opengris-parfun 7.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. opengris_parfun-7.3.0.dist-info/METADATA +165 -0
  2. opengris_parfun-7.3.0.dist-info/RECORD +43 -0
  3. opengris_parfun-7.3.0.dist-info/WHEEL +5 -0
  4. opengris_parfun-7.3.0.dist-info/licenses/LICENSE +201 -0
  5. opengris_parfun-7.3.0.dist-info/licenses/LICENSE.spdx +7 -0
  6. opengris_parfun-7.3.0.dist-info/licenses/NOTICE +7 -0
  7. opengris_parfun-7.3.0.dist-info/top_level.txt +1 -0
  8. parfun/__init__.py +26 -0
  9. parfun/about.py +1 -0
  10. parfun/backend/__init__.py +0 -0
  11. parfun/backend/dask.py +151 -0
  12. parfun/backend/local_multiprocessing.py +92 -0
  13. parfun/backend/local_single_process.py +47 -0
  14. parfun/backend/mixins.py +68 -0
  15. parfun/backend/profiled_future.py +50 -0
  16. parfun/backend/scaler.py +226 -0
  17. parfun/backend/utility.py +7 -0
  18. parfun/combine/__init__.py +0 -0
  19. parfun/combine/collection.py +13 -0
  20. parfun/combine/dataframe.py +13 -0
  21. parfun/dataframe.py +175 -0
  22. parfun/decorators.py +135 -0
  23. parfun/entry_point.py +180 -0
  24. parfun/functions.py +71 -0
  25. parfun/kernel/__init__.py +0 -0
  26. parfun/kernel/function_signature.py +197 -0
  27. parfun/kernel/parallel_function.py +262 -0
  28. parfun/object.py +7 -0
  29. parfun/partition/__init__.py +0 -0
  30. parfun/partition/api.py +136 -0
  31. parfun/partition/collection.py +13 -0
  32. parfun/partition/dataframe.py +16 -0
  33. parfun/partition/object.py +50 -0
  34. parfun/partition/primitives.py +317 -0
  35. parfun/partition/utility.py +54 -0
  36. parfun/partition_size_estimator/__init__.py +0 -0
  37. parfun/partition_size_estimator/linear_regression_estimator.py +189 -0
  38. parfun/partition_size_estimator/mixins.py +22 -0
  39. parfun/partition_size_estimator/object.py +19 -0
  40. parfun/profiler/__init__.py +0 -0
  41. parfun/profiler/functions.py +261 -0
  42. parfun/profiler/object.py +68 -0
  43. parfun/py_list.py +56 -0
@@ -0,0 +1,92 @@
1
+ import multiprocessing
2
+ from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor
3
+ from threading import BoundedSemaphore
4
+
5
+ import attrs
6
+ import psutil
7
+ from attrs.validators import instance_of
8
+
9
+ from parfun.backend.mixins import BackendEngine, BackendSession
10
+ from parfun.backend.profiled_future import ProfiledFuture
11
+ from parfun.profiler.functions import profile, timed_function
12
+
13
+
14
class LocalMultiprocessingSession(BackendSession):
    """
    Task-submitting session backed by a :py:class:`concurrent.futures` executor.

    Limits the number of in-flight submissions to the executor's worker count, so ``submit()``
    blocks instead of queueing unboundedly.
    """

    # Additional constant scheduling overhead that cannot be accounted for when measuring the task execution duration.
    CONSTANT_SCHEDULING_OVERHEAD = 1_500_000  # 1.5ms

    def __init__(self, underlying_executor: Executor):
        self._underlying_executor = underlying_executor
        # NOTE(review): relies on the private `_max_workers` attribute of the stdlib executor —
        # confirm it exists on the supported Python versions.
        self._concurrent_task_guard = BoundedSemaphore(underlying_executor._max_workers)  # type: ignore[attr-defined]

    def __enter__(self) -> "LocalMultiprocessingSession":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        return None

    def submit(self, fn, *args, **kwargs) -> ProfiledFuture:
        """
        Schedules ``fn(*args, **kwargs)`` on the underlying executor.

        Blocks while all worker slots are busy. The returned future's duration accounts for
        submission, remote execution and result-collection overheads.
        """
        with profile() as submit_duration:
            future = ProfiledFuture()

            # Blocks until a worker slot is available.
            self._concurrent_task_guard.acquire()

            # `timed_function` measures the task's execution duration on the worker itself.
            underlying_future = self._underlying_executor.submit(timed_function, fn, *args, **kwargs)

            def on_done_callback(underlying_future: Future):
                assert submit_duration.value is not None

                if underlying_future.cancelled():
                    future.cancel()
                    return

                with profile() as release_duration:
                    exception = underlying_future.exception()

                    if exception is None:
                        result, function_duration = underlying_future.result()
                    else:
                        function_duration = 0
                        result = None

                    # Frees the worker slot before propagating the outcome.
                    self._concurrent_task_guard.release()

                task_duration = (
                    self.CONSTANT_SCHEDULING_OVERHEAD + submit_duration.value + function_duration + release_duration.value
                )

                # Setting the outcome fires the profiled future's completion callbacks.
                if exception is None:
                    future.set_result(result, duration=task_duration)
                else:
                    future.set_exception(exception, duration=task_duration)

            underlying_future.add_done_callback(on_done_callback)

        return future
