mindspore 2.4.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1406) hide show
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +53 -0
  18. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1419 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +299 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1448 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4475 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +433 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +86 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_pijit_context.py +190 -0
  89. mindspore/common/_register_for_adapter.py +74 -0
  90. mindspore/common/_register_for_recompute.py +48 -0
  91. mindspore/common/_register_for_tensor.py +46 -0
  92. mindspore/common/_stub_tensor.py +210 -0
  93. mindspore/common/_tensor_overload.py +139 -0
  94. mindspore/common/_utils.py +122 -0
  95. mindspore/common/api.py +2064 -0
  96. mindspore/common/auto_dynamic_shape.py +507 -0
  97. mindspore/common/dtype.py +422 -0
  98. mindspore/common/dump.py +130 -0
  99. mindspore/common/file_system.py +48 -0
  100. mindspore/common/generator.py +254 -0
  101. mindspore/common/hook_handle.py +143 -0
  102. mindspore/common/initializer.py +880 -0
  103. mindspore/common/jit_config.py +98 -0
  104. mindspore/common/lazy_inline.py +240 -0
  105. mindspore/common/mindir_util.py +111 -0
  106. mindspore/common/mutable.py +234 -0
  107. mindspore/common/no_inline.py +54 -0
  108. mindspore/common/np_dtype.py +25 -0
  109. mindspore/common/parameter.py +1081 -0
  110. mindspore/common/recompute.py +292 -0
  111. mindspore/common/seed.py +260 -0
  112. mindspore/common/sparse_tensor.py +1175 -0
  113. mindspore/common/symbol.py +122 -0
  114. mindspore/common/tensor.py +5039 -0
  115. mindspore/communication/__init__.py +37 -0
  116. mindspore/communication/_comm_helper.py +501 -0
  117. mindspore/communication/_hccl_management.py +297 -0
  118. mindspore/communication/comm_func.py +1395 -0
  119. mindspore/communication/management.py +673 -0
  120. mindspore/config/op_info.config +533 -0
  121. mindspore/context.py +2077 -0
  122. mindspore/d3dcompiler_47.dll +0 -0
  123. mindspore/dataset/__init__.py +90 -0
  124. mindspore/dataset/audio/__init__.py +61 -0
  125. mindspore/dataset/audio/transforms.py +3690 -0
  126. mindspore/dataset/audio/utils.py +386 -0
  127. mindspore/dataset/audio/validators.py +1172 -0
  128. mindspore/dataset/callback/__init__.py +20 -0
  129. mindspore/dataset/callback/ds_callback.py +368 -0
  130. mindspore/dataset/callback/validators.py +32 -0
  131. mindspore/dataset/core/__init__.py +13 -0
  132. mindspore/dataset/core/config.py +1095 -0
  133. mindspore/dataset/core/datatypes.py +101 -0
  134. mindspore/dataset/core/py_util_helpers.py +65 -0
  135. mindspore/dataset/core/validator_helpers.py +781 -0
  136. mindspore/dataset/debug/__init__.py +21 -0
  137. mindspore/dataset/debug/debug_hook.py +97 -0
  138. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  139. mindspore/dataset/engine/__init__.py +124 -0
  140. mindspore/dataset/engine/cache_admin.py +47 -0
  141. mindspore/dataset/engine/cache_client.py +129 -0
  142. mindspore/dataset/engine/datasets.py +4582 -0
  143. mindspore/dataset/engine/datasets_audio.py +911 -0
  144. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  145. mindspore/dataset/engine/datasets_text.py +2161 -0
  146. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  147. mindspore/dataset/engine/datasets_vision.py +4816 -0
  148. mindspore/dataset/engine/iterators.py +371 -0
  149. mindspore/dataset/engine/obs/__init__.py +23 -0
  150. mindspore/dataset/engine/obs/config_loader.py +68 -0
  151. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  152. mindspore/dataset/engine/obs/util.py +482 -0
  153. mindspore/dataset/engine/offload.py +596 -0
  154. mindspore/dataset/engine/queue.py +304 -0
  155. mindspore/dataset/engine/samplers.py +895 -0
  156. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  157. mindspore/dataset/engine/validators.py +2895 -0
  158. mindspore/dataset/text/__init__.py +51 -0
  159. mindspore/dataset/text/transforms.py +1703 -0
  160. mindspore/dataset/text/utils.py +715 -0
  161. mindspore/dataset/text/validators.py +642 -0
  162. mindspore/dataset/transforms/__init__.py +45 -0
  163. mindspore/dataset/transforms/c_transforms.py +638 -0
  164. mindspore/dataset/transforms/py_transforms.py +393 -0
  165. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  166. mindspore/dataset/transforms/transforms.py +1260 -0
  167. mindspore/dataset/transforms/validators.py +410 -0
  168. mindspore/dataset/utils/__init__.py +19 -0
  169. mindspore/dataset/utils/browse_dataset.py +190 -0
  170. mindspore/dataset/utils/line_reader.py +126 -0
  171. mindspore/dataset/vision/__init__.py +65 -0
  172. mindspore/dataset/vision/c_transforms.py +2641 -0
  173. mindspore/dataset/vision/py_transforms.py +2120 -0
  174. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  175. mindspore/dataset/vision/transforms.py +7295 -0
  176. mindspore/dataset/vision/utils.py +863 -0
  177. mindspore/dataset/vision/validators.py +1483 -0
  178. mindspore/default_config.py +2 -0
  179. mindspore/dnnl.dll +0 -0
  180. mindspore/dpcmi.dll +0 -0
  181. mindspore/experimental/__init__.py +20 -0
  182. mindspore/experimental/es/__init__.py +22 -0
  183. mindspore/experimental/es/embedding_service.py +883 -0
  184. mindspore/experimental/es/embedding_service_layer.py +581 -0
  185. mindspore/experimental/llm_boost/__init__.py +21 -0
  186. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  187. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  188. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  189. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  190. mindspore/experimental/llm_boost/register.py +129 -0
  191. mindspore/experimental/llm_boost/utils.py +31 -0
  192. mindspore/experimental/map_parameter.py +309 -0
  193. mindspore/experimental/optim/__init__.py +40 -0
  194. mindspore/experimental/optim/adadelta.py +161 -0
  195. mindspore/experimental/optim/adagrad.py +168 -0
  196. mindspore/experimental/optim/adam.py +193 -0
  197. mindspore/experimental/optim/adamax.py +170 -0
  198. mindspore/experimental/optim/adamw.py +290 -0
  199. mindspore/experimental/optim/asgd.py +153 -0
  200. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  201. mindspore/experimental/optim/nadam.py +157 -0
  202. mindspore/experimental/optim/optimizer.py +262 -0
  203. mindspore/experimental/optim/radam.py +194 -0
  204. mindspore/experimental/optim/rmsprop.py +154 -0
  205. mindspore/experimental/optim/rprop.py +164 -0
  206. mindspore/experimental/optim/sgd.py +156 -0
  207. mindspore/hal/__init__.py +40 -0
  208. mindspore/hal/_ascend.py +57 -0
  209. mindspore/hal/_base.py +57 -0
  210. mindspore/hal/_cpu.py +56 -0
  211. mindspore/hal/_gpu.py +57 -0
  212. mindspore/hal/contiguous_tensors_handle.py +175 -0
  213. mindspore/hal/device.py +356 -0
  214. mindspore/hal/event.py +179 -0
  215. mindspore/hal/memory.py +326 -0
  216. mindspore/hal/stream.py +357 -0
  217. mindspore/include/OWNERS +7 -0
  218. mindspore/include/api/allocator.h +97 -0
  219. mindspore/include/api/callback/callback.h +93 -0
  220. mindspore/include/api/callback/ckpt_saver.h +41 -0
  221. mindspore/include/api/callback/loss_monitor.h +33 -0
  222. mindspore/include/api/callback/lr_scheduler.h +51 -0
  223. mindspore/include/api/callback/time_monitor.h +34 -0
  224. mindspore/include/api/callback/train_accuracy.h +37 -0
  225. mindspore/include/api/cell.h +90 -0
  226. mindspore/include/api/cfg.h +82 -0
  227. mindspore/include/api/context.h +602 -0
  228. mindspore/include/api/data_type.h +47 -0
  229. mindspore/include/api/delegate.h +178 -0
  230. mindspore/include/api/delegate_api.h +75 -0
  231. mindspore/include/api/dual_abi_helper.h +208 -0
  232. mindspore/include/api/format.h +28 -0
  233. mindspore/include/api/graph.h +46 -0
  234. mindspore/include/api/kernel.h +58 -0
  235. mindspore/include/api/kernel_api.h +168 -0
  236. mindspore/include/api/metrics/accuracy.h +36 -0
  237. mindspore/include/api/metrics/metrics.h +41 -0
  238. mindspore/include/api/model.h +438 -0
  239. mindspore/include/api/model_group.h +91 -0
  240. mindspore/include/api/model_parallel_runner.h +168 -0
  241. mindspore/include/api/serialization.h +185 -0
  242. mindspore/include/api/status.h +192 -0
  243. mindspore/include/api/types.h +431 -0
  244. mindspore/include/api/visible.h +41 -0
  245. mindspore/include/c_api/context_c.h +179 -0
  246. mindspore/include/c_api/data_type_c.h +52 -0
  247. mindspore/include/c_api/format_c.h +46 -0
  248. mindspore/include/c_api/model_c.h +347 -0
  249. mindspore/include/c_api/status_c.h +79 -0
  250. mindspore/include/c_api/tensor_c.h +146 -0
  251. mindspore/include/c_api/types_c.h +67 -0
  252. mindspore/include/dataset/config.h +163 -0
  253. mindspore/include/dataset/constants.h +363 -0
  254. mindspore/include/dataset/execute.h +196 -0
  255. mindspore/include/dataset/text.h +1092 -0
  256. mindspore/include/dataset/transforms.h +638 -0
  257. mindspore/include/dataset/vision.h +2129 -0
  258. mindspore/include/dataset/vision_ascend.h +206 -0
  259. mindspore/include/dataset/vision_lite.h +625 -0
  260. mindspore/jpeg62.dll +0 -0
  261. mindspore/log.py +633 -0
  262. mindspore/mindrecord/__init__.py +43 -0
  263. mindspore/mindrecord/common/__init__.py +17 -0
  264. mindspore/mindrecord/common/constant.py +20 -0
  265. mindspore/mindrecord/common/enums.py +44 -0
  266. mindspore/mindrecord/common/exceptions.py +311 -0
  267. mindspore/mindrecord/config.py +809 -0
  268. mindspore/mindrecord/filereader.py +174 -0
  269. mindspore/mindrecord/filewriter.py +722 -0
  270. mindspore/mindrecord/mindpage.py +210 -0
  271. mindspore/mindrecord/shardheader.py +141 -0
  272. mindspore/mindrecord/shardindexgenerator.py +74 -0
  273. mindspore/mindrecord/shardreader.py +117 -0
  274. mindspore/mindrecord/shardsegment.py +128 -0
  275. mindspore/mindrecord/shardutils.py +185 -0
  276. mindspore/mindrecord/shardwriter.py +237 -0
  277. mindspore/mindrecord/tools/__init__.py +17 -0
  278. mindspore/mindrecord/tools/cifar10.py +140 -0
  279. mindspore/mindrecord/tools/cifar100.py +153 -0
  280. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  281. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  282. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  283. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  284. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  285. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  286. mindspore/mindspore_backend.dll +0 -0
  287. mindspore/mindspore_common.dll +0 -0
  288. mindspore/mindspore_core.dll +0 -0
  289. mindspore/mindspore_glog.dll +0 -0
  290. mindspore/mindspore_np_dtype.dll +0 -0
  291. mindspore/mindspore_ops.dll +0 -0
  292. mindspore/mint/__init__.py +1586 -0
  293. mindspore/mint/distributed/__init__.py +31 -0
  294. mindspore/mint/distributed/distributed.py +254 -0
  295. mindspore/mint/linalg/__init__.py +22 -0
  296. mindspore/mint/nn/__init__.py +757 -0
  297. mindspore/mint/nn/functional.py +679 -0
  298. mindspore/mint/nn/layer/__init__.py +39 -0
  299. mindspore/mint/nn/layer/activation.py +133 -0
  300. mindspore/mint/nn/layer/normalization.py +477 -0
  301. mindspore/mint/nn/layer/pooling.py +110 -0
  302. mindspore/mint/optim/__init__.py +24 -0
  303. mindspore/mint/optim/adamw.py +206 -0
  304. mindspore/mint/special/__init__.py +63 -0
  305. mindspore/msobj140.dll +0 -0
  306. mindspore/mspdb140.dll +0 -0
  307. mindspore/mspdbcore.dll +0 -0
  308. mindspore/mspdbst.dll +0 -0
  309. mindspore/mspft140.dll +0 -0
  310. mindspore/msvcdis140.dll +0 -0
  311. mindspore/msvcp140.dll +0 -0
  312. mindspore/msvcp140_1.dll +0 -0
  313. mindspore/msvcp140_2.dll +0 -0
  314. mindspore/msvcp140_atomic_wait.dll +0 -0
  315. mindspore/msvcp140_codecvt_ids.dll +0 -0
  316. mindspore/multiprocessing/__init__.py +73 -0
  317. mindspore/nn/__init__.py +47 -0
  318. mindspore/nn/cell.py +2787 -0
  319. mindspore/nn/dynamic_lr.py +482 -0
  320. mindspore/nn/grad/__init__.py +21 -0
  321. mindspore/nn/grad/cell_grad.py +196 -0
  322. mindspore/nn/layer/__init__.py +63 -0
  323. mindspore/nn/layer/activation.py +1822 -0
  324. mindspore/nn/layer/basic.py +1629 -0
  325. mindspore/nn/layer/channel_shuffle.py +90 -0
  326. mindspore/nn/layer/combined.py +248 -0
  327. mindspore/nn/layer/container.py +734 -0
  328. mindspore/nn/layer/conv.py +1505 -0
  329. mindspore/nn/layer/dense.py +204 -0
  330. mindspore/nn/layer/embedding.py +869 -0
  331. mindspore/nn/layer/image.py +661 -0
  332. mindspore/nn/layer/math.py +1069 -0
  333. mindspore/nn/layer/normalization.py +1273 -0
  334. mindspore/nn/layer/padding.py +880 -0
  335. mindspore/nn/layer/pooling.py +2302 -0
  336. mindspore/nn/layer/rnn_cells.py +388 -0
  337. mindspore/nn/layer/rnns.py +849 -0
  338. mindspore/nn/layer/thor_layer.py +963 -0
  339. mindspore/nn/layer/timedistributed.py +155 -0
  340. mindspore/nn/layer/transformer.py +823 -0
  341. mindspore/nn/learning_rate_schedule.py +512 -0
  342. mindspore/nn/loss/__init__.py +36 -0
  343. mindspore/nn/loss/loss.py +2924 -0
  344. mindspore/nn/metrics.py +53 -0
  345. mindspore/nn/optim/__init__.py +45 -0
  346. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  347. mindspore/nn/optim/ada_grad.py +217 -0
  348. mindspore/nn/optim/adadelta.py +206 -0
  349. mindspore/nn/optim/adafactor.py +448 -0
  350. mindspore/nn/optim/adam.py +1297 -0
  351. mindspore/nn/optim/adamax.py +220 -0
  352. mindspore/nn/optim/adasum.py +548 -0
  353. mindspore/nn/optim/asgd.py +216 -0
  354. mindspore/nn/optim/ftrl.py +401 -0
  355. mindspore/nn/optim/lamb.py +296 -0
  356. mindspore/nn/optim/lars.py +202 -0
  357. mindspore/nn/optim/lazyadam.py +533 -0
  358. mindspore/nn/optim/momentum.py +239 -0
  359. mindspore/nn/optim/optimizer.py +1034 -0
  360. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  361. mindspore/nn/optim/rmsprop.py +264 -0
  362. mindspore/nn/optim/rprop.py +251 -0
  363. mindspore/nn/optim/sgd.py +237 -0
  364. mindspore/nn/optim/tft_wrapper.py +127 -0
  365. mindspore/nn/optim/thor.py +1310 -0
  366. mindspore/nn/probability/__init__.py +22 -0
  367. mindspore/nn/probability/bijector/__init__.py +35 -0
  368. mindspore/nn/probability/bijector/bijector.py +337 -0
  369. mindspore/nn/probability/bijector/exp.py +65 -0
  370. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  371. mindspore/nn/probability/bijector/invert.py +126 -0
  372. mindspore/nn/probability/bijector/power_transform.py +196 -0
  373. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  374. mindspore/nn/probability/bijector/softplus.py +189 -0
  375. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  376. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  377. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  378. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  379. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  380. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  381. mindspore/nn/probability/distribution/__init__.py +56 -0
  382. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  383. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  384. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  385. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  386. mindspore/nn/probability/distribution/beta.py +391 -0
  387. mindspore/nn/probability/distribution/categorical.py +435 -0
  388. mindspore/nn/probability/distribution/cauchy.py +383 -0
  389. mindspore/nn/probability/distribution/distribution.py +827 -0
  390. mindspore/nn/probability/distribution/exponential.py +350 -0
  391. mindspore/nn/probability/distribution/gamma.py +391 -0
  392. mindspore/nn/probability/distribution/geometric.py +335 -0
  393. mindspore/nn/probability/distribution/gumbel.py +257 -0
  394. mindspore/nn/probability/distribution/half_normal.py +133 -0
  395. mindspore/nn/probability/distribution/laplace.py +128 -0
  396. mindspore/nn/probability/distribution/log_normal.py +272 -0
  397. mindspore/nn/probability/distribution/logistic.py +379 -0
  398. mindspore/nn/probability/distribution/normal.py +336 -0
  399. mindspore/nn/probability/distribution/poisson.py +288 -0
  400. mindspore/nn/probability/distribution/student_t.py +149 -0
  401. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  402. mindspore/nn/probability/distribution/uniform.py +375 -0
  403. mindspore/nn/reinforcement/__init__.py +24 -0
  404. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  405. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  406. mindspore/nn/reinforcement/tensor_array.py +145 -0
  407. mindspore/nn/sparse/__init__.py +23 -0
  408. mindspore/nn/sparse/sparse.py +147 -0
  409. mindspore/nn/wrap/__init__.py +49 -0
  410. mindspore/nn/wrap/cell_wrapper.py +968 -0
  411. mindspore/nn/wrap/grad_reducer.py +608 -0
  412. mindspore/nn/wrap/loss_scale.py +694 -0
  413. mindspore/numpy/__init__.py +121 -0
  414. mindspore/numpy/array_creations.py +2731 -0
  415. mindspore/numpy/array_ops.py +2629 -0
  416. mindspore/numpy/dtypes.py +185 -0
  417. mindspore/numpy/fft.py +966 -0
  418. mindspore/numpy/logic_ops.py +936 -0
  419. mindspore/numpy/math_ops.py +5911 -0
  420. mindspore/numpy/utils.py +214 -0
  421. mindspore/numpy/utils_const.py +565 -0
  422. mindspore/opencv_core452.dll +0 -0
  423. mindspore/opencv_imgcodecs452.dll +0 -0
  424. mindspore/opencv_imgproc452.dll +0 -0
  425. mindspore/ops/__init__.py +56 -0
  426. mindspore/ops/_constants.py +30 -0
  427. mindspore/ops/_grad_experimental/__init__.py +31 -0
  428. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  429. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  430. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  431. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  432. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  433. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  434. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  435. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  436. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  437. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  438. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  439. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  440. mindspore/ops/_op_impl/__init__.py +23 -0
  441. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  442. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  443. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  447. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  448. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  449. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  450. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  451. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  452. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  453. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  462. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  463. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  464. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  465. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  471. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  472. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  473. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  474. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  475. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  476. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  477. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  478. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  479. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  480. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  481. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  490. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  491. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  494. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  495. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  498. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  499. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  500. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  501. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  502. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  503. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  504. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  505. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  506. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  507. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  508. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  510. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  511. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  513. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  514. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  515. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  516. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  517. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  518. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  519. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  520. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  521. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  522. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  523. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  524. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  526. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  527. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  529. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  531. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  532. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  533. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  534. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  535. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  537. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  538. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  539. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  540. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  541. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  542. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  544. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  545. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  546. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  547. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  548. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  550. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  551. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  552. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  553. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  554. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  556. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  557. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  558. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  559. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  560. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  561. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  562. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  563. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  564. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  565. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  567. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  568. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  569. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  570. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  571. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  572. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  574. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  575. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  576. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  577. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  578. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  579. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  580. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  581. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  582. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  583. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  584. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  585. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  586. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  587. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  588. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  589. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  590. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  591. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  592. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  593. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  594. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  595. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  596. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  597. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  598. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  599. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  600. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  601. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  603. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  610. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  611. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  612. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  615. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  616. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  617. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  618. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  619. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  620. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  621. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  622. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  623. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  625. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  626. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  627. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  628. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  629. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  630. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  634. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  635. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  636. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  637. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  638. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  639. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  640. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  641. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  642. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  643. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  644. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  646. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  647. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  648. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  650. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  651. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  652. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  653. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  654. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  655. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  656. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  657. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  658. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  659. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  660. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  661. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  662. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  663. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  664. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  665. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  667. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  668. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  669. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  670. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  673. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  675. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  676. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  677. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  679. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  680. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  681. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  682. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  683. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  684. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  685. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  686. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  687. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  688. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  699. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  700. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  701. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  702. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  706. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  707. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  708. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  709. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  710. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  711. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  712. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  713. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  714. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  715. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  716. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  717. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  718. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  720. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  721. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  722. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  723. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  724. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  725. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  726. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  727. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  728. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  729. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  730. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  731. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  732. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  733. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  734. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  735. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  736. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  737. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  738. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  739. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  740. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  741. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  742. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  743. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  744. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  745. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  746. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  747. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  748. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  749. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  750. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  751. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  752. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  753. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  754. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  755. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  757. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  758. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  759. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  760. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  761. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  762. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  763. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  764. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  765. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  766. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  767. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  768. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  769. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  770. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  774. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  775. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  776. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  777. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  778. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  779. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  782. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  783. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  784. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  785. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  786. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  787. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  788. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  789. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  790. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  791. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  792. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  793. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  794. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  795. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  796. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  797. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  798. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  799. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  804. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  805. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  806. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  810. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  811. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  812. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  813. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  817. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  818. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  819. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  820. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  821. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  823. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  825. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  826. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  827. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  828. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  829. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  862. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  863. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  864. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  865. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  866. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  868. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  869. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  870. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  871. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  872. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  873. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  874. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  876. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1318 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/function/__init__.py +782 -0
  1098. mindspore/ops/function/array_func.py +7226 -0
  1099. mindspore/ops/function/clip_func.py +384 -0
  1100. mindspore/ops/function/debug_func.py +181 -0
  1101. mindspore/ops/function/fft_func.py +44 -0
  1102. mindspore/ops/function/grad/__init__.py +34 -0
  1103. mindspore/ops/function/grad/grad_func.py +1425 -0
  1104. mindspore/ops/function/image_func.py +292 -0
  1105. mindspore/ops/function/linalg_func.py +416 -0
  1106. mindspore/ops/function/math_func.py +12228 -0
  1107. mindspore/ops/function/nn_func.py +8609 -0
  1108. mindspore/ops/function/other_func.py +115 -0
  1109. mindspore/ops/function/parameter_func.py +134 -0
  1110. mindspore/ops/function/random_func.py +1715 -0
  1111. mindspore/ops/function/reshard_func.py +104 -0
  1112. mindspore/ops/function/sparse_func.py +884 -0
  1113. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1114. mindspore/ops/function/spectral_func.py +150 -0
  1115. mindspore/ops/function/vmap_func.py +117 -0
  1116. mindspore/ops/functional.py +464 -0
  1117. mindspore/ops/op_info_register.py +1572 -0
  1118. mindspore/ops/operations/__init__.py +722 -0
  1119. mindspore/ops/operations/_csr_ops.py +403 -0
  1120. mindspore/ops/operations/_custom_grad.py +181 -0
  1121. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1122. mindspore/ops/operations/_grad_ops.py +2978 -0
  1123. mindspore/ops/operations/_infer_ops.py +19 -0
  1124. mindspore/ops/operations/_inner_ops.py +2544 -0
  1125. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1126. mindspore/ops/operations/_ms_kernel.py +601 -0
  1127. mindspore/ops/operations/_ocr_ops.py +379 -0
  1128. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1129. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1130. mindspore/ops/operations/_quant_ops.py +1844 -0
  1131. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1132. mindspore/ops/operations/_scalar_ops.py +106 -0
  1133. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1134. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1135. mindspore/ops/operations/_tensor_array.py +359 -0
  1136. mindspore/ops/operations/_thor_ops.py +807 -0
  1137. mindspore/ops/operations/array_ops.py +6124 -0
  1138. mindspore/ops/operations/comm_ops.py +1985 -0
  1139. mindspore/ops/operations/control_ops.py +127 -0
  1140. mindspore/ops/operations/custom_ops.py +1129 -0
  1141. mindspore/ops/operations/debug_ops.py +678 -0
  1142. mindspore/ops/operations/image_ops.py +1041 -0
  1143. mindspore/ops/operations/inner_ops.py +697 -0
  1144. mindspore/ops/operations/linalg_ops.py +95 -0
  1145. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1146. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1147. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1148. mindspore/ops/operations/math_ops.py +5095 -0
  1149. mindspore/ops/operations/nn_ops.py +9575 -0
  1150. mindspore/ops/operations/other_ops.py +874 -0
  1151. mindspore/ops/operations/random_ops.py +1288 -0
  1152. mindspore/ops/operations/reshard_ops.py +53 -0
  1153. mindspore/ops/operations/rl_ops.py +288 -0
  1154. mindspore/ops/operations/sparse_ops.py +2753 -0
  1155. mindspore/ops/operations/spectral_ops.py +111 -0
  1156. mindspore/ops/primitive.py +1046 -0
  1157. mindspore/ops/signature.py +54 -0
  1158. mindspore/ops/vm_impl_registry.py +91 -0
  1159. mindspore/ops_generate/__init__.py +27 -0
  1160. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1161. mindspore/ops_generate/arg_handler.py +197 -0
  1162. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1163. mindspore/ops_generate/gen_constants.py +36 -0
  1164. mindspore/ops_generate/gen_ops.py +1099 -0
  1165. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1166. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1167. mindspore/ops_generate/gen_utils.py +209 -0
  1168. mindspore/ops_generate/op_proto.py +145 -0
  1169. mindspore/ops_generate/pyboost_utils.py +367 -0
  1170. mindspore/ops_generate/template.py +261 -0
  1171. mindspore/parallel/__init__.py +30 -0
  1172. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1173. mindspore/parallel/_cell_wrapper.py +174 -0
  1174. mindspore/parallel/_cost_model_context.py +700 -0
  1175. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1176. mindspore/parallel/_offload_context.py +275 -0
  1177. mindspore/parallel/_parallel_serialization.py +561 -0
  1178. mindspore/parallel/_ps_context.py +242 -0
  1179. mindspore/parallel/_recovery_context.py +110 -0
  1180. mindspore/parallel/_tensor.py +730 -0
  1181. mindspore/parallel/_transformer/__init__.py +35 -0
  1182. mindspore/parallel/_transformer/layers.py +765 -0
  1183. mindspore/parallel/_transformer/loss.py +251 -0
  1184. mindspore/parallel/_transformer/moe.py +693 -0
  1185. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1186. mindspore/parallel/_transformer/transformer.py +3119 -0
  1187. mindspore/parallel/_utils.py +612 -0
  1188. mindspore/parallel/algo_parameter_config.py +400 -0
  1189. mindspore/parallel/checkpoint_transform.py +650 -0
  1190. mindspore/parallel/cluster/__init__.py +15 -0
  1191. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1192. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1193. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1194. mindspore/parallel/cluster/run.py +136 -0
  1195. mindspore/parallel/mpi/__init__.py +14 -0
  1196. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1197. mindspore/parallel/parameter_broadcast.py +151 -0
  1198. mindspore/parallel/shard.py +481 -0
  1199. mindspore/parallel/transform_safetensors.py +993 -0
  1200. mindspore/perf_msvcbuildinsights.dll +0 -0
  1201. mindspore/pgodb140.dll +0 -0
  1202. mindspore/pgort140.dll +0 -0
  1203. mindspore/profiler/__init__.py +28 -0
  1204. mindspore/profiler/common/__init__.py +14 -0
  1205. mindspore/profiler/common/constant.py +29 -0
  1206. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1207. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1208. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1209. mindspore/profiler/common/process_pool.py +41 -0
  1210. mindspore/profiler/common/registry.py +47 -0
  1211. mindspore/profiler/common/singleton.py +28 -0
  1212. mindspore/profiler/common/struct_type.py +118 -0
  1213. mindspore/profiler/common/util.py +472 -0
  1214. mindspore/profiler/common/validator/__init__.py +14 -0
  1215. mindspore/profiler/common/validator/validate_path.py +84 -0
  1216. mindspore/profiler/dynamic_profiler.py +694 -0
  1217. mindspore/profiler/envprofiling.py +254 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1229. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1230. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1231. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1232. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1233. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1234. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1235. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1236. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1237. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1238. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1239. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1240. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1241. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1242. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1243. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1244. mindspore/profiler/parser/container.py +229 -0
  1245. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1246. mindspore/profiler/parser/flops_parser.py +531 -0
  1247. mindspore/profiler/parser/framework_enum.py +111 -0
  1248. mindspore/profiler/parser/framework_parser.py +464 -0
  1249. mindspore/profiler/parser/framework_struct.py +61 -0
  1250. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1251. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1252. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1253. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1254. mindspore/profiler/parser/hccl_parser.py +573 -0
  1255. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1256. mindspore/profiler/parser/integrator.py +526 -0
  1257. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1258. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1259. mindspore/profiler/parser/minddata_parser.py +186 -0
  1260. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1261. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1262. mindspore/profiler/parser/optime_parser.py +250 -0
  1263. mindspore/profiler/parser/profiler_info.py +213 -0
  1264. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1265. mindspore/profiler/profiler.py +153 -0
  1266. mindspore/profiler/profiling.py +1922 -0
  1267. mindspore/rewrite/__init__.py +28 -0
  1268. mindspore/rewrite/api/__init__.py +17 -0
  1269. mindspore/rewrite/api/node.py +519 -0
  1270. mindspore/rewrite/api/node_type.py +53 -0
  1271. mindspore/rewrite/api/pattern_engine.py +490 -0
  1272. mindspore/rewrite/api/scoped_value.py +181 -0
  1273. mindspore/rewrite/api/symbol_tree.py +497 -0
  1274. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1275. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1276. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1277. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1278. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1279. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1280. mindspore/rewrite/common/__init__.py +19 -0
  1281. mindspore/rewrite/common/config.py +24 -0
  1282. mindspore/rewrite/common/error_log.py +39 -0
  1283. mindspore/rewrite/common/event.py +28 -0
  1284. mindspore/rewrite/common/namer.py +271 -0
  1285. mindspore/rewrite/common/namespace.py +118 -0
  1286. mindspore/rewrite/common/observable.py +44 -0
  1287. mindspore/rewrite/common/observer.py +54 -0
  1288. mindspore/rewrite/node/__init__.py +22 -0
  1289. mindspore/rewrite/node/call_function.py +95 -0
  1290. mindspore/rewrite/node/cell_container.py +139 -0
  1291. mindspore/rewrite/node/control_flow.py +113 -0
  1292. mindspore/rewrite/node/node.py +1428 -0
  1293. mindspore/rewrite/node/node_manager.py +283 -0
  1294. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1295. mindspore/rewrite/parsers/__init__.py +29 -0
  1296. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1297. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1298. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1299. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1300. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1301. mindspore/rewrite/parsers/container_parser.py +88 -0
  1302. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1303. mindspore/rewrite/parsers/for_parser.py +61 -0
  1304. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1305. mindspore/rewrite/parsers/if_parser.py +85 -0
  1306. mindspore/rewrite/parsers/module_parser.py +117 -0
  1307. mindspore/rewrite/parsers/parser.py +43 -0
  1308. mindspore/rewrite/parsers/parser_register.py +86 -0
  1309. mindspore/rewrite/parsers/return_parser.py +37 -0
  1310. mindspore/rewrite/parsers/while_parser.py +59 -0
  1311. mindspore/rewrite/sparsify/__init__.py +0 -0
  1312. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1313. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1314. mindspore/rewrite/sparsify/utils.py +179 -0
  1315. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1316. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1317. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1318. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1319. mindspore/run_check/__init__.py +20 -0
  1320. mindspore/run_check/_check_version.py +507 -0
  1321. mindspore/run_check/run_check.py +66 -0
  1322. mindspore/safeguard/__init__.py +18 -0
  1323. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1324. mindspore/swresample-4.dll +0 -0
  1325. mindspore/swscale-6.dll +0 -0
  1326. mindspore/tbbmalloc.dll +0 -0
  1327. mindspore/tinyxml2.dll +0 -0
  1328. mindspore/train/__init__.py +48 -0
  1329. mindspore/train/_utils.py +465 -0
  1330. mindspore/train/amp.py +935 -0
  1331. mindspore/train/anf_ir_pb2.py +1517 -0
  1332. mindspore/train/callback/__init__.py +44 -0
  1333. mindspore/train/callback/_backup_and_restore.py +117 -0
  1334. mindspore/train/callback/_callback.py +613 -0
  1335. mindspore/train/callback/_checkpoint.py +814 -0
  1336. mindspore/train/callback/_cluster_monitor.py +201 -0
  1337. mindspore/train/callback/_dataset_graph.py +150 -0
  1338. mindspore/train/callback/_early_stop.py +239 -0
  1339. mindspore/train/callback/_flops_collector.py +239 -0
  1340. mindspore/train/callback/_history.py +92 -0
  1341. mindspore/train/callback/_lambda_callback.py +80 -0
  1342. mindspore/train/callback/_landscape.py +1049 -0
  1343. mindspore/train/callback/_loss_monitor.py +107 -0
  1344. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1345. mindspore/train/callback/_on_request_exit.py +298 -0
  1346. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1347. mindspore/train/callback/_summary_collector.py +1184 -0
  1348. mindspore/train/callback/_tft_register.py +352 -0
  1349. mindspore/train/callback/_time_monitor.py +141 -0
  1350. mindspore/train/checkpoint_pb2.py +233 -0
  1351. mindspore/train/data_sink.py +219 -0
  1352. mindspore/train/dataset_helper.py +692 -0
  1353. mindspore/train/lineage_pb2.py +1260 -0
  1354. mindspore/train/loss_scale_manager.py +213 -0
  1355. mindspore/train/memory_profiling_pb2.py +298 -0
  1356. mindspore/train/metrics/__init__.py +175 -0
  1357. mindspore/train/metrics/accuracy.py +133 -0
  1358. mindspore/train/metrics/auc.py +129 -0
  1359. mindspore/train/metrics/bleu_score.py +170 -0
  1360. mindspore/train/metrics/confusion_matrix.py +700 -0
  1361. mindspore/train/metrics/cosine_similarity.py +109 -0
  1362. mindspore/train/metrics/dice.py +116 -0
  1363. mindspore/train/metrics/error.py +175 -0
  1364. mindspore/train/metrics/fbeta.py +167 -0
  1365. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1366. mindspore/train/metrics/loss.py +97 -0
  1367. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1368. mindspore/train/metrics/metric.py +373 -0
  1369. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1370. mindspore/train/metrics/perplexity.py +133 -0
  1371. mindspore/train/metrics/precision.py +160 -0
  1372. mindspore/train/metrics/recall.py +159 -0
  1373. mindspore/train/metrics/roc.py +223 -0
  1374. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1375. mindspore/train/metrics/topk.py +167 -0
  1376. mindspore/train/mind_ir_pb2.py +1908 -0
  1377. mindspore/train/model.py +2252 -0
  1378. mindspore/train/node_strategy_pb2.py +653 -0
  1379. mindspore/train/print_pb2.py +184 -0
  1380. mindspore/train/profiling_parallel_pb2.py +151 -0
  1381. mindspore/train/serialization.py +3325 -0
  1382. mindspore/train/summary/__init__.py +23 -0
  1383. mindspore/train/summary/_lineage_adapter.py +41 -0
  1384. mindspore/train/summary/_summary_adapter.py +496 -0
  1385. mindspore/train/summary/_writer_pool.py +207 -0
  1386. mindspore/train/summary/enums.py +56 -0
  1387. mindspore/train/summary/summary_record.py +581 -0
  1388. mindspore/train/summary/writer.py +167 -0
  1389. mindspore/train/summary_pb2.py +1165 -0
  1390. mindspore/train/train_thor/__init__.py +20 -0
  1391. mindspore/train/train_thor/convert_utils.py +268 -0
  1392. mindspore/train/train_thor/dataset_helper.py +192 -0
  1393. mindspore/train/train_thor/model_thor.py +257 -0
  1394. mindspore/turbojpeg.dll +0 -0
  1395. mindspore/utils/__init__.py +21 -0
  1396. mindspore/utils/utils.py +60 -0
  1397. mindspore/vcmeta.dll +0 -0
  1398. mindspore/vcomp140.dll +0 -0
  1399. mindspore/vcruntime140.dll +0 -0
  1400. mindspore/vcruntime140_1.dll +0 -0
  1401. mindspore/version.py +1 -0
  1402. mindspore-2.4.0.dist-info/METADATA +352 -0
  1403. mindspore-2.4.0.dist-info/RECORD +1406 -0
  1404. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1405. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1406. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4582 @@
