mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/dataset/engine/datasets.py

```diff
@@ -1,4 +1,4 @@
-# Copyright 2019-
+# Copyright 2019-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
```diff
@@ -19,12 +19,13 @@ After declaring the dataset object, you can further apply dataset operations
 (e.g. filter, skip, concat, map, batch) on it.
 """
 import builtins
-import copy
 import errno
+import itertools
 import math
 import os
 import signal
 import time
+from types import GeneratorType
 import multiprocessing
 from multiprocessing.util import Finalize
 import queue
```
```diff
@@ -46,7 +47,7 @@ from . import samplers
 from .queue import _SharedQueue
 from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
 from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \
-    get_enable_watchdog, get_debug_mode
+    get_enable_watchdog, get_debug_mode, get_seed, set_seed
 from ..core.datatypes import mstypelist_to_detypelist
 from ..core.py_util_helpers import ExceptionHandler
 from ..transforms import transforms
```
```diff
@@ -89,7 +90,7 @@ def _generator_fn(generator, num_samples):
         yield _convert_row(val)
 
 
-def _cpp_sampler_fn(sample_ids, dataset):
+def _cpp_sampler_fn(dataset, sample_ids):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
     """
```
```diff
@@ -104,7 +105,7 @@ def _cpp_sampler_fn(sample_ids, dataset):
         yield _convert_row(val)
 
 
-def _cpp_sampler_fn_mp(sample_ids, sample_fn):
+def _cpp_sampler_fn_mp(sample_fn, sample_ids):
     """
     Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
```
```diff
@@ -116,6 +117,14 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     return sample_fn.process(sample_ids)
 
 
+def _generator_fn_wrapper(function, *args):
+    """
+    Generate a new function that wraps the specified generator function with partial
+    application of the given arguments and keywords.
+    """
+    return partial(function, *args)
+
+
 def _fill_worker_indices(workers, indices, idx_cursor, worker_to_quit):
     """
     Worker index queue filler, fill worker index queue in round robin order or QUIT flag.
```
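The `partial`-based wrapper replaces the bare lambdas the 2.3.0 code used to bind `dataset`/`sample_fn` (visible in the `prepare_multiprocessing` hunks further down). A plausible motivation, inferred rather than stated in the diff, is that a `functools.partial` over a module-level function can be pickled and re-invoked, while a lambda cannot; a minimal standalone sketch:

```python
import pickle
from functools import partial

def sampler_fn(dataset, sample_ids):
    # stand-in for _cpp_sampler_fn: fetch rows by index
    return [dataset[i] for i in sample_ids]

bound = partial(sampler_fn, list(range(10)))  # same shape as _generator_fn_wrapper(fn, dataset)
print(bound([1, 2, 3]))                       # -> [1, 2, 3]
pickle.dumps(bound)                           # works; pickling a lambda here would raise
```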
```diff
@@ -178,25 +187,42 @@ def _convert_row(row):
     return tuple(value)
 
 
-class SamplerFn:
+class SamplerFn(cde.PythonMultiprocessingRuntime):
     """
     Multiprocessing or multithread generator function wrapper master process.
     """
 
     def __init__(self, dataset, num_worker, multi_process, max_rowsize):
+        super(SamplerFn, self).__init__()
         self.workers = []
         self.dataset = dataset
         self.num_worker = num_worker
         self.multi_process = multi_process
         self.max_rowsize = max_rowsize
         self.need_join = False
+
+    def is_mp_enabled(self):
+        return self.workers is not None and self.workers
+
+    def launch(self, op_id=-1):
+        """launch the multiprocessing pool"""
+        self.op_id = op_id
+        logger.info("Launching new Python Multiprocessing pool for GeneratorOp:" + str(self.op_id))
+        if self.is_mp_enabled():
+            message = "Launching a new Python multiprocessing pool for GeneratorOp while a pool already exists!" + \
+                      " The existing pool will be terminated first."
+            logger.warning(message)
+            self._stop_subprocess()
+            self.reset()
+            self.workers = []
+
         self.ppid = os.getpid()
         self.pids = []
         self.check_interval = get_multiprocessing_timeout_interval()  # the interval of check queue's size
         self._final_join = True
 
         # Event for end of epoch
-        if multi_process is True:
+        if self.multi_process is True:
             try:
                 self.eof = multiprocessing.Event()
             except Exception:
```
```diff
@@ -206,22 +232,22 @@ class SamplerFn:
             self.eof = threading.Event()
         # Create workers
 
-        # get default queue size and adjust
+        # get default queue size and adjust queue size per worker if there are large # workers
         queue_size = get_prefetch_size()
-        queue_size = min(queue_size, queue_size * 4 // num_worker)
+        queue_size = min(queue_size, queue_size * 4 // self.num_worker)
         queue_size = max(2, queue_size)
 
-        if multi_process and get_enable_shared_mem():
+        if self.multi_process and get_enable_shared_mem():
             # generator dataset use idx_queue and res_queue to transfer data between main and subprocess
             # idx_queue is used multiprocess.Queue which is not shared memory, so it's size is 0.
-            # res_queue is used shared memory, so
-            _check_shm_usage(num_worker, queue_size, 0, max_rowsize)
+            # res_queue is used shared memory, so its size is max_rowsize which is defined by user.
+            _check_shm_usage(self.num_worker, queue_size, 0, self.max_rowsize)
             self.count = multiprocessing.Value('i', 0)
-        for worker_id in range(num_worker):
-            if multi_process is True:
+        for worker_id in range(self.num_worker):
+            if self.multi_process is True:
                 try:
-                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid,
-                                                worker_id)
+                    worker = _GeneratorWorkerMp(self.dataset, self.eof, self.max_rowsize, queue_size, self.ppid,
+                                                self.count, worker_id)
                     worker.daemon = True
                     # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
                     # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization
```
```diff
@@ -240,10 +266,12 @@ class SamplerFn:
                 self.pids.append(worker.pid)
                 self.need_join = True
             else:
-                worker = _GeneratorWorkerMt(dataset, self.eof, worker_id)
+                worker = _GeneratorWorkerMt(self.dataset, self.eof, worker_id)
                 worker.daemon = True
+                self.need_join = True
             self.workers.append(worker)
-        self._launch_cleanup_worker(multi_process=multi_process)
+        if self.multi_process and platform.system().lower() != 'windows':
+            self._launch_cleanup_worker()
 
     def _interval_log(self, i, start_time, wait_count):
         cost_time = int(time.time()) - start_time
```
```diff
@@ -252,11 +280,10 @@ class SamplerFn:
             self._log_stuck_warning(self.workers[i % self.num_worker], cost_time)
         return wait_count
 
-    def process(self, indices):
-        """
-        The main process, start the child process or child thread, and fill the index queue.
-        Get the result and return.
-        """
+    def _check_and_start_process(self):
+        """Check the idx_queue and start the process"""
+        if self.workers is None:
+            raise RuntimeError("The GeneratorDataset subprocess worker may be killed or exit abnormally.")
         for w in self.workers:
             # Check whether the queue of the subprocess is empty.
             if not w.queue_empty():
```
```diff
@@ -270,7 +297,20 @@ class SamplerFn:
                 continue
             # Start all workers
             if not w.is_alive():
-                w.start()
+                try:
+                    w.start()
+                except RuntimeError as e:
+                    # the worker may be being started.
+                    if w._started.is_set():  # pylint: disable=W0212
+                        continue
+                    raise e
+
+    def process(self, indices):
+        """
+        The main process, start the child process or child thread, and fill the index queue.
+        Get the result and return.
+        """
+        self._check_and_start_process()
 
         # Fill initial index queues
         idx_cursor = 0
```
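The `try/except` around `w.start()` tolerates a worker that is already running. `_GeneratorWorkerMt` subclasses `threading.Thread`, where a second `start()` raises `RuntimeError` and the private `_started` event records the first start; a standalone sketch of the same guard (note `_started` is a CPython-internal attribute):

```python
import threading

t = threading.Thread(target=lambda: None)
t.start()
try:
    t.start()                  # second start(): "threads can only be started once"
except RuntimeError:
    if t._started.is_set():    # pylint: disable=W0212  (private CPython attribute)
        pass                   # already running: benign, mirrors the `continue` above
    else:
        raise
t.join()
```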
```diff
@@ -300,14 +340,6 @@ class SamplerFn:
                     time.sleep(0.1)
                     wait_count = self._interval_log(i, start_time, wait_count)
                 result = self.workers[i % self.num_worker].get()
-                # Because there is no need to copy when creating Tensors in the C++layer, it reduces the time
-                # from np.ndarray to C++Tensor creation. However, when using shared memory in multiple processes,
-                # the address of the shared memory will always be passed to subsequent nodes in the dataset pipeline,
-                # and the shared memory will also be written by the current node, causing dirty data to be accessed
-                # by subsequent nodes in the pipeline. So make a memory copy here to solve the problem of
-                # shared memory being contaminated.
-                if self.multi_process is True and get_enable_shared_mem():
-                    result = copy.deepcopy(result)
                 if isinstance(result, ExceptionHandler):
                     result.reraise()
             except queue.Empty:
```
```diff
@@ -360,44 +392,74 @@ class SamplerFn:
                           "the `mindspore.dataset.config.set_multiprocessing_timeout_interval` interface."
         logger.warning(warning_message)
 
-    def _launch_cleanup_worker(self, multi_process):
+    def _launch_cleanup_worker(self):
         """
         We need a extra thread and process if main process or subprocess was killed.
-
-        Args:
-            multi_process: Whether use multiprocess.
         """
-        [original cleanup thread/process body: unrecoverable from the extraction]
+        _clean_worker_func = _PythonMultiprocessing._clean_process  # pylint: disable=W0212
+        self.cleaning_process = multiprocessing.Process(target=_clean_worker_func,
+                                                        name="GeneratorCleanProcess",
+                                                        args=(self.ppid, self.workers, self.eof))
+        self.cleaning_process.daemon = True
+        self.cleaning_process.start()
+
+        if get_enable_watchdog():
+            self.eot = threading.Event()
+            self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
+                                              name="GeneratorWatchDog",
+                                              args=(self.eot, self.workers + [self.cleaning_process]))
+            self.watch_dog.daemon = True
+            self.watch_dog.start()
+
+        if self._final_join is True:
+            self._jointhread = Finalize(
+                self.watch_dog, self._finalize_join,
+                args=(weakref.ref(self.watch_dog), self.eot),
+                exitpriority=-5
+            )
+
+    def _release_fd(self):
+        """Release the file descriptor by subprocess"""
+        # release the file descriptor handle
+        check_interval = get_multiprocessing_timeout_interval()
+        for w in self.workers:
+            try:
+                subprocess_file_descriptor = w.sentinel
+                st = time.time()
+                while _PythonMultiprocessing.is_process_alive(w.pid):
+                    time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
+                    if time.time() - st > check_interval:
+                        logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
+                        st += check_interval
+            except ValueError as e:
+                if "process object is closed" in str(e):
+                    continue
+                raise e
+            try:
+                if w.is_alive():
+                    os.close(subprocess_file_descriptor)
+            except OSError as e:
+                # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
+                if "Bad file descriptor" not in str(e):
+                    raise e
+            except AttributeError:  # maybe occur "'NoneType' object has no attribute 'maxsize'"
+                pass
 
     def _stop_subprocess(self):
-        """Only the main process can call join."""
+        """Only the main process can call join. All the sub-process / sub-thread will be stopped."""
         if self.need_join is True and self.ppid == os.getpid():
+            # the sub-process / sub-thread will stop by self.eof.set()
             if hasattr(self, 'eof') and self.eof is not None:
-                self.eof.set()
+                try:
+                    self.eof.set()
+                except AttributeError:  # maybe occur "'NoneType' object has no attribute 'maxsize'"
+                    pass
+
             # close the watch dog first
             self._abort_watchdog()
             self.need_join = False
+
+        # waiting for the sub-process stop
         for w in self.workers:
             if self.multi_process is True and hasattr(w, '_closed') and w._closed is False:  # pylint: disable=W0212
                 try:
```
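`_launch_cleanup_worker` registers `_finalize_join` via `multiprocessing.util.Finalize`, which fires its callback when the tracked object is garbage-collected or at interpreter exit (finalizers with a non-`None` `exitpriority` run at shutdown, higher priorities first). A self-contained sketch with hypothetical names:

```python
import weakref
from multiprocessing.util import Finalize

class Watchdog:
    """Hypothetical stand-in for the watch_dog thread object."""

def _finalize_join(ref):
    obj = ref()               # weakref, as in the diff: the finalizer must not
    if obj is not None:       # keep the watchdog alive on its own
        print("joining", obj)

w = Watchdog()
# runs at interpreter exit (or when `w` is collected); among exit-time
# finalizers, higher exitpriority values run earlier
Finalize(w, _finalize_join, args=(weakref.ref(w),), exitpriority=-5)
```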
```diff
@@ -415,28 +477,8 @@ class SamplerFn:
                     # Block all errors when join
                     continue
 
-        # release the file descriptor handle
-        check_interval = get_multiprocessing_timeout_interval()
-        for w in self.workers:
-            try:
-                subprocess_file_descriptor = w.sentinel
-                st = time.time()
-                while _PythonMultiprocessing.is_process_alive(w.pid):
-                    time.sleep(0.01)  # sleep 10ms, waiting for the subprocess exit
-                    if time.time() - st > check_interval:
-                        logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
-                        st += check_interval
-            except ValueError as e:
-                if "process object is closed" in str(e):
-                    continue
-                raise e
-            try:
-                if w.is_alive():
-                    os.close(subprocess_file_descriptor)
-            except OSError as e:
-                # Maybe the file descriptor had been released, so ignore the 'Bad file descriptor'
-                if "Bad file descriptor" not in str(e):
-                    raise e
+        if self.multi_process is True:
+            self._release_fd()
 
         self.workers.clear()
         self.workers = None
```
```diff
@@ -498,13 +540,21 @@ def _main_process_already_exit(eof, is_multiprocessing, idx_queue, result_queue,
     return False
 
 
-def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1):
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, worker_id, ppid=-1):
     """
     Multithread or multiprocess generator worker process loop.
     """
+    # Initialize C++ side signal handlers
+    cde.register_worker_handlers()
+
     if is_multiprocessing:
         result_queue.cancel_join_thread()  # Ensure that the process does not hung when exiting
         signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
+
+    # init the random seed and np.random seed for the subprocess
+    if get_seed() != 5489:
+        set_seed(get_seed() + worker_id)
+
     while not eof.is_set():
         _ignore_sigint(is_multiprocessing=is_multiprocessing)
```
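The new `worker_id` argument feeds the per-worker seeding above: `5489` is the default seed in `mindspore.dataset.config`, so seeds are only derived when the user set one explicitly, giving each worker reproducible but non-identical randomness. A sketch of the same pattern with NumPy (the real loop seeds MindSpore's own generators through `set_seed`):

```python
import numpy as np

DEFAULT_SEED = 5489   # mindspore.dataset's default; treated as "no explicit seed"
user_seed = 1234      # e.g. what mindspore.dataset.config.set_seed(1234) recorded

def worker_init(worker_id):
    # derive per-worker seeds only from an explicit user seed, as in the hunk above
    if user_seed != DEFAULT_SEED:
        np.random.seed(user_seed + worker_id)
    return np.random.randint(0, 100)

print([worker_init(wid) for wid in range(3)])  # distinct, reproducible draws
```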
```diff
@@ -562,7 +612,8 @@ class _GeneratorWorkerMt(threading.Thread):
     def __init__(self, dataset, eof, worker_id):
         self.idx_queue = queue.Queue(16)
         self.res_queue = queue.Queue(16)
-        super().__init__(target=_generator_worker_loop,
+        super().__init__(target=_generator_worker_loop,
+                         args=(dataset, self.idx_queue, self.res_queue, eof, False, worker_id),
                          name="GeneratorWorkerThread" + str(worker_id))
 
     def put(self, item):
```
```diff
@@ -598,8 +649,9 @@ class _GeneratorWorkerMp(multiprocessing.Process):
             self.res_queue = _SharedQueue(queue_size, count, max_rowsize=max_rowsize)
         else:
             self.res_queue = multiprocessing.Queue(queue_size)
-        self.idx_queue.cancel_join_thread()  # Ensure that the process does not
-        super().__init__(target=_generator_worker_loop,
+        self.idx_queue.cancel_join_thread()  # Ensure that the process does not hang when exiting
+        super().__init__(target=_generator_worker_loop,
+                         args=(dataset, self.idx_queue, self.res_queue, eof, True, worker_id, ppid),
                          name="GeneratorWorkerProcess" + str(worker_id))
 
     def put(self, item):
```
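`cancel_join_thread()` on `idx_queue` matters at teardown: by default a process that has `put()` items blocks on exit until the queue's feeder thread flushes them to the pipe, which can hang shutdown; cancelling the join trades that for possibly dropping unflushed items, acceptable for an index queue that is being torn down anyway. Minimal illustration:

```python
import multiprocessing

q = multiprocessing.Queue(16)
q.put(0)
# opt out of the exit-time flush/join of the queue's feeder thread;
# unflushed items may be lost, but the process can never hang on exit
q.cancel_join_thread()
```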
```diff
@@ -634,6 +686,20 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         del self.res_queue
 
 
+class _GeneratorWrapper:
+    """Wrapper the generator so that it can be iterated multiple times in GeneratorDataset."""
+    def __init__(self, generator):
+        self.generator = generator
+        self.generator_new, self.generator = itertools.tee(self.generator)
+
+    def __iter__(self):
+        self.generator_new, self.generator = itertools.tee(self.generator)
+        return self
+
+    def __next__(self):
+        return next(self.generator_new)
+
+
 class GeneratorDataset(MappableDataset, UnionBaseDataset):
     """
     A source dataset that generates data from Python by invoking Python data source each epoch.
```
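`itertools.tee` is what makes the one-shot generator replayable: every `__iter__` re-splits the stream, and `tee`'s internal buffer re-serves items the previous epoch consumed. The same trick standalone:

```python
import itertools

def source():
    yield from range(3)

gen = source()
epoch1, gen = itertools.tee(gen)   # what _GeneratorWrapper.__iter__ does
print(list(epoch1))                # [0, 1, 2]
epoch2, gen = itertools.tee(gen)   # second epoch replays from tee's buffer
print(list(epoch2))                # [0, 1, 2]
```

Note the cost: `tee` buffers whatever the surviving copy has not consumed, so a fully iterated epoch keeps one epoch of rows in memory.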
```diff
@@ -671,11 +737,11 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             Random accessible input is required.
         python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
             option could be beneficial if the Python operation is computational heavy. Default: ``True``.
-        max_rowsize(int, optional): Maximum size of
+        max_rowsize(int, optional): Maximum size of data (in MB) that is used for shared memory
             allocation to copy data between processes, the total occupied shared memory will increase as
             ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1,
             shared memory will be dynamically allocated with the actual size of data. This is only used if
-            ``python_multiprocessing`` is set to True. Default:
+            ``python_multiprocessing`` is set to True. Default: ``None`` , allocate shared memory dynamically.
 
     Raises:
         RuntimeError: If source raises an exception during execution.
```
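A minimal usage sketch of the new default (`max_rowsize=None`, sizing shared memory from the actual rows); an integer still caps the per-row shared memory in MB as before. Names and shapes here are illustrative only:

```python
import numpy as np
import mindspore.dataset as ds

def my_source():
    for i in range(10):
        yield (np.full((16, 16), i, dtype=np.int32),)

dataset = ds.GeneratorDataset(my_source, column_names=["data"],
                              num_parallel_workers=2,
                              python_multiprocessing=True,
                              max_rowsize=None)   # dynamic shared-memory sizing
for row in dataset.create_tuple_iterator(output_numpy=True):
    pass
```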
```diff
@@ -693,16 +759,16 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
 
     Note:
         - If you configure `python_multiprocessing=True` (Default: ``True`` ) and `num_parallel_workers>1`
-          (default: ``1`` ) indicates that the
+          (default: ``1`` ) indicates that the multiprocessing mode is started for data load acceleration.
           At this time, as the dataset iterates, the memory consumption of the subprocess will gradually increase,
           mainly because the subprocess of the user-defined dataset obtains the member variables from the main
           process in the Copy On Write way.
           Example: If you define a dataset with `__init__` function which contains a large number of member variable
           data (for example, a very large file name list is loaded during the dataset construction) and uses the
-
+          multiprocessing mode, which may cause the problem of OOM (the estimated total memory usage is:
           `(num_parallel_workers+1) * size of the parent process` ). The simplest solution is to replace Python objects
           (such as list/dict/int/float/string) with non referenced data types
-          (such as Pandas, Numpy or PyArrow objects) for member variables, or load less
+          (such as Pandas, Numpy or PyArrow objects) for member variables, or load less metadata in member variables,
           or configure `python_multiprocessing=False` to use multi-threading mode.
 
         There are several classes/functions that can help you reduce the size of member variables, and you can choose
```
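A sketch of the docstring's advice: keep bulky members out of refcounted Python containers so forked workers share the pages copy-on-write instead of duplicating them. The class below is illustrative, not from the package:

```python
import numpy as np

class FileDataset:
    def __init__(self, file_names):
        # a Python list of str gets its refcounts touched in every forked worker,
        # so copy-on-write duplicates the pages; a fixed-width numpy bytes array
        # stays effectively read-only-shared across workers
        self.file_names = np.array(file_names, dtype=np.bytes_)

    def __getitem__(self, index):
        return np.fromfile(self.file_names[index].decode(), dtype=np.uint8)

    def __len__(self):
        return len(self.file_names)
```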
```diff
@@ -782,7 +848,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
     @check_generatordataset
     def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
                  num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
-                 python_multiprocessing=True, max_rowsize=
+                 python_multiprocessing=True, max_rowsize=None):
         super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                          shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
         if isinstance(source, builtins.zip):
```
```diff
@@ -790,6 +856,11 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             self.source = [item for item in source]
         else:
             self.source = source
+
+        # wrapper the generator so that it can be iterated multiple times
+        if isinstance(self.source, GeneratorType):
+            self.source = _GeneratorWrapper(self.source)
+
         self.prepared_source = None  # source to be sent to C++
         if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
             self.num_parallel_workers = 1
```
```diff
@@ -805,7 +876,6 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
         if self.python_multiprocessing and get_debug_mode():
             logger.warning("Python multiprocessing is not supported in debug mode."
                            " Ignoring Python multiprocessing for GeneratorDataset.")
-            self.python_multiprocessing = False
 
         self.column_names = to_list(column_names)
```
```diff
@@ -829,7 +899,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
         if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
             self.source_len = len(list(sampler))
 
-        self.max_rowsize = max_rowsize
+        self.max_rowsize = max_rowsize if max_rowsize is not None else -1
         self.sample_fn = None
 
     def __deepcopy__(self, memodict):
```
```diff
@@ -863,14 +933,14 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             if self.source_len == -1:
                 raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
 
-            if self.num_parallel_workers > 1:
+            if self.num_parallel_workers > 1 and not get_debug_mode():
                 self.__validate_memory_usage()
 
                 sample_fn = SamplerFn(self.source, self.num_parallel_workers, self.python_multiprocessing,
                                       self.max_rowsize)
-                self.prepared_source = (
+                self.prepared_source = _generator_fn_wrapper(_cpp_sampler_fn_mp, sample_fn)
             else:
-                self.prepared_source = (
+                self.prepared_source = _generator_fn_wrapper(_cpp_sampler_fn, self.source)
             self.sample_fn = sample_fn
         else:
             self.sampler = None
```
```diff
@@ -878,30 +948,30 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
             self.source_len = min(self.source_len, self.num_samples) if self.num_samples != 0 else self.source_len
             if not hasattr(self.source, "__iter__"):
                 # Use generator function if input callable
-                self.prepared_source = (
+                self.prepared_source = _generator_fn_wrapper(_generator_fn, self.source, self.num_samples)
             else:
                 # Use iterator function if input is iterable
                 # Random accessible input is also iterable
-                self.prepared_source = (
+                self.prepared_source = _generator_fn_wrapper(_iter_fn, self.source, self.num_samples)
 
     def parse(self, children=None):
         self.prepare_multiprocessing()
         if self.schema is None:
             return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
-                                     self.sampler, self.num_parallel_workers)
+                                     self.sampler, self.num_parallel_workers, self.sample_fn)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
         return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
-                                 self.num_parallel_workers)
+                                 self.num_parallel_workers, self.sample_fn)
 
     def __validate_memory_usage(self):
         """
-        Check memory usage when
+        Check memory usage when multiprocessing mode, when 85% prompt warning and 100% raise error.
         """
         if self.python_multiprocessing:
-            #
-            #
+            # setting num_parallel_workers too large when using python multiprocessing may cause
+            # out of memory for getting num_shards
             valid_num_shards = 1
             if isinstance(self.sampler, samplers.DistributedSampler):
                 valid_num_shards = self.sampler.num_shards
```