66
+
67
+
68
@attrs.define(init=False)
class LocalMultiprocessingBackend(BackendEngine):
    """
    Concurrent engine that uses Python builtin :mod:`multiprocessing` module.
    """

    _underlying_executor: Executor = attrs.field(validator=instance_of(Executor), init=False)
    _concurrent_task_guard: BoundedSemaphore = attrs.field(validator=instance_of(BoundedSemaphore), init=False)

    def __init__(self, max_workers=None, is_process: bool = True, **kwargs):
        """
        :param max_workers: number of worker processes or threads. When ``None`` (the default), uses the number
            of physical CPU cores minus one, with a floor of 1. The previous default expression was evaluated
            once at import time and raised a ``TypeError`` on platforms where
            ``psutil.cpu_count(logical=False)`` returns ``None``.
        :param is_process: when True, uses a process pool with the "spawn" start method; otherwise a thread
            pool.
        :param kwargs: forwarded to the underlying executor's constructor.
        """
        if max_workers is None:
            # Computed at call time; falls back to the logical core count when the physical
            # count is unavailable.
            cores = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 2
            max_workers = max(1, cores - 1)

        if is_process:
            self._underlying_executor = ProcessPoolExecutor(
                max_workers=max_workers, mp_context=multiprocessing.get_context("spawn"), **kwargs
            )
        else:
            self._underlying_executor = ThreadPoolExecutor(max_workers=max_workers, **kwargs)

    def session(self) -> LocalMultiprocessingSession:
        """Returns a new submitting session that shares this backend's executor."""
        return LocalMultiprocessingSession(self._underlying_executor)

    def shutdown(self, wait=True):
        """Shuts down the underlying executor, optionally waiting for pending tasks to finish."""
        self._underlying_executor.shutdown(wait=wait)

    def allows_nested_tasks(self) -> bool:
        # Nested task submission is not supported by this backend.
        return False
@@ -0,0 +1,47 @@
1
+ from typing import Callable
2
+
3
+ from parfun.backend.mixins import BackendEngine, BackendSession
4
+ from parfun.backend.profiled_future import ProfiledFuture
5
+ from parfun.profiler.functions import profile
6
+
7
+
8
class LocalSingleProcessSession(BackendSession):
    """Session that executes every submitted task synchronously in the calling process."""

    # Fixed scheduling overhead added on top of the measured execution duration.
    CONSTANT_SCHEDULING_OVERHEAD = 5_000  # 5 us

    def __enter__(self) -> "LocalSingleProcessSession":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        return None

    def submit(self, fn: Callable, *args, **kwargs) -> ProfiledFuture:
        """Runs ``fn(*args, **kwargs)`` immediately and returns an already-completed future."""
        error = None
        value = None

        with profile() as function_duration:
            future = ProfiledFuture()

            try:
                value = fn(*args, **kwargs)
            except Exception as caught:
                error = caught

        total_duration = self.CONSTANT_SCHEDULING_OVERHEAD + function_duration.value

        if error is not None:
            future.set_exception(error, duration=total_duration)
        else:
            future.set_result(value, duration=total_duration)

        return future