1
+ # Copyright 2022-2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """
16
+ 1. This file is an abstraction of the dataset loading class. It contains
17
+ some basic dataset operations(skip, filter, map, batch, ...).
18
+ 2. Specific dataset loading classes can be found in datasets_vision.py, datasets_text.py,
19
+ datasets_audio.py, datasets_standard_format.py and datasets_user_defined.py files.
20
+ datasets_vision.py: contains vision dataset loading classes.
21
+ datasets_text.py: contains text dataset loading classes.
22
+ datasets_audio.py: contains audio dataset loading classes.
23
+ datasets_standard_format.py: contains standard format loading classes which
24
+ any other kinds of datasets can be converted to.
25
+ datasets_user_defined.py: contains basic classes that help users to define
26
+ flexible ways to load dataset.
27
+ """
28
+ import atexit
29
+ import glob
30
+ import json
31
+ import os
32
+ import queue
33
+ import signal
34
+ import stat
35
+ import subprocess
36
+ import warnings
37
+
38
+ import gc
39
+ import time
40
+ import uuid
41
+ import multiprocessing
42
+ from enum import Enum
43
+ from importlib import import_module
44
+ import sys
45
+ import threading
46
+
47
+ import copy
48
+ import weakref
49
+ import platform
50
+ import psutil
51
+
52
+ import mindspore._c_dataengine as cde
53
+ from mindspore._c_expression import typing
54
+
55
+ from mindspore import log as logger
56
+ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context,\
57
+ _enable_distributed_mindrt
58
+ from mindspore.dataset.engine.offload import GetOffloadModel
59
+
60
+ import mindspore.dataset.transforms.c_transforms as c_transforms
61
+ import mindspore.dataset.transforms.py_transforms as py_transforms
62
+ import mindspore.dataset.transforms as transforms
63
+ from mindspore.dataset.text.utils import SentencePieceModel, DE_C_INTER_SENTENCEPIECE_MODE
64
+ from mindspore.parallel._utils import _get_device_num
65
+ from mindspore.dataset.debug import DebugHook
66
+
67
+ from mindspore.dataset.engine import samplers
68
+ from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
69
+ ITERATORS_LIST, _unset_iterator_cleanup, _cleanup_the_iterators_if_created
70
+ from .queue import _SharedQueue, _Queue
71
+ from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
72
+ check_rename, check_device_send, check_take, check_output_shape, check_project, \
73
+ check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
74
+ check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch, \
75
+ check_total_batch, check_sync_update
76
+ from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
77
+ get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval, _get_debug_hook_list
78
+ from ..core.datatypes import mstype_to_detype
79
+ from ..core.validator_helpers import replace_none
80
+ from ..core.py_util_helpers import ExceptionHandler
81
+ from ..transforms.py_transforms_util import FuncWrapper, Implementation
82
+ from ..vision.transforms import ToNumpy
83
+ from ...mindrecord.config import _get_enc_key, _get_enc_mode, _get_hash_mode, encrypt, append_hash_to_file
84
+
85
+ try:
86
+ context = import_module("mindspore.context")
87
+ except ModuleNotFoundError:
88
+ context = None
89
+
90
+ if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork":
91
+ multiprocessing.set_start_method("fork", True)
92
+
93
+ OffloadToManualOffloadMode = {
94
+ None: cde.ManualOffloadMode.UNSPECIFIED,
95
+ False: cde.ManualOffloadMode.DISABLED,
96
+ True: cde.ManualOffloadMode.ENABLED
97
+ }
98
+
99
+ _train_dataset = None
100
+
101
+
102
+ def _set_training_dataset(dataset):
103
+ """
104
+ Set the dataset to be used when training recovery has occurred.
105
+
106
+ Args:
107
+ dataset: the training dataset or iterator
108
+ """
109
+ global _train_dataset
110
+ _train_dataset = dataset
111
+
112
+
113
+ def _get_training_dataset():
114
+ """
115
+ Get the dataset to be used when training recovery has occurred.
116
+
117
+ Returns:
118
+ training dataset/iterator
119
+ """
120
+ return _train_dataset
121
+
122
+
123
+ def _reset_training_dataset(global_step, dataset_size):
124
+ """
125
+ Reset the training dataset to the given global step.
126
+
127
+ Args:
128
+ global_step (int): Number of global steps that have completed training.
129
+ Dataset will provide data from its next step after reset.
130
+ dataset_size (int): Number of steps per epoch.
131
+ """
132
+ dataset = _get_training_dataset()
133
+ if dataset is not None:
134
+ dataset._reset(global_step, dataset_size) # pylint: disable=protected-access
135
+ else:
136
+ raise RuntimeError("Training dataset is not set.")
137
+
138
+
139
+ class Shuffle(str, Enum):
140
+ """Specify the shuffle mode.
141
+
142
+ - ``Shuffle.GLOBAL`` : Shuffle both the files and samples.
143
+ - ``Shuffle.FILES`` : Shuffle files only.
144
+ - ``Shuffle.INFILE`` : Shuffle data within each file.
145
+ """
146
+ GLOBAL: str = "global"
147
+ FILES: str = "files"
148
+ INFILE: str = "infile"
149
+
150
+
151
+ ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
152
+ Shuffle.GLOBAL: cde.ShuffleMode.GLOBAL,
153
+ Shuffle.INFILE: cde.ShuffleMode.INFILE}
154
+
155
+
156
+ def shuffle_to_shuffle_mode(shuffle):
157
+ """
158
+ Shuffle Enum to Shuffle Mode
159
+
160
+ Args:
161
+ shuffle (Shuffle): shuffle flag to shuffle mode in C layer
162
+
163
+ Returns:
164
+ ShuffleMode, shuffle mode
165
+ """
166
+ shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle
167
+ if not isinstance(shuffle, Shuffle):
168
+ if shuffle is None or shuffle:
169
+ shuffle_mode = cde.ShuffleMode.GLOBAL # Global shuffle
170
+ else:
171
+ shuffle_mode = cde.ShuffleMode.FALSE # No shuffle
172
+ else:
173
+ shuffle_mode = ShuffleToShuffleMode[shuffle]
174
+ return shuffle_mode
175
+
176
+
177
+ def shuffle_to_bool(shuffle):
178
+ """
179
+ Shuffle Enum to bool
180
+
181
+ Args:
182
+ shuffle (Shuffle): shuffle flag to bool
183
+
184
+ Returns:
185
+ bool, True / False
186
+ """
187
+ if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
188
+ raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
189
+ "'Shuffle.FILES' or 'Shuffle.INFILE'.")
190
+
191
+ shuffle_bool = True
192
+ if not isinstance(shuffle, Shuffle):
193
+ if shuffle is None:
194
+ shuffle_bool = None
195
+ elif shuffle:
196
+ shuffle_bool = True
197
+ else:
198
+ shuffle_bool = False
199
+ else:
200
+ shuffle_bool = True
201
+ return shuffle_bool
202
+
203
+
204
+ @check_zip
205
+ def zip(datasets):
206
+ """
207
+ Zip the datasets in the input tuple of datasets.
208
+
209
+ Args:
210
+ datasets (tuple[Dataset]): A tuple of datasets to be zipped together.
211
+ The number of datasets must be more than 1.
212
+
213
+ Returns:
214
+ Dataset, a new dataset with the above operation applied.
215
+
216
+ Raises:
217
+ ValueError: If the number of datasets is 1.
218
+ TypeError: If datasets is not a tuple.
219
+
220
+ Examples:
221
+ >>> # Create a dataset which is the combination of dataset_1 and dataset_2
222
+ >>> import mindspore.dataset as ds
223
+ >>>
224
+ >>> dataset_1 = ds.GeneratorDataset([1], "column1")
225
+ >>> dataset_2 = ds.GeneratorDataset([2], "column2")
226
+ >>> dataset = ds.zip((dataset_1, dataset_2))
227
+ """
228
+ if len(datasets) <= 1:
229
+ raise ValueError(
230
+ "Can't zip empty or just one dataset!")
231
+ for dataset in datasets:
232
+ if not isinstance(dataset, Dataset):
233
+ raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
234
+ return ZipDataset(datasets)
235
+
236
+
237
+ def _get_operator_process():
238
+ """
239
+ Inner implemented method, mainly for passing sub-process id in C layer
240
+
241
+ Returns:
242
+ dict, mapping dict of operation id and corresponding process id.
243
+ """
244
+ global _OP_PROCESS
245
+ process_info = _OP_PROCESS
246
+ op_process = dict()
247
+ keys = process_info.keys()
248
+ fetched_all = True
249
+ for key in keys:
250
+ try:
251
+ op_process[key] = list(process_info[key][1])
252
+ item_full = (len(process_info[key][1]) == process_info[key][0])
253
+ except KeyError as err:
254
+ raise err
255
+ fetched_all = fetched_all and item_full
256
+ return op_process, fetched_all
257
+
258
+
259
+ def _set_dataset_permissions(file_name, num_files):
260
+ """
261
+ set saved dataset files' permissions to 600
262
+ the rule of dataset filenames should be the same as those in C++.
263
+ """
264
+ num_digits = len(str(num_files - 1))
265
+ if num_files == 1:
266
+ paths = [file_name]
267
+ else:
268
+ paths = ["{}{}".format(file_name, str(x).rjust(num_digits, '0')) for x in range(num_files)]
269
+
270
+ for item in paths:
271
+ if os.path.exists(item):
272
+ os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
273
+ index_file = item + ".db"
274
+ if os.path.exists(index_file):
275
+ os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
276
+
277
+
278
+ class Dataset:
279
+ """
280
+ Abstract class to represent a dataset in DataEngine's data pipeline.
281
+
282
+ This class is the base class of SourceDataset and Dataset, and represents
283
+ a node in the data flow graph.
284
+ Dataset
285
+ -----------------------------------------------------------
286
+ | | | |
287
+ VisionBaseDataset TextBaseDataset AudioBaseDataset |
288
+ - - - |
289
+ | | | |
290
+ ---------------------------------------- |
291
+ UnionBaseDataset |
292
+ |
293
+ SourceDataset
294
+ -
295
+ |
296
+ MappableDataset
297
+
298
+ DatasetOperation: MapDataset(UnionBaseDataset)
299
+ BatchDataset(UnionBaseDataset)
300
+ PaddedBatchDataset(UnionBaseDataset)
301
+ BucketBatchByLengthDataset(UnionBaseDataset)
302
+ ShuffleDataset(UnionBaseDataset)
303
+ FilterDataset(UnionBaseDataset)
304
+ RepeatDataset(UnionBaseDataset)
305
+ SkipDataset(UnionBaseDataset)
306
+ TakeDataset(UnionBaseDataset)
307
+ ZipDataset(UnionBaseDataset)
308
+ ConcatDataset(UnionBaseDataset)
309
+ RenameDataset(UnionBaseDataset)
310
+ ProjectDataset(UnionBaseDataset)
311
+ SyncWaitDataset(UnionBaseDataset)
312
+
313
+ Impl Dataset - vision: ImageFolderDataset(MappableDataset, VisionBaseDataset)
314
+ USPSDataset(SourceDataset, VisionBaseDataset)
315
+ Impl Dataset - text: TextFileDataset(SourceDataset, TextBaseDataset)
316
+ YahooAnswersDataset(SourceDataset, TextBaseDataset)
317
+ Impl Dataset - audio: LJSpeechDataset(MappableDataset, AudioBaseDataset)
318
+ TedliumDataset(MappableDataset, AudioBaseDataset)
319
+ Impl Dataset - standard: MindDataset(MappableDataset, UnionBaseDataset)
320
+ TFRecordDataset(SourceDataset, UnionBaseDataset)
321
+ Impl Dataset - user defined: GeneratorDataset(MappableDataset, UnionBaseDataset)
322
+ NumpySlicesDataset(GeneratorDataset)
323
+
324
+ Args:
325
+ num_parallel_workers (int, optional): Number of workers to process the dataset in parallel.
326
+ Default: ``None``.
327
+ """
328
+
329
+ def __init__(self, children=None, num_parallel_workers=None, cache=None):
330
+ # Note: children and parent are internal variables, not recommended for external using.
331
+ self.children = replace_none(children, [])
332
+ if isinstance(self.children, tuple):
333
+ self.children = list(self.children)
334
+ if not isinstance(self.children, list):
335
+ self.children = [self.children]
336
+
337
+ self.parent = []
338
+ for child in self.children:
339
+ child.parent.append(weakref.ref(self))
340
+ self.num_parallel_workers = num_parallel_workers
341
+ self.cache = cache
342
+
343
+ self._device_iter = 0
344
+ self._input_indexs = ()
345
+ self.saved_output_types = None
346
+ self.saved_output_shapes = None
347
+ self.estimated_output_shapes = None
348
+ self.runtime_context = None
349
+ self._col_names = None
350
+ self.dataset_size = None
351
+ self._batch_size = None
352
+ self._num_classes = None
353
+ self._repeat_count = None
354
+ self._class_indexing = None
355
+ self._sync = False
356
+ self._global_step = None
357
+
358
+ @staticmethod
359
+ def _get_operator_id(dataset):
360
+ """
361
+ Internal method to iterate the tree and obtain op_id of each operation.
362
+
363
+ Returns:
364
+ Dataset, the root dataset of the tree.
365
+ """
366
+ op_name = dict()
367
+ generator_process = dict()
368
+ op_name[str(dataset)] = 0
369
+ op_id = 1
370
+
371
+ def process_name(datasets, operator_id):
372
+ if not datasets:
373
+ return 0
374
+ temp = []
375
+ for item in datasets:
376
+ for d in item.children:
377
+ temp.append(d)
378
+ op_name[str(d)] = operator_id
379
+
380
+ from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset
381
+ if isinstance(d, GeneratorDataset) and d.sample_fn and d.sample_fn.pids:
382
+ generator_process[operator_id] = [d.num_parallel_workers, set(d.sample_fn.pids)]
383
+
384
+ operator_id = operator_id + 1
385
+ return process_name(temp, operator_id)
386
+
387
+ process_name([dataset], op_id)
388
+ if generator_process:
389
+ global _OP_PROCESS
390
+ _OP_PROCESS.update(generator_process)
391
+ return op_name
392
+
393
+ def create_ir_tree(self, getter_mode=False):
394
+ """
395
+ Internal method to build an IR tree.
396
+
397
+ Args:
398
+ getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``.
399
+
400
+ Returns:
401
+ Union[DatasetNode, Dataset], the root node of the IR tree and the root dataset of the IR tree.
402
+ """
403
+ parent = self.parent
404
+ self.parent = []
405
+ dataset = copy.deepcopy(self)
406
+ global _OP_NAME
407
+ _OP_NAME = Dataset._get_operator_id(dataset)
408
+ ir_tree = dataset.parse_tree(getter_mode)
409
+ self.parent = parent
410
+ _init_device_info()
411
+ return ir_tree, dataset
412
+
413
def parse_tree(self, getter_mode=False):
    """
    Internal method to convert this API node (and, recursively, its children) into an IR tree.

    Children are parsed first. Then this node runs `pre_parse`, `iterator_bootstrap`, `parse` and
    `post_parse`, in that order.

    Args:
        getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``False``.

    Returns:
        DatasetNode, the root node of the IR tree.

    Raises:
        ValueError: If this node has more than one parent (consumer), i.e. the pipeline is not a tree.
    """
    if len(self.parent) >= 2:
        raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")

    child_nodes = []
    for child in self.children:
        child_nodes.append(child.parse_tree(getter_mode))

    # Bootstrapping must only happen on a copied node: bootstrapping the original node would make
    # every iterator share the same process pool.
    self.pre_parse(getter_mode)
    self.iterator_bootstrap()
    return self.post_parse(self.parse(child_nodes))
433
+
434
def __safe_deepcopy__(self, memodict, exclude=()):
    """
    Deep-copy this object, falling back to sharing attributes that cannot be deep-copied.

    Args:
        memodict (dict): The `copy.deepcopy` memo, keyed by ``id()`` of already-copied objects.
        exclude (Iterable[str], optional): Attribute names copied by reference instead of deep-copied.

    Returns:
        A new instance of the same class. An attribute whose deep copy raises ``TypeError`` is shared
        by reference with the original.
    """
    key = id(self)
    if key in memodict:
        return memodict[key]

    klass = self.__class__
    clone = klass.__new__(klass)
    # Register the clone before copying attributes so reference cycles resolve to it.
    memodict[key] = clone

    for attr_name, attr_value in self.__dict__.items():
        if attr_name in exclude:
            setattr(clone, attr_name, attr_value)
            continue
        try:
            setattr(clone, attr_name, copy.deepcopy(attr_value, memodict))
        except TypeError:
            # Not deep-copyable (e.g. generators): keep the shared reference.
            setattr(clone, attr_name, attr_value)
    return clone
449
+
450
@staticmethod
def _noop_mode():
    """Return ``True`` when the current process is the scheduler role (per `_is_role_sched`), else ``False``."""
    return bool(_is_role_sched())
455
+
456
def iterator_bootstrap(self):
    """Hook invoked by `parse_tree` on the copied node; does nothing in the base implementation."""
458
+
459
def __add__(self, datasets):
    """Support ``a + b`` as shorthand for ``a.concat(b)``."""
    concatenated = self.concat(datasets)
    return concatenated
461
+
462
def to_json(self, filename=""):
    """
    Serialize a pipeline into JSON and dump it into a file if `filename` is provided.

    Args:
        filename (str): filename of JSON file to be saved as. Default: ``""``, no file is written.

    Returns:
        dict, the serialized pipeline. The JSON string produced by the IR tree is parsed back with
        :func:`json.loads`, so the result is a Python object (a dict for a pipeline), not a string.

    Examples:
        >>> import mindspore.dataset as ds
        >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
        >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)
        >>> dataset_json = dataset.to_json("/path/to/mnist_dataset_pipeline.json")
    """
    ir_tree, _ = self.create_ir_tree()
    # The IR tree emits a JSON string (and writes the file when `filename` is given); return it decoded.
    return json.loads(ir_tree.to_json(filename))
480
+
481
@check_bucket_batch_by_length
def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None,
                           pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False):
    """
    Bucket elements according to their lengths. Each bucket will be padded and batched when
    it is full.

    A length function is called on each row in the dataset. The row is then
    bucketed based on its length and bucket boundaries. When a bucket reaches its
    corresponding size specified in bucket_batch_sizes, the entire bucket will be
    padded according to pad_info, and then form a batch.

    Refer to the following figure for the execution process:

    .. image:: bucket_batch_by_length_en.png

    Note:
        - When using `Data Sinking <https://www.mindspore.cn/docs/en/master/model_train/train_process/optimize/
          sink_mode.html#data-sinking>`_ in Graph mode, the input shape of the network should keep consistent.
          You should set `drop_remainder` to "True" to discard the last incomplete batch of data,
          or supplement/remove samples to ensure the dataset size is divisible by `batch_size`.

    Args:
        column_names (list[str]): Columns passed to element_length_function.
        bucket_boundaries (list[int]): A list consisting of the upper boundaries
            of the buckets. Must be strictly increasing. If there are n boundaries,
            n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one
            bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
            0<=i<n-1, and the last bucket for [bucket_boundaries[n-1], inf).
        bucket_batch_sizes (list[int]): A list consisting of the batch sizes for
            each bucket. Must contain len(bucket_boundaries)+1 elements.
        element_length_function (Callable, optional): A function that takes in
            M arguments where M = len(column_names) and returns an integer. If no value is
            provided, len(column_names) must be 1, and the size of the first
            dimension of that column will be taken as the length. Default: ``None``.
        pad_info (dict, optional): The information about how to batch each column. The key
            corresponds to the column name, and the value must be a tuple of 2 elements.
            The first element corresponds to the shape to pad to, and the second
            element corresponds to the value to pad with. If a column is not
            specified, then that column will be padded to the longest in the current
            batch, and 0 will be used as the padding value. Any None dimensions will
            be padded to the longest in the current batch, unless
            `pad_to_bucket_boundary` is ``True``. If no padding is wanted, set `pad_info`
            to ``None``. Default: ``None``.
        pad_to_bucket_boundary (bool, optional): If ``True``, will pad each None
            dimension in `pad_info` to the bucket_boundary minus 1. If there are any
            elements that fall into the last bucket, an error will occur.
            Default: ``False``.
        drop_remainder (bool, optional): If ``True``, will drop the last batch for each
            bucket if it is not a full batch. Default: ``False``.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> # Bucket rows by length, then pad and batch each bucket. The last incomplete
        >>> # batch of each bucket is kept because drop_remainder is left as False.
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>> def generate_2_columns(n):
        ...     for i in range(n):
        ...         yield (np.array([i]), np.array([j for j in range(i + 1)]))
        >>>
        >>> column_names = ["col1", "col2"]
        >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
        >>> bucket_boundaries = [5, 10]
        >>> bucket_batch_sizes = [2, 1, 1]
        >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
        >>> # Will pad col2 to shape [bucket_boundaries[i]] where i is the
        >>> # index of the bucket that is currently being batched.
        >>> pad_info = {"col2": ([None], -1)}
        >>> pad_to_bucket_boundary = True
        >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
        ...                                          bucket_batch_sizes,
        ...                                          element_length_function, pad_info,
        ...                                          pad_to_bucket_boundary)
    """
    return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
                                      element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder)
560
+
561
@check_batch
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, **kwargs):
    """
    Combine batch_size number of consecutive rows into batch which apply per_batch_map to the samples first.

    For any column, all the elements within that column must have the same shape.

    Refer to the following figure for the execution process:

    .. image:: batch_en.png

    Note:
        - The order of using repeat and batch reflects the number of batches and per_batch_map.
          It is recommended to apply the repeat operation after the batch operation.
        - When using `Data Sinking <https://www.mindspore.cn/docs/en/master/model_train/train_process/optimize/
          sink_mode.html#data-sinking>`_ in Graph mode, the input shape of the network should keep consistent.
          You should set `drop_remainder` to "True" to discard the last incomplete batch of data,
          or supplement/remove samples to ensure the dataset size is divisible by `batch_size`.

    Args:
        batch_size (Union[int, Callable]): The number of rows each batch is created with. An
            int or callable object which takes exactly 1 parameter, BatchInfo.
        drop_remainder (bool, optional): Determines whether or not to drop the last block
            whose data row number is less than batch size. Default: ``False`` . If ``True`` ,
            and if there are less than `batch_size` rows available to make the last batch,
            then those rows will be dropped and not propagated downstream.
        num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel.
            Default: ``None`` .
        **kwargs:

            - per_batch_map (Callable[[List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo], \
              (List[numpy.ndarray], ..., List[numpy.ndarray])], optional): Per batch map callable.
              Default: ``None``.
              A callable which takes (List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo) as input parameters.
              Each list[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists
              should match with the number of entries in input_columns. The last parameter of the callable should
              always be a BatchInfo object. Per_batch_map should return
              (list[numpy.ndarray], list[numpy.ndarray], ...). The length of each list in output should be the same
              as the input. output_columns is required if the number of output lists is different from input.

            - input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of
              the list should match with signature of `per_batch_map` callable. Default: ``None`` .

            - output_columns (Union[str, list[str]], optional): List of names assigned to the columns
              outputted by the last operation. This parameter is mandatory if len(input_columns) !=
              len(output_columns). The size of this list must match the number of output
              columns of the last operation. Default: ``None`` , output columns will have the same
              name as the input columns, i.e., the columns will be replaced.

            - python_multiprocessing (bool, optional): Parallelize Python function `per_batch_map` with
              multiprocessing or multithreading mode, ``True`` means multiprocessing,
              ``False`` means multithreading. If `per_batch_map` is an I/O bound task, use
              multithreading mode. If `per_batch_map` is a CPU bound task, it is recommended to use
              multiprocessing mode. Default: ``False`` , use python multithreading mode.

            - max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory
              allocation to copy data between processes, the total occupied shared memory will increase as
              ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set
              to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if
              ``python_multiprocessing`` is set to True. If it is an int value, it represents
              ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
              If it is a list, the first element represents the ``input_columns`` use this value as the unit to
              create shared memory, and the second element represents ``output_columns`` use this value as the unit
              to create shared memory. Default: ``None`` , allocate shared memory dynamically.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> # 1) Create a dataset where every 5 rows are combined into a batch
        >>> # and drops the last incomplete batch if there is one.
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>> from PIL import Image
        >>>
        >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
        >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
        >>> dataset = dataset.batch(5, True)
        >>>
        >>> # 2) resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
        >>> def np_resize(col, BatchInfo):
        ...     output = col.copy()
        ...     s = (BatchInfo.get_batch_num() + 1) ** 2
        ...     index = 0
        ...     for c in col:
        ...         img = Image.fromarray(c.astype('uint8')).convert('RGB')
        ...         img = img.resize((s, s))
        ...         output[index] = np.array(img)
        ...         index += 1
        ...     return (output,)
        >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
        >>>
        >>> # 3) Create a dataset where its batch size is dynamic
        >>> # Define a callable batch size function and let batch size increase 1 each time.
        >>> def add_one(BatchInfo):
        ...     return BatchInfo.get_batch_num() + 1
        >>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True)
    """
    return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, **kwargs)
659
+
660
@check_padded_batch
def padded_batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None):
    """
    Combine batch_size number of consecutive rows into batch which apply pad_info to the samples first.

    Refer to the following figure for the execution process:

    .. image:: padded_batch_en.png

    Note:
        - The order of using repeat and padded_batch reflects the number of batches.
          It is recommended to apply the repeat operation after the padded_batch operation.
        - When using `Data Sinking <https://www.mindspore.cn/docs/en/master/model_train/train_process/optimize/
          sink_mode.html#data-sinking>`_ in Graph mode, the input shape of the network should keep consistent.
          You should set `drop_remainder` to "True" to discard the last incomplete batch of data,
          or supplement/remove samples to ensure the dataset size is divisible by `batch_size`.

    Args:
        batch_size (Union[int, Callable]): The number of rows each batch is created with. An
            int or callable object which takes exactly 1 parameter, BatchInfo.
        drop_remainder (bool, optional): Determines whether or not to drop the last block
            whose data row number is less than batch size. Default: ``False``. If ``True``, and if there
            are less than batch_size rows available to make the last batch, then those rows will
            be dropped and not propagated downstream.
        num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel.
            Default: ``None``.
        pad_info (dict, optional): The pad information about how to batch each column. The key
            corresponds to the column name, and the value must be a tuple of 2 elements.
            The first element corresponds to the shape to pad to, and the second
            element corresponds to the value to pad with. If a column is not
            specified, then that column will be padded to the longest in the current
            batch, and 0 will be used as the padding value. If ``pad_info={"col1": ([224, 224], 0)}``,
            expand the data column named ``col1`` to shape (224, 224), and fill in the missing values with 0.
            If ``pad_info={}``, all samples in the batch will be filled to the shape with the largest sample
            in the current batch. If ``pad_info={"col1": (None, 100)}``, all samples in the batch will be filled
            to the shape with the largest sample in the current batch, and fill in the missing values with 100.
            If no padding is wanted, set `pad_info` to ``None``. Default: ``None``.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> # 1) Pad every sample to the largest sample's shape and batch the samples
        >>> import mindspore.dataset as ds
        >>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1")
        >>> dataset = dataset.padded_batch(2, True, pad_info={})
        >>>
        >>> # 2) Create a dataset where every 3 rows are combined into a batch
        >>> # and drops the last incomplete batch if there is one.
        >>> dataset = ds.NumpySlicesDataset([i for i in range(10)], "column1")
        >>> dataset = dataset.padded_batch(3, True)
        >>>
        >>> # 3) Create a dataset where its batch size is dynamic
        >>> # Define a callable batch size function and let batch size increase 1 each time.
        >>> def add_one(BatchInfo):
        ...     return BatchInfo.get_batch_num() + 1
        >>> dataset = dataset.padded_batch(batch_size=add_one, drop_remainder=True)
    """
    padded = PaddedBatchDataset(self, batch_size, drop_remainder, num_parallel_workers, pad_info)
    return padded
719
+
720
@check_sync_wait
def sync_wait(self, condition_name, num_batch=1, callback=None):
    """
    Add a blocking condition to the input Dataset and a synchronize action will be applied.

    Args:
        condition_name (str): The condition name that is used to toggle sending next row.
        num_batch (int): the number of batches without blocking at the start of each epoch.
            Default: ``1``.
        callback (function): The callback function that will be invoked when sync_update is called.
            Default: ``None``.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Raises:
        RuntimeError: If condition name already exists.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>> def gen():
        ...     for i in range(100):
        ...         yield (np.array(i),)
        >>>
        >>> class Augment:
        ...     def __init__(self, loss):
        ...         self.loss = loss
        ...
        ...     def preprocess(self, input_):
        ...         return input_
        ...
        ...     def update(self, data):
        ...         self.loss = data["loss"]
        >>>
        >>> batch_size = 4
        >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
        >>>
        >>> aug = Augment(0)
        >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
        >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
        >>> dataset = dataset.batch(batch_size)
        >>> count = 0
        >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        ...     assert data["input"][0] == count
        ...     count += batch_size
        ...     data = {"loss": count}
        ...     dataset.sync_update(condition_name="policy", data=data)
    """
    synced = SyncWaitDataset(self, condition_name, num_batch, callback)
    return synced
770
+
771
@check_shuffle
def shuffle(self, buffer_size):
    """
    Shuffle the dataset by creating a cache with the size of `buffer_size` .

    1. Make a shuffle buffer that contains the first `buffer_size` rows.
    2. Randomly select an element from the shuffle buffer to be the next row
       propagated downstream.
    3. Get the next row (if any) from the input dataset and put it in the shuffle buffer.
    4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.

    A random seed can be provided to be used on the first epoch via `mindspore.dataset.config.set_seed` .
    In every subsequent epoch, the seed is changed to a new one, randomly generated value.

    Args:
        buffer_size (int): The size of the buffer (must be larger than 1) for
            shuffling. Setting `buffer_size` equal to the number of rows in the entire
            dataset will result in a global shuffle.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Raises:
        RuntimeError: If exist sync operations before shuffle.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>>
        >>> # Optionally set the seed for fixed randomness
        >>> ds.config.set_seed(58)
        >>>
        >>> # Create a shuffled dataset using a shuffle buffer of size 4
        >>> dataset = dataset.shuffle(4)
    """
    shuffled = ShuffleDataset(self, buffer_size)
    return shuffled
807
+
808
def flat_map(self, func):
    """
    Map `func` to each row in dataset and flatten the result.

    Every row of this dataset is passed to `func` (one positional argument per column). The
    datasets returned for all rows are concatenated, in row order, into a single dataset.

    Args:
        func (function): A function that takes one `numpy.ndarray` per column of a row as its
            arguments and returns a `Dataset` .

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # 1) flat_map on one column dataset
        >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
        >>>
        >>> def repeat(array):
        ...     # create a NumpySlicesDataset with the array
        ...     data = ds.NumpySlicesDataset(array, shuffle=False)
        ...     # repeat the dataset twice
        ...     data = data.repeat(2)
        ...     return data
        >>>
        >>> dataset = dataset.flat_map(repeat)
        >>> # [0, 1, 0, 1, 2, 3, 2, 3]
        >>>
        >>> # 2) flat_map on multi column dataset
        >>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
        >>>
        >>> def plus_and_minus(col1, col2):
        ...     # apply different methods on columns
        ...     data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
        ...     return data
        >>>
        >>> dataset = dataset.flat_map(plus_and_minus)
        >>> # ([1, 2, 3, 4], [-1, -2, -3, -4])

    Raises:
        TypeError: If `func` is not a function.
        TypeError: If `func` doesn't return a Dataset, or if this dataset yields no rows
            (so there is no Dataset to return).
    """
    if not callable(func):
        logger.critical("func must be a function.")
        raise TypeError("func must be a function.")

    dataset = None
    for row_data in self.create_tuple_iterator(num_epochs=1, output_numpy=True):
        mapped = func(*row_data)
        # Validate each result immediately. Otherwise a non-Dataset return value would be combined
        # with `+` and only reported (or misreported) after the whole input had been consumed.
        if not isinstance(mapped, Dataset):
            logger.critical("flat_map must return a Dataset object.")
            raise TypeError("flat_map must return a Dataset object.")
        if dataset is None:
            dataset = mapped
        else:
            dataset += mapped

    if dataset is None:
        # The input produced no rows, so `func` never returned a Dataset.
        logger.critical("flat_map must return a Dataset object.")
        raise TypeError("flat_map must return a Dataset object.")
    return dataset
864
+
865
@check_map
def map(self, operations, input_columns=None, output_columns=None, column_order=None,
        num_parallel_workers=None, **kwargs):
    """
    Apply each operation in operations to this dataset.

    Each operation will be passed one or more columns from the dataset as input, and one or
    more columns will be outputted. The first operation will be passed the columns specified
    in input_columns as input. If there is more than one operation in operations, the outputted
    columns of the previous operation are used as the input columns for the next operation.

    The columns outputted by the very last operation will be assigned names specified by
    `output_columns` , and if not specified, the column name of output column is same as that of `input_columns` .

    - If you use transformations (
      `vision transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
dataset.transforms.html#module-mindspore.dataset.vision>`_ ,
      `nlp transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
dataset.transforms.html#module-mindspore.dataset.text>`_ ,
      `audio transform <https://mindspore.cn/docs/en/master/api_python/mindspore.\
dataset.transforms.html#module-mindspore.dataset.audio>`_ )
      provided by mindspore dataset, please use the following parameters:

      .. image:: map_parameter_en.png

    - If you use user-defined transform as PyFunc (Python Func), please use the following parameters:

      .. image:: map_parameter_pyfunc_en.png

    Args:
        operations (Union[list[TensorOperation], list[functions]]): List of operations to be
            applied on the dataset. Operations are applied in the order they appear in this list.
        input_columns (Union[str, list[str]], optional): List of the names of the columns that will be passed to
            the first operation as input. The size of this list must match the number of
            input columns expected by the first operation. Default: ``None``, the first
            operation will be passed however many columns that are required, starting from
            the first column.
        output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
            the last operation. This parameter is mandatory if len(input_columns) !=
            len(output_columns). The size of this list must match the number of output
            columns of the last operation. Default: ``None``, output columns will have the same
            name as the input columns, i.e., the columns will be replaced.
        num_parallel_workers (int, optional): Number of threads used to process the dataset in
            parallel. Default: ``None``, the value from the configuration will be used.
        **kwargs:

            - python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
              This option could be beneficial if the Python operation is computational heavy. Default: ``False``.

            - max_rowsize (Union[int, list[int]], optional): Maximum size of row in MB that is used for shared
              memory allocation to copy data between processes, the total occupied shared memory will increase as
              ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set
              to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if
              ``python_multiprocessing`` is set to True. If it is an int value, it represents
              ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
              If it is a list, the first element represents the ``input_columns`` use this value as the unit to
              create shared memory, and the second element represents ``output_columns`` use this value as the unit
              to create shared memory. Default: ``None`` , allocate shared memory dynamically.

            - cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
              Default: ``None``, which means no cache is used.

            - callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called.
              Default: ``None``.

            - offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``.

    Note:
        - Input `operations` accepts TensorOperations defined in mindspore.dataset part, plus user-defined
          Python functions (PyFuncs).
        - Do not add network computing operators from mindspore.nn and mindspore.ops or others into this
          `operations` .

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.vision as vision
        >>> # dataset is an instance of Dataset which has 2 columns, "image" and "label".
        >>> # image is of type bytes type which can be decoded to RGB
        >>> # label is of type int32
        >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
        >>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir)
        >>>
        >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column.
        >>> decode_op = vision.Decode(to_pil=False)
        >>> random_jitter_op = vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
        ...                                             saturation=(1, 1), hue=(0, 0))
        >>>
        >>> # 1) Simple map example.
        >>>
        >>> # Apply decode_op on column "image".
        >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
        >>>
        >>> # Decode and rename column "image" to "decoded_image".
        >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
        >>>
        >>> # A simple example for user defined python function transform.
        >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
        >>> dataset = dataset.map(operations=[(lambda x: x - 1)], input_columns=["data"])
        >>>
        >>> # 2) Map example with more than one operation.
        >>>
        >>> # Create a dataset where the images are decoded, then randomly color jittered.
        >>> # decode_op takes column "image" as input and outputs one column. The column
        >>> # outputted by decode_op is passed as input to random_jitter_op.
        >>> # random_jitter_op will output one column. Column "image" will be replaced by
        >>> # the column outputted by random_jitter_op (the very last operation). All other
        >>> # columns are unchanged.
        >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
        >>>
        >>> # Rename the column outputted by random_jitter_op to "image_mapped".
        >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
        ...                       output_columns=["image_mapped"])
        >>>
        >>> # Map with multiple operations using pyfunc and rename column's name
        >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
        >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
        ...                       output_columns=["data_mapped"])
        >>>
        >>> # 3) Example where number of input columns is not equal to number of output columns.
        >>>
        >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
        >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
        >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
        >>> #
        >>> # Note: The number of output columns of operation[i] must equal the number of
        >>> # input columns of operation[i+1]. Otherwise, this map call will also result
        >>> # in an error.
        >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
        ...               (lambda x, y, z: x * y * z),
        ...               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
        >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
        >>> dataset = dataset.map(operations, input_columns=["x", "y"],
        ...                       output_columns=["mod2", "mod3", "mod5", "mod7"])
    """
    # NOTE(review): `column_order` is part of the signature but is not forwarded to MapDataset here;
    # presumably it is rejected by the `check_map` validator -- confirm.
    workers = num_parallel_workers
    if getattr(self, 'operator_mixed', False) is True:
        # Network operators cannot be compiled from several threads, so fall back to a single worker.
        workers = 1
        logger.warning(
            "Input 'operations' of 'map' includes network computing operators like in mindspore.nn, mindspore.ops, "
            "mindspore.numpy module and etc, which do not support multithreading compiling, recommend to replace "
            "it with python implemented operator like numpy etc. Here decrease 'num_parallel_workers' into 1.")

    return MapDataset(self, operations, input_columns, output_columns, workers, **kwargs)
1010
+
1011
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=None):
    """
    Filter dataset by prediction.

    Args:
        predicate (callable): Python callable which returns a boolean value. If False then filter the element.
        input_columns (Union[str, list[str]], optional): List of names of the input columns. If not provided
            or provided with ``None``, the predicate will be applied on all columns in the dataset.
            Default: ``None``.
        num_parallel_workers (int, optional): Number of workers to process the dataset
            in parallel. Default: ``None``.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> # generator data(0 ~ 19)
        >>> # filter out the data that is greater than or equal to 11
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(20)], "data")
        >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
    """
    filtered = FilterDataset(self, predicate, input_columns, num_parallel_workers)
    return filtered
1035
+
1036
@check_repeat
def repeat(self, count=None):
    """
    Repeat this dataset `count` times. Repeat infinitely if the `count` is ``None`` or ``-1``.

    Note:
        The order of using repeat and batch reflects the number of batches. It is recommended that
        the repeat operation is used after the batch operation.

    Args:
        count (int): Number of times the dataset is going to be repeated. Default: ``None``.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> # Create a dataset with 10 elements
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> ori_size = dataset.get_dataset_size()
        >>>
        >>> # Repeat the dataset 50 times.
        >>> dataset = dataset.repeat(50)
        >>> repeated_size = dataset.get_dataset_size()
        >>> print("ori_size", ori_size, ", repeated_size", repeated_size)
        ori_size 10 , repeated_size 500
        >>>
        >>> # Since the original dataset size is less than batch_size, thus no data is returned
        >>> dataset1 = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> dataset1 = dataset1.batch(batch_size=20, drop_remainder=True)
        >>> dataset1 = dataset1.repeat(6)
        >>>
        >>> # Repeat the original dataset to 60 elements, thus 3 batches are returned
        >>> dataset2 = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> dataset2 = dataset2.repeat(6)
        >>> dataset2 = dataset2.batch(batch_size=20, drop_remainder=True)
        >>> print("dataset1 size", dataset1.get_dataset_size(), ", dataset2 size", dataset2.get_dataset_size())
        dataset1 size 0 , dataset2 size 3
    """
    repeated = RepeatDataset(self, count)
    return repeated
1077
+
1078
@check_skip
def skip(self, count):
    """
    Drop the first `count` rows of the dataset.

    Args:
        count (int): How many leading rows to discard.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> # Skip the first 3 elements and keep the remaining 7.
        >>> dataset = dataset.skip(3)
    """
    remaining = SkipDataset(self, count)
    return remaining
1096
+
1097
@check_take
def take(self, count=-1):
    """
    Keep only the first `count` samples of the dataset.

    Args:
        count (int, optional): Number of samples to keep. When it is larger than the
            dataset, every sample is returned. Default: ``-1`` , which returns all data.

    Note:
        Operations that change the number of samples make the position of `take` in the
        pipeline significant. For example, `batch` merges `batch_size` consecutive samples
        into one, so `.batch(batch_size).take(1)` is equivalent to
        `.take(batch_size).batch(batch_size)`.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
        >>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir)
        >>> # Take 50 samples from MNIST dataset.
        >>> dataset = dataset.take(50)
    """
    head = TakeDataset(self, count)
    return head
1125
+
1126
def _get_absolute_split_sizes(self, sizes):
    """
    Convert the `sizes` argument of `split` into absolute row counts.

    The dataset size is queried here (instead of in the ``check_split`` validator) so
    that it is computed only once.

    Args:
        sizes (list): Either absolute row counts (all ints) or fractions (floats).

    Returns:
        list[int], the number of rows of each split. A list of ints is returned unchanged
        once it has been validated.

    Raises:
        RuntimeError: If the dataset size is unknown or not positive, if `sizes` is not a
            list, if integer sizes do not add up to the dataset size, if a fraction rounds
            to zero rows, or if the rounded sizes cannot be reconciled with the dataset size.
    """
    total_rows = self.get_dataset_size()
    if total_rows is None or total_rows <= 0:
        raise RuntimeError("dataset_size is unknown, unable to split.")
    if not isinstance(sizes, list):
        raise RuntimeError("sizes must be a list.")

    # Integer sizes are already absolute; they only have to add up to the dataset size.
    if all(isinstance(entry, int) for entry in sizes):
        requested = sum(sizes)
        if requested != total_rows:
            raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                               .format(requested, total_rows))
        return sizes

    # Fractional sizes: scale each one by the dataset size and round it.
    rounded = []
    for fraction in sizes:
        rows = int(round(fraction * total_rows))
        if rows == 0:
            raise RuntimeError("Split percentage {} is too small.".format(fraction))
        rounded.append(rows)

    rounded_total = sum(rounded)
    shortfall = int(total_rows - rounded_total)
    if shortfall > 0:
        # Rounding lost rows: give them all to the first split.
        rounded[0] += shortfall
    else:
        # Rounding produced extra rows (or none): take them from the first split that
        # still keeps at least one row afterwards.
        for idx, rows in enumerate(rounded):
            if rows + shortfall > 0:
                rounded[idx] = rows + shortfall
                break

    if sum(rounded) != total_rows:
        raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
                           .format(rounded_total, total_rows))
    return rounded
1179
+
1180
@check_split
def split(self, sizes, randomize=True):
    """
    Divide the dataset into several smaller, non-overlapping datasets.

    Args:
        sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
            given, the dataset is split into n datasets of sizes s1, s2, …, sn. If these
            sizes do not add up to the original dataset size, an error is raised.
            If a list of floats [f1, f2, …, fn] is given, every float must lie between 0 and 1
            and they must add up to 1, otherwise an error is raised. The dataset is split into
            n datasets of sizes round(f1*K), round(f2*K), …, round(fn*K), where K is the size
            of the original dataset.
            If, after rounding:

            - Any size equals 0, an error is raised.
            - The sum of split sizes is < K, the difference K - sigma(round(fi * K)) is added to
              the first split.
            - The sum of split sizes is > K, the difference sigma(round(fi * K)) - K is removed
              from the first split that is large enough to keep at least 1 row afterwards.

        randomize (bool, optional): Whether to split the data randomly. Default: ``True``.
            If True, the rows are split randomly; otherwise each split consists of
            consecutive rows of the dataset.

    Note:
        1. The dataset must not be sharded if split is going to be called.
        2. Prefer randomize=True over shuffling the dataset yourself. Shuffling may not be
           deterministic, in which case the contents of each split differ from epoch to epoch.

    Returns:
        Tuple[Dataset], a tuple of new datasets split from the original one.

    Raises:
        RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
        RuntimeError: If `sizes` is a list of integers whose sum differs from the dataset size.
        RuntimeError: If `sizes` is a list of floats and a split ends up with size 0.
        RuntimeError: If the dataset is sharded prior to calling split.
        ValueError: If `sizes` is a list of floats and not all floats are between 0 and 1, or
            the floats don't add up to 1.

    Examples:
        >>> # Split the data into train part and test part.
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
    """
    if self.is_shuffled():
        logger.warning("Dataset is shuffled before split.")
    if self.is_sharded():
        raise RuntimeError("Dataset should not be sharded before split.")

    split_sizes = self._get_absolute_split_sizes(sizes)
    pieces = []
    offset = 0
    for piece_size in split_sizes:
        piece = copy.deepcopy(self)
        if randomize:
            # Shuffle the same way every epoch before splitting. alter_tree uses a shuffle
            # buffer of at least 10000, so the same value is used here.
            piece = piece.shuffle(10000)
            piece.reshuffle_each_epoch = False
        if offset > 0:
            piece = piece.skip(offset)
        pieces.append(piece.take(piece_size))
        offset += piece_size
    return tuple(pieces)
1256
+
1257
@check_zip_dataset
def zip(self, datasets):
    """
    Combine this dataset column-wise with one or more other datasets.

    The columns of all input datasets must have distinct names.

    Args:
        datasets (Union[Dataset, tuple[Dataset]]): A single dataset or a tuple of datasets
            to zip with this dataset.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Raises:
        TypeError: The parameter is not dataset object or tuple of dataset objects.

    Examples:
        >>> # Create a dataset which is the combination of dataset_1 and dataset_2
        >>> import mindspore.dataset as ds
        >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1")
        >>> dataset_2 = ds.GeneratorDataset([1, 2, 3], "column2")
        >>> dataset = dataset_1.zip(dataset_2)
    """
    if isinstance(datasets, tuple):
        members = (self,) + tuple(datasets)
    elif isinstance(datasets, Dataset):
        members = (self, datasets)
    else:
        raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets)
    return ZipDataset(members)
1287
+
1288
@check_concat
def concat(self, datasets):
    """
    Concatenate the dataset objects in the input list.
    Performing "+" operation on dataset objects can achieve the same effect.

    For a dataset concatenated from several other datasets, the data is returned in the
    order in which the datasets were passed in. To change the data order (such as random
    selection from each dataset instead of in sequence), apply the `use_sampler` method on
    the concatenated dataset object. Currently `use_sampler` supports
    `dataset.DistributedSampler` for sharding selection from each dataset and
    `dataset.RandomSampler` for random selection from each dataset; see the examples below.

    Note:
        The column names, and the rank and type of the column data, must be the same in
        all input datasets.

    Args:
        datasets (Union[list, Dataset]): A list of datasets or a single class Dataset
            to be concatenated together with this dataset.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Raises:
        TypeError: If `datasets` is neither a Dataset nor a list.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
        >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
        >>>
        >>> # Create a dataset by concatenating dataset_1 and dataset_2 with "+" operator
        >>> dataset = dataset_1 + dataset_2
        >>> # Create a dataset by concatenating dataset_1 and dataset_2 with concat operation
        >>> dataset = dataset_1.concat(dataset_2)
        >>>
        >>> # Check the data order of dataset
        >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
        >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
        >>> dataset = dataset_1 + dataset_2
        >>> result = list(dataset)
        >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 2)],
        >>> #  [Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)],
        >>> #  [Tensor(shape=[], dtype=Int64, value= 5)], [Tensor(shape=[], dtype=Int64, value= 6)]]
        >>>
        >>> # Change the data order of concatenated dataset with sharding selection
        >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
        >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
        >>> dataset = dataset_1.concat(dataset_2)
        >>> dataset.use_sampler(ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False))
        >>> result = list(dataset)
        >>> # [[Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 4)],
        >>> #  [Tensor(shape=[], dtype=Int64, value= 6)]]
        >>>
        >>> # Change the data order of concatenated dataset with random selection
        >>> dataset_1 = ds.GeneratorDataset([1, 2, 3], "column1", shuffle=False)
        >>> dataset_2 = ds.GeneratorDataset([4, 5, 6], "column1", shuffle=False)
        >>> dataset = dataset_1.concat(dataset_2)
        >>> dataset.use_sampler(ds.RandomSampler())
        >>> result = list(dataset)
        >>> # [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 4)],
        >>> #  [Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 5)],
        >>> #  [Tensor(shape=[], dtype=Int64, value= 6)], [Tensor(shape=[], dtype=Int64, value= 3)]]
    """
    if isinstance(datasets, Dataset):
        datasets = [self] + [datasets]
    elif isinstance(datasets, list):
        datasets = [self] + datasets
    else:
        # Wrap the operand in a 1-tuple: a tuple argument (which does reach this branch)
        # would otherwise be unpacked by `%` and break the message formatting.
        raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % (datasets,))
    return ConcatDataset(datasets)
1355
+
1356
@check_rename
def rename(self, input_columns, output_columns):
    """
    Give new names to columns of the dataset.

    Args:
        input_columns (Union[str, list[str]]): Current names of the columns to rename.
        output_columns (Union[str, list[str]]): New names, paired position by position
            with `input_columns`.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> input_columns = ["input_col1", "input_col2", "input_col3"]
        >>> output_columns = ["output_col1", "output_col2", "output_col3"]
        >>>
        >>> # Create a dataset with 3 columns
        >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns)
        >>>
        >>> # Rename "input_col1" to "output_col1", "input_col2" to "output_col2", "input_col3" to "output_col3"
        >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
    """
    renamed = RenameDataset(self, input_columns, output_columns)
    return renamed
1381
+
1382
@check_project
def project(self, columns):
    """
    Keep only the given columns, in the given order, and discard all other columns.

    Args:
        columns(Union[str, list[str]]): Names of the columns to keep, in the order they
            should appear in the output.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # Create a dataset with 3 columns
        >>> input_columns = ["column1", "column2", "column3"]
        >>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns)
        >>>
        >>> columns_to_project = ["column3", "column1", "column2"]
        >>> # Output the columns as column3, column1, column2,
        >>> # in that order, regardless of the original order of columns.
        >>> dataset = dataset.project(columns=columns_to_project)
    """
    projected = ProjectDataset(self, columns)
    return projected
1406
+
1407
def apply(self, apply_func):
    """
    Apply a function in this dataset.

    Args:
        apply_func (function): A function that must take one `Dataset` as an argument and
            return a preprocessed `Dataset` .

    Returns:
        Dataset, a new dataset with the above operation applied.

    Raises:
        TypeError: If apply_func is not callable.
        TypeError: If apply_func doesn't return a Dataset.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>>
        >>> # Declare an apply_func function which returns a Dataset object
        >>> def apply_func(data):
        ...     data = data.batch(2)
        ...     return data
        >>>
        >>> # Use apply to call apply_func
        >>> dataset = dataset.apply(apply_func)
    """
    # callable() inspects the type's __call__, which is what a call actually uses.
    # hasattr(obj, '__call__') also accepted objects that merely carry an instance-level
    # `__call__` attribute; those then failed with an unrelated TypeError when invoked.
    if not callable(apply_func):
        raise TypeError("apply_func must be a function.")

    dataset = apply_func(self)
    if not isinstance(dataset, Dataset):
        raise TypeError("apply_func must return a dataset.")
    return dataset
1442
+
1443
@check_device_send
def device_que(self, send_epoch_end=True, create_data_info_queue=False, queue_name=""):
    """
    Wrap this dataset in a TransferDataset that sends its data to a device.

    Args:
        send_epoch_end (bool, optional): Whether an end-of-sequence marker is sent to the
            device. Default: ``True``.
        create_data_info_queue (bool, optional): Whether to create a queue that stores the
            types and shapes of the data. Default: ``False``.
        queue_name (str, optional): Name of the queue connecting dataset processing and
            model computing. Default: ``""``.

    Note:
        On Ascend devices the features of the data are transferred one at a time, with a
        limit of 256M per transmission.

    Returns:
        Dataset, a new dataset with the above operation applied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import time
        >>>
        >>> data = ds.TFRecordDataset('/path/to/TF_FILES', '/path/to/TF_SCHEMA_FILE', shuffle=ds.Shuffle.FILES)
        >>> data = data.device_que()
        >>> data.send()
        >>> time.sleep(0.1)
        >>> data.stop_send()
    """
    transfer = TransferDataset(self, send_epoch_end, create_data_info_queue, queue_name)
    return transfer
1474
+
1475
@check_save
def save(self, file_name, num_files=1, file_type='mindrecord'):
    """
    Save the dynamic data processed by the dataset pipeline in common dataset format.
    Supported dataset formats: ``'mindrecord'`` only. And you can use
    :class:`mindspore.dataset.MindDataset` API to read the saved file(s).

    Implicit type casting exists when saving data as ``'mindrecord'`` . The transform table shows how to do
    type casting.

    .. list-table:: Implicit Type Casting when Saving as `mindrecord`
       :widths: 25 25 50
       :header-rows: 1

       * - Type in `dataset`
         - Type in `mindrecord`
         - Details
       * - bool
         - int32
         - transform to int32
       * - int8
         - int32
         -
       * - uint8
         - int32
         -
       * - int16
         - int32
         -
       * - uint16
         - int32
         -
       * - int32
         - int32
         -
       * - uint32
         - int64
         -
       * - int64
         - int64
         -
       * - uint64
         - int64
         - Maybe reverse
       * - float16
         - float32
         -
       * - float32
         - float32
         -
       * - float64
         - float64
         -
       * - string
         - string
         - Multi-dimensional string not supported
       * - bytes
         - bytes
         - Multi-dimensional bytes not supported

    Note:
        1. To save the samples in order, set dataset's `shuffle` to ``False`` and `num_files` to ``1``.
        2. Before calling the function, do not use batch operation, repeat operation or data augmentation operations
           with random attribute in map operation.
        3. When array dimension is variable, one-dimensional arrays or
           multidimensional arrays with variable dimension 0 are supported.
        4. MindRecord does not support multidimensional string or multidimensional bytes.

    Args:
        file_name (str): Path to dataset file.
        num_files (int, optional): Number of dataset files. Default: ``1`` .
        file_type (str, optional): Dataset format. Default: ``'mindrecord'`` .

    Raises:
        RuntimeError: If an encryption key or a hash mode is configured and `num_files` > 1.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>>
        >>> def generator_1d():
        ...     for i in range(10):
        ...         yield (np.array([i]),)
        >>>
        >>> # apply dataset operations
        >>> d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
        >>> d1.save('/path/to/save_file')
    """
    # The hash/encrypt post-processing below only touches `file_name` and
    # `file_name + ".db"`; presumably that is why multi-file output is refused when
    # either feature is enabled.
    if (_get_enc_key() is not None or _get_hash_mode() is not None) and num_files > 1:
        raise RuntimeError("When encode mode or hash check is enabled, " +
                           "the automatic sharding function is unavailable.")

    # `api_tree` stays referenced until the `del` at the very end of this method.
    ir_tree, api_tree = self.create_ir_tree()

    # Wire a save-to-disk consumer into a fresh runtime context and write the data.
    runtime_context = cde.PythonRuntimeContext()
    runtime_context.Init()
    consumer = cde.PythonSaveToDisk(file_name, num_files, file_type)
    consumer.Init(ir_tree)
    runtime_context.AssignConsumer(consumer)

    consumer.Save()

    # Post-process the written file and its ".db" companion. Order matters: the hash
    # is appended first, then encryption (if configured) is applied.
    if _get_hash_mode() is not None:
        append_hash_to_file(file_name)
        append_hash_to_file(file_name + ".db")

    if _get_enc_key() is not None:
        encrypt(file_name, _get_enc_key(), _get_enc_mode())
        encrypt(file_name + ".db", _get_enc_key(), _get_enc_mode())

    _set_dataset_permissions(file_name, num_files)
    del api_tree
1584
+
1585
@check_tuple_iterator
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
    """
    Build an iterator that yields each sample as a list holding one entry per column.

    Args:
        columns (list[str], optional): Output columns and their order.
            Default: ``None``, keep all the output columns in their original order.
        num_epochs (int, optional): Number of epochs to iterate over the entire dataset.
            Default: ``-1`` , the dataset can be iterated indefinitely.
        output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray
            instead of converting it to Tensor. ``None`` is treated as ``False``.
            Default: ``False`` .
        do_copy (bool, optional): Whether to copy the data when converting output to Tensor,
            or reuse the buffer for better performance; only effective when `output_numpy`
            is ``False`` . Default: ``True`` .

    Returns:
        Iterator, a dataset iterator that yields samples of type list.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data")
        >>> num_epochs = 3
        >>> iterator = dataset.create_tuple_iterator(num_epochs=num_epochs)
        >>> for epoch in range(num_epochs):
        ...     for item in iterator:
        ...         # output is of type list
        ...         print(type(item))
        ...         break
        ...     break
        <class 'list'>
    """
    as_numpy = False if output_numpy is None else output_numpy
    if Dataset._noop_mode():
        return DummyIterator(self, 'tuple', as_numpy)
    return TupleIterator(self, columns, num_epochs, as_numpy, do_copy)
1625
+
1626
@check_dict_iterator
def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True):
    """
    Build an iterator that yields each sample as a dict mapping column name to data.

    Args:
        num_epochs (int, optional): Number of epochs to iterate over the entire dataset.
            Default: ``-1`` , the dataset can be iterated indefinitely.
        output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray
            instead of converting it to Tensor. ``None`` is treated as ``False``.
            Default: ``False`` .
        do_copy (bool, optional): Whether to copy the data when converting output to Tensor,
            or reuse the buffer for better performance; only effective when `output_numpy`
            is ``False`` . Default: ``True`` .

    Returns:
        Iterator, a dataset iterator that yields samples of type dict.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "data")
        >>> num_epochs = 3
        >>> iterator = dataset.create_dict_iterator(num_epochs=num_epochs)
        >>> for epoch in range(num_epochs):
        ...     for item in iterator:
        ...         # output is of type dict
        ...         print(type(item))
        ...         break
        ...     break
        <class 'dict'>
    """
    as_numpy = False if output_numpy is None else output_numpy
    if Dataset._noop_mode():
        return DummyIterator(self, 'dict', as_numpy)
    return DictIterator(self, num_epochs, as_numpy, do_copy)
1664
+
1665
def __iter__(self):
    """Iterate over the dataset once, using a single-epoch tuple iterator."""
    single_epoch = 1
    return self.create_tuple_iterator(num_epochs=single_epoch)
1668
+
1669
@property
def input_indexs(self):
    """
    Get the column index, i.e. how the order of data columns maps onto the network
    inputs when sink mode is used.

    A value set on this node wins. Otherwise the first non-empty value found among the
    children is returned; if there is none (or the node is a leaf), this node's own
    value (the empty tuple) is returned.

    Returns:
        int, tuple of the input index information.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> # set input_indexs
        >>> dataset.input_indexs = 10
        >>> print(dataset.input_indexs)
        10
    """
    if self._input_indexs != ():
        return self._input_indexs
    # Every child's value is looked up before choosing the first non-empty one.
    from_children = [node.input_indexs for node in self.children]
    return next((found for found in from_children if found != ()), self._input_indexs)


@input_indexs.setter
def input_indexs(self, value):
    # Store the explicit value; it overrides anything inherited from children.
    self._input_indexs = value
1703
+
1704
def copy_batch_size(self, value):
    """Record `value` as this dataset's batch size (stored in ``_batch_size``)."""
    setattr(self, "_batch_size", value)
