ssb-sgis 1.1.17__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,11 +4,12 @@ import numpy as np
  import pandas as pd
  from geopandas import GeoDataFrame
  from geopandas import GeoSeries
- from geopandas import __version__ as geopandas_version
  from shapely import Geometry
- from shapely import STRtree

+ from ..conf import _get_instance
+ from ..conf import config
  from .conversion import to_gdf
+ from .runners import RTreeQueryRunner

  gdf_type_error_message = "'gdf' should be of type GeoDataFrame or GeoSeries."

@@ -18,6 +19,8 @@ def sfilter(
  other: GeoDataFrame | GeoSeries | Geometry,
  predicate: str = "intersects",
  distance: int | float | None = None,
+ n_jobs: int | None = None,
+ rtree_runner: RTreeQueryRunner | None = None,
  ) -> GeoDataFrame:
  """Filter a GeoDataFrame or GeoSeries by spatial predicate.

@@ -33,6 +36,9 @@ def sfilter(
  other: The geometry object to filter 'gdf' by.
  predicate: Spatial predicate to use. Defaults to 'intersects'.
  distance: Max distance to allow if predicate=="dwithin".
+ n_jobs: Number of workers.
+ rtree_runner: Optionally debug/manipulate the spatial indexing operations.
+ See the 'runners' module for example implementations.

  Returns:
  A copy of 'gdf' with only the rows matching the
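For orientation, the two new keywords documented above could be exercised roughly like this. This is an illustrative sketch, not part of the diff: the geometries, CRS and worker count are made up, and it assumes sfilter and to_gdf are exposed at the top level of sgis as in earlier releases.

```python
# Illustrative sketch, not part of the diff. Assumes sfilter and to_gdf are
# exposed at the top level of the sgis package; geometries and CRS are made up.
import sgis as sg
from shapely.geometry import Point

points = sg.to_gdf([Point(0, 0), Point(10, 10)], crs=25833)
mask = sg.to_gdf([Point(0, 0)], crs=25833).buffer(1)

# Same behaviour as in 1.1.x:
hits = sg.sfilter(points, mask, predicate="intersects")

# New in 1.2: fan the rtree queries out over workers, or inject a custom
# runner (see the 'runners' module) to debug the spatial indexing step.
hits_parallel = sg.sfilter(points, mask, n_jobs=4)
```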
@@ -80,7 +86,9 @@ def sfilter(

  other = _sfilter_checks(other, crs=gdf.crs)

- indices = _get_sfilter_indices(gdf, other, predicate, distance)
+ indices = _get_sfilter_indices(
+ gdf, other, predicate, distance, n_jobs, rtree_runner
+ )

  return gdf.iloc[indices]

@@ -90,6 +98,8 @@ def sfilter_split(
  other: GeoDataFrame | GeoSeries | Geometry,
  predicate: str = "intersects",
  distance: int | float | None = None,
+ n_jobs: int = 1,
+ rtree_runner: RTreeQueryRunner | None = None,
  ) -> tuple[GeoDataFrame, GeoDataFrame]:
  """Split a GeoDataFrame or GeoSeries by spatial predicate.

@@ -101,6 +111,9 @@ def sfilter_split(
  other: The geometry object to filter 'gdf' by.
  predicate: Spatial predicate to use. Defaults to 'intersects'.
  distance: Max distance to allow if predicate=="dwithin".
+ n_jobs: Number of workers.
+ rtree_runner: Optionally debug/manipulate the spatial indexing operations.
+ See the 'runners' module for example implementations.

  Returns:
  A tuple of GeoDataFrames, one with the rows that match the spatial predicate
@@ -151,7 +164,9 @@ def sfilter_split(

  other = _sfilter_checks(other, crs=gdf.crs)

- indices = _get_sfilter_indices(gdf, other, predicate, distance)
+ indices = _get_sfilter_indices(
+ gdf, other, predicate, distance, n_jobs, rtree_runner
+ )

  return (
  gdf.iloc[indices],
@@ -164,6 +179,8 @@ def sfilter_inverse(
  other: GeoDataFrame | GeoSeries | Geometry,
  predicate: str = "intersects",
  distance: int | float | None = None,
+ n_jobs: int = 1,
+ rtree_runner: RTreeQueryRunner | None = None,
  ) -> GeoDataFrame | GeoSeries:
  """Filter a GeoDataFrame or GeoSeries by inverse spatial predicate.

@@ -174,6 +191,9 @@ def sfilter_inverse(
  other: The geometry object to filter 'gdf' by.
  predicate: Spatial predicate to use. Defaults to 'intersects'.
  distance: Max distance to allow if predicate=="dwithin".
+ n_jobs: Number of workers.
+ rtree_runner: Optionally debug/manipulate the spatial indexing operations.
+ See the 'runners' module for example implementations.

  Returns:
  A copy of 'gdf' with only the rows that do not match the
@@ -215,11 +235,10 @@ def sfilter_inverse(
  """
  if not isinstance(gdf, (GeoDataFrame | GeoSeries)):
  raise TypeError(gdf_type_error_message)
-
  other = _sfilter_checks(other, crs=gdf.crs)
-
- indices = _get_sfilter_indices(gdf, other, predicate, distance)
-
+ indices = _get_sfilter_indices(
+ gdf, other, predicate, distance, n_jobs, rtree_runner
+ )
  return gdf.iloc[pd.Index(range(len(gdf))).difference(pd.Index(indices))]


@@ -252,6 +271,8 @@ def _get_sfilter_indices(
  right: GeoDataFrame | GeoSeries | Geometry,
  predicate: str,
  distance: int | float | None,
+ n_jobs: int,
+ rtree_runner: RTreeQueryRunner | None,
  ) -> np.ndarray:
  """Compute geometric comparisons and get matching indices.

@@ -264,6 +285,9 @@ def _get_sfilter_indices(
  right : GeoDataFrame
  predicate : string
  Binary predicate to query.
+ n_jobs: Number of workers.
+ rtree_runner: Optionally debug/manipulate the spatial indexing operations.
+ See the 'runners' module for example implementations.

  Returns:
  -------
@@ -273,6 +297,9 @@ def _get_sfilter_indices(
  """
  original_predicate = predicate

+ if rtree_runner is None:
+ rtree_runner = _get_instance(config, "rtree_runner", n_jobs=n_jobs)
+
  with warnings.catch_warnings():
  # We don't need to show our own warning here
  # TODO remove this once the deprecation has been enforced
@@ -285,25 +312,16 @@ def _get_sfilter_indices(
  # contains is a faster predicate
  # see discussion at https://github.com/geopandas/geopandas/pull/1421
  predicate = "contains"
- sindex, kwargs = _get_spatial_tree(left)
- input_geoms = right.geometry if isinstance(right, GeoDataFrame) else right
+ arr1 = right.geometry.values
+ arr2 = left.geometry.values
  else:
  # all other predicates are symmetric
  # keep them the same
- sindex, kwargs = _get_spatial_tree(right)
- input_geoms = left.geometry if isinstance(left, GeoDataFrame) else left
+ arr1 = left.geometry.values
+ arr2 = right.geometry.values

- l_idx, r_idx = sindex.query(
- input_geoms, predicate=predicate, distance=distance, **kwargs
- )
+ left, right = rtree_runner.run(arr1, arr2, predicate=predicate, distance=distance)

  if original_predicate == "within":
- return np.sort(np.unique(r_idx))
-
- return np.sort(np.unique(l_idx))
-
-
- def _get_spatial_tree(df):
- if int(geopandas_version[0]) >= 1:
- return df.sindex, {"sort": False}
- return STRtree(df.geometry.values), {}
+ return np.sort(np.unique(right))
+ return np.sort(np.unique(left))
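The query itself is now delegated to rtree_runner.run(arr1, arr2, predicate=..., distance=...), which is expected to return a pair of index arrays (indices into arr1, indices into arr2). The RTreeQueryRunner implementations live in the new 'runners' module and are not shown in this diff; a minimal, hypothetical runner with a compatible interface could look like this:

```python
# Hypothetical minimal runner with the call signature used above; the real
# RTreeQueryRunner lives in sgis' 'runners' module and is not shown in this diff.
import numpy as np
from shapely import STRtree


class NaiveRTreeQueryRunner:
    """Build an STRtree from the second array and query it with the first."""

    def run(
        self,
        arr1: np.ndarray,
        arr2: np.ndarray,
        predicate: str = "intersects",
        distance: int | float | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        tree = STRtree(arr2)
        kwargs = {"distance": distance} if predicate == "dwithin" else {}
        # shapely returns a (2, n) array: row 0 indexes arr1, row 1 indexes arr2
        left, right = tree.query(arr1, predicate=predicate, **kwargs)
        return left, right
```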
@@ -0,0 +1,37 @@
+ import numpy as np
+ import pandas as pd
+ from geopandas import GeoSeries
+ from shapely import make_valid
+ from shapely import union_all
+
+ from .geometry_types import to_single_geom_type
+
+
+ def _unary_union_for_notna(geoms, **kwargs):
+ try:
+ return make_valid(union_all(geoms, **kwargs))
+ except TypeError:
+ return make_valid(union_all([geom for geom in geoms.dropna().values], **kwargs))
+
+
+ def make_valid_and_keep_geom_type(geoms: np.ndarray, geom_type: str) -> GeoSeries:
+ """Make GeometryCollections into (Multi)Polygons, (Multi)LineStrings or (Multi)Points.
+
+ Because GeometryCollections might appear after dissolving (union_all).
+ And this makes shapely difference/intersection fail.
+
+ Args:
+ geoms: Array of geometries.
+ geom_type: geometry type to be kept.
+ """
+ geoms = GeoSeries(geoms)
+ geoms.index = range(len(geoms))
+ geoms.loc[:] = make_valid(geoms.to_numpy())
+ geoms_with_correct_type = geoms.explode(index_parts=False).pipe(
+ to_single_geom_type, geom_type
+ )
+ only_one = geoms_with_correct_type.groupby(level=0).transform("size") == 1
+ one_hit = geoms_with_correct_type[only_one]
+ many_hits = geoms_with_correct_type[~only_one].groupby(level=0).agg(union_all)
+ geoms_with_wrong_type = geoms.loc[~geoms.index.isin(geoms_with_correct_type.index)]
+ return pd.concat([one_hit, many_hits, geoms_with_wrong_type]).sort_index()
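make_valid_and_keep_geom_type is a new helper (its file path is not shown in this diff). It explodes GeometryCollections, keeps only the parts matching geom_type, re-unions rows that exploded into several matching parts, and leaves rows with no matching parts untouched. A rough illustration of the kind of input it targets; the geometries below are made up, and the call is only described in a comment because the new module's import path is not part of the diff:

```python
# Illustration of the input this helper targets; geometries are made up and the
# import is omitted because the new module's path is not shown in this diff.
from shapely.geometry import GeometryCollection, LineString, Point, box

geoms = [
    GeometryCollection([box(0, 0, 1, 1), LineString([(0, 0), (2, 2)])]),
    GeometryCollection([Point(5, 5)]),
    box(10, 10, 11, 11),
]

# make_valid_and_keep_geom_type(geoms, "polygon") would return a GeoSeries of
# the same length where row 0 keeps only the box, row 2 is unchanged, and
# row 1 (no polygon parts at all) keeps its original geometry.
```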
sgis/helpers.py CHANGED
@@ -198,7 +198,7 @@ def get_all_files(root: str, recursive: bool = True) -> list[str]:


  def return_two_vals(
- vals: tuple[str, str] | list[str] | str | int | float
+ vals: tuple[str, str] | list[str] | str | int | float,
  ) -> tuple[str | int | float, str | int | float]:
  """Return a two-length tuple from a str/int/float or list/tuple of length 1 or 2.

@@ -128,11 +128,11 @@ def read_geopandas(
  return gpd.GeoDataFrame(
  _read_partitioned_parquet(
  gcs_path,
- read_func=pq.read_table,
  file_system=file_system,
  mask=mask,
  filters=filters,
  child_paths=child_paths,
+ use_threads=use_threads,
  **kwargs,
  )
  )
@@ -145,7 +145,7 @@ def read_geopandas(
  read_func = gpd.read_file

  with file_system.open(gcs_path, mode="rb") as file:
- return _read_geopandas(
+ return _read_geopandas_single_path(
  file,
  read_func=read_func,
  file_format=file_format,
@@ -163,18 +163,10 @@ def _read_geopandas_from_iterable(
  paths = list(paths.index)
  elif mask is None:
  paths = list(paths)
- else:
- if not isinstance(paths, GeoSeries):
- bounds_series: GeoSeries = get_bounds_series(
- paths,
- file_system,
- use_threads=use_threads,
- pandas_fallback=pandas_fallback,
- )
- else:
- bounds_series = paths
- new_bounds_series = sfilter(bounds_series, mask)
- if not len(new_bounds_series):
+ elif isinstance(paths, GeoSeries):
+ bounds_series = sfilter(paths, mask)
+ if not len(bounds_series):
+ # return GeoDataFrame with correct columns
  if isinstance(kwargs.get("columns"), Iterable):
  cols = {col: [] for col in kwargs["columns"]}
  else:
@@ -186,29 +178,14 @@ def _read_geopandas_from_iterable(
  if file_system.isfile(path):
  raise ArrowInvalid(e, path) from e
  return GeoDataFrame(cols | {"geometry": []})
- paths = list(new_bounds_series.index)
+ paths = list(bounds_series.index)

- # recursive read with threads
- threads = (
- min(len(paths), int(multiprocessing.cpu_count())) or 1 if use_threads else 1
+ results: list[pyarrow.Table] = _read_pyarrow_with_treads(
+ paths, file_system=file_system, mask=mask, use_threads=use_threads, **kwargs
  )
- with joblib.Parallel(n_jobs=threads, backend="threading") as parallel:
- dfs: list[GeoDataFrame] = parallel(
- joblib.delayed(read_geopandas)(
- x,
- file_system=file_system,
- pandas_fallback=pandas_fallback,
- mask=mask,
- use_threads=use_threads,
- **kwargs,
- )
- for x in paths
- )
-
- if dfs:
- df = pd.concat(dfs, ignore_index=True)
+ if results:
  try:
- df = GeoDataFrame(df)
+ return _concat_pyarrow_to_geopandas(results, paths, file_system)
  except Exception as e:
  if not pandas_fallback:
  print(e)
@@ -219,6 +196,49 @@ def _read_geopandas_from_iterable(
  return df


+ def _read_pyarrow_with_treads(
+ paths: list[str | Path | os.PathLike], file_system, use_threads, mask, **kwargs
+ ) -> list[pyarrow.Table]:
+ read_partial = functools.partial(
+ _read_pyarrow, mask=mask, file_system=file_system, **kwargs
+ )
+ if not use_threads:
+ return [x for x in map(read_partial, paths) if x is not None]
+ with ThreadPoolExecutor() as executor:
+ return [x for x in executor.map(read_partial, paths) if x is not None]
+
+
+ def intersects(file, mask, file_system) -> bool:
+ bbox, _ = _get_bounds_parquet_from_open_file(file, file_system)
+ return shapely.box(*bbox).intersects(to_shapely(mask))
+
+
+ def _read_pyarrow(path: str, file_system, mask=None, **kwargs) -> pyarrow.Table | None:
+ try:
+ with file_system.open(path, "rb") as file:
+ if mask is not None and not intersects(file, mask, file_system):
+ return
+
+ # 'get' instead of 'pop' because dict is mutable
+ schema = kwargs.get("schema", pq.read_schema(file))
+ new_kwargs = {
+ key: value for key, value in kwargs.items() if key != "schema"
+ }
+
+ return pq.read_table(file, schema=schema, **new_kwargs)
+ except ArrowInvalid as e:
+ glob_func = _get_glob_func(file_system)
+ if not len(
+ {
+ x
+ for x in glob_func(str(Path(path) / "**"))
+ if not paths_are_equal(path, x)
+ }
+ ):
+ raise e
+ # allow not being able to read empty directories that are hard to delete in gcs
+
+
  def _get_bounds_parquet(
  path: str | Path, file_system: GCSFileSystem, pandas_fallback: bool = False
  ) -> tuple[list[float], dict] | tuple[None, None]:
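The new _read_pyarrow_with_treads helper reads each path into a pyarrow.Table, optionally skipping files whose parquet bounds do not intersect the mask, and fans the reads out over a thread pool. A generic sketch of that fan-out-and-concatenate pattern, using plain pyarrow and made-up local paths rather than sgis internals:

```python
# Generic illustration of the fan-out-and-concatenate pattern used above;
# the local file paths are made up and no sgis internals are involved.
from concurrent.futures import ThreadPoolExecutor

import pyarrow
import pyarrow.parquet as pq

paths = ["part-0.parquet", "part-1.parquet", "part-2.parquet"]

with ThreadPoolExecutor() as executor:
    tables = list(executor.map(pq.read_table, paths))

# Schemas are merged permissively before converting, mirroring what the
# new code does with pyarrow.concat_tables(..., promote_options="permissive").
df = pyarrow.concat_tables(tables, promote_options="permissive").to_pandas()
```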
@@ -662,10 +682,10 @@ def expression_match_path(expression: ds.Expression, path: str) -> bool:
  return bool(len(table))


- def _read_geopandas(
+ def _read_geopandas_single_path(
  file,
- read_func: Callable = gpd.read_parquet,
- file_format: str = "parquet",
+ read_func: Callable,
+ file_format: str,
  **kwargs,
  ):
  try:
@@ -681,32 +701,29 @@ def read_geopandas(
  raise e.__class__(f"{e.__class__.__name__}: {e} for {file}.") from e


- def _read_pandas(gcs_path: str, **kwargs):
+ def _read_pandas(gcs_path: str, use_threads: bool = True, **kwargs):
  file_system = _get_file_system(None, kwargs)

  if not isinstance(gcs_path, (str | Path | os.PathLike)):
- # recursive read with threads
- threads = (
- min(len(gcs_path), int(multiprocessing.cpu_count())) or 1
- if kwargs.get("use_threads")
- else 1
+ results: list[pyarrow.Table] = _read_pyarrow_with_treads(
+ gcs_path,
+ file_system=file_system,
+ mask=None,
+ use_threads=use_threads,
+ **kwargs,
  )
- with joblib.Parallel(n_jobs=threads, backend="threading") as parallel:
- return pd.concat(
- parallel(
- joblib.delayed(_read_pandas)(x, file_system=file_system, **kwargs)
- for x in gcs_path
- )
- )
+ results = pyarrow.concat_tables(results, promote_options="permissive")
+ return results.to_pandas()

  child_paths = get_child_paths(gcs_path, file_system)
  if child_paths:
  return _read_partitioned_parquet(
  gcs_path,
- read_func=pd.read_parquet,
  file_system=file_system,
  mask=None,
  child_paths=child_paths,
+ use_threads=use_threads,
+ to_geopandas=False,
  **kwargs,
  )

@@ -716,11 +733,12 @@ def _read_pandas(gcs_path: str, **kwargs):


  def _read_partitioned_parquet(
  path: str,
- read_func: Callable,
  filters=None,
  file_system=None,
  mask=None,
  child_paths: list[str] | None = None,
+ use_threads: bool = True,
+ to_geopandas: bool = True,
  **kwargs,
  ):
@@ -731,62 +749,22 @@ def _read_partitioned_parquet(

  filters = _filters_to_expression(filters)

- def intersects(file, mask) -> bool:
- bbox, _ = _get_bounds_parquet_from_open_file(file, file_system)
- return shapely.box(*bbox).intersects(to_shapely(mask))
-
- def read(child_path: str) -> pyarrow.Table | None:
- try:
- with file_system.open(child_path, "rb") as file:
- if mask is not None and not intersects(file, mask):
- return
-
- # 'get' instead of 'pop' because dict is mutable
- schema = kwargs.get("schema", pq.read_schema(file))
- new_kwargs = {
- key: value for key, value in kwargs.items() if key != "schema"
- }
-
- return read_func(file, schema=schema, filters=filters, **new_kwargs)
- except ArrowInvalid as e:
- if not len(
- {
- x
- for x in glob_func(str(Path(child_path) / "**"))
- if not paths_are_equal(child_path, x)
- }
- ):
- raise e
- # allow not being able to read hard-to-delete empty directories
+ results: list[pyarrow.Table] = _read_pyarrow_with_treads(
+ (
+ path
+ for path in child_paths
+ if filters is None or expression_match_path(filters, path)
+ ),
+ file_system=file_system,
+ mask=mask,
+ use_threads=use_threads,
+ **kwargs,
+ )

- with ThreadPoolExecutor() as executor:
- results = [
- df
- for df in (
- executor.map(
- read,
- (
- path
- for path in child_paths
- if filters is None or expression_match_path(filters, path)
- ),
- )
- )
- if df is not None
- ]
-
- if results:
- if all(isinstance(x, DataFrame) for x in results):
- return pd.concat(results)
- else:
- geo_metadata = _get_geo_metadata(next(iter(child_paths)), file_system)
- return _arrow_to_geopandas(
- pyarrow.concat_tables(
- results,
- promote_options="permissive",
- ),
- geo_metadata,
- )
+ if results and to_geopandas:
+ return _concat_pyarrow_to_geopandas(results, child_paths, file_system)
+ elif results:
+ return pyarrow.concat_tables(results, promote_options="permissive").to_pandas()

  # add columns to empty DataFrame
  first_path = next(iter(child_paths + [path]))
@@ -796,6 +774,17 @@ def _read_partitioned_parquet(
  return df


+ def _concat_pyarrow_to_geopandas(
+ results: list[pyarrow.Table], paths: list[str], file_system: Any
+ ):
+ results = pyarrow.concat_tables(
+ results,
+ promote_options="permissive",
+ )
+ geo_metadata = _get_geo_metadata(next(iter(paths)), file_system)
+ return _arrow_to_geopandas(results, geo_metadata)
+
+
  def paths_are_equal(path1: Path | str, path2: Path | str) -> bool:
  return Path(path1).parts == Path(path2).parts

sgis/maps/map.py CHANGED
@@ -307,7 +307,9 @@ class Map:
  notna = array[array.notna()]
  isna = array[array.isna()]

- unique_multiplied = (notna * self._multiplier).astype(np.int64)
+ unique_multiplied = (notna.astype(np.float64) * self._multiplier).astype(
+ np.int64
+ )

  return pd.concat([unique_multiplied, isna]).sort_index()

sgis/parallel/parallel.py CHANGED
@@ -75,13 +75,15 @@ def parallel_overlay(
  Returns:
  A GeoDataFrame containing the result of the overlay operation.
  """
+ if how != "intersection":
+ raise ValueError("parallel_overlay only supports how='intersection'.")
  return pd.concat(
  chunkwise(
  _clean_overlay_with_print,
  df1,
  kwargs={
  "df2": df2,
- # "to_print": to_print,
+ "to_print": to_print,
  "how": how,
  }
  | kwargs,
@@ -672,7 +674,7 @@ class Parallel:
  def chunkwise(
  self,
  func: Callable,
- iterable: Collection[Iterable[Any]],
+ *iterables: Collection[Iterable[Any]],
  args: tuple | None = None,
  kwargs: dict | None = None,
  max_rows_per_chunk: int | None = None,
@@ -682,8 +684,8 @@ class Parallel:
  Args:
  func: Function to run chunkwise. It should take
  (a chunk of) the iterable as first argument.
- iterable: Iterable to split in chunks and passed
- as first argument to 'func'.
+ iterables: Iterable(s) to split in chunks and passed
+ as first argument(s) to 'func'. Iterables must have same length.
  args: Positional arguments in 'func' after the DataFrame.
  kwargs: Additional keyword arguments in 'func'.
  max_rows_per_chunk: Alternatively decide number of chunks
@@ -691,7 +693,7 @@ class Parallel:
  """
  return chunkwise(
  func,
- iterable,
+ *iterables,
  args=args,
  kwargs=kwargs,
  processes=self.processes,
@@ -1067,7 +1069,7 @@ def _fix_missing_muni_numbers(

  def chunkwise(
  func: Callable,
- iterable: Collection[Iterable[Any]],
+ *iterables: Collection[Iterable[Any]],
  args: tuple | None = None,
  kwargs: dict | None = None,
  processes: int = 1,
@@ -1082,7 +1084,7 @@ def chunkwise(
  Args:
  func: The function to apply to each chunk. This function must accept a DataFrame as
  its first argument and return a DataFrame.
- iterable: Iterable to be chunked and processed.
+ iterables: Iterable(s) to be chunked and processed. Must have same length.
  args: Additional positional arguments to pass to 'func'.
  kwargs: Keyword arguments to pass to 'func'.
  processes: The number of parallel jobs to run. Defaults to 1 (no parallel execution).
@@ -1096,30 +1098,36 @@ def chunkwise(
  args = args or ()
  kwargs = kwargs or {}

+ if len({len(x) for x in iterables}) not in [0, 1]:
+ raise ValueError(
+ f"iterables must have same length. Got {', '.join([len(x) for x in iterables])}"
+ )
+
  if max_rows_per_chunk is None:
  n_chunks: int = processes
  else:
- n_chunks: int = len(iterable) // max_rows_per_chunk
-
+ n_chunks: int = len(next(iter(iterables))) // max_rows_per_chunk
  if n_chunks <= 1:
- return [func(iterable, *args, **kwargs)]
+ return [func(*iterables, *args, **kwargs)]

- chunks = np.array_split(np.arange(len(iterable)), n_chunks)
+ chunks = np.array_split(np.arange(len(next(iter(iterables)))), n_chunks)

- if hasattr(iterable, "iloc"):
- iterable_chunked: list[pd.DataFrame | pd.Series] = [
- iterable.iloc[chunk] for chunk in chunks
- ]
- elif is_array_like(iterable):
- iterable_chunked: list[np.ndarray] = [iterable[chunk] for chunk in chunks]
- else:
- to_type: type = iterable.__class__
- iterable_chunked: list[Iterable] = [
- to_type(chunk) for chunk in np.array_split(list(iterable), n_chunks)
- ]
- return Parallel(processes, backend=backend).map(
+ def get_chunk(iterable, chunk):
+ if hasattr(iterable, "iloc"):
+ return iterable.iloc[chunk]
+ elif is_array_like(iterable):
+ return iterable[chunk]
+ else:
+ to_type: type = iterable.__class__
+ return to_type([x for i, x in enumerate(iterable) if i in chunk])
+
+ iterables_chunked: list[list[Iterable[Any]]] = [
+ [get_chunk(iterable, chunk) for iterable in iterables] for chunk in chunks
+ ]
+
+ return Parallel(processes, backend=backend).starmap(
  func,
- iterable_chunked,
+ iterables_chunked,
  args=args,
  kwargs=kwargs,
  )
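
chunkwise (and Parallel.chunkwise) now accepts several equally long iterables and dispatches the chunks with starmap instead of map. A hedged usage sketch, not part of the diff: the data and the toy function are made up, and only the module path shown in this diff is assumed.

```python
# Hedged usage sketch of the new multi-iterable chunkwise; data and the toy
# function are made up, and only the module path shown in this diff is assumed.
import pandas as pd

from sgis.parallel.parallel import chunkwise

left = pd.DataFrame({"x": range(10)})
right = pd.DataFrame({"y": range(10)})  # must have the same length as 'left'


def add_y(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    # each call receives one chunk of every iterable, in the order they were passed
    return df1.assign(y=df2["y"].values)


# Before 1.2 only one iterable could be passed; now chunks of 'left' and 'right'
# are dispatched together via starmap. max_rows_per_chunk controls the chunking.
chunks = chunkwise(add_y, left, right, max_rows_per_chunk=5)
result = pd.concat(chunks, ignore_index=True)
```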