37
+
38
+
39
class LocalSingleProcessBackend(BackendEngine):
    """Backend engine that runs every task sequentially in the current process."""

    def session(self) -> BackendSession:
        """Returns a new synchronous task-submitting session."""
        return LocalSingleProcessSession()

    def shutdown(self):
        # No resources to release: tasks run in the caller's process.
        pass

    def allows_nested_tasks(self) -> bool:
        return False
@@ -0,0 +1,68 @@
1
+ import abc
2
+ from contextlib import AbstractContextManager
3
+ from typing import Any, Callable
4
+
5
+ from parfun.backend.profiled_future import ProfiledFuture
6
+
7
+
8
class BackendSession(AbstractContextManager, metaclass=abc.ABCMeta):
    """
    A task submitting session to a backend engine that manages the lifecycle of the task objects (preloaded values,
    argument values and future objects).

    Sessions are context managers; implementations release session-held resources in ``__exit__``.
    """

    def preload_value(self, value: Any) -> Any:
        """
        Preloads a value to the backend engine.

        The returned value will be used when calling ``submit()`` instead of the original value.
        """
        # By default, does not do anything
        return value

    @abc.abstractmethod
    def submit(self, fn: Callable, *args, **kwargs) -> ProfiledFuture:
        """
        Executes an asynchronous computation.

        **Blocking if no computing resource is available**.
        """

        raise NotImplementedError()
32
+
33
+
34
class BackendEngine(metaclass=abc.ABCMeta):
    """
    Asynchronous task manager interface.
    """

    @abc.abstractmethod
    def session(self) -> BackendSession:
        """
        Returns a new managed session for submitting tasks.

        .. code:: python

            with backend.session() as session:
                arg_ref = session.preload_value(arg)

                future = session.submit(fn, arg_ref)

                print(future.result())

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def shutdown(self):
        """
        Shutdowns all resources required by the backend engine.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def allows_nested_tasks(self) -> bool:
        """
        Indicates if Parfun can submit new tasks from other tasks.
        """
        raise NotImplementedError()