1706
+
1707
def _init_tree_getters(self, getter_mode=True):
    """
    Build the IR tree and attach a TreeGetters consumer for querying pipeline information.

    Args:
        getter_mode (bool, optional): Whether to build IR tree in pull mode. Default: ``True``.

    Returns:
        tuple, ``(getter, runtime_context, api_tree)``. Callers keep the whole tuple
        referenced while they use the getter.
    """
    ir_tree, api_tree = self.create_ir_tree(getter_mode)

    context = cde.PythonRuntimeContext()
    context.Init()
    tree_getter = cde.TreeGetters()
    tree_getter.Init(ir_tree)
    context.AssignConsumer(tree_getter)
    return tree_getter, context, api_tree
1722
+
1723
def __init_size_getter(self):
    """
    Build the IR tree and attach a DatasetSizeGetters consumer for size queries.

    Returns:
        tuple, ``(getter, runtime_context, api_tree)``.
    """
    ir_tree, api_tree = self.create_ir_tree()

    context = cde.PythonRuntimeContext()
    context.Init()
    size_getter = cde.DatasetSizeGetters()
    size_getter.Init(ir_tree)
    context.AssignConsumer(size_getter)
    return size_getter, context, api_tree
1735
+
1736
def get_col_names(self):
    """
    Return the column names of the dataset; they are computed once and then cached.

    Returns:
        list, list of column names in the dataset.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> col_names = dataset.get_col_names()
        >>> print(col_names)
        ['column1']
    """
    if self._col_names is None:
        # Keep the runtime context and api tree referenced while the getter is queried.
        names_getter, _context, _api_tree = self._init_tree_getters()
        self._col_names = names_getter.GetColumnNames()
    return self._col_names
