mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0

mindspore/dataset/engine/datasets.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2023 Huawei Technologies Co., Ltd
+# Copyright 2019-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,12 +19,13 @@ After declaring the dataset object, you can further apply dataset operations
 (e.g. filter, skip, concat, map, batch) on it.
 """
 import builtins
-import copy
 import errno
+import itertools
 import math
 import os
 import signal
 import time
+from types import GeneratorType
 import multiprocessing
 from multiprocessing.util import Finalize
 import queue
@@ -46,7 +47,7 @@ from . import samplers
 from .queue import _SharedQueue
 from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
 from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \
-    get_enable_watchdog, get_debug_mode
+    get_enable_watchdog, get_debug_mode, get_seed, set_seed
 from ..core.datatypes import mstypelist_to_detypelist
 from ..core.py_util_helpers import ExceptionHandler
 from ..transforms import transforms
@@ -89,7 +90,7 @@ def _generator_fn(generator, num_samples):
         yield _convert_row(val)


-def _cpp_sampler_fn(sample_ids, dataset):
+def _cpp_sampler_fn(dataset, sample_ids):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
     """
@@ -104,7 +105,7 @@ def _cpp_sampler_fn(sample_ids, dataset):
         yield _convert_row(val)


-def _cpp_sampler_fn_mp(sample_ids, sample_fn):
+def _cpp_sampler_fn_mp(sample_fn, sample_ids):
     """
     Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
@@ -116,6 +117,14 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     return sample_fn.process(sample_ids)


+def _generator_fn_wrapper(function, *args):
+    """
+    Generate a new function that wraps the specified generator function with partial
+    application of the given arguments and keywords.
+    """
+    return partial(function, *args)
+
+
 def _fill_worker_indices(workers, indices, idx_cursor, worker_to_quit):
     """
     Worker index queue filler, fill worker index queue in round robin order or QUIT flag.
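
Note on the `_generator_fn_wrapper` helper above: it replaces the lambdas that previously wrapped `_cpp_sampler_fn`, `_generator_fn` and `_iter_fn` (see the hunks near the bottom of this diff). A plausible motivation, sketched below with stand-in names rather than MindSpore internals, is picklability: a `functools.partial` over a module-level function survives pickling across process boundaries, while a lambda closure does not. The argument reorder to `_cpp_sampler_fn(dataset, sample_ids)` fits this pattern, binding the dataset first so the sample ids can be supplied later.

    # Sketch (assumed rationale, stand-in names): partial objects pickle,
    # lambda closures do not.
    import pickle
    from functools import partial

    def _sample(dataset, sample_ids):              # stand-in for _cpp_sampler_fn
        return [dataset[i] for i in sample_ids]

    dataset = list(range(100))
    bound = partial(_sample, dataset)              # like _generator_fn_wrapper(_sample, dataset)
    restored = pickle.loads(pickle.dumps(bound))   # round-trips fine
    print(restored([0, 2, 4]))                     # [0, 2, 4]

    # pickle.dumps(lambda ids: _sample(dataset, ids))  # raises a pickling error
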
@@ -178,25 +187,42 @@ def _convert_row(row):
     return tuple(value)


-class SamplerFn:
+class SamplerFn(cde.PythonMultiprocessingRuntime):
     """
     Multiprocessing or multithread generator function wrapper master process.
     """

     def __init__(self, dataset, num_worker, multi_process, max_rowsize):
+        super(SamplerFn, self).__init__()
         self.workers = []
         self.dataset = dataset
         self.num_worker = num_worker
         self.multi_process = multi_process
         self.max_rowsize = max_rowsize
         self.need_join = False
+
+    def is_mp_enabled(self):
+        return self.workers is not None and self.workers
+
+    def launch(self, op_id=-1):
+        """launch the multiprocessing pool"""
+        self.op_id = op_id
+        logger.info("Launching new Python Multiprocessing pool for GeneratorOp:" + str(self.op_id))
+        if self.is_mp_enabled():
+            message = "Launching a new Python multiprocessing pool for GeneratorOp while a pool already exists!" + \
+                      " The existing pool will be terminated first."
+            logger.warning(message)
+            self._stop_subprocess()
+            self.reset()
+            self.workers = []
+
         self.ppid = os.getpid()
         self.pids = []
         self.check_interval = get_multiprocessing_timeout_interval()  # the interval of check queue's size
         self._final_join = True

         # Event for end of epoch
-        if multi_process is True:
+        if self.multi_process is True:
             try:
                 self.eof = multiprocessing.Event()
             except Exception:
@@ -206,22 +232,22 @@ class SamplerFn:
             self.eof = threading.Event()
         # Create workers

-        # get default queue size and adjust queuesize per worker if there are large # workers
+        # get default queue size and adjust queue size per worker if there are large # workers
         queue_size = get_prefetch_size()
-        queue_size = min(queue_size, queue_size * 4 // num_worker)
+        queue_size = min(queue_size, queue_size * 4 // self.num_worker)
         queue_size = max(2, queue_size)

-        if multi_process and get_enable_shared_mem():
+        if self.multi_process and get_enable_shared_mem():
             # generator dataset use idx_queue and res_queue to transfer data between main and subprocess
             # idx_queue is used multiprocess.Queue which is not shared memory, so it's size is 0.
-            # res_queue is used shared memory, so it' size is max_rowsize which is defined by user.
-            _check_shm_usage(num_worker, queue_size, 0, max_rowsize)
+            # res_queue is used shared memory, so its size is max_rowsize which is defined by user.
+            _check_shm_usage(self.num_worker, queue_size, 0, self.max_rowsize)
             self.count = multiprocessing.Value('i', 0)
-        for worker_id in range(num_worker):
-            if multi_process is True:
+        for worker_id in range(self.num_worker):
+            if self.multi_process is True:
                 try:
-                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid, self.count,
-                                                worker_id)
+                    worker = _GeneratorWorkerMp(self.dataset, self.eof, self.max_rowsize, queue_size, self.ppid,
+                                                self.count, worker_id)
                     worker.daemon = True
                     # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
                     # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization
@@ -240,10 +266,12 @@ class SamplerFn:
                     self.pids.append(worker.pid)
                     self.need_join = True
             else:
-                worker = _GeneratorWorkerMt(dataset, self.eof, worker_id)
+                worker = _GeneratorWorkerMt(self.dataset, self.eof, worker_id)
                 worker.daemon = True
+                self.need_join = True
             self.workers.append(worker)
-        self._launch_cleanup_worker(multi_process=multi_process)
+        if self.multi_process and platform.system().lower() != 'windows':
+            self._launch_cleanup_worker()

     def _interval_log(self, i, start_time, wait_count):
         cost_time = int(time.time()) - start_time
@@ -252,11 +280,10 @@ class SamplerFn:
             self._log_stuck_warning(self.workers[i % self.num_worker], cost_time)
         return wait_count

-    def process(self, indices):
-        """
-        The main process, start the child process or child thread, and fill the index queue.
-        Get the result and return.
-        """
+    def _check_and_start_process(self):
+        """Check the idx_queue and start the process"""
+        if self.workers is None:
+            raise RuntimeError("The GeneratorDataset subprocess worker may be killed or exit abnormally.")
         for w in self.workers:
             # Check whether the queue of the subprocess is empty.
             if not w.queue_empty():
@@ -270,7 +297,20 @@ class SamplerFn:
                 continue
             # Start all workers
             if not w.is_alive():
-                w.start()
+                try:
+                    w.start()
+                except RuntimeError as e:
+                    # the worker may be being started.
+                    if w._started.is_set():  # pylint: disable=W0212
+                        continue
+                    raise e
+
+    def process(self, indices):
+        """
+        The main process, start the child process or child thread, and fill the index queue.
+        Get the result and return.
+        """
+        self._check_and_start_process()

         # Fill initial index queues
         idx_cursor = 0
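
Note on the try/except around `w.start()` above: a second `start()` on a worker that is already starting raises, and the private `_started` event lets the caller tell "already running" apart from a real failure. The snippet below illustrates the thread case with a plain `threading.Thread` (an illustration, not the MindSpore worker class):

    # Minimal illustration of the double-start guard (thread case).
    import threading

    t = threading.Thread(target=lambda: None)
    t.start()
    try:
        t.start()                  # a second start raises RuntimeError
    except RuntimeError:
        if t._started.is_set():    # pylint: disable=W0212
            pass                   # already running, safe to continue
        else:
            raise
    t.join()
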
@@ -300,14 +340,6 @@ class SamplerFn:
                     time.sleep(0.1)
                     wait_count = self._interval_log(i, start_time, wait_count)
                 result = self.workers[i % self.num_worker].get()
-                # Because there is no need to copy when creating Tensors in the C++layer, it reduces the time
-                # from np.ndarray to C++Tensor creation. However, when using shared memory in multiple processes,
-                # the address of the shared memory will always be passed to subsequent nodes in the dataset pipeline,
-                # and the shared memory will also be written by the current node, causing dirty data to be accessed
-                # by subsequent nodes in the pipeline. So make a memory copy here to solve the problem of
-                # shared memory being contaminated.
-                if self.multi_process is True and get_enable_shared_mem():
-                    result = copy.deepcopy(result)
                 if isinstance(result, ExceptionHandler):
                     result.reraise()
             except queue.Empty:
@@ -360,44 +392,74 @@ class SamplerFn:
                           "the `mindspore.dataset.config.set_multiprocessing_timeout_interval` interface."
         logger.warning(warning_message)

-    def _launch_cleanup_worker(self, multi_process):
+    def _launch_cleanup_worker(self):
         """
         We need a extra thread and process if main process or subprocess was killed.
-
-        Args:
-            multi_process: Whether use multiprocess.
         """
-        if multi_process is True and platform.system().lower() != 'windows':
-            _clean_worker_func = _PythonMultiprocessing._clean_process  # pylint: disable=W0212
-            self.cleaning_process = multiprocessing.Process(target=_clean_worker_func,
-                                                            name="GeneratorCleanProcess",
-                                                            args=(self.ppid, self.workers, self.eof))
-            self.cleaning_process.daemon = True
-            self.cleaning_process.start()
-
-            if get_enable_watchdog():
-                self.eot = threading.Event()
-                self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
-                                                  name="GeneratorWatchDog",
-                                                  args=(self.eot, self.workers + [self.cleaning_process]))
-                self.watch_dog.daemon = True
-                self.watch_dog.start()
-
-            if self._final_join is True:
-                self._jointhread = Finalize(
-                    self.watch_dog, self._finalize_join,
-                    args=(weakref.ref(self.watch_dog), self.eot),
-                    exitpriority=-5
-                )
+        _clean_worker_func = _PythonMultiprocessing._clean_process  # pylint: disable=W0212
+        self.cleaning_process = multiprocessing.Process(target=_clean_worker_func,
+                                                        name="GeneratorCleanProcess",
+                                                        args=(self.ppid, self.workers, self.eof))
+        self.cleaning_process.daemon = True
+        self.cleaning_process.start()
+
+        if get_enable_watchdog():
+            self.eot = threading.Event()
+            self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
+                                              name="GeneratorWatchDog",
+                                              args=(self.eot, self.workers + [self.cleaning_process]))
+            self.watch_dog.daemon = True
+            self.watch_dog.start()
+
+        if self._final_join is True:
+            self._jointhread = Finalize(
+                self.watch_dog, self._finalize_join,
+                args=(weakref.ref(self.watch_dog), self.eot),
+                exitpriority=-5
+            )
+
+    def _release_fd(self):
+        """Release the file descriptor by subprocess"""
+        # release the file descriptor handle
+        check_interval = get_multiprocessing_timeout_interval()
+        for w in self.workers:
+            try:
+                subprocess_file_descriptor = w.sentinel
+                st = time.time()
+                while _PythonMultiprocessing.is_process_alive(w.pid):
+                    time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
+                    if time.time() - st > check_interval:
+                        logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
+                        st += check_interval
+            except ValueError as e:
+                if "process object is closed" in str(e):
+                    continue
+                raise e
+            try:
+                if w.is_alive():
+                    os.close(subprocess_file_descriptor)
+            except OSError as e:
+                # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
+                if "Bad file descriptor" not in str(e):
+                    raise e
+            except AttributeError:  # maybe occur "'NoneType' object has no attribute 'maxsize'"
+                pass

     def _stop_subprocess(self):
-        """Only the main process can call join."""
+        """Only the main process can call join. All the sub-process / sub-thread will be stopped."""
         if self.need_join is True and self.ppid == os.getpid():
+            # the sub-process / sub-thread will stop by self.eof.set()
             if hasattr(self, 'eof') and self.eof is not None:
-                self.eof.set()
+                try:
+                    self.eof.set()
+                except AttributeError:  # maybe occur "'NoneType' object has no attribute 'maxsize'"
+                    pass
+
             # close the watch dog first
             self._abort_watchdog()
             self.need_join = False
+
+            # waiting for the sub-process stop
            for w in self.workers:
                if self.multi_process is True and hasattr(w, '_closed') and w._closed is False:  # pylint: disable=W0212
                    try:
@@ -415,28 +477,8 @@ class SamplerFn:
                         # Block all errors when join
                         continue

-            # release the file descriptor handle
-            check_interval = get_multiprocessing_timeout_interval()
-            for w in self.workers:
-                try:
-                    subprocess_file_descriptor = w.sentinel
-                    st = time.time()
-                    while _PythonMultiprocessing.is_process_alive(w.pid):
-                        time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
-                        if time.time() - st > check_interval:
-                            logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
-                            st += check_interval
-                except ValueError as e:
-                    if "process object is closed" in str(e):
-                        continue
-                    raise e
-                try:
-                    if w.is_alive():
-                        os.close(subprocess_file_descriptor)
-                except OSError as e:
-                    # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
-                    if "Bad file descriptor" not in str(e):
-                        raise e
+            if self.multi_process is True:
+                self._release_fd()

         self.workers.clear()
         self.workers = None
@@ -498,13 +540,21 @@ def _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue,
     return False


-def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1):
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, worker_id, ppid=-1):
     """
     Multithread or multiprocess generator worker process loop.
     """
+    # Initialize C++ side signal handlers
+    cde.register_worker_handlers()
+
     if is_multiprocessing:
         result_queue.cancel_join_thread()  # Ensure that the process does not hung when exiting
         signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
+
+        # init the random seed and np.random seed for the subprocess
+        if get_seed() != 5489:
+            set_seed(get_seed() + worker_id)
+
     while not eof.is_set():
         _ignore_sigint(is_multiprocessing=is_multiprocessing)

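Note on the seeding block above: each subprocess offsets the user-configured seed by its worker id, so workers that inherit identical RNG state after fork do not generate identical random augmentations; the `get_seed() != 5489` check suggests 5489 is the default-seed sentinel, so the offset only applies when a seed was explicitly configured. A standalone sketch of the same idea using plain `random`/`numpy` (illustrative, not the MindSpore API):

    import multiprocessing
    import random

    import numpy as np

    def _worker(worker_id, base_seed):
        # mirror set_seed(get_seed() + worker_id): offset the user seed per worker
        seed = base_seed + worker_id
        random.seed(seed)
        np.random.seed(seed)
        print(worker_id, np.random.randint(0, 100, size=3))  # differs per worker

    if __name__ == "__main__":
        for wid in range(3):
            multiprocessing.Process(target=_worker, args=(wid, 1234)).start()
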
@@ -562,7 +612,8 @@ class _GeneratorWorkerMt(threading.Thread):
     def __init__(self, dataset, eof, worker_id):
         self.idx_queue = queue.Queue(16)
         self.res_queue = queue.Queue(16)
-        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False),
+        super().__init__(target=_generator_worker_loop,
+                         args=(dataset, self.idx_queue, self.res_queue, eof, False, worker_id),
                          name="GeneratorWorkerThread" + str(worker_id))

     def put(self, item):
@@ -598,8 +649,9 @@ class _GeneratorWorkerMp(multiprocessing.Process):
             self.res_queue = _SharedQueue(queue_size, count, max_rowsize=max_rowsize)
         else:
             self.res_queue = multiprocessing.Queue(queue_size)
-        self.idx_queue.cancel_join_thread()  # Ensure that the process does not hung when exiting
-        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid),
+        self.idx_queue.cancel_join_thread()  # Ensure that the process does not hang when exiting
+        super().__init__(target=_generator_worker_loop,
+                         args=(dataset, self.idx_queue, self.res_queue, eof, True, worker_id, ppid),
                          name="GeneratorWorkerProcess" + str(worker_id))

     def put(self, item):
@@ -634,6 +686,20 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         del self.res_queue


+class _GeneratorWrapper:
+    """Wrapper the generator so that it can be iterated multiple times in GeneratorDataset."""
+    def __init__(self, generator):
+        self.generator = generator
+        self.generator_new, self.generator = itertools.tee(self.generator)
+
+    def __iter__(self):
+        self.generator_new, self.generator = itertools.tee(self.generator)
+        return self
+
+    def __next__(self):
+        return next(self.generator_new)
+
+
 class GeneratorDataset(MappableDataset, UnionBaseDataset):
     """
     A source dataset that generates data from Python by invoking Python data source each epoch.
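
Note on `_GeneratorWrapper` above: a raw Python generator is exhausted after one pass, but GeneratorDataset re-iterates its source every epoch. `itertools.tee` solves this by handing out a fresh iterator on each `__iter__` while keeping an unconsumed copy for the next epoch. A standalone sketch of the pattern (illustrative class name):

    import itertools

    class ReplayableGenerator:
        """Re-iterable view over a one-shot generator, as in _GeneratorWrapper."""
        def __init__(self, generator):
            self.generator = generator

        def __iter__(self):
            # hand out a fresh iterator, keep a pristine copy for the next epoch
            fresh, self.generator = itertools.tee(self.generator)
            return fresh

    gen = ReplayableGenerator(x * x for x in range(4))
    print(list(gen))  # [0, 1, 4, 9]
    print(list(gen))  # [0, 1, 4, 9] again; a bare generator would yield []

The trade-off is that `tee` buffers items until both of its iterators have consumed them; since the kept copy never advances, an epoch's worth of items stays buffered, which matters for large in-memory generators.
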
@@ -671,11 +737,11 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             Random accessible input is required.
         python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
             option could be beneficial if the Python operation is computational heavy. Default: ``True``.
-        max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory
+        max_rowsize(int, optional): Maximum size of data (in MB) that is used for shared memory
             allocation to copy data between processes, the total occupied shared memory will increase as
             ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
             shared memory will be dynamically allocated with the actual size of data. This is only used if
-            ``python_multiprocessing`` is set to True. Default: 16.
+            ``python_multiprocessing`` is set to True. Default: ``None`` , allocate shared memory dynamically.

     Raises:
         RuntimeError: If source raises an exception during execution.
@@ -693,16 +759,16 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):

     Note:
         - If you configure `python_multiprocessing=True` (Default: ``True`` ) and `num_parallel_workers>1`
-          (default: ``1`` ) indicates that the multi-process mode is started for data load acceleration.
+          (default: ``1`` ) indicates that the multiprocessing mode is started for data load acceleration.
           At this time, as the datasetiterates, the memory consumption of the subprocess will gradually increase,
           mainly because the subprocess of the user-defined dataset obtains the member variables from the main
           process in the Copy On Write way.
           Example: If you define a dataset with `__ init__` function which contains a large number of member variable
           data (for example, a very large file name list is loaded during the dataset construction) and uses the
-          multi-process mode, which may cause the problem of OOM (the estimated total memory usage is:
+          multiprocessing mode, which may cause the problem of OOM (the estimated total memory usage is:
           `(num_parallel_workers+1) * size of the parent process` ). The simplest solution is to replace Python objects
           (such as list/dict/int/float/string) with non referenced data types
-          (such as Pandas, Numpy or PyArrow objects) for member variables, or load less meta data in member variables,
+          (such as Pandas, Numpy or PyArrow objects) for member variables, or load less metadata in member variables,
           or configure `python_multiprocessing=False` to use multi-threading mode.

           There are several classes/functions that can help you reduce the size of member variables, and you can choose
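
Note on the copy-on-write caveat described above: even read-only access to a Python object updates its reference count, which dirties the memory page it lives on and forces each forked worker to copy it, so a large list of Python strings is gradually duplicated roughly `num_parallel_workers` times. The hedged sketch below (hypothetical dataset class) shows the kind of refactor the docstring recommends, one contiguous NumPy buffer instead of many small Python objects:

    import numpy as np

    class MyDataset:
        def __init__(self, file_list):
            # instead of self.files = file_list (a list of Python str objects,
            # each touched by refcounting in every worker), keep one buffer
            self.files = np.array(file_list)

        def __getitem__(self, index):
            return str(self.files[index])

        def __len__(self):
            return len(self.files)
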
@@ -782,7 +848,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
     @check_generatordataset
     def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
                  num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
-                 python_multiprocessing=True, max_rowsize=6):
+                 python_multiprocessing=True, max_rowsize=None):
         super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                          shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
         if isinstance(source, builtins.zip):
@@ -790,6 +856,11 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             self.source = [item for item in source]
         else:
             self.source = source
+
+        # wrapper the generator so that it can be iterated multiple times
+        if isinstance(self.source, GeneratorType):
+            self.source = _GeneratorWrapper(self.source)
+
         self.prepared_source = None  # source to be sent to C++
         if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
             self.num_parallel_workers = 1
@@ -805,7 +876,6 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
         if self.python_multiprocessing and get_debug_mode():
             logger.warning("Python multiprocessing is not supported in debug mode."
                            " Ignoring Python multiprocessing for GeneratorDataset.")
-            self.python_multiprocessing = False

         self.column_names = to_list(column_names)

@@ -829,7 +899,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
         if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
             self.source_len = len(list(sampler))

-        self.max_rowsize = max_rowsize
+        self.max_rowsize = max_rowsize if max_rowsize is not None else -1
         self.sample_fn = None

     def __deepcopy__(self, memodict):
@@ -863,14 +933,14 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             if self.source_len == -1:
                 raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")

-            if self.num_parallel_workers > 1:
+            if self.num_parallel_workers > 1 and not get_debug_mode():
                 self.__validate_memory_usage()

                 sample_fn = SamplerFn(self.source, self.num_parallel_workers, self.python_multiprocessing,
                                       self.max_rowsize)
-                self.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
+                self.prepared_source = _generator_fn_wrapper(_cpp_sampler_fn_mp, sample_fn)
             else:
-                self.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+                self.prepared_source = _generator_fn_wrapper(_cpp_sampler_fn, self.source)
             self.sample_fn = sample_fn
         else:
             self.sampler = None
@@ -878,30 +948,30 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             self.source_len = min(self.source_len, self.num_samples) if self.num_samples != 0 else self.source_len
             if not hasattr(self.source, "__iter__"):
                 # Use generator function if input callable
-                self.prepared_source = (lambda: _generator_fn(self.source, self.num_samples))
+                self.prepared_source = _generator_fn_wrapper(_generator_fn, self.source, self.num_samples)
             else:
                 # Use iterator function if input is iterable
                 # Random accessible input is also iterable
-                self.prepared_source = (lambda: _iter_fn(self.source, self.num_samples))
+                self.prepared_source = _generator_fn_wrapper(_iter_fn, self.source, self.num_samples)

     def parse(self, children=None):
         self.prepare_multiprocessing()
         if self.schema is None:
             return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
-                                     self.sampler, self.num_parallel_workers)
+                                     self.sampler, self.num_parallel_workers, self.sample_fn)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
         return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
-                                 self.num_parallel_workers)
+                                 self.num_parallel_workers, self.sample_fn)

     def __validate_memory_usage(self):
         """
-        Check memory usage when mulit-processing mode, when 85% prompt warning and 100% raise error.
+        Check memory usage when multiprocessing mode, when 85% prompt warning and 100% raise error.
         """
         if self.python_multiprocessing:
-            # if use num_parallel_workers is to large when python_multiprocessing=True which would cause
-            # OOM error get the num_shards
+            # setting num_parallel_workers too large when using python multiprocessing may cause
+            # out of memory for getting num_shards
             valid_num_shards = 1
             if isinstance(self.sampler, samplers.DistributedSampler):
                 valid_num_shards = self.sampler.num_shards