@@ -0,0 +1,50 @@
1
+ from concurrent.futures import Future
2
+ from typing import Any, Optional, Tuple
3
+
4
+ from parfun.profiler.object import TraceTime
5
+
6
+
7
class ProfiledFuture(Future):
    """Future that provides an additional duration metric used to profile the task's duration."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Total task duration, recorded together with the result or exception.
        self._duration = None

    def set_result(self, result: Any, duration: Optional[TraceTime] = None):
        # Sets the task duration before the result, as set_result() triggers all completion callbacks.
        self._duration = duration
        return super().set_result(result)

    def set_exception(self, exception: Optional[BaseException], duration: Optional[TraceTime] = None) -> None:
        # Sets the task duration before the exception, as set_exception() triggers all completion callbacks.
        self._duration = duration
        return super().set_exception(exception)

    def duration(self, timeout: Optional[float] = None) -> Optional[TraceTime]:
        """
        The total CPU time (i.e. user + system times) required to run the task, or `None` if the task didn't provide
        task profiling.

        This **should** include the overhead time required to schedule the task.
        """

        self.exception(timeout)  # Waits until the task finishes.

        return self._duration

    def result_and_duration(self, timeout: Optional[float] = None) -> Tuple[Any, Optional[TraceTime]]:
        """
        Combines the calls to `result() and duration()`:

        .. code:: python

            result, duration = future.result(), future.duration()
            # is equivalent to
            result, duration = future.result_and_duration()

        """

        # result() blocks until completion, so `_duration` is already set when we read it.
        result = self.result(timeout)
        return result, self._duration
@@ -0,0 +1,226 @@
1
+ import inspect
2
+ import itertools
3
+ import threading
4
+ from collections import deque
5
+ from threading import BoundedSemaphore
6
+ from typing import Any, Deque, Dict, Optional, Set, Tuple
7
+
8
+ try:
9
+ from scaler import Client, SchedulerClusterCombo
10
+ from scaler.client.future import ScalerFuture
11
+ from scaler.client.object_reference import ObjectReference
12
+ except ImportError:
13
+ raise ImportError("Scaler dependency missing. Use `pip install 'opengris-parfun[scaler]'` to install Scaler.")
14
+
15
+ import psutil
16
+
17
+ from parfun.backend.mixins import BackendEngine, BackendSession
18
+ from parfun.backend.profiled_future import ProfiledFuture
19
+ from parfun.profiler.functions import profile
20
+
21
+
22
class ScalerSession(BackendSession):
    """
    Task-submitting session backed by a Scaler client connection.

    Limits the number of in-flight submissions to the cluster's worker count.
    """

    # Additional constant scheduling overhead that cannot be accounted for when measuring the task execution duration.
    CONSTANT_SCHEDULING_OVERHEAD = 8_000_000  # 8ms

    def __init__(self, scheduler_address: str, n_workers: int, client_kwargs: Dict):
        self._concurrent_task_guard = BoundedSemaphore(n_workers)
        # `profiling=True` is required so futures report remote CPU time (see submit()).
        self._client = Client(address=scheduler_address, profiling=True, **client_kwargs)

    def __enter__(self) -> "ScalerSession":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._client.disconnect()

    def preload_value(self, value: Any) -> ObjectReference:
        # Uploads the value once; the returned reference can be reused across submissions.
        return self._client.send_object(value)

    def submit(self, fn, *args, **kwargs) -> Optional[ProfiledFuture]:
        """
        Schedules ``fn(*args, **kwargs)`` on the Scaler cluster.

        Blocks while all worker slots are busy.
        """
        with profile() as submit_duration:
            future = ProfiledFuture()

            acquired = self._concurrent_task_guard.acquire()
            # NOTE(review): a no-argument BoundedSemaphore.acquire() always returns True, so this
            # branch looks unreachable — confirm whether a non-blocking acquire was intended.
            if not acquired:
                return None

            underlying_future = self._client.submit(fn, *args, **kwargs)

            def on_done_callback(underlying_future: ScalerFuture):
                assert submit_duration.value is not None

                if underlying_future.cancelled():
                    future.cancel()
                    return

                with profile() as release_duration:
                    exception = underlying_future.exception()

                    if exception is None:
                        result = underlying_future.result()
                        # Scaler reports remote CPU time in seconds; converts to nanoseconds.
                        function_duration = int(underlying_future.profiling_info().cpu_time_s * 1_000_000_000)
                    else:
                        function_duration = 0
                        result = None

                    # Frees the task slot before propagating the outcome.
                    self._concurrent_task_guard.release()

                task_duration = (
                    self.CONSTANT_SCHEDULING_OVERHEAD + submit_duration.value + function_duration + release_duration.value
                )

                if exception is None:
                    future.set_result(result, duration=task_duration)
                else:
                    future.set_exception(exception, duration=task_duration)

            underlying_future.add_done_callback(on_done_callback)

        return future
80
+
81
+
82
class ScalerClientPool:
    """
    Thread-safe pool of reusable Scaler client connections to a single scheduler.

    Acquired clients are tracked by their ``identity``; released clients are kept alive for reuse
    up to ``max_unused_clients``, and disconnected beyond that.
    """

    def __init__(self, scheduler_address: str, client_kwargs: Dict, max_unused_clients: int = 1):
        self._scheduler_address = scheduler_address
        self._client_kwargs = client_kwargs
        self._max_unused_clients = max_unused_clients

        # Protects both client collections below.
        self._lock = threading.Lock()
        self._assigned_clients: Dict[bytes, Client] = {}
        self._unassigned_clients: Deque[Client] = deque()  # a FIFO pool of up to `max_unused_clients`.

    def acquire(self) -> Client:
        """Returns an idle pooled client, connecting a new one if none is available."""
        with self._lock:
            if len(self._unassigned_clients) > 0:
                client = self._unassigned_clients.popleft()
            else:
                client = Client(address=self._scheduler_address, profiling=True, **self._client_kwargs)

            self._assigned_clients[client.identity] = client

            return client

    def release(self, client: Client) -> None:
        """Returns a previously acquired client to the pool, disconnecting it if the pool is full."""
        with self._lock:
            self._assigned_clients.pop(client.identity)

            if len(self._unassigned_clients) < self._max_unused_clients:
                self._unassigned_clients.append(client)
            else:
                client.disconnect()

    def disconnect_all(self) -> None:
        """Disconnects every client, including those currently assigned, and empties the pool."""
        with self._lock:
            for client in itertools.chain(self._unassigned_clients, self._assigned_clients.values()):
                client.disconnect()

            self._unassigned_clients.clear()
            self._assigned_clients.clear()
119
+
120
+
121
class ScalerRemoteBackend(BackendEngine):
    """Connects to a previously instantiated Scaler instance as a backend engine."""

    def __init__(
        self,
        scheduler_address: str,
        n_workers=None,
        allows_nested_tasks: bool = True,
        **client_kwargs,
    ):
        """
        :param scheduler_address: ``tcp://host:port`` address of the running Scaler scheduler.
        :param n_workers: maximum number of concurrently submitted tasks per session. When ``None`` (the
            default), uses the number of physical CPU cores minus one, with a floor of 1. The previous
            default expression was evaluated once at import time and raised a ``TypeError`` on platforms
            where ``psutil.cpu_count(logical=False)`` returns ``None``.
        :param allows_nested_tasks: whether tasks may themselves submit new tasks.
        :param client_kwargs: extra keyword arguments forwarded to the Scaler ``Client`` constructor.
        """
        if n_workers is None:
            # Computed at call time; falls back to the logical core count when the physical count is unknown.
            cores = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 2
            n_workers = max(1, cores - 1)

        self._scheduler_address = scheduler_address
        self._n_workers = n_workers
        self._allows_nested_tasks = allows_nested_tasks
        self._client_kwargs = client_kwargs

    def __getstate__(self) -> dict:
        # Only plain connection settings are pickled; client connections are per-process.
        return {
            "scheduler_address": self._scheduler_address,
            "n_workers": self._n_workers,
            "allows_nested_tasks": self._allows_nested_tasks,
            "client_kwargs": self._client_kwargs,
        }

    def __setstate__(self, state: dict) -> None:
        self._scheduler_address = state["scheduler_address"]
        self._n_workers = state["n_workers"]
        self._allows_nested_tasks = state["allows_nested_tasks"]
        self._client_kwargs = state["client_kwargs"]

    def session(self) -> ScalerSession:
        """Opens a new task-submitting session against the remote scheduler."""
        return ScalerSession(self._scheduler_address, self._n_workers, self._client_kwargs)

    def get_scheduler_address(self) -> str:
        return self._scheduler_address

    def disconnect(self):
        # Sessions own their client connections; nothing to disconnect at the backend level.
        pass

    def shutdown(self):
        # The remote cluster's lifecycle is managed externally.
        pass

    def allows_nested_tasks(self) -> bool:
        return self._allows_nested_tasks
164
+
165
+
166
class ScalerLocalBackend(ScalerRemoteBackend):
    """Creates a Scaler cluster on the local machine and uses it as a backend engine."""

    def __init__(
        self,
        scheduler_address: Optional[str] = None,
        n_workers=None,
        per_worker_task_queue_size: int = 1000,
        allows_nested_tasks: bool = True,
        logging_paths: Tuple[str, ...] = ("/dev/stdout",),
        logging_level: str = "INFO",
        logging_config_file: Optional[str] = None,
        **kwargs,
    ):
        """
        :param scheduler_address the ``tcp://host:port`` tuple to use as a cluster address. If ``None``, listen to the
            local host on an available TCP port.
        :param n_workers: number of cluster workers. When ``None`` (the default), uses the number of physical
            CPU cores minus one, with a floor of 1. The previous default expression was evaluated once at
            import time and raised a ``TypeError`` on platforms where ``psutil.cpu_count(logical=False)``
            returns ``None``.
        :param kwargs: forwarded to ``SchedulerClusterCombo`` and/or ``Client``, matched by parameter name.
        """

        if n_workers is None:
            # Computed at call time; falls back to the logical core count when the physical count is unknown.
            cores = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 2
            n_workers = max(1, cores - 1)

        # Splits `kwargs` between the two downstream constructors based on their signatures.
        client_kwargs = self.__get_constructor_arg_names(Client)
        scheduler_cluster_combo_kwargs = self.__get_constructor_arg_names(SchedulerClusterCombo)

        self._cluster = SchedulerClusterCombo(
            address=scheduler_address,
            n_workers=n_workers,
            logging_paths=logging_paths,
            logging_level=logging_level,
            logging_config_file=logging_config_file,
            per_worker_task_queue_size=per_worker_task_queue_size,
            **{kwarg: value for kwarg, value in kwargs.items() if kwarg in scheduler_cluster_combo_kwargs},
        )
        scheduler_address = self._cluster.get_address()

        super().__init__(
            scheduler_address=scheduler_address,
            allows_nested_tasks=allows_nested_tasks,
            n_workers=n_workers,
            **{kwarg: value for kwarg, value in kwargs.items() if kwarg in client_kwargs},
        )

    def __setstate__(self, state: dict) -> None:
        super().__setstate__(state)
        self._cluster = None  # Unserialized instances have no cluster reference.

    @property
    def cluster(self) -> SchedulerClusterCombo:
        """The underlying local cluster; unavailable on unpickled copies."""
        if self._cluster is None:
            raise AttributeError("cluster is undefined for serialized instances.")

        return self._cluster

    def shutdown(self):
        """Shuts down the local cluster. Safe to call more than once."""
        super().shutdown()

        if self._cluster is not None:
            self._cluster.shutdown()
            self._cluster = None

    @staticmethod
    def __get_constructor_arg_names(class_: type) -> Set:
        # Parameter names accepted by the class' constructor.
        return set(inspect.signature(class_).parameters.keys())
@@ -0,0 +1,7 @@
1
+ import socket
2
+
3
+
4
def get_available_tcp_port(hostname: str = "127.0.0.1") -> int:
    """
    Returns a TCP port that is currently free on ``hostname``.

    Note: the probing socket is closed before returning, so another process may claim the port
    before the caller binds it (inherent race).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind((hostname, 0))
        _address, port = probe.getsockname()
        return port