1756
+
1757
@check_output_shape
@_cleanup_the_iterators_if_created
def output_shapes(self, estimate=False):
    """
    Get the shapes of output data.

    Results are cached separately for ``estimate=False`` (``saved_output_shapes``) and
    ``estimate=True`` (``estimated_output_shapes``).

    Args:
        estimate (bool): If `estimate` is ``False`` , will return the shapes of first data row.
            Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
            where dynamic shape is marked as None (used in dynamic data shapes scenario).
            Default: ``False`` .

    Returns:
        list, list of shapes of each column.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>>
        >>> def generator1():
        ...     for i in range(1, 100):
        ...         yield np.ones((16, 83, 83)), np.array([i])
        >>>
        >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
        >>> output_shapes = dataset.output_shapes()
        >>> print(output_shapes)
        [[16, 83, 83], [1]]
    """
    # cache single shape
    if not estimate and self.saved_output_shapes is not None:
        return self.saved_output_shapes
    # cache estimate shape
    if estimate and self.estimated_output_shapes is not None:
        return self.estimated_output_shapes

    # We have a hang problem when two-level pipeline with multiprocessing, we need to extend the life cycle
    # of runtime_context. We found this hang problem only occur on output_types and output_shapes.
    runtime_getter = self._init_tree_getters()
    self.runtime_context = runtime_getter[1]
    api_tree = runtime_getter[2]
    output_shapes = runtime_getter[0].GetOutputShapes(estimate)
    del api_tree
    # Need to terminate the runtime context to avoid the occasional hang problem for
    # Python (with multiprocessing enabled) in sink mode.
    # NOTE(review): if GetOutputShapes raises, Terminate() is skipped and
    # self.runtime_context stays set; confirm whether the decorators cover that path.
    self.runtime_context.Terminate()
    del self.runtime_context

    # Store the result in the cache matching the requested mode.
    if estimate:
        self.estimated_output_shapes = output_shapes
    else:
        self.saved_output_shapes = output_shapes
    return output_shapes
1809
+
1810
@_cleanup_the_iterators_if_created
def output_types(self):
    """
    Get the types of output data; they are computed once and cached in ``saved_output_types``.

    Returns:
        list, list of data types.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>>
        >>> def generator1():
        ...     for i in range(1, 100):
        ...         yield np.ones((16, 83, 83)).astype(np.float32), np.array([i]).astype(np.int32)
        >>>
        >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
        >>> output_types = dataset.output_types()
        >>> print(output_types)
        [dtype('float32'), dtype('int32')]
    """
    if self.saved_output_types is None:
        runtime_getter = self._init_tree_getters()
        # We have a hang problem when two-level pipeline with multiprocessing, we need to extend the life cycle
        # of runtime_context. We found this hang problem only occur on output_types and output_shapes.
        self.runtime_context = runtime_getter[1]
        api_tree = runtime_getter[2]
        self.saved_output_types = runtime_getter[0].GetOutputTypes()
        del api_tree
        # Need to terminate the runtime context to avoid the occasional hang problem for
        # Python (with multiprocessing enabled) in sink mode.
        # NOTE(review): if GetOutputTypes raises, Terminate() is skipped and
        # self.runtime_context stays set; confirm whether the decorator covers that path.
        self.runtime_context.Terminate()
        del self.runtime_context
    return self.saved_output_types
1844
+
1845
@_cleanup_the_iterators_if_created
def get_dataset_size(self):
    """
    Return the number of batches in one epoch; the value is computed once and then cached.

    A warning is logged when the freshly computed size is 0.

    Returns:
        int, number of batches.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import numpy as np
        >>>
        >>> # A generator return 66 samples
        >>> def generator1():
        ...     for i in range(66):
        ...         yield np.ones((16, 83, 83)), np.array([i])
        >>>
        >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
        >>> dataset_size = dataset.get_dataset_size()
        >>> print(dataset_size)
        66
    """
    if self.dataset_size is not None:
        return self.dataset_size
    # Keep the runtime context and api tree referenced while the size is queried.
    size_getter, _context, _api_tree = self.__init_size_getter()
    self.dataset_size = size_getter.GetDatasetSize(False)
    if self.dataset_size == 0:
        logger.warning("Got 0 sample from dataset pipeline, check if drop all data or load dataset fail.")
    return self.dataset_size
1874
+
1875
def num_classes(self):
    """
    Get the number of classes in the dataset, or None when the backend reports -1.

    The backend value is fetched once and cached in ``_num_classes``.

    Returns:
        int, number of classes.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # Read image files
        >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir)
        >>> # Check how many classes exist in image folder
        >>> num_classes = dataset.num_classes()
    """
    if self._num_classes is None:
        # Keep the runtime context and api tree referenced while the getter is queried.
        classes_getter, _context, _api_tree = self._init_tree_getters()
        self._num_classes = classes_getter.GetNumClasses()
    return None if self._num_classes == -1 else self._num_classes
1897
+
1898
def get_sync_notifiers(self):
    """Return the sync notifiers reported by the first child, or an empty dict when there are no children."""
    if not self.children:
        return {}
    return self.children[0].get_sync_notifiers()
1902
+
1903
def disable_sync(self):
    """Forward the disable request to the first child; a node without children returns an empty dict."""
    if not self.children:
        return {}
    return self.children[0].disable_sync()
1907
+
1908
def is_sync(self):
    """Report the first child's sync state; a node without children is never in sync mode."""
    if not self.children:
        return False
    return self.children[0].is_sync()
1912
+
1913
@check_sync_update
def sync_update(self, condition_name, num_batch=None, data=None):
    """
    Release a blocking condition and trigger callback with given data.

    Every validation failure first calls :meth:`disable_sync` and only then raises.
    This keeps a ``sync_wait`` node in the pipeline from staying blocked on an
    update that will never arrive.

    Args:
        condition_name (str): The condition name that is used to toggle sending next row.
        num_batch (Union[int, None]): The number of batches (rows) that are released.
            When `num_batch` is ``None``, it will default to the number specified by the
            `sync_wait` operation. Default: ``None``.
        data (Any): The data passed to the callback, user defined. Default: ``None``.

    Raises:
        RuntimeError: If `num_batch` is neither ``None`` nor a positive int.
        TypeError: If `condition_name` is not a str.
        RuntimeError: If `condition_name` is not among the registered sync notifiers.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>>
        >>> def gen():
        ...     for i in range(100):
        ...         yield (np.array(i),)
        >>>
        >>> class Augment:
        ...     def __init__(self, loss):
        ...         self.loss = loss
        ...
        ...     def preprocess(self, input_):
        ...         return input_
        ...
        ...     def update(self, data):
        ...         self.loss = data["loss"]
        >>>
        >>> batch_size = 10
        >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
        >>> aug = Augment(0)
        >>> dataset = dataset.sync_wait(condition_name='', num_batch=1)
        >>> dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
        >>> dataset = dataset.batch(batch_size)
        >>>
        >>> count = 0
        >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        ...     count += 1
        ...     data = {"loss": count}
        ...     dataset.sync_update(condition_name="", data=data)
    """
    if (not isinstance(num_batch, int) and num_batch is not None) or \
            (isinstance(num_batch, int) and num_batch <= 0):
        # throwing exception, disable all sync_wait in pipeline
        self.disable_sync()
        raise RuntimeError("Sync_update batch size can only be positive integer, got : {}.".format(num_batch))
    notifiers_dict = self.get_sync_notifiers()
    if not isinstance(condition_name, str):
        # throwing exception, disable all sync_wait in pipeline (same as the other error paths)
        self.disable_sync()
        raise TypeError("Argument condition_name with value {} is not of type str, but got {}."
                        .format(condition_name, type(condition_name)))
    if condition_name not in notifiers_dict:
        # throwing exception, disable all sync_wait in pipeline
        self.disable_sync()
        raise RuntimeError("Condition name not found.")
    if num_batch is not None:
        # The notifier counts rows, so scale the batch count by the batch size.
        num_batch *= self.get_batch_size()
    notifiers_dict[condition_name](num_batch, data)