File without changes
@@ -0,0 +1,13 @@
1
+ import warnings
2
+
3
+ from parfun.py_list import concat
4
+
5
+
6
+ warnings.warn(
7
+ "parfun.combine.collection is deprecated and will be removed in a future version, use parfun.py_list.",
8
+ DeprecationWarning
9
+ )
10
+
11
+ list_concat = concat
12
+
13
+ __all__ = ["list_concat"]
@@ -0,0 +1,13 @@
1
+ import warnings
2
+
3
+ from parfun.dataframe import concat
4
+
5
+
6
+ warnings.warn(
7
+ "parfun.combine.dataframe is deprecated and will be removed in a future version, use parfun.dataframe.",
8
+ DeprecationWarning
9
+ )
10
+
11
+ df_concat = concat
12
+
13
+ __all__ = ["df_concat"]
parfun/dataframe.py ADDED
@@ -0,0 +1,175 @@
1
+ """
2
+ A collection of pre-defined APIs to partition and combine dataframes.
3
+ """
4
+
5
+ from typing import Iterable, List, Tuple
6
+
7
+ try:
8
+ import pandas as pd
9
+ except ImportError:
10
+ raise ImportError("Pandas dependency missing. Use `pip install 'opengris-parfun[pandas]'` to install Pandas.")
11
+
12
+ from parfun.partition.object import PartitionFunction, PartitionGenerator
13
+
14
+
15
def concat(dfs: Iterable[pd.DataFrame]) -> pd.DataFrame:
    """
    Concatenates the dataframes vertically, rebuilding a fresh 0-based index.

    Similar to :py:func:`pandas.concat`.

    .. code:: python

        df_1 = pd.DataFrame([1,2,3])
        df_2 = pd.DataFrame([4,5,6])

        print(concat([df_1, df_2]))
        #    0
        # 0  1
        # 1  2
        # 2  3
        # 3  4
        # 4  5
        # 5  6

    """

    combined = pd.concat(dfs, ignore_index=True)
    return combined
36
+
37
+
38
def by_row(*dfs: pd.DataFrame) -> PartitionGenerator[Tuple[pd.DataFrame, ...]]:
    """
    Partitions one or multiple Pandas dataframes by rows.

    If multiple dataframes are given, these returned partitions will be of identical number of rows.

    .. code:: python

        df_1 = pd.DataFrame(range(0, 5))
        df_2 = df_1 ** 2

        with_partition_size(by_row(df_1, df_2), partition_size=2)

    """

    __validate_dfs_parameter(*dfs)

    # Primes the generator: the first `send()` supplies the requested partition size.
    chunk_size = yield None

    def dfs_chunk(rng_start: int, rng_end: int) -> Tuple[pd.DataFrame, ...]:
        # Row-slices every dataframe over the same half-open range.
        return tuple(df.iloc[rng_start:rng_end] for df in dfs)

    total_size = dfs[0].shape[0]
    range_start = 0
    range_end = chunk_size
    while range_end < total_size:
        # Yields (actual size, chunk) and receives the next requested partition size.
        chunk_size = yield chunk_size, dfs_chunk(range_start, range_end)

        range_start = range_end
        range_end += chunk_size

    if range_start < total_size:
        # Final tail chunk; may be smaller than the requested partition size.
        yield total_size - range_start, dfs_chunk(range_start, total_size)