1972
+
1973
def get_batch_size(self):
    """
    Return the size of batch.

    The value is fetched once from the tree getter and cached in ``self._batch_size``.
    If the getter reports ``None``, the batch size falls back to 1.

    Returns:
        int, the batch size of data.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> dataset = dataset.batch(2)
        >>> batch_size = dataset.get_batch_size()
        >>> print(batch_size)
        2
    """
    if self._batch_size is None:
        tree_getter = self._init_tree_getters()[0]
        reported = tree_getter.GetBatchSize()
        self._batch_size = 1 if reported is None else reported
    return self._batch_size
1994
+
1995
def get_repeat_count(self):
    """
    Get the replication times in RepeatDataset. Default: ``1`` .

    The value is fetched once from the tree getter and cached in ``self._repeat_count``.
    If the getter reports ``None``, the count falls back to 1.

    Returns:
        int, the count of repeat.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
        >>> dataset = dataset.repeat(5)
        >>> repeat_count = dataset.get_repeat_count()
        >>> print(repeat_count)
        5
    """
    if self._repeat_count is None:
        tree_getter = self._init_tree_getters()[0]
        reported = tree_getter.GetRepeatCount()
        self._repeat_count = 1 if reported is None else reported
    return self._repeat_count
2016
+
2017
def get_class_indexing(self):
    """
    Get the mapping dictionary from category names to category indexes.

    This dictionary can be used to look up which category name corresponds to a particular category index.
    This base implementation asks the first child; a node without children returns an empty dict.

    Returns:
        Dict[str, int], the mappings from category names to category indexes.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # Read image files
        >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
        >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir)
        >>> # Check how many classes exist in image folder
        >>> class_indexing = dataset.get_class_indexing()
    """
    if not self.children:
        return {}
    return self.children[0].get_class_indexing()
2037
+
2038
def reset(self):
    """
    Reset the dataset for next epoch.

    This base implementation performs no work and returns ``None``.

    Examples:
        >>> import mindspore.dataset as ds
        >>> mind_dataset_dir = ["/path/to/mind_dataset_file"]
        >>> dataset = ds.MindDataset(dataset_files=mind_dataset_dir)
        >>> for _ in range(5):
        ...     num_iter = 0
        ...     for data in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        ...         num_iter += 1
        ...     dataset.reset()
    """
    return None
2052
+
2053
def is_shuffled(self):
    """Returns True if any child dataset reports itself as shuffled (checked in order, stopping at the first hit)."""
    return any(child.is_shuffled() for child in self.children)
2060
+
2061
def is_sharded(self):
    """Returns True if any child dataset reports itself as sharded (checked in order, stopping at the first hit)."""
    return any(child.is_sharded() for child in self.children)
2068
+
2069
def parse(self, children=None):
    """Build the IR node for this dataset; every concrete dataset must override this."""
    message = "Dataset has to implement parse method."
    raise NotImplementedError(message)
2071
+
2072
def __len__(self):
    """
    Get the length of dataset.

    This is the value returned by :meth:`get_dataset_size`.

    Returns:
        int, the length of dataset.
    """
    size = self.get_dataset_size()
    return size
2080
+
2081
@staticmethod
def _update_data_shard(num_shards, shard_id):
    """
    Update the shard number and shard id if necessary.
    This is normally used in distributed training mode like Parameter Server training.

    When this process is a parameter server, distributed MindRT is enabled and the
    worker count differs from the server count, the shard count becomes the worker
    count and the shard id becomes 0. Otherwise the inputs are returned unchanged.

    Returns:
        tuple, (num_shards, shard_id) to use.
    """
    worker_num = _get_ps_context("worker_num")
    server_num = _get_ps_context("server_num")
    needs_override = _is_role_pserver() and _enable_distributed_mindrt() and (worker_num != server_num)
    if not needs_override:
        return num_shards, shard_id
    return worker_num, 0
2095
+
2096
def pre_parse(self, getter_mode):
    """
    Prepare this node before it is parsed into the IR tree.

    In getter mode, multiprocessing is switched off and the worker count is forced to 1,
    each only if this object already has the corresponding attribute. Outside getter mode
    nothing changes.
    """
    if not getter_mode:
        return
    if hasattr(self, "python_multiprocessing"):
        self.python_multiprocessing = False
    if hasattr(self, "num_parallel_workers"):
        self.num_parallel_workers = 1
2102
+
2103
def post_parse(self, ir_node):
    """
    Attach cache and worker settings to a freshly parsed IR node.

    Args:
        ir_node: The node produced by ``parse``; its setters return the updated node.

    Returns:
        The node after applying the cache client (if ``self.cache`` is truthy) and
        then the worker count (if ``self.num_parallel_workers`` is truthy).
    """
    cached = ir_node.set_cache_client(self.cache.cache_client) if self.cache else ir_node
    workers = self.num_parallel_workers
    return cached.set_num_workers(workers) if workers else cached
2110
+
2111
def set_init_step(self, init_step):
    """Store `init_step` as this node's starting global step (read back by ``get_init_step``)."""
    setattr(self, "_global_step", init_step)
2113
+
2114
def get_init_step(self):
    """
    Return the initial global step for this node.

    An explicitly stored step wins. Otherwise a node with exactly one child inherits the
    child's step, and any other node starts at 0.
    """
    if self._global_step is None:
        children = self.children
        # With several inputs there is no single child to inherit the step from,
        # so start from the beginning.
        return children[0].get_init_step() if len(children) == 1 else 0
    return self._global_step
2122
+
2123
+
2124
class VisionBaseDataset(Dataset):
    """
    Abstract base for vision source datasets that feed content into the data pipeline.

    Concrete subclasses must override :meth:`parse`.
    """

    def __init__(self, children=None, num_parallel_workers=None, cache=None):
        # Forward every argument unchanged to the next initializer in the MRO.
        super(VisionBaseDataset, self).__init__(children=children,
                                                num_parallel_workers=num_parallel_workers,
                                                cache=cache)

    def parse(self, children=None):
        """Not implemented at this level."""
        raise NotImplementedError("Dataset has to implement parse method.")
2134
+
2135
+
2136
class TextBaseDataset(Dataset):
    """
    Abstract class to represent a text source dataset which produces content to the data pipeline.
    """

    def __init__(self, children=None, num_parallel_workers=None, cache=None):
        super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)

    def parse(self, children=None):
        raise NotImplementedError("Dataset has to implement parse method.")

    def build_vocab(self, columns, freq_range, top_k, special_tokens, special_first):
        """
        Function to create a Vocab from source dataset.
        Desired source dataset is a text type dataset.

        Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
        which contains top_k most frequent words (if top_k is specified).

        Note:
            mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0
            and will be removed in a future version. Use mindspore.dataset.text.Vocab.from_dataset instead.

        Args:
            columns(Union[str, list[str]]): Column names to get words from.
            freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency
                range will be stored.
                Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
                can be set to default, which corresponds to 0/total_words separately.
            top_k(int): Number of words to be built into vocab. top_k most frequent words are
                taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
            special_tokens(list[str]): A list of strings, each one is a special token.
            special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens
                is specified and special_first is set to default, special_tokens will be prepended.

        Returns:
            Vocab, vocab built from the dataset.
        """
        warnings.warn("mindspore.dataset.Dataset.build_vocab is deprecated from version 2.0 "
                      "and will be removed in a future version. "
                      "Use mindspore.dataset.text.Vocab.from_dataset instead.", DeprecationWarning)
        # Deprecated, but still honor the documented contract (return a Vocab) until removal.
        return self._build_vocab(columns, freq_range, top_k, special_tokens, special_first)

    def build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params):
        """
        Function to create a SentencePieceVocab from source dataset.
        Desired source dataset is a text type dataset.

        Note:
            mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0
            and will be removed in a future version. Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead.

        Args:
            columns(list[str]): Column names to get words from.
            vocab_size(int): Vocabulary size.
            character_coverage(float): Percentage of characters covered by the model, must be between
                0.98 and 1.0 Good defaults are: 0.9995 for languages with rich character sets like
                Japanese or Chinese character sets, and 1.0 for other languages with small character sets
                like English or Latin.
            model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
                The input sentence must be pre-tokenized when using word type.
            params(dict): Any extra optional parameters of sentencepiece library according to your raw data

        Returns:
            SentencePieceVocab, vocab built from the dataset.
        """
        warnings.warn("mindspore.dataset.Dataset.build_sentencepiece_vocab is deprecated from version 2.0 "
                      "and will be removed in a future version. "
                      "Use mindspore.dataset.text.SentencePieceVocab.from_dataset instead.", DeprecationWarning)
        # Deprecated, but still honor the documented contract (return a SentencePieceVocab) until removal.
        return self._build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)

    def _build_vocab(self, columns, freq_range, top_k, special_tokens, special_first):
        """
        Function to create a Vocab from source dataset.
        Desired source dataset is a text type dataset.

        Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
        which contains top_k most frequent words (if top_k is specified).

        Args:
            columns(Union[str, list[str]]): Column names to get words from.
            freq_range(tuple[int]): A tuple of integers (min_frequency, max_frequency). Words within the frequency
                range will be stored.
                Naturally 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
                can be set to default, which corresponds to 0/total_words separately.
            top_k(int): Number of words to be built into vocab. top_k most frequent words are
                taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
            special_tokens(list[str]): A list of strings, each one is a special token.
            special_first(bool): Whether special_tokens will be prepended/appended to vocab, If special_tokens
                is specified and special_first is set to default, special_tokens will be prepended.

        Returns:
            Vocab, vocab built from the dataset.
        """
        # Largest signed 64-bit value, used as "unbounded" for frequency and top_k limits.
        int64_max = 9223372036854775807

        vocab = cde.Vocab()
        columns = replace_none(columns, [])
        if not isinstance(columns, list):
            columns = [columns]

        freq_range = replace_none(freq_range, (0, int64_max))
        if freq_range[0] is None:
            freq_range = (0, freq_range[1])
        if freq_range[1] is None:
            freq_range = (freq_range[0], int64_max)
        special_tokens = replace_none(special_tokens, [])
        top_k = replace_none(top_k, int64_max)

        ir_tree, api_tree = self.create_ir_tree()

        # vocab node
        vocab_node = cde.BuildVocabNode(ir_tree, vocab, columns, freq_range, top_k, special_tokens, special_first)

        runtime_context = cde.PythonRuntimeContext()
        runtime_context.Init()

        # build vocab
        consumer = cde.PythonBuildVocabConsumer()
        consumer.Init(vocab_node)
        runtime_context.AssignConsumer(consumer)

        consumer.Start()
        del api_tree

        return vocab

    def _build_sentencepiece_vocab(self, columns, vocab_size, character_coverage, model_type, params):
        """
        Function to create a SentencePieceVocab from source dataset.
        Desired source dataset is a text type dataset.

        Args:
            columns(list[str]): Column names to get words from.
            vocab_size(int): Vocabulary size.
            character_coverage(float): Percentage of characters covered by the model, must be between
                0.98 and 1.0 Good defaults are: 0.9995 for languages with rich character sets like
                Japanese or Chinese character sets, and 1.0 for other languages with small character sets
                like English or Latin.
            model_type(SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
                The input sentence must be pre-tokenized when using word type.
            params(dict): Any extra optional parameters of sentencepiece library according to your raw data

        Returns:
            SentencePieceVocab, vocab built from the dataset.

        Raises:
            TypeError: If `model_type` is not a SentencePieceModel.
        """
        if not isinstance(model_type, SentencePieceModel):
            raise TypeError("Argument model_type with value {0} is not of type SentencePieceModel, but got {1}."
                            .format(model_type, type(model_type)))
        model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type]
        vocab = cde.SentencePieceVocab()

        ir_tree, api_tree = self.create_ir_tree()

        # vocab node
        vocab_node = cde.BuildSentenceVocabNode(ir_tree, vocab, columns, vocab_size, character_coverage, model_type,
                                                params)

        runtime_context = cde.PythonRuntimeContext()
        runtime_context.Init()

        # build vocab
        consumer = cde.PythonBuildVocabConsumer()
        consumer.Init(vocab_node)
        runtime_context.AssignConsumer(consumer)

        consumer.Start()
        del api_tree

        return vocab
2302
+
2303
+
2304
class AudioBaseDataset(Dataset):
    """
    Abstract base for an audio source dataset that feeds content into the data pipeline.

    Concrete subclasses must override :meth:`parse`.
    """

    def __init__(self, children=None, num_parallel_workers=None, cache=None):
        # Forward every argument unchanged to the next initializer in the MRO.
        super(AudioBaseDataset, self).__init__(children=children,
                                               num_parallel_workers=num_parallel_workers,
                                               cache=cache)

    def parse(self, children=None):
        """Not implemented at this level."""
        raise NotImplementedError("Dataset has to implement parse method.")
2314
+
2315
+
2316
class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset):
    """
    Abstract base combining the vision, text and audio source dataset bases.

    Concrete subclasses must override :meth:`parse`.
    """

    def __init__(self, children=None, num_parallel_workers=None, cache=None):
        # Continue along the MRO (VisionBaseDataset comes next) with the same keyword arguments.
        super(UnionBaseDataset, self).__init__(children=children,
                                               num_parallel_workers=num_parallel_workers,
                                               cache=cache)

    def parse(self, children=None):
        """Not implemented at this level."""
        raise NotImplementedError("Dataset has to implement parse method.")
2326
+
2327
+
2328
class SourceDataset(Dataset):
    """
    Abstract class to represent a source dataset which produces content to the data pipeline.

    Normalizes the common sampling arguments and encodes ``shuffle`` as an integer
    ``shuffle_flag``: 0 = no shuffle, 1 = files shuffle, 2 = global shuffle,
    3 = in-file shuffle.
    """

    def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None,
                 cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, cache=cache)
        self.num_samples = replace_none(num_samples, 0)
        self.num_shards = replace_none(num_shards, 1)
        self.shard_id = replace_none(shard_id, 0)

        if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
            raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
                            "'Shuffle.FILES' or 'Shuffle.INFILE'.")

        if isinstance(shuffle, Shuffle):
            # Any Shuffle member not listed here keeps the global-shuffle default.
            flag_by_mode = {Shuffle.GLOBAL: 2, Shuffle.FILES: 1, Shuffle.INFILE: 3}
            self.shuffle_flag = flag_by_mode.get(shuffle, 2)
        else:
            # None and True select global shuffle; False disables shuffling.
            self.shuffle_flag = 2 if (shuffle is None or shuffle) else 0

    def parse(self, children=None):
        raise NotImplementedError("Dataset has to implement parse method.")

    @staticmethod
    def _find_files(patterns):
        """
        Utility function to search for files with the given glob patterns.

        Patterns are expanded with recursive globbing and only regular files are kept.

        Args:
            patterns (Union[str, list[str]]): String or list of patterns to be searched.

        Returns:
            list, list of files.

        Raises:
            ValueError: If any pattern matches no file, or if no file was found at all.
        """
        pattern_list = patterns if isinstance(patterns, list) else [patterns]

        found = []
        missing = []
        for pattern in pattern_list:
            hits = [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
            if not hits:
                missing.append(pattern)
                continue
            found.extend(hits)

        if missing:
            raise ValueError("The following patterns did not match any files: {}.".format(missing))
        if not found:
            raise ValueError("The list of path names matching the patterns is empty.")
        return found

    def is_shuffled(self):
        """Return True unless shuffling is disabled (``shuffle_flag`` is 0)."""
        return self.shuffle_flag > 0

    def is_sharded(self):
        """Return True when the data is split across more than one shard."""
        return False if self.num_shards is None else self.num_shards > 1
2400
+
2401
+
2402
class MappableDataset(SourceDataset):
    """
    Abstract class to represent a source dataset which supports use of samplers.
    """

    def parse(self, children=None):
        raise NotImplementedError("Dataset has to implement parse method.")

    def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
                 shard_id=None, cache=None):
        # Parameter Server mode may override the requested sharding.
        num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.shuffle_flag = replace_none(shuffle, True)
        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)

    def add_sampler(self, new_sampler):
        """
        Add a child sampler for the current dataset.

        The currently attached sampler becomes the child of `new_sampler`, and
        `new_sampler` becomes this dataset's sampler. The cached dataset size is cleared.

        Args:
            new_sampler (Sampler): The child sampler to be added.

        Examples:
            >>> import mindspore.dataset as ds
            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
            >>>
            >>> new_sampler = ds.DistributedSampler(10, 2)
            >>> dataset.add_sampler(new_sampler)
        """
        # Sampled IDs pass through the existing samplers first, then flow into new_sampler.
        self.dataset_size = None
        previous = self.sampler
        new_sampler.add_child(previous)
        self.sampler = new_sampler

    def use_sampler(self, new_sampler):
        """
        Replace the most recently added sampler of the current dataset with `new_sampler`.

        The sampler that sat beneath the replaced one (its child) is kept and becomes the
        child of `new_sampler`. The cached dataset size is cleared.

        Args:
            new_sampler (Sampler): The new sampler to replace with.

        Raises:
            TypeError: If `new_sampler` is ``None`` or is not a sampler instance.

        Examples:
            >>> import mindspore.dataset as ds
            >>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
            >>>
            >>> # use a DistributedSampler instead
            >>> new_sampler = ds.DistributedSampler(10, 2)
            >>> dataset.use_sampler(new_sampler)
        """
        if new_sampler is None:
            raise TypeError("Input sampler can not be None.")
        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
            raise TypeError("Input sampler is not an instance of a sampler.")
        self.dataset_size = None

        # Drop the outermost sampler, then stack the new one on top of what was beneath it.
        remaining = self.sampler.child_sampler
        self.sampler = remaining
        self.add_sampler(new_sampler)

    def is_shuffled(self):
        """Delegate the shuffle query to the attached sampler."""
        return self.sampler.is_shuffled()

    def is_sharded(self):
        """Delegate the shard query to the attached sampler."""
        return self.sampler.is_sharded()

    @check_split
    def split(self, sizes, randomize=True):
        """
        Split the dataset into smaller, non-overlapping datasets.

        Args:
            sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                respectively. If the sum of all sizes does not equal the original dataset size, an
                error will occur.
                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
                and must sum to 1, otherwise an error will occur. The dataset will be split into n
                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
                original dataset.
                If after rounding:

                - Any size equals 0, an error will occur.
                - The sum of split sizes < K, the difference will be added to the first split.
                - The sum of split sizes > K, the difference will be removed from the first large
                  enough split such that it will have at least 1 row after removing the difference.

            randomize (bool, optional): Determines whether or not to split the data randomly. Default: ``True``.
                If ``True``, the data will be randomly split. Otherwise, each split will be created with
                consecutive rows from the dataset.

        Note:
            1. There is an optimized split function, which will be called automatically when the dataset
               that calls this function is a MappableDataset.
            2. Dataset should not be sharded if split is going to be called. Instead, create a
               :class:`mindspore.dataset.DistributedSampler` and specify a split to shard after splitting.
               If the dataset is sharded after a split, it is strongly recommended setting the same
               seed in each instance of execution, otherwise each shard may not be part of the same
               split (see Examples).
            3. It is strongly recommended to not shuffle the dataset, but set `randomize` to ``True`` instead.
               Shuffling the dataset may not be deterministic, which means the data in each split
               will be different in each epoch. Furthermore, if sharding occurs after split, each
               shard may not be part of the same split.

        Returns:
            Tuple[Dataset], a tuple of new datasets split from the original one.

        Raises:
            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
            RuntimeError: If `sizes` is list of integers and sum of all elements in sizes does not
                equal the dataset size.
            RuntimeError: If `sizes` is list of float and there is a split with size 0 after calculations.
            RuntimeError: If the dataset is sharded prior to calling split.
            ValueError: If `sizes` is list of float and not all floats are between 0 and 1, or if the
                floats don't sum to 1.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
            >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
            >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, shuffle=False)
            >>>
            >>> # Set the seed, and tell split to use this seed when randomizing.
            >>> # This is needed because sharding will be done later
            >>> ds.config.set_seed(58)
            >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
            >>>
            >>> # To shard the train dataset, use a DistributedSampler
            >>> train_sampler = ds.DistributedSampler(10, 2)
            >>> train_dataset.use_sampler(train_sampler)
        """
        if self.is_shuffled():
            logger.warning("Dataset is shuffled before split.")

        if self.is_sharded():
            raise RuntimeError("Dataset should not be sharded before split.")

        pieces = []
        offset = 0
        for piece_size in self._get_absolute_split_sizes(sizes):
            piece = copy.deepcopy(self)
            piece.dataset_size = None
            if randomize:
                # Shuffle identically every epoch before slicing; relies on the user's set_seed.
                shuffler = samplers.RandomSampler()
                shuffler.reshuffle_each_epoch = False
                piece.add_sampler(shuffler)

            piece.add_sampler(samplers.SequentialSampler(offset, piece_size))

            # A trailing sequential sampler is what a later use_sampler call will replace,
            # so the slicing sampler above stays in place.
            piece.add_sampler(samplers.SequentialSampler())

            pieces.append(piece)
            offset += piece_size

        return tuple(pieces)
2564
+
2565
+
2566
class BucketBatchByLengthDataset(UnionBaseDataset):
    """
    The result of applying BucketBatchByLength operation to the input dataset.

    Optional arguments given as ``None`` are normalized: list/dict arguments become empty
    containers and flag arguments become ``False``.
    """

    def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function,
                 pad_info, pad_to_bucket_boundary, drop_remainder):
        super().__init__(children=input_dataset)

        self.column_names = to_list(column_names)
        self.bucket_boundaries = replace_none(bucket_boundaries, [])
        self.bucket_batch_sizes = replace_none(bucket_batch_sizes, [])
        self.element_length_function = element_length_function
        self.pad_info = replace_none(pad_info, {})
        self.pad_to_bucket_boundary = replace_none(pad_to_bucket_boundary, False)
        self.drop_remainder = replace_none(drop_remainder, False)

    def parse(self, children=None):
        """Build the BucketBatchByLength IR node on top of the first child node."""
        node_args = (self.column_names, self.bucket_boundaries, self.bucket_batch_sizes,
                     self.element_length_function, self.pad_info, self.pad_to_bucket_boundary,
                     self.drop_remainder)
        return cde.BucketBatchByLengthNode(children[0], *node_args)
2587
+
2588
+
2589
def _check_shm_usage(num_worker, queue_size, in_rowsize, out_rowsize):
    """
    Check sufficient shared memory is available for shared memory queues
    when training in parallel mode.

    The estimate is only made on platforms other than Windows and macOS, and only when
    both row sizes are fixed; a row size of -1 (dynamic shared memory) skips the check.

    Args:
        num_worker (int): Number of worker processes.
        queue_size (int): Queue size per worker; two extra slots are included in the estimate.
        in_rowsize (int): Input row size in MB, or -1 for dynamically allocated shared memory.
        out_rowsize (int): Output row size in MB, or -1 for dynamically allocated shared memory.

    Raises:
        RuntimeError: If the estimated usage reaches 80% of the free space in /dev/shm,
            or if /dev/shm does not exist.
    """
    threshold_ratio = 0.8
    # Verify available size only when using static shared memory on Linux
    if platform.system().lower() in {"windows", "darwin"} or in_rowsize == -1 or out_rowsize == -1:
        return

    device_num = _get_device_num()
    # In the cluster, _get_device_num indicates the number of the entire cluster. The maximum number of cards
    # on the ascend server is 8.
    if device_num > 1:
        device_num = min(device_num, 8)
    shm_estimate_usage = device_num * num_worker * \
        (queue_size + 2) * (in_rowsize + out_rowsize) * 1024 * 1024

    # Only the disk_usage call can raise FileNotFoundError; keep the original cause attached.
    try:
        shm_available = psutil.disk_usage('/dev/shm').free
    except FileNotFoundError as err:
        raise RuntimeError("Expected /dev/shm to exist.") from err

    if shm_estimate_usage >= threshold_ratio * shm_available:
        raise RuntimeError(
            "Insufficient shared memory available. Required: {}, Available: {}. "
            "The required memory can't exceed 80% of the available shared memory, "
            "it's recommended to reduce memory usage by following methods:\n"
            "1. reduce value of parameter max_rowsize or num_parallel_workers.\n"
            "2. reduce prefetch size by set_prefetch_size().\n"
            "3. disable shared memory by set_enable_shared_mem().".format(shm_estimate_usage, shm_available))
2616
+
2617
+
2618
class BatchDataset(UnionBaseDataset):
    """
    The result of applying Batch operation to the input dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be batched.
        batch_size (Union[int, function]): The number of rows each batch is created with. An
            int or callable which takes exactly 1 parameter, BatchInfo.
        drop_remainder (bool, optional): Determines whether or not to drop the last
            possibly incomplete batch. Default: ``False``. If True, and if there are less
            than batch_size rows available to make the last batch, then those rows will
            be dropped and not propagated to the child node.
        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``.
        per_batch_map (callable, optional): Per batch map callable. A callable which takes
            (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
            Tensors on a given column. The number of lists should match with number of entries in input_columns. The
            last parameter of the callable must always be a BatchInfo object.
        input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must
            match with signature of per_batch_map callable.
        output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by
            the last operation. This parameter is mandatory if len(input_columns) !=
            len(output_columns). The size of this list must match the number of output
            columns of the last operation. Default: ``None``, output columns will have the same
            name as the input columns, i.e., the columns will be replaced.
        python_multiprocessing (bool, optional): Whether to run ``per_batch_map`` in a pool of worker
            processes. It is ignored (with a warning) when ``per_batch_map`` is None, in debug mode, and on
            Windows. Default: ``False``.
        max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory
            allocation to copy data between processes, the total occupied shared memory will increase as
            ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
            shared memory will be dynamically allocated with the actual size of data. This is only used if
            ``python_multiprocessing`` is set to True. If it is an int value, it represents
            ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory.
            If it is a list, the first element represents the ``input_columns`` use this value as the unit to
            create shared memory, and the second element represents ``output_columns`` use this value as the unit
            to create shared memory. Any element equal to -1 keeps dynamic allocation for that side; every other
            value is multiplied by the batch size (1 when ``batch_size`` is callable), since one batch carries
            that many rows. Default: ``None`` , allocate shared memory dynamically.

    """

    def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
                 input_columns=None, output_columns=None, python_multiprocessing=False, max_rowsize=None):
        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

        if BatchDataset._is_ancestor_of_repeat(input_dataset):
            logger.warning("Repeat is located before batch, data from two epochs can be batched together.")

        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)

        # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function
        self.batch_size = batch_size if not callable(batch_size) else 1
        self.batch_size_func = None if not callable(batch_size) else batch_size

        self.drop_remainder = replace_none(drop_remainder, False)

        self.per_batch_map = per_batch_map

        self.input_columns = to_list(input_columns)
        self.output_columns = to_list(output_columns)

        self.python_multiprocessing = python_multiprocessing
        # Created lazily in iterator_bootstrap() when multiprocessing is actually used.
        self.process_pool = None
        self.max_rowsize = BatchDataset._scale_max_rowsize(max_rowsize, self.batch_size)

    @staticmethod
    def _scale_max_rowsize(max_rowsize, batch_size):
        """
        Convert the user's per-row ``max_rowsize`` into per-batch sizes for the input and output queues.

        Args:
            max_rowsize (Union[None, int, list[int]]): ``None`` means dynamic allocation on both sides; an int
                applies to both sides; a list gives ``[input_side, output_side]``.
            batch_size (int): Number of rows per batch.

        Returns:
            list[int], ``[input_rowsize, output_rowsize]``. A value of -1 stays -1 (dynamic allocation);
            any other value is multiplied by ``batch_size``.
        """
        if max_rowsize is None:
            return [-1, -1]
        if isinstance(max_rowsize, int):
            max_rowsize = [max_rowsize, max_rowsize]

        def _scale(size):
            # -1 is the "allocate dynamically" marker and must not be scaled, otherwise it would turn into
            # -batch_size and lose its special meaning.
            return size if size == -1 else size * batch_size

        return [_scale(max_rowsize[0]), _scale(max_rowsize[1])]

    def __del__(self):
        if hasattr(self, "process_pool") and self.process_pool is not None:
            self.process_pool.terminate()
            del self.process_pool

    def parse(self, children=None):
        return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, False, self.input_columns,
                             self.output_columns, self.batch_size_func, self.per_batch_map, {},
                             self.process_pool)

    @staticmethod
    def _is_ancestor_of_repeat(dataset):
        """
        Utility function to find the case where repeat is used before batch.

        Args:
            dataset (Dataset): Dataset to be checked.

        Returns:
            bool, whether repeat is used before batch.
        """
        if isinstance(dataset, RepeatDataset):
            return True
        return any(BatchDataset._is_ancestor_of_repeat(child) for child in dataset.children)

    @staticmethod
    def _update_batch_size_for_syncwait(dataset, batch_size):
        """
        Utility function to notify batch size to sync_wait.

        Recursively visits ``dataset`` and all of its descendants and calls ``update_sync_batch_size`` on every
        SyncWaitDataset found.

        Args:
            dataset (Dataset): Dataset to be checked.
            batch_size (int): batch size to notify.
        """
        if isinstance(dataset, SyncWaitDataset):
            dataset.update_sync_batch_size(batch_size)
        for input_dataset in dataset.children:
            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)

    def __deepcopy__(self, memodict):
        return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__"))

    # Iterator bootstrap will be called on iterator construction.
    # A deep copy of Dataset object is created prior of iterator_bootstrap.
    # This method will create per iterator process pool and bind pyfunc execution to the pool.
    def iterator_bootstrap(self):
        """
        Per iterator bootstrap callback.

        Creates the worker process pool for ``per_batch_map`` when Python multiprocessing is enabled and
        supported; otherwise wraps ``per_batch_map`` (if any) in a FuncWrapper.
        """
        if self.python_multiprocessing and platform.system().lower() == 'windows':
            logger.warning("Python multiprocessing is not supported on Windows platform.")
        if self.python_multiprocessing and get_debug_mode():
            logger.warning("Python multiprocessing is not supported in debug mode."
                           " Ignoring Python multiprocessing for batch operation.")
            self.python_multiprocessing = False
        if self.python_multiprocessing and platform.system().lower() != 'windows':
            if self.per_batch_map is None:
                logger.warning("per_batch_map is None so python_multiprocessing is ignored for batch.")
                return

            # If user didn't specify num_parallel_workers, set it to default
            if self.num_parallel_workers is None:
                self.num_parallel_workers = get_num_parallel_workers()

            self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, [self.per_batch_map],
                                                       self.max_rowsize)
            # Wrap per_batch_map into _PythonCallable
            self.per_batch_map = _PythonCallable(self.per_batch_map, 0, self.process_pool)
        else:
            if self.per_batch_map is not None:
                self.per_batch_map = FuncWrapper(self.per_batch_map)
2757
+
2758
+
2759
class BatchInfo(cde.CBatchInfo):
    """
    Runtime information about the batch being built, handed to a callable `batch_size` or `per_batch_map`
    of the `batch` operation so that it can adapt to the current position in the data pipeline.
    """

    def get_batch_num(self):
        """
        Get the index of the batch currently being processed within the current epoch, counting from 0.

        Examples:
            >>> # Build a dataset whose batch size grows by one for every new batch.
            >>> import mindspore.dataset as ds
            >>> from mindspore.dataset import BatchInfo
            >>>
            >>> dataset = ds.GeneratorDataset([i for i in range(3)], "column1", shuffle=False)
            >>> def add_one(batch_info):
            ...     return batch_info.get_batch_num() + 1
            >>> dataset = dataset.batch(batch_size=add_one)
            >>> print(list(dataset))
            [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[2], dtype=Int64, value= [1, 2])]]
        """
        return None

    def get_epoch_num(self):
        """
        Get the index of the current epoch, counting from 0.

        Examples:
            >>> # Build a dataset whose batch size grows by one for every new epoch.
            >>> import mindspore.dataset as ds
            >>> from mindspore.dataset import BatchInfo
            >>>
            >>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False)
            >>> def add_one_by_epoch(batch_info):
            ...     return batch_info.get_epoch_num() + 1
            >>> dataset = dataset.batch(batch_size=add_one_by_epoch)
            >>>
            >>> result = []
            >>> epoch = 2
            >>> iterator = dataset.create_tuple_iterator(num_epochs=epoch)
            >>> for i in range(epoch):
            ...     result.extend(list(iterator))
            >>> # result:
            >>> # [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[1], dtype=Int64, value= [1])],
            >>> #  [Tensor(shape=[1], dtype=Int64, value= [2])], [Tensor(shape=[1], dtype=Int64, value= [3])],
            >>> #  [Tensor(shape=[2], dtype=Int64, value= [0, 1])], [Tensor(shape=[2], dtype=Int64, value= [2, 3])]]
        """
        return None
2810
+
2811
+
2812
class BlockReleasePair:
    """
    Counting gate used by SyncWaitDataset to hold the pipeline until ``sync_update`` lets more rows through.

    ``row_count`` starts at ``-init_release_rows``. Each ``block_func`` call waits until the count is negative
    (or the gate is disabled) and then consumes one row; each ``release_func`` call lowers the count again.

    Args:
        init_release_rows (int): Number of lines to allow through the pipeline.
        callback (function): Invoked with the ``data`` argument of every ``release_func`` call. Default: ``None``.
    """

    def __init__(self, init_release_rows, callback=None):
        if isinstance(init_release_rows, int) and init_release_rows <= 0:
            raise ValueError("release_rows need to be greater than 0.")
        self.default_rows = init_release_rows
        self.row_count = -init_release_rows
        self.callback = callback
        self.disable = False
        self.cv = threading.Condition()

    def __deepcopy__(self, memodict):
        # Deep copies of the pipeline share one gate instead of cloning it.
        return self

    def reset(self):
        """Restore the initial row allowance and wake every waiter."""
        with self.cv:
            self.row_count = -self.default_rows
            self.cv.notify_all()

    def update_batched_size(self, batch_size):
        """
        Scale the counters from batches to rows. Should only be used before the pipeline is created.

        Raises:
            ValueError: If ``batch_size`` is an int that is not positive.
        """
        if isinstance(batch_size, int) and batch_size <= 0:
            raise ValueError("batch_size need to be greater than 0.")
        self.row_count = self.row_count * batch_size
        self.default_rows = self.default_rows * batch_size

    def block_func(self):
        """
        Function for handing blocking condition.

        Waits (up to the callback timeout) until a row may pass or the gate is disabled; on timeout the gate
        disables itself. Consumes one row either way.

        Returns:
            bool, True.
        """
        with self.cv:
            released = self.cv.wait_for(lambda: self.row_count < 0 or self.disable,
                                        timeout=get_callback_timeout())
            if not released:
                # wait_for() returned a falsy predicate value: the timeout expired.
                logger.warning("Timeout happened in sync_wait, maybe dataset.sync_update(condition=...) "
                               "is not added after dataset.create_dict_iterator(...), now disabling lock.")
                self.disable = True
            self.row_count += 1
        return True

    def release_func(self, pass_rows=None, data=None):
        """Let ``pass_rows`` more rows through (``default_rows`` when None), run the callback, wake waiters."""
        with self.cv:
            self.row_count -= self.default_rows if pass_rows is None else pass_rows
            if self.callback is not None:
                self.callback(data)
            self.cv.notify_all()

    def disable_lock(self):
        """Permanently open the gate and wake every waiter."""
        with self.cv:
            self.disable = True
            self.cv.notify_all()
2879
+
2880
+
2881
class PaddedBatchDataset(UnionBaseDataset):
    """
    The result of applying Batch operation, with optional per-column padding, to the input dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be batched.
        batch_size (Union[int, function]): Rows per batch, either an int or a callable taking exactly one
            BatchInfo argument.
        drop_remainder (bool, optional): When True, a final batch with fewer than batch_size rows is dropped
            instead of being propagated to the child node. Default: ``False``.
        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``.
        pad_info (dict, optional): Columns to pad. pad_info={"col1":([224,224],0)} pads column "col1" to a tensor
            of size [224,224], filling the missing values with 0. Padding is enabled only when this is not None.
    """

    def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None):
        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

        if PaddedBatchDataset._is_ancestor_of_repeat(input_dataset):
            logger.warning("Repeat is located before padded_batch, data from two epochs can be batched together.")

        PaddedBatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)

        # A callable batch size is handed to the C++ node as batch_size_func, with a placeholder size of 1.
        size_is_dynamic = callable(batch_size)
        self.batch_size = 1 if size_is_dynamic else batch_size
        self.batch_size_func = batch_size if size_is_dynamic else None

        self.drop_remainder = replace_none(drop_remainder, False)

        self.pad = pad_info is not None
        self.pad_info = replace_none(pad_info, {})

    def parse(self, children=None):
        first_child = children[0]
        return cde.BatchNode(first_child, self.batch_size, self.drop_remainder, self.pad, [],
                             [], self.batch_size_func, None, self.pad_info, None)

    @staticmethod
    def _is_ancestor_of_repeat(dataset):
        """
        Report whether a RepeatDataset appears in ``dataset`` or anywhere beneath it (i.e. repeat before batch).

        Args:
            dataset (Dataset): Dataset to be checked.

        Returns:
            bool, whether repeat is used before batch.
        """
        if isinstance(dataset, RepeatDataset):
            return True
        found = False
        for child in dataset.children:
            if PaddedBatchDataset._is_ancestor_of_repeat(child):
                found = True
        return found

    @staticmethod
    def _update_batch_size_for_syncwait(dataset, batch_size):
        """
        Push ``batch_size`` to every SyncWaitDataset in ``dataset`` and its descendants.

        Args:
            dataset (Dataset): Dataset to be checked.
            batch_size (int): batch size to notify.
        """
        pending = [dataset]
        while pending:
            node = pending.pop(0)
            if isinstance(node, SyncWaitDataset):
                node.update_sync_batch_size(batch_size)
            # Prepend children so nodes are visited in the same depth-first pre-order as recursion would.
            pending[0:0] = list(node.children)

    def __deepcopy__(self, memodict):
        return self.__safe_deepcopy__(memodict, exclude=("batch_size_func", "__transfer_dataset__"))
2953
+
2954
+
2955
class SyncWaitDataset(UnionBaseDataset):
    """
    The result of adding a blocking condition to the input Dataset.

    Args:
        input_dataset (Dataset): Input dataset to apply flow control.
        num_batch (int): Number of batches without blocking at the start of each epoch.
        condition_name (str): Condition name that is used to toggle sending next row.
        callback (function): Callback function that will be invoked when sync_update is called. Default: ``None``.

    Raises:
        ValueError: If `num_batch` is an int that is not greater than 0.
        RuntimeError: If condition name already exists.
    """

    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
        super().__init__(children=input_dataset)

        self._condition_name = condition_name
        if isinstance(num_batch, int) and num_batch <= 0:
            raise ValueError("num_batch need to be greater than 0.")

        # The gate starts out counting in batches; a downstream batch operation rescales it to rows
        # through update_sync_batch_size().
        self._pair = BlockReleasePair(num_batch, callback)
        if self._condition_name in self.children[0].get_sync_notifiers():
            raise RuntimeError("Condition name is already in use.")
        logger.info("Please remember to add dataset.sync_update(condition=%s), otherwise hanging will result. "
                    "If dataset.sync_update(condition=%s) has already been added, you can ignore the info.",
                    condition_name, condition_name)

    def parse(self, children=None):
        return cde.SyncWaitNode(children[0], self._condition_name, self._pair.block_func)

    def get_sync_notifiers(self):
        """Return the notifiers of the child pipeline plus this dataset's condition name -> release function."""
        return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}

    def is_sync(self):
        return True

    def update_sync_batch_size(self, batch_size):
        """
        Rescale the blocking gate from batches to rows.

        Raises:
            ValueError: If `batch_size` is an int that is not greater than 0.
        """
        if isinstance(batch_size, int) and batch_size <= 0:
            # The argument validated here is batch_size, so report it by that name.
            raise ValueError("batch_size need to be greater than 0.")
        self._pair.update_batched_size(batch_size)

    def disable_sync(self):
        logger.info("Disabling Sync")
        self._pair.disable_lock()

    @staticmethod
    def _is_ancestor_of_batch(dataset):
        """
        Utility function to find the case where sync_wait is used before batch.

        Args:
            dataset (Dataset): Dataset to be checked.

        Returns:
            bool, whether sync_wait is used before batch.
        """
        if isinstance(dataset, (BatchDataset, PaddedBatchDataset)):
            return True
        return any(SyncWaitDataset._is_ancestor_of_batch(child) for child in dataset.children)

    def iterator_bootstrap(self):
        # Every new iterator starts with a fresh row allowance.
        self._pair.reset()
3022
+
3023
+
3024
class ShuffleDataset(UnionBaseDataset):
    """
    Dataset produced by applying a Shuffle operation to its input; data is reshuffled every epoch.

    Args:
        input_dataset (Dataset): Input Dataset to be shuffled.
        buffer_size (int): Size of the buffer.

    Raises:
        RuntimeError: If exist sync operations before shuffle.
    """

    def __init__(self, input_dataset, buffer_size):
        super().__init__(children=input_dataset)
        self.buffer_size = buffer_size
        self.reshuffle_each_epoch = True

        # Shuffling after a sync_wait would break the row accounting of the sync gate.
        if self.is_sync():
            raise RuntimeError("No shuffle after sync operators.")

    def parse(self, children=None):
        upstream = children[0]
        return cde.ShuffleNode(upstream, self.buffer_size, self.reshuffle_each_epoch)

    def is_shuffled(self):
        return True
3049
+
3050
+
3051
# Registries for multiprocess pyfuncs.
# Per the original design note, these globals are only used within subprocesses.
_OP_NAME = {}
_OP_PROCESS = {}
3055
+
3056
+
3057
# PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable:
    """
    Internal Python function wrapper for multiprocessing pyfunc.

    Calling an instance forwards the arguments to the process pool and returns the worker's result, always
    wrapped as a tuple. ``None`` is returned instead once the pool is no longer running or iterator cleanup
    has started.

    Args:
        py_callable: Original Python callable from the user; only its ``to_json()`` is called directly.
        idx (int): Index of the callable within the pool's operation list.
        pool: Process pool created for the current iterator. Default: ``None``.
    """

    def __init__(self, py_callable, idx, pool=None):
        self.py_callable = py_callable
        self.idx = idx
        self.pool = pool

    def __call__(self, *args):
        while True:
            pool_usable = self.pool.is_running() and check_iterator_cleanup() is False
            if not pool_usable:
                # worker process is stopped
                logger.info("The worker process of map operation is stopped. "
                            "So return None to main thread and break the main thread.")
                return None
            try:
                result = self.pool.execute(self.idx, *args)
            except multiprocessing.TimeoutError:
                # Nothing came back in time: re-check the pool state and try again.
                continue
            return result if isinstance(result, tuple) else (result,)

    def to_json(self):
        return self.py_callable.to_json()
3093
+
3094
+
3095
# used when python_multiprocessing=True in map
class Pipe:
    """
    Class to handle communication between the master process and the worker processes.

    Holds a request queue (master -> worker) and a result queue (worker -> master), both created with
    capacity 1, plus an ``eof`` event that tells either side to stop.

    Args:
        warning_ctl: Forwarded to ``_SharedQueue`` when shared memory is used.
        shared_memory (bool): Use ``_SharedQueue`` instead of ``_Queue``. Default: ``False``.
        max_rowsize (Sequence[int]): Row size limits for the request and result queues respectively;
            only used with shared memory. Default: ``(-1, -1)``.
    """

    def __init__(self, warning_ctl, shared_memory=False, max_rowsize=(-1, -1)):
        self.shared_memory = shared_memory
        self.eof = multiprocessing.Event()
        if not self.shared_memory:
            self.in_queue = _Queue(1)
            self.res_queue = _Queue(1)
        else:
            self.in_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[0])
            self.res_queue = _SharedQueue(1, warning_ctl, max_rowsize=max_rowsize[1])
        # Ensure that the process does not hang when exiting
        self.in_queue.cancel_join_thread()

    def master_send(self, func_index, data):
        """Queue ``(func_index, *data)`` for a worker without blocking."""
        message = (func_index,) + tuple(data)
        self.in_queue.put_nowait(message)

    def master_receive(self):
        """Return the next worker result (waiting up to 1s), or None once ``eof`` is set."""
        eof = self.eof
        if eof is None:
            raise RuntimeError("EOF is none when get data from worker.")
        return None if eof.is_set() else self.res_queue.get(timeout=1)

    def master_close(self):
        """Signal shutdown: set ``eof``, send the worker quit marker, then push the finish signal."""
        self.eof.set()
        self.send_finish_signal_to_worker()
        self.send_finish_signal()

    def send_finish_signal(self):
        self.worker_send(None)

    def send_finish_signal_to_worker(self):
        # master_send star-unpacks its data, so this arrives as (0, 'Q', 'U', 'I', 'T').
        self.master_send(0, "QUIT")

    def worker_send(self, data):
        self.res_queue.put_until(data, timeout=1, exit_signal=self.eof)

    def worker_receive(self):
        """Return ``(func_index, args_tuple)`` for the next request, or None when nothing was received."""
        result = self.in_queue.get_until(timeout=1, exit_signal=self.eof)
        if result is None:
            return None
        if len(result) == 1:
            raise RuntimeError(f"Corrupted data. Worker received {len(result)} elements, it should be more than 1.")
        func_index, *payload = result
        return func_index, tuple(payload)
3144
+
3145
+
3146
def _main_process_already_exit():
    """
    Tell whether the parent (main) process of this worker has gone away.

    Always False on Windows; elsewhere the current parent pid is probed with ``is_process_alive``.
    """
    parent_pid = os.getppid()
    if platform.system().lower() == 'windows':
        return False
    return not _PythonMultiprocessing.is_process_alive(parent_pid)
3156
+
3157
+
3158
def _is_quit_signal(input_tensors):
    """
    Return True if ``input_tensors`` is the worker quit marker.

    ``Pipe.send_finish_signal_to_worker`` sends the marker via ``master_send(0, "QUIT")``, which star-unpacks
    the string, so the worker actually receives the character tuple ``('Q', 'U', 'I', 'T')``. The plain string
    and any tuple of strings spelling "QUIT" are both accepted. Element types are checked before comparing so
    that real tensor data (e.g. numpy arrays) is never compared against a string.
    """
    if isinstance(input_tensors, str):
        return input_tensors == "QUIT"
    if isinstance(input_tensors, tuple) and input_tensors and all(isinstance(item, str) for item in input_tensors):
        return "".join(input_tensors) == "QUIT"
    return False


def _worker_loop(operations, pipe, worker_id):
    """
    Multiprocess worker process loop.

    Repeatedly receives ``(idx, input_tensors)`` from ``pipe``, runs ``operations[idx](*input_tensors)`` and
    sends back either the output or an ExceptionHandler describing the failure. Returns when the pipe yields
    None, stops on the quit marker, and also stops once the main process is detected as gone.

    Args:
        operations (list): Callables that may be executed by this worker, addressed by index.
        pipe (Pipe): Communication channel with the master process.
        worker_id (int): Index of this worker, used to derive a per-worker random seed.
    """
    # Initialize C++ side signal handlers
    cde.register_worker_handlers()

    # Ensure that the process does not hang when exiting
    pipe.res_queue.cancel_join_thread()

    def _ignore_sigint():
        """
        We need to ignore sigint signal here so subprocesses can exit normally and clear.
        """
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    # If the default random seed has not been changed, there is no need to fix the randomness.
    # Otherwise, set the random seed for each child process to "base_seed + worker_id" to ensure
    # that the random results of each process are different. (5489 is treated as the default seed.)
    if get_seed() != 5489:
        set_seed(get_seed() + worker_id)

    while not _main_process_already_exit():
        _ignore_sigint()

        result = pipe.worker_receive()
        if result is None:
            return
        (idx, input_tensors) = result
        # The marker arrives as a tuple of characters, so a plain `== "QUIT"` check would never match and the
        # worker would instead run the user operation on ('Q', 'U', 'I', 'T').
        if _is_quit_signal(input_tensors):
            break
        try:
            output_tensors = operations[idx](*input_tensors)

            pipe.worker_send(output_tensors)
        except Exception:
            pipe.worker_send(ExceptionHandler(where="in map(or batch) worker and execute Python function"))
            # Do not return

    # release the queue when stop the worker by master
    del pipe.in_queue
    del pipe.res_queue
3200
+
3201
+
3202
def worker_target(operations, worker_id):
    """Build the process ``target``: a callable that runs ``_worker_loop`` for the given pipe."""
    def _run(pipe):
        return _worker_loop(operations, pipe, worker_id)
    return _run
3204
+
3205
+
3206
class _MPWorker(multiprocessing.Process):
    """
    Worker process for multiprocessing.

    A daemon process named ``"MapWorker<worker_id>"`` that runs ``_worker_loop`` over its own Pipe.

    Args:
        operations (list): Callables the worker may execute.
        warning_ctl: Forwarded to the Pipe's shared-memory queues.
        max_rowsize (tuple[int, int]): Row size limits for the request/result queues. Default: ``(-1, -1)``.
        worker_id (int): Index of this worker. Default: ``0``.
    """

    def __init__(self, operations, warning_ctl, max_rowsize=(-1, -1), worker_id=0):
        self.pipe = Pipe(warning_ctl, shared_memory=get_enable_shared_mem(), max_rowsize=max_rowsize)
        self.check_interval = get_multiprocessing_timeout_interval()
        super().__init__(target=worker_target(operations, worker_id), name="MapWorker" + str(worker_id),
                         args=(self.pipe,), daemon=True)

    def _report_hang(self, cost_time):
        """Warn that the worker has been silent for ``cost_time`` seconds and dump its stack via py-spy if present."""
        logger.warning("It has been waiting for " + "%.3f" % cost_time + "s because the sub-process "
                       "worker of the map operation is hanging. "
                       "Check whether the user defined data transform is too slow or the "
                       "output data is too large. You can also set the timeout interval by "
                       "ds.config.set_multiprocessing_timeout_interval to adjust the output frequency "
                       "of this log.")
        pid = self.pid
        logger.warning("Map worker subprocess ID {} is stuck.".format(pid))
        install_status, _ = subprocess.getstatusoutput("py-spy --version")
        if install_status != 0:
            logger.warning("Please `pip install py-spy` to get the stacks of the stuck process.")
            return
        stack = subprocess.getoutput("py-spy dump -p {} -l".format(pid))
        logger.warning("Map worker subprocess stack:\n{}".format(stack))

    def execute(self, idx, *args):
        """Acquiring data from a worker in an infinite loop"""
        self.pipe.master_send(idx, args)
        started_at = time.time()
        next_report = 1
        while True:
            elapsed = time.time() - started_at
            # Emit one hang report per elapsed check_interval.
            if elapsed / self.check_interval >= next_report:
                next_report += 1
                self._report_hang(elapsed)
            try:
                res = self.pipe.master_receive()
            except queue.Empty:
                continue
            if res is None:
                # receive finish signal
                return None
            if isinstance(res, ExceptionHandler):
                res.reraise()
            return res

    def close(self):
        """Shut the worker down if it is still alive; quietly does nothing if it was already closed."""
        try:
            if not self.is_alive():
                return
            # release the eager executor which is used by current process
            transforms.transforms.clean_unused_executors()

            logger.info(f"Closing worker with PID: {self.pid}")
            self.pipe.master_close()
            # del the handle which hold by master
            del self.pipe.in_queue
            del self.pipe.res_queue
            super().terminate()
            super().join()
            super().close()
        except ValueError:
            # Process has been closed already
            return

    def is_alive(self):
        """Like Process.is_alive(), but reports False instead of raising ValueError for a closed process."""
        alive = False
        try:
            alive = super().is_alive()
        except ValueError:
            pass
        return alive