89
+
90
def by_group(*args, **kwargs) -> PartitionFunction[pd.DataFrame]:
    """
    Partitions one or multiple Pandas dataframes by groups of identical numbers of rows, similar to
    :py:func:`pandas.DataFrame.groupby`.

    See :py:func:`pandas.DataFrame.groupby` for function parameters.

    .. code:: python

        df_1 = pd.DataFrame({"country": ["USA", "China", "Belgium"], "capital": ["Washington", "Beijing", "Brussels"]})
        df_2 = pd.DataFrame({"country": ["USA", "China", "Belgium"], "iso_code": ["US", "CN", "BE"]})

        with_partition_size(by_group(by="country")(df_1, df_2), partition_size=1)

    """

    def generator(*dfs: pd.DataFrame) -> PartitionGenerator[Tuple[pd.DataFrame, ...]]:
        __validate_dfs_parameter(*dfs)

        # Iterates the matching groups of all dataframes in lockstep.
        groups: Iterable[Tuple[pd.DataFrame, ...]] = zip(
            *((group for _name, group in df.groupby(*args, **kwargs)) for df in dfs)
        )

        it = iter(groups)

        # Accumulates whole groups (one list per input dataframe) until the requested size is reached.
        chunked_group: Tuple[List[pd.DataFrame], ...] = tuple([] for _ in range(0, len(dfs)))
        chunked_group_size: int = 0

        # Primes the generator: the first `send()` supplies the requested partition size.
        target_chunk_size = yield None

        def concat_chunked_group_dfs(chunked_group: Tuple[List[pd.DataFrame], ...]):
            # Concatenates the accumulated groups into one dataframe per input dataframe.
            return tuple(pd.concat(chunked_dfs) for chunked_dfs in chunked_group)

        while True:
            try:
                group = next(it)
                assert isinstance(group, tuple)
                assert isinstance(group[0], pd.DataFrame)

                group_size = group[0].shape[0]

                # Groups are never split across partitions, so all dataframes must agree on each group's size.
                if any(group_df.shape[0] != group_size for group_df in group[1:]):
                    raise ValueError("all dataframe group sizes should be identical.")

                chunked_group_size += group_size

                for i, group_df in enumerate(group):
                    chunked_group[i].append(group_df)

                if chunked_group_size >= target_chunk_size:
                    target_chunk_size = yield chunked_group_size, concat_chunked_group_dfs(chunked_group)

                    chunked_group = tuple([] for _ in range(0, len(dfs)))
                    chunked_group_size = 0
            except StopIteration:
                # Flushes the final, possibly undersized partition.
                if chunked_group_size > 0:
                    yield chunked_group_size, concat_chunked_group_dfs(chunked_group)

                return

    return generator
164
+
165
+
166
def __validate_dfs_parameter(*dfs: pd.DataFrame) -> None:
    """
    Checks that at least one dataframe is provided, that every value is a dataframe, and that all
    dataframes have the same number of rows. Raises :py:exc:`ValueError` otherwise.
    """

    if not dfs:
        raise ValueError("missing `dfs` parameter.")

    for df in dfs:
        if not isinstance(df, pd.DataFrame):
            raise ValueError("all `dfs` values should be DataFrame instances.")

    row_counts = {df.shape[0] for df in dfs}
    if len(row_counts) > 1:
        raise ValueError("all DataFrames should have the same number of rows.")