3277
+
3278
+
3279
+ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
3280
+ """
3281
+ A wrapper to multiprocessing.pool that performs cleanup and ensure proper termination of forked processes.
3282
+ """
3283
+
3284
+ class _ExceptHookHandler:
3285
+ """
3286
+ Internal class ExceptionHandler
3287
+ """
3288
+
3289
+ def __init__(self):
3290
+ self.origin_hook = sys.excepthook
3291
+ sys.excepthook = self.__handler_exception
3292
+
3293
+ @staticmethod
3294
+ def mp_pool_exit_preprocess():
3295
+ if check_iterator_cleanup() is False:
3296
+ # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async
3297
+ # applied to the multiprocessing task to prevent multiprocessing from hang when exiting
3298
+ _set_iterator_cleanup()
3299
+ time.sleep(3)
3300
+
3301
+ def __handler_exception(self, ex_type, value, tb):
3302
+ self.origin_hook(ex_type, value, tb)
3303
+ self.mp_pool_exit_preprocess()
3304
+
3305
+ def __init__(self, op_name, num_parallel_workers, operations, max_rowsize=(-1, -1)):
3306
+ super(_PythonMultiprocessing, self).__init__()
3307
+ self.op_name = op_name
3308
+ self.num_parallel_workers = num_parallel_workers
3309
+ self.operations = operations
3310
+ self.max_rowsize = max_rowsize
3311
+
3312
+ self.workers = None
3313
+ self.pids = None
3314
+ self.op_id = -1
3315
+
3316
+ self.queues_map = {}
3317
+ self.next_queue = 0
3318
+
3319
+ self.eot = None
3320
+ self.watch_dog = None
3321
+ self.ppid = None
3322
+ self.hook = None
3323
+ self.warning_ctl = None
3324
+ # cache thread (get_ident()) to worker_id mapping in Python layer
3325
+ self.python_threads_to_workers = {}
3326
+ self.eof = None
3327
+
3328
+ def __del__(self):
3329
+ try:
3330
+ self.terminate()
3331
+ except TypeError:
3332
+ pass
3333
+
3334
+ # This wait function is for cleaning zombie subprocesses
3335
+ @staticmethod
3336
+ def wait_pid():
3337
+ """
3338
+ This function is used by the main process to release subprocess resources.
3339
+ """
3340
+ try:
3341
+ while True:
3342
+ child_pid, _ = os.waitpid(-1, os.WNOHANG)
3343
+ if child_pid == 0:
3344
+ break
3345
+ except OSError:
3346
+ # waitpid may fail for some reason, so we ignore this error
3347
+ pass
3348
+
3349
+ # Dataset need watch_dog thread to monitoring fork multiprocessing,
3350
+ # and thread can't be a member function otherwise python won't collect and release resources.
3351
+ @staticmethod
3352
+ def _watch_dog(eot, workers):
3353
+ """
3354
+ This thread is for monitoring subprocesses forked by GeneratorDataset/map/batch
3355
+ """
3356
+ if not isinstance(workers, list):
3357
+ raise TypeError("[Internal Error] The 2nd parameter of watch dog thread should be list of process, "
3358
+ "but got {}.".format(type(workers)))
3359
+
3360
+ while not eot.is_set():
3361
+ # Monitoring and count how many subprocesses already exit
3362
+ clear_subprocess_timeout = _PythonMultiprocessing._monitor_subprocess_exit(workers)
3363
+ # If find subprocess exit, we will wait for 30s and do some waitpid operations
3364
+ if clear_subprocess_timeout > 0:
3365
+ start = time.time()
3366
+ while time.time() - start < clear_subprocess_timeout:
3367
+ # We need to distinguishing get_dataset_size or train finished normally and hang scenario.
3368
+ # If get_dataset_size or train finished normally, _stop_subprocess can be execute and
3369
+ # self.need_abort can be set to True. If main process is hang in get(), self.need_abort
3370
+ # will never set to True, then we wait for 30s and kill main process
3371
+ if eot.is_set():
3372
+ return
3373
+ # Sometimes subprocess may be zombie, so in 30s we can wait and do some useful tasks(waitpid).
3374
+ _PythonMultiprocessing.wait_pid()
3375
+ # multiprocessing.queue may hang in .get() forever when put() process was killed.
3376
+ # We have to exit main process otherwise main process will hang.
3377
+ _PythonMultiprocessing._terminate_processes(workers)
3378
+ logger.critical("The subprocess of dataset may exit unexpected or be killed, "
3379
+ "main process will exit. If this is not an artificial operation, you can use "
3380
+ "ds.config.set_enable_watchdog(False) to block this error.")
3381
+ os.kill(os.getpid(), signal.SIGTERM)
3382
+ # sleep to release GIL
3383
+ time.sleep(1)
3384
+
3385
+ # release the workers
3386
+ del workers
3387
+
3388
+ @staticmethod
3389
+ def _terminate_processes(processes):
3390
+ """Terminate subprocesses"""
3391
+
3392
+ for p in processes:
3393
+ try:
3394
+ if p.exitcode is None:
3395
+ p.terminate()
3396
+ except Exception: # pylint: disable=broad-except
3397
+ # process has been closed already
3398
+ pass
3399
+ for p in processes:
3400
+ if p._closed is False: # pylint: disable=W0212
3401
+ # We don't use w.join because join can only used in main process or join will raise an error.
3402
+ p._popen.wait() # pylint: disable=W0212
3403
+
3404
+ # Monitor the exit number of subprocesses
3405
+ @staticmethod
3406
+ def _monitor_subprocess_exit(workers):
3407
+ """
3408
+ To monitor whether process is exit.
3409
+
3410
+ Args:
3411
+ workers (list of multiprocessing.Process): multiprocessing.Process.
3412
+
3413
+ Returns:
3414
+ int, the timeout(in seconds) when process exit.
3415
+ """
3416
+ for w in workers:
3417
+ try:
3418
+ exit_code = w.exitcode
3419
+ if exit_code is not None:
3420
+ # For kill -9, we can exit quickly
3421
+ if exit_code == -9:
3422
+ return 1
3423
+ # For kill -15, we still exit after 30s
3424
+ if exit_code == -15:
3425
+ return 30
3426
+ # In some cases the subprocess has been killed but the exitcode is still None.
3427
+ # So we use os.kill(pid, 0) to check if it is alive.
3428
+ subprocess_alive = _PythonMultiprocessing.is_process_alive(w.pid)
3429
+ if not subprocess_alive:
3430
+ # Like kill -15, we wait 30s before exit
3431
+ return 30
3432
+ except ValueError:
3433
+ # process has been closed already
3434
+ return 0
3435
+ return 0
3436
+
3437
@staticmethod
def is_process_alive(pid):
    """
    Report whether a process with the given pid still exists.

    Note: psutil and ``w.exitcode`` were observed to deadlock when used for this check,
    so signal 0 is sent with ``os.kill(pid, 0)``. On POSIX this only performs error checking.

    Args:
        pid: pid of the process to be checked.

    Returns:
        bool, False if ``os.kill(pid, 0)`` raises ``OSError`` (including PermissionError), otherwise True.
    """
    alive = True
    try:
        os.kill(pid, 0)
    except OSError:
        alive = False
    return alive
3456
+
3457
# When the main process exits, the subprocesses will be terminated
@staticmethod
def _clean_process(ppid, workers, quit_signal):
    """
    Body of the cleaning process: once the main process is gone, clean up the worker subprocesses.

    Args:
        ppid: The process id of the main process.
        workers: The list of worker subprocesses.
        quit_signal: Event; once it is set this function returns without cleaning anything.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    main_alive = _PythonMultiprocessing.is_process_alive
    while main_alive(ppid):
        if quit_signal.is_set():
            # the pool owner asked us to stop: nothing to clean
            return
        # Independent dataset mode: the subprocesses of GeneratorDataset / map / batch should exit
        # once the independent dataset process has exited (our parent pid changed).
        reparented = os.getppid() != ppid
        if reparented:
            break
        time.sleep(0.1)

    # The main process is gone (or we were re-parented): terminate the workers, then ourselves.
    _PythonMultiprocessing._terminate_processes(workers)
    del workers
    os.kill(os.getpid(), signal.SIGTERM)
3483
+
3484
def launch(self, op_id=-1):
    """
    Start (or restart) the Python multiprocessing pool for an operation.

    If a pool already exists, it is terminated and reset before the new one is created.

    Args:
        op_id: ID of the operation the pool is launched for. Default: -1.
    """
    self.python_threads_to_workers = {}
    self.op_id = op_id
    logger.info("Launching new Python Multiprocessing pool for Op:" + str(self.op_id))
    already_running = self.is_mp_enabled()
    if already_running:
        logger.warning("Launching a new Python multiprocessing pool while a pool already exists!"
                       " The existing pool will be terminated first.")
        self.terminate()
        self.reset()
    self.ppid = os.getpid()
    self.create_pool()
3505
+
3506
def create_pool(self):
    """
    Create and start the worker processes of the pool.

    When shared memory is enabled, shared-memory usage is checked first. The method then spawns
    ``num_parallel_workers`` ``_MPWorker`` processes, installs the exception hook handler,
    launches the watchdog, and registers ``terminate`` to run at interpreter exit.

    Raises:
        Exception: If the pool has already been created and not closed.
    """
    if get_enable_shared_mem():
        _check_shm_usage(self.num_parallel_workers, 1, self.max_rowsize[0], self.max_rowsize[1])

    if self.workers is not None:
        raise Exception("Pool was already created, close it first.")

    # Let gc collect unreferenced memory to avoid child processes in the pool to do it
    gc.collect()

    # Construct python worker processes
    self.workers = []
    self.warning_ctl = multiprocessing.Value('i', 0)
    for worker_id in range(self.num_parallel_workers):
        worker = _MPWorker(self.operations, self.warning_ctl, self.max_rowsize, worker_id)
        worker.start()
        self.workers.append(worker)

    logger.info("Op: " + str(self.op_id) + " Python multiprocessing pool workers' PIDs: " + str(self.get_pids()))

    self.hook = _PythonMultiprocessing._ExceptHookHandler()

    # The op (Map, Batch, etc) multiprocessing will launch a watch dog thread for monitoring sub processes
    self._launch_watch_dog()

    # launch() -> create_pool() runs again every time add_new_workers/remove_workers resizes the
    # pool. Drop any earlier registration of this bound method first so atexit does not pile up
    # duplicate handlers (atexit.unregister matches with ==, which bound methods of the same
    # object satisfy).
    atexit.unregister(self.terminate)
    atexit.register(self.terminate)
3537
+
3538
def terminate(self):
    """Shut the pool down.

    Stops the watchdog, then closes every worker, then drops ``warning_ctl`` if it exists.
    """
    # Close the watchdog first, then all the workers.
    # Presumably this stops the watchdog from treating the closing workers as killed.
    self.abort_watchdog()
    self.close_all_workers()
    try:
        del self.warning_ctl
    except AttributeError:
        # already released (or never created)
        pass
3544
+
3545
def get_pids(self):
    """
    Return the PIDs of the pool's worker processes.

    The list is built on first use and cached in ``self.pids``. An empty cache is rebuilt on the
    next call. Workers whose ``pid`` access raises ``ValueError`` (closed process object) are skipped.

    Returns:
        list, the workers' ``pid`` values, or an empty list if the pool is not enabled.
    """
    if not self.is_mp_enabled():
        return []
    if self.pids:
        return self.pids
    self.pids = []
    for worker in self.workers or []:
        try:
            self.pids.append(worker.pid)
        except ValueError:
            # the process object has been closed already
            continue
    return self.pids
3563
+
3564
def add_new_workers(self, num_new_workers):
    """
    Grow the pool by ``num_new_workers`` workers by tearing it down and relaunching it.

    Args:
        num_new_workers (int): number of workers to add.
    """
    old_count = self.num_parallel_workers
    new_count = old_count + num_new_workers
    logger.info("Increasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
                ", old num_workers=" + str(old_count) + " new num_workers=" + str(new_count) + ".")
    self.terminate()
    self.num_parallel_workers += num_new_workers
    self.launch(self.op_id)
3573
+
3574
def remove_workers(self, num_removed_workers):
    """
    Shrink the pool by ``num_removed_workers`` workers by tearing it down and relaunching it.

    Args:
        num_removed_workers (int): number of workers to remove.
    """
    old_count = self.num_parallel_workers
    new_count = old_count - num_removed_workers
    logger.info("Decreasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
                ", old num_workers=" + str(old_count) + " new num_workers=" + str(new_count) + ".")
    self.terminate()
    self.num_parallel_workers -= num_removed_workers
    self.launch(self.op_id)
3583
+
3584
def is_mp_enabled(self):
    """Return True while the worker pool exists, i.e. ``self.workers`` is not None."""
    workers = self.workers
    return workers is not None
3586
+
3587
def execute(self, idx, *args):
    """
    Run a Python operation on the worker bound to the calling thread.

    Args:
        idx: index forwarded to the worker's ``execute``. Presumably this is the position of the callable
            in the pool, as assigned by ``_PythonCallable``; verify against the caller.
        *args: input data forwarded to the worker.

    Returns:
        The worker's result, or None if the pool is not running or the iterator has been cleaned up.

    Raises:
        RuntimeError: If the worker id bound to this thread is out of range.
    """
    t_id = threading.get_ident()
    # Get the worker_id from the Python layer cache first, and query the Cpp layer only on a miss.
    # (dict.setdefault would evaluate get_thread_to_worker() eagerly on every call.)
    worker_id = self.python_threads_to_workers.get(t_id)
    if worker_id is None:
        worker_id = self.get_thread_to_worker()
        self.python_threads_to_workers[t_id] = worker_id
    if worker_id >= len(self.workers):
        raise RuntimeError("[Internal] worker_id value is greater than number of available workers!")

    # todo check_iterator_cleanup
    if self.is_running() and check_iterator_cleanup() is False:
        return self.workers[worker_id].execute(idx, *args)

    return None
3602
+
3603
def _launch_watch_dog(self):
    """
    Launch a watchdog thread and a cleaning process that clean up when a process was killed.

    - The cleaning process (non-Windows only) runs ``_clean_process``. It cleans up the subprocesses
      when the main process was killed.
    - The watchdog thread (only when ``get_enable_watchdog()`` is true) runs ``_watch_dog``. It cleans
      up the subprocesses and the main process when one of the monitored subprocesses was killed.
    """
    monitored = list(self.workers)
    if platform.system().lower() != 'windows':
        self.eof = multiprocessing.Event()
        self.cleaning_process = multiprocessing.Process(target=self._clean_process,
                                                        name="MapCleanProcess",
                                                        args=(self.ppid, self.workers, self.eof),
                                                        daemon=True)
        self.cleaning_process.start()
        # the cleaning process is handed to the watchdog together with the workers
        monitored.append(self.cleaning_process)

    if get_enable_watchdog():
        self.eot = threading.Event()
        # `monitored` only contains the cleaning process where it was actually created, so no
        # AttributeError is raised on Windows.
        self.watch_dog = threading.Thread(target=self._watch_dog,
                                          name="MapWatchDog",
                                          args=(self.eot, monitored),
                                          daemon=True)
        self.watch_dog.start()
3624
+
3625
+ def _abort_watchdog(self):
3626
+ if not self.eot.is_set():
3627
+ self.eot.set()
3628
+
3629
def abort_watchdog(self):
    """
    Stop the watchdog thread and the cleaning process, if they were launched.

    The watchdog is signalled through ``_abort_watchdog``. The cleaning process gets its ``eof``
    event set, is terminated, and its attribute is deleted.
    """
    watchdog_running = getattr(self, 'watch_dog', None) is not None and getattr(self, 'eot', None) is not None
    if watchdog_running:
        self._abort_watchdog()

    if getattr(self, 'cleaning_process', None) is None:
        return
    quit_event = getattr(self, 'eof', None)
    if quit_event is not None and not quit_event.is_set():
        quit_event.set()
    _PythonMultiprocessing._terminate_processes([self.cleaning_process])
    del self.cleaning_process
3637
+
3638
def is_running(self):
    """
    Return True if ``self.workers`` is set and every worker reports alive.

    An empty worker list counts as running, because ``all`` of nothing is True.
    """
    workers = getattr(self, 'workers', None)
    if workers is None:
        return False
    alive_flags = [worker.is_alive() for worker in workers]
    return all(alive_flags)
3642
+
3643
def close_all_workers(self):
    """
    Close all the subprocess workers and release the worker list.

    Every worker's ``close()`` is called first. Then, for each worker, this waits until its pid is
    no longer alive, polling every 10ms and logging a warning every
    ``get_multiprocessing_timeout_interval()`` seconds. Its sentinel descriptor is closed if the
    worker still reports alive. Finally ``workers`` and ``pids`` are reset to None.
    """
    workers = getattr(self, 'workers', None)
    if workers is None:
        return
    for worker in workers:
        worker.close()
    interval = get_multiprocessing_timeout_interval()
    for worker in workers:
        try:
            sentinel = worker.sentinel
            last_warning = time.time()
            while _PythonMultiprocessing.is_process_alive(worker.pid):
                time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
                if time.time() - last_warning > interval:
                    logger.warning("Waiting for the subprocess worker [{}] to exit.".format(worker.pid))
                    last_warning += interval
        except ValueError as err:
            if "process object is closed" in str(err):
                continue
            raise err
        try:
            if worker.is_alive():
                os.close(sentinel)
        except OSError as err:
            # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
            if "Bad file descriptor" not in str(err):
                raise err

    # use clear to release the handle which is better than self.workers = None
    self.workers.clear()
    self.workers = None
    self.pids = None
3674
+
3675
+
3676
class MapDataset(UnionBaseDataset):
    """
    The result of applying the Map operation to the input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be mapped.
        operations (Union[list[TensorOperation], list[functions]]): A function mapping a nested structure of tensors
            to another nested structure of tensor. Default: ``None``.
        input_columns (Union[str, list[str]]): List of names of the input columns.
            Default: ``None``, the operations will be applied on the first columns in the dataset.
            The size of the list should match the number of inputs of the first operation.
        output_columns (Union[str, list[str]], optional): List of names of the output columns.
            The size of the list should match the number of outputs of the last operation.
            Default: ``None``, output columns will be the input columns, i.e., the columns will
            be replaced.
        num_parallel_workers (int, optional): Number of workers to process the dataset
            in parallel. Default: ``None``.
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
            This option could be beneficial if the Python operation is computationally heavy. Default: ``False``.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
            Default: ``None``, which means no cache is used.
        callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called. Default: ``None``.
        max_rowsize (Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory
            allocation to copy data between processes. The total occupied shared memory will increase as
            ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
            shared memory will be dynamically allocated with the actual size of data. This is only used if
            ``python_multiprocessing`` is set to True. If it is an int value, both ``input_columns`` and
            ``output_columns`` use this value as the unit to create shared memory. If it is a list, the first
            element is the unit for ``input_columns`` and the second element is the unit for ``output_columns``.
            Default: ``None`` , allocate shared memory dynamically.
        offload (bool, optional): Flag to indicate whether offload is used. Default: ``None``.

    Raises:
        ValueError: If an element of ``operations`` is a class rather than an instance, or is neither a
            dataset transform nor callable.
    """

    def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None,
                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=None,
                 offload=None):
        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
        self.operations = to_list(operations)
        for op in self.operations:
            # Passing a class without parentheses (e.g. c_vision.HWC2CHW) is an error
            if type(op) == type:  # pylint: disable=unidiomatic-typecheck
                raise ValueError("Parameter operations's element of method map should be a dataset processing "
                                 "operation instance, but got: {}. It may be missing parentheses for "
                                 "instantiation.".format(op))
            if not isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation)) \
                    and not callable(op):
                raise ValueError("Parameter operations's element of method map should be a python function or "
                                 "class method which should be callable, but got: {}. It doesn't need parentheses "
                                 "for python function or class method.".format(op))

        self.input_columns = to_list(input_columns)
        self.output_columns = to_list(output_columns)

        # If output_columns were not provided then use input_columns
        self.output_columns = self.input_columns if not self.output_columns else self.output_columns

        self.python_multiprocessing = python_multiprocessing
        self.process_pool = None

        self.callbacks = to_list(callbacks)
        # Normalize max_rowsize to [input_unit, output_unit]; -1 means dynamic allocation
        if max_rowsize is None:
            self.max_rowsize = [-1, -1]
        elif isinstance(max_rowsize, int):
            self.max_rowsize = [max_rowsize] * 2
        else:
            self.max_rowsize = max_rowsize
        self.offload = offload

    def parse(self, children=None):
        """
        Build the C++ ``MapNode`` for this map operation.

        Args:
            children (list): Already-parsed child nodes; ``children[0]`` is the input node.

        Returns:
            cde.MapNode, the IR node for this map.

        Raises:
            RuntimeError: If legacy c/py_transforms are mixed with the new unified transforms.
        """
        operations = self.__decompose_callable_operations()

        count_old_transforms, count_new_transforms, count_non_data_vision_transforms = \
            self.__count_transforms(operations)
        count_pyfunc = self.__count_pyfuncs(operations)
        if count_new_transforms + count_pyfunc == len(operations):
            prev_op = None
            for op in operations:
                # skip user added DebugHook to avoid changing to Py-implementation.
                if self.__is_debug_hook_op(op):
                    if prev_op:
                        # manually set previous_op_name
                        prev_op_name = self.__parse_op_name(prev_op)
                        op.set_previous_op_name(prev_op_name)
                    continue
                if op.implementation is None:
                    # An unspecified implementation follows the previous op: stay in Python after a
                    # Python op, otherwise use the C++ implementation.
                    if prev_op and prev_op.implementation == Implementation.PY:
                        op.implementation = Implementation.PY
                    else:
                        op.implementation = Implementation.C
                prev_op = op
            operations = self.__insert_debug_wrapper(operations)
            operations = transforms.transforms.Compose.reduce(operations)
        elif count_old_transforms + count_pyfunc + count_non_data_vision_transforms == len(operations):
            operations = self.__insert_debug_wrapper(operations)
            operations = transforms.py_transforms.Compose.reduce(operations)
        else:
            raise RuntimeError("Mixing old legacy c/py_transforms and new unified transforms is not allowed.")

        self.operations = self.__process_final_operations(operations)
        self.prepare_multiprocessing()

        callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
        return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns,
                           callbacks, OffloadToManualOffloadMode.get(self.offload), self.process_pool)

    def __deepcopy__(self, memodict):
        return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))

    def __del__(self):
        # Shut down the per-iterator worker pool, if one was created
        if hasattr(self, "process_pool") and self.process_pool is not None:
            self.process_pool.terminate()
            del self.process_pool

    @staticmethod
    def __parse_op_name(op):
        """
        Utility method to get the operation name.

        For a FuncWrapper, the wrapped transform's ``__name__`` is used, falling back to its class name.
        Otherwise the operation's class name is used.
        """
        if isinstance(op, transforms.py_transforms_util.FuncWrapper):
            try:
                return op.transform.__name__
            except AttributeError:
                return op.transform.__class__.__name__
        return op.__class__.__name__

    @staticmethod
    def __construct_debug_hook(previous_op_name=None, is_first_op=False):
        """
        Wrap a copy of every registered debug hook into a Python-implemented FuncWrapper.

        Args:
            previous_op_name (str, optional): name of the op the hooks follow.
            is_first_op (bool): whether the hooks are placed before the first op.

        Returns:
            list, the wrapped hooks (empty if no debug hook is registered).
        """
        inserted_functions = []
        debug_hook_list = _get_debug_hook_list()
        if debug_hook_list:
            for fn in debug_hook_list:
                # making deep copy to allow each debug hook instance hold unique variables
                new_fn = copy.deepcopy(fn)
                new_fn.set_previous_op_name(previous_op_name)
                new_fn.set_is_first(is_first_op)
                inserted_func = transforms.py_transforms_util.FuncWrapper(new_fn)
                inserted_func.implementation = Implementation.PY
                inserted_functions.append(inserted_func)
        return inserted_functions

    @staticmethod
    def __is_debug_hook_op(op):
        """
        Check whether op is a user-added DebugHook, so it can be skipped when setting implementations.

        Raises:
            ValueError: If a DebugHook is used while debug mode is off.
        """
        if isinstance(op, DebugHook):
            if not get_debug_mode():
                raise ValueError("It is not allowed to inject DebugHook object in non-debug mode.")
            return True
        return False

    @staticmethod
    def __count_pyfuncs(operations):
        """
        Count the number of pyfunc (FuncWrapper) operations.
        """
        return sum(1 for op in operations if isinstance(op, FuncWrapper))

    @staticmethod
    def __count_transforms(operations):
        """
        Count the various flavors of transform operations.

        Returns:
            tuple of int, (legacy data/vision c_transforms and py_transforms,
            new unified transforms, text/audio transforms).
        """
        # Count the number of old legacy data and vision c_transforms and py_transforms
        count_old_transforms = sum(
            1 for op in operations
            if "c_transforms" in str(op)
            or isinstance(op, (c_transforms.TensorOperation, py_transforms.PyTensorOperation))
            or ("py_transforms" in str(op) and not isinstance(op, FuncWrapper)))
        # Count the number of new unified data and vision transforms
        count_new_transforms = sum(
            1 for op in operations if hasattr(op, "implementation") and not isinstance(op, FuncWrapper))
        # Count the number of non-data transforms and non-vision transforms
        count_non_data_vision_transforms = sum(
            1 for op in operations if "text.transforms" in str(op) or "audio.transforms" in str(op))
        return count_old_transforms, count_new_transforms, count_non_data_vision_transforms

    @staticmethod
    def __operation_valid_for_multiprocessing(op):
        # C transforms are callable too, but must not be run in the Python worker pool
        return callable(op) and str(op).find("c_transform") < 0

    @staticmethod
    def __process_final_operations(operations):
        """
        Build the final list of operations.

        C-implemented unified transforms are parsed into their C++ objects. Python-implemented ones,
        FuncWrapper and ToNumpy are kept as-is. Legacy ops with a ``parse`` method are parsed; anything
        else is kept as-is.

        Raises:
            RuntimeError: If a unified transform has an unexpected implementation value.
        """
        operations_fin = []
        for op in operations:
            if hasattr(op, "implementation"):
                if op.implementation == Implementation.C and not isinstance(op, (FuncWrapper, ToNumpy)):
                    operations_fin.append(op.parse())
                elif op.implementation == Implementation.PY:
                    operations_fin.append(op)
                elif isinstance(op, (FuncWrapper, ToNumpy)):
                    operations_fin.append(op)
                else:
                    raise RuntimeError("Wrong implementation")
            else:
                if op and getattr(op, 'parse', None):
                    operations_fin.append(op.parse())
                else:
                    operations_fin.append(op)
        return operations_fin

    # Iterator bootstrap will be called on iterator construction.
    # A deep copy of Dataset object is created prior of iterator_bootstrap.
    # This method will create per iterator process pool and bind pyfunc execution to the pool.
    def prepare_multiprocessing(self):
        """
        Per iterator bootstrap callback.

        If ``python_multiprocessing`` is enabled (and the platform is not Windows and debug mode is off),
        creates a ``_PythonMultiprocessing`` pool for the Python callables. Each callable is then
        replaced by a ``_PythonCallable`` bound to that pool.
        """
        if self.python_multiprocessing and platform.system().lower() == 'windows':
            logger.warning("Python multiprocessing is not supported on Windows platform.")
            return
        if self.python_multiprocessing and get_debug_mode():
            logger.warning("Python multiprocessing is not supported in debug mode."
                           " Ignoring Python multiprocessing for map operation.")
            return
        if self.python_multiprocessing:
            iter_specific_operations = []
            callable_list = []

            # If user didn't specify num_parallel_workers, set it to default
            if self.num_parallel_workers is None:
                self.num_parallel_workers = get_num_parallel_workers()

            # Pass #1, look for Python callables and build list
            for op in self.operations:
                # our c transforms are now callable and should not be run in Python multiprocessing
                if MapDataset.__operation_valid_for_multiprocessing(op):
                    callable_list.append(op)

            if callable_list:
                self.process_pool = _PythonMultiprocessing(str(self), self.num_parallel_workers, callable_list,
                                                           self.max_rowsize)
                # Pass #2
                idx = 0
                for op in self.operations:
                    # our c transforms are now callable and should not be run in Python multiprocessing
                    if MapDataset.__operation_valid_for_multiprocessing(op):
                        # Wrap Python callable into _PythonCallable
                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
                        idx += 1
                    else:
                        # CPP ops remain the same
                        iter_specific_operations.append(op)
                self.operations = iter_specific_operations

    def __insert_debug_wrapper(self, operations):
        """
        Insert debug hooks before the first op and after each op if debug mode is on.

        ``operations`` is returned unchanged when debug mode is off or the list is empty. An empty
        list has no first op to name the leading hook after; the old ``operations[0]`` lookup raised
        IndexError there.
        """
        if not get_debug_mode() or not operations:
            return operations
        first_op_name = self.__parse_op_name(operations[0])
        inserted_operations = self.__construct_debug_hook(first_op_name, is_first_op=True)
        for op in operations:
            inserted_operations.append(op)
            op_name = self.__parse_op_name(op)
            inserted_operations.extend(self.__construct_debug_hook(op_name))
        return inserted_operations

    def __decompose_callable_operations(self):
        """
        Decompose the operations and wrap plain Python callables into FuncWrapper.

        Callables without an ``implementation`` attribute that are not legacy c/py transform objects
        are wrapped in ``FuncWrapper``.
        """
        decomposed_operations = transforms.transforms.Compose.decompose(self.operations)
        operations = []
        for op in decomposed_operations:
            if callable(op) and not hasattr(op, "implementation") and str(op).find(
                    "c_transform") < 0 and not isinstance(op, c_transforms.TensorOperation) and \
                    not isinstance(op, py_transforms.PyTensorOperation):
                op = transforms.py_transforms_util.FuncWrapper(op)
            operations.append(op)
        return operations
3960
+
3961
+
3962
class FilterDataset(UnionBaseDataset):
    """
    Dataset node produced by filtering the input Dataset with a predicate.

    Args:
        input_dataset (Dataset): Input Dataset to be filtered.
        predicate (callable): Python callable whose result is converted with ``bool``.
            An element is dropped when that result is False.
        input_columns (Union[str, list[str]], optional): List of names of the input columns.
            Default: ``None``, the predicate will be applied to all columns in the dataset.
        num_parallel_workers (int, optional): Number of workers to process the dataset
            in parallel. Default: ``None``.
    """

    def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
        super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)
        # wrap the user predicate so the C++ side always receives a real bool
        self.predicate = lambda *args: bool(predicate(*args))
        self.input_columns = to_list(input_columns)

    def parse(self, children=None):
        """Create the C++ FilterNode on top of the first child node."""
        child_node = children[0]
        return cde.FilterNode(child_node, self.predicate, self.input_columns)
3982
+
3983
+
3984
class RepeatDataset(UnionBaseDataset):
    """
    Dataset node produced by applying the Repeat operation to an input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be repeated.
        count (int): Number of times the dataset will be repeated. ``None`` is passed through
            ``replace_none(count, -1)``; -1 means repeat indefinitely.
    """

    def __init__(self, input_dataset, count):
        super().__init__(children=input_dataset)
        self.count = replace_none(count, -1)

    def parse(self, children=None):
        """Create the C++ RepeatNode on top of the first child node."""
        child_node = children[0]
        return cde.RepeatNode(child_node, self.count)
3999
+
4000
+
4001
class SkipDataset(UnionBaseDataset):
    """
    Dataset node produced by applying the Skip operation to an input Dataset.

    Args:
        input_dataset (Dataset): Input dataset to have elements skipped.
        count (int): Number of elements to be skipped in the dataset.
    """

    def __init__(self, input_dataset, count):
        super().__init__(input_dataset)
        self.count = count

    def parse(self, children=None):
        """Create the C++ SkipNode on top of the first child node."""
        child_node = children[0]
        return cde.SkipNode(child_node, self.count)
4016
+
4017
+
4018
class TakeDataset(UnionBaseDataset):
    """
    Dataset node produced by applying the Take operation to an input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to have elements taken from.
        count (int): Number of elements to be taken from the dataset.
    """

    def __init__(self, input_dataset, count):
        super().__init__(children=input_dataset)
        self.count = count

    def parse(self, children=None):
        """Create the C++ TakeNode on top of the first child node."""
        child_node = children[0]
        return cde.TakeNode(child_node, self.count)
4033
+
4034
+
4035
class ZipDataset(UnionBaseDataset):
    """
    Dataset node produced by zipping several datasets together.

    Args:
        datasets (tuple): A tuple of datasets to be zipped together.

    Raises:
        TypeError: If dataset is not an instance of Dataset. This is not checked in this class;
            presumably the base-class constructor checks it.
    """

    def __init__(self, datasets):
        super().__init__(children=datasets)

    def parse(self, children=None):
        """Create the C++ ZipNode over all child nodes."""
        return cde.ZipNode(children)

    def is_sync(self):
        """Return True if any child dataset reports ``is_sync()``."""
        child_flags = [child.is_sync() for child in self.children]
        return any(child_flags)
4054
+
4055
+
4056
class ConcatDataset(UnionBaseDataset):
    """
    The result of applying Concat operation to the input Dataset.

    Args:
        datasets (list): A list of datasets to be concatenated together.

    Raises:
        TypeError: If dataset is not an instance of Dataset.
        ValueError: If there are no samples in one of the datasets.
    """

    def __init__(self, datasets):
        super().__init__(children=datasets)
        for dataset in datasets:
            if not isinstance(dataset, Dataset):
                raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
        self.datasets = datasets
        self._sampler = samplers.SequentialSampler(num_samples=None)

        self.children_sizes_ = [c.get_dataset_size() for c in self.children]
        for child_index, item in enumerate(self.children_sizes_):
            if item == 0:
                raise ValueError("There are no samples in the dataset number %d. Please make sure there are "
                                 "valid samples in the dataset." % child_index)

        self._children_sizes = self.children_sizes_.copy()

        # _children_flag_and_nums: A list of pair<int ,int>. The first element of the pair is a flag that
        # characterizes whether the dataset is mappable (0) or not (1). The second element is the dataset length.
        self._children_flag_and_nums = []

        # _children_start_end_index_: A list of pair<int ,int>. The elements of the pair characterize
        # the valid position of the dataset corresponding to the subscript when sampling.
        self._children_start_end_index_ = []

        # Kept as a local import (presumably to avoid a circular import), but done once rather than per child.
        from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset
        for index, child in enumerate(self.children):
            self._children_start_end_index_.append([-1, -1])
            dataset_len = self.children_sizes_[index]

            # A GeneratorDataset whose source has no __getitem__ is recorded with length 0
            if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
                dataset_len = 0
                self.children_sizes_[index] = 0

            if isinstance(child, MappableDataset):
                self._children_flag_and_nums.append((0, dataset_len))
            else:
                self._children_flag_and_nums.append((1, dataset_len))

    def parse(self, children=None):
        return cde.ConcatNode(children, self._sampler, self._children_flag_and_nums, self._children_start_end_index_,
                              self._children_sizes)

    def use_sampler(self, sampler):
        """
        Set the sampler of the concat dataset.

        Args:
            sampler (Sampler): The sampler to use for the current dataset.
                Currently supported: DistributedSampler and RandomSampler.

        Raises:
            TypeError: If the sampler is not an instance of DistributedSampler or RandomSampler.
            ValueError: If a RandomSampler has ``replacement`` set or ``num_samples`` set.
            ValueError: If the parameter shuffle of a DistributedSampler is True.
            ValueError: If num_shards <= 0.
            ValueError: If the parameter num_samples of a DistributedSampler is not None.
            ValueError: If a child dataset's sampler has num_samples set (DistributedSampler path).
            TypeError: If a child is a BatchDataset or PaddedBatchDataset (DistributedSampler path).
        """
        if not isinstance(sampler, (samplers.DistributedSampler, samplers.RandomSampler)):
            raise TypeError("The parameter %s of concat must be DistributedSampler or RandomSampler!" % sampler)

        if isinstance(sampler, samplers.RandomSampler):
            if sampler.replacement:
                raise ValueError("The parameter replacement of RandomSampler must be False!")

            if sampler.get_num_samples() is not None:
                raise ValueError("The parameter num_samples of RandomSampler is not support to be set!")

            self._sampler = sampler
            self._children_sizes = [c.get_dataset_size() for c in self.children]

            # Recursive access to other child concat nodes
            def set_child(node):
                for c in node.children:
                    if isinstance(c, ConcatDataset):
                        c.use_sampler(sampler)
                    set_child(c)
            set_child(self)

            return

        if sampler.is_shuffled():
            raise ValueError("The parameter shuffle of DistributedSampler must be False!")

        if sampler.num_shards <= 0:
            raise ValueError("The parameter num_shards of DistributedSampler must be positive int!")

        if sampler.get_num_samples() is not None:
            raise ValueError("The parameter num_samples of DistributedSampler is not support to be set!")

        self.dataset_size = None

        self._sampler = sampler
        cumulative_samples_nums = 0
        for index, child in enumerate(self.children):
            if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
                raise ValueError("The parameter NumSamples of %s is not support to be set!" % child)

            if isinstance(child, (BatchDataset, PaddedBatchDataset)):
                raise TypeError("The parameter %s of concat must not be BatchDataset or PaddedBatchDataset!" % child)

            # if child is mappable and the length is greater than 0
            if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:

                tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]

                if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
                    if tem_value < sampler.num_shards:
                        self._children_start_end_index_[index][0] = cumulative_samples_nums
                        self._children_start_end_index_[index][1] = tem_value
                    else:
                        self._children_start_end_index_[index][0] = cumulative_samples_nums
                        self._children_start_end_index_[index][1] = tem_value % sampler.num_shards

                # each mappable child gets its own sampler copy, offset by the samples seen so far
                tem_sampler = copy.deepcopy(sampler)
                tem_sampler.set_offset(cumulative_samples_nums)
                child.use_sampler(tem_sampler)

            cumulative_samples_nums += self.children_sizes_[index]
            cumulative_samples_nums %= sampler.num_shards
4188
+
4189
+
4190
class RenameDataset(UnionBaseDataset):
    """
    Dataset node produced by renaming columns of an input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be Renamed.
        input_columns (Union[str, list[str]]): List of names of the input columns.
        output_columns (Union[str, list[str]]): List of names of the output columns.
    """

    def __init__(self, input_dataset, input_columns, output_columns):
        super().__init__(children=input_dataset)
        self.input_column_names = to_list(input_columns)
        self.output_column_names = to_list(output_columns)

    def parse(self, children=None):
        """Create the C++ RenameNode on top of the first child node."""
        child_node = children[0]
        return cde.RenameNode(child_node, self.input_column_names, self.output_column_names)
4207
+
4208
+
4209
def to_list(items):
    """
    Normalize ``items`` into a list.

    ``None`` becomes ``[]``. A tuple is converted to a new list. A list is returned as-is
    (the same object). Any other value is wrapped in a one-element list.
    """
    if isinstance(items, list):
        return items
    if items is None:
        return []
    return list(items) if isinstance(items, tuple) else [items]
4217
+
4218
+
4219
class ProjectDataset(UnionBaseDataset):
    """
    Dataset node that projects an upstream dataset onto the listed columns.

    Args:
        input_dataset (Dataset): Upstream dataset to project.
        columns (Union[str, list[str]]): Name(s) of the column(s) to project.
    """

    def __init__(self, input_dataset, columns):
        super().__init__(children=input_dataset)
        # A single column name is accepted and normalized to a one-element list.
        self.columns = to_list(columns)

    def parse(self, children=None):
        upstream = children[0]
        return cde.ProjectNode(upstream, self.columns)
4234
+
4235
+
4236
class _ToDevice:
    """
    Internal wrapper around the C++ ``ToDevice`` consumer that sends a dataset
    pipeline to the device.

    Args:
        dataset (Dataset): Pipeline to send; its IR tree is built on construction.
        num_epochs (int): Forwarded to ``cde.ToDevice``.
    """

    def __init__(self, dataset, num_epochs):
        if get_debug_mode():
            # The problem is only logged; construction is not aborted.
            logger.error("MindData debugger cannot be used in dataset sink mode. Please manually turn off "
                         "sink mode and try debugger again.")
        ir_tree, self.api_tree = dataset.create_ir_tree()

        # Build and initialize the runtime context, then the consumer, then
        # attach the consumer to the context.
        self._runtime_context = cde.PythonRuntimeContext()
        self._runtime_context.Init()
        self._to_device = cde.ToDevice(num_epochs)
        if dataset.get_init_step() == 0:
            # Initial step is 0: initialize with step 0 and a dataset size of -1.
            self._to_device.Init(ir_tree, 0, -1)
        else:
            # Non-zero initial step: pass it along with the real dataset size.
            start_step = dataset.get_init_step()
            self._to_device.Init(ir_tree, start_step, dataset.get_dataset_size())
        self._runtime_context.AssignConsumer(self._to_device)

        # Only a weak reference is registered, so this list does not keep the
        # object alive.
        ITERATORS_LIST.append(weakref.ref(self))
        _unset_iterator_cleanup()

    def send(self):
        """Forward to the consumer's ``Send``."""
        self._to_device.Send()

    def stop_send(self):
        """
        Forward a stop-send signal to the pipeline; used when the end of
        sequence is sent at the end of an epoch.
        """
        self._to_device.StopSend()

    def continue_send(self):
        """
        Forward a continue-send signal to the pipeline; used when the end of
        sequence is sent at the end of an epoch.
        """
        self._to_device.ContinueSend()

    def get_data_info(self):
        """Return the type and shape information of the current batch."""
        return self._to_device.GetDataInfo()

    def get_mbuf_queue_size(self):
        """Return the number of elements currently held in the mbuf queue."""
        return self._to_device.GetMbufQueueSize()

    def get_send_info(self):
        """
        Return the sink-mode send information collected so far.

        It includes the number of sent batches, the time spent fetching data on
        the host and the time spent sending data.
        """
        return self._to_device.GetSendInfo()

    def release(self):
        """
        Terminate the device queue explicitly instead of waiting for destruction.

        Does nothing unless both the runtime context and the consumer are present
        and truthy. Afterwards both attributes are deleted, so a second call is a
        no-op.
        """
        if not (hasattr(self, '_runtime_context') and self._runtime_context):
            return
        if not (hasattr(self, '_to_device') and self._to_device):
            return
        self._runtime_context.Terminate()
        # Drop the consumer before the context that owns it.
        del self._to_device
        del self._runtime_context

    def __deepcopy__(self, memodict):
        # Deep-copying is deliberately a no-op: the same instance is shared.
        return self

    def get_offload_model(self, col_names):
        """Build the offload model from the ops removed from the pipeline."""
        return GetOffloadModel(self._to_device, col_names)

    def _reset(self, step, dataset_size):
        # Forward the reset position to the consumer.
        self._to_device.Reset(step, dataset_size)
4318
+
4319
+
4320
class TransferDataset(Dataset):
    """
    The result of applying TDT operation to the input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to be transferred.
        send_epoch_end (bool, optional): Whether to send end of sequence to device or not. Default: ``True``.
        create_data_info_queue (bool, optional): Whether to create queue which stores
            types and shapes of data or not. Default: ``False``.
        queue_name (str, optional): Name of the data queue. When empty, a new name is
            generated with ``uuid.uuid1()``; otherwise the given name is used as-is.
            Default: ``""``.

    Raises:
        TypeError: If device_type is empty.
        ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
        RuntimeError: If dataset is unknown.
    """

    def __init__(self, input_dataset, send_epoch_end=True, create_data_info_queue=False, queue_name=""):
        super().__init__(children=input_dataset)
        if queue_name == "":
            self.queue_name = str(uuid.uuid1())
            logger.info(f"queue_name is newly generated. value is {self.queue_name}")
        else:
            self.queue_name = queue_name
            logger.info(f"queue_name is read from compile cache. value is {self.queue_name}")
        self.device_type = context.get_context("device_target") if context else "CPU"
        self.device_id = context.get_context("device_id") if context else 0

        self._send_epoch_end = replace_none(send_epoch_end, True)
        self._create_data_info_queue = create_data_info_queue
        # The _ToDevice consumer is only created by send().
        self._to_device = None
        self.column_name = input_dataset.get_col_names()

    def parse(self, children=None):
        total_batch = 0
        # The child may carry a __total_batch__ attribute; otherwise 0 is used.
        if hasattr(self.children[0], "__total_batch__"):
            total_batch = self.children[0].__total_batch__
            check_total_batch(total_batch)
        return cde.DataQueueNode(children[0], self.queue_name, self.device_type, self.device_id, self._send_epoch_end,
                                 total_batch, self._create_data_info_queue)

    def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
        raise RuntimeError("TransferDataset is not iterable.")

    def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
        raise RuntimeError("TransferDataset is not iterable.")

    def __iter__(self):
        raise RuntimeError("TransferDataset is not iterable.")

    def output_shapes(self):
        raise RuntimeError("TransferDataset does not support obtaining output_shapes.")

    def output_types(self):
        raise RuntimeError("TransferDataset does not support obtaining output_types.")

    @check_to_device_send
    def send(self, num_epochs=-1):
        """
        Send to device.

        Any previously created consumer is dropped, and a fresh one is built
        before sending. Does nothing in no-op mode.
        """
        if Dataset._noop_mode():
            return
        if self._to_device is not None:
            del self._to_device
        self._to_device = _ToDevice(self, num_epochs)
        self._to_device.send()

    def stop_send(self):
        """Forward a stop-send signal if the consumer exists; otherwise do nothing."""
        if self._to_device is not None:
            self._to_device.stop_send()

    def continue_send(self):
        """Forward a continue-send signal if the consumer exists; otherwise do nothing."""
        if self._to_device is not None:
            self._to_device.continue_send()

    def get_data_info(self):
        """
        Get type and shape of current batch.

        Raises:
            RuntimeError: If send() has not created the consumer yet.
        """
        if self._to_device is not None:
            return self._to_device.get_data_info()
        raise RuntimeError("Calling get_data_info with bad state.")

    def get_mbuf_queue_size(self):
        """
        Get element numbers inside mbuf.

        Raises:
            RuntimeError: If send() has not created the consumer yet.
        """
        if self._to_device is not None:
            return self._to_device.get_mbuf_queue_size()
        raise RuntimeError("Device queue is not init, call get_mbuf_queue_size failed.")

    def get_send_info(self):
        """
        In sink mode, it returns the send information of dataset at this moment.
        Send information includes number of send batches, time summary of fetching data on host
        and time summary of sending data.

        Raises:
            RuntimeError: If send() has not created the consumer yet.
        """
        if self._to_device is not None:
            return self._to_device.get_send_info()
        raise RuntimeError("Calling get_send_info with bad state, data queue is not initialized.")

    def get_offload_model(self):
        """
        Return the offload model for this dataset's columns.

        Raises:
            RuntimeError: If send() has not created the consumer yet.
        """
        if self._to_device is not None:
            return self._to_device.get_offload_model(self.column_name)
        raise RuntimeError("get_offload_model, _to_device is None")

    def release(self):
        """
        Manually terminate Device Queue instead of relying on out of scope destruction.
        """
        if self._to_device is not None:
            self._to_device.release()

    def _reset(self, step, dataset_size):
        """
        Reset the pipeline to `step`; does nothing if the consumer does not exist.
        """
        if self._to_device is None:
            return
        # The epoch is only computed for the log message. Guard it so that a
        # dataset_size of 0 cannot raise ZeroDivisionError, and a negative size
        # cannot log a meaningless epoch, before the reset itself runs.
        if dataset_size > 0:
            logger.info("Reset the dataset pipeline to step: " + str(step) + ", epoch: " + str(step // dataset_size))
        else:
            logger.info("Reset the dataset pipeline to step: " + str(step))
        self._to_device._reset(step, dataset_size)  # pylint: disable=protected-access
4438
+
4439
+
4440
class Schema:
    """
    Describes the columns (name, data type, shape) of a dataset.

    Args:
        schema_file (str): Path of a schema file to load. Default: ``None``, in which
            case an empty path is handed to the underlying C++ schema object.

    Raises:
        RuntimeError: If schema file failed to load.

    Examples:
        >>> import mindspore.dataset as ds
        >>> from mindspore import dtype as mstype
        >>>
        >>> # Create schema; specify column name, mindspore.dtype and shape of the column
        >>> schema = ds.Schema()
        >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
    """

    @check_schema
    def __init__(self, schema_file=None):
        self.schema_file = replace_none(schema_file, "")
        self.cpp_schema = cde.SchemaObj(self.schema_file)

    @check_add_column
    def add_column(self, name, de_type, shape=None):
        """
        Append a column definition to the schema.

        Args:
            name (str): Name of the new column.
            de_type (Union[str, mindspore.dtype]): Data type of the column, either as a
                mindspore dtype (converted internally) or a value accepted by the C++
                data type constructor.
            shape (list[int], optional): Shape of the column.
                Default: ``None``, [-1] which is an unknown shape of rank 1.

        Raises:
            ValueError: If column type is unknown.

        Examples:
            >>> import mindspore.dataset as ds
            >>> from mindspore import dtype as mstype
            >>>
            >>> schema = ds.Schema()
            >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
        """
        if isinstance(de_type, typing.Type):
            col_type = str(mstype_to_detype(de_type))
        else:
            col_type = str(cde.DataType(de_type))
        # Only pass a shape through when the caller supplied one.
        column_args = (name, col_type) if shape is None else (name, col_type, shape)
        self.cpp_schema.add_column(*column_args)

    def parse_columns(self, columns):
        """
        Add column definitions described by `columns` to this schema.

        Args:
            columns (Union[dict, list[dict], tuple[dict]]): Column descriptions, e.g. as
                decoded from a schema file.

                - list[dict], `name` and `type` must be in keys, `shape` optional.

                - dict, columns.keys() as name, columns.values() is dict, and `type` inside, `shape` optional.

        Raises:
            RuntimeError: If failed to parse columns.
            RuntimeError: If column's name field is missing.
            RuntimeError: If column's type field is missing.

        Examples:
            >>> from mindspore.dataset import Schema
            >>> schema = Schema()
            >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
            ...             {'name': 'label', 'type': 'int8', 'shape': [1]}]
            >>> schema.parse_columns(columns1)
            >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
            >>> schema.parse_columns(columns2)
        """
        # The C++ side consumes the columns as a JSON string.
        serialized = json.dumps(columns, indent=2)
        self.cpp_schema.parse_columns(serialized)

    def to_json(self):
        """
        Serialize the schema to a JSON string.

        Returns:
            str, JSON string of the schema.

        Examples:
            >>> from mindspore.dataset import Schema
            >>> from mindspore import dtype as mstype
            >>>
            >>> schema = Schema()
            >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
            >>> json_str = schema.to_json()
        """
        return self.cpp_schema.to_json()

    def from_json(self, json_obj):
        """
        Load the schema from an already-parsed JSON object.

        Args:
            json_obj(dictionary): Object of JSON parsed.

        Raises:
            RuntimeError: if there is unknown item in the object.
            RuntimeError: if dataset type is missing in the object.
            RuntimeError: if columns are missing in the object.

        Examples:
            >>> import json
            >>> from mindspore.dataset import Schema
            >>>
            >>> with open("/path/to/schema_file", "r") as file:
            ...     json_obj = json.load(file)
            ...     schema = Schema()
            ...     schema.from_json(json_obj)
        """
        # Re-serialize so the C++ schema object can parse it from a string.
        serialized = json.dumps(json_obj, indent=2)
        self.cpp_schema.from_string(serialized)

    def __str__(self):
        return self.to_json()

    @staticmethod
    def get_num_rows(schema):
        """
        Return the row count recorded in `schema`.

        Args:
            schema (Union[Schema, str]): A Schema instance, or any value accepted by
                the Schema constructor (such as a schema file path), which is then
                loaded first.
        """
        schema_obj = schema if isinstance(schema, Schema) else Schema(schema)
        return schema_obj.cpp_schema.get_num_rows()
4571
+
4572
+
4573
class DeserializedDataset(Dataset):
    """
    Dataset pipeline rebuilt from a serialized JSON description.

    Args:
        input_obj (Union[dict, str]): A dict is serialized with ``json.dumps`` and
            handed to ``cde.Dataset.from_json_string``. Any other value is passed to
            ``cde.Dataset.from_json_file`` (presumably a JSON file path; confirm
            against callers).
    """

    def __init__(self, input_obj):
        super().__init__()
        self.input_obj = input_obj

    def parse(self, children=None):
        source = self.input_obj
        if not isinstance(source, dict):
            return cde.Dataset.from_json_file(source)
        return cde.Dataset.from_json_string(json.dumps(source))