deltacat 0.1.10.dev0__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. deltacat/__init__.py +41 -15
  2. deltacat/aws/clients.py +12 -31
  3. deltacat/aws/constants.py +1 -1
  4. deltacat/aws/redshift/__init__.py +7 -2
  5. deltacat/aws/redshift/model/manifest.py +54 -50
  6. deltacat/aws/s3u.py +176 -187
  7. deltacat/catalog/delegate.py +151 -185
  8. deltacat/catalog/interface.py +78 -97
  9. deltacat/catalog/model/catalog.py +21 -21
  10. deltacat/catalog/model/table_definition.py +11 -9
  11. deltacat/compute/compactor/__init__.py +12 -16
  12. deltacat/compute/compactor/compaction_session.py +237 -166
  13. deltacat/compute/compactor/model/delta_annotated.py +60 -44
  14. deltacat/compute/compactor/model/delta_file_envelope.py +5 -6
  15. deltacat/compute/compactor/model/delta_file_locator.py +10 -8
  16. deltacat/compute/compactor/model/materialize_result.py +6 -7
  17. deltacat/compute/compactor/model/primary_key_index.py +38 -34
  18. deltacat/compute/compactor/model/pyarrow_write_result.py +3 -4
  19. deltacat/compute/compactor/model/round_completion_info.py +25 -19
  20. deltacat/compute/compactor/model/sort_key.py +18 -15
  21. deltacat/compute/compactor/steps/dedupe.py +119 -94
  22. deltacat/compute/compactor/steps/hash_bucket.py +48 -47
  23. deltacat/compute/compactor/steps/materialize.py +86 -92
  24. deltacat/compute/compactor/steps/rehash/rehash_bucket.py +13 -13
  25. deltacat/compute/compactor/steps/rehash/rewrite_index.py +5 -5
  26. deltacat/compute/compactor/utils/io.py +59 -47
  27. deltacat/compute/compactor/utils/primary_key_index.py +91 -80
  28. deltacat/compute/compactor/utils/round_completion_file.py +22 -23
  29. deltacat/compute/compactor/utils/system_columns.py +33 -45
  30. deltacat/compute/metastats/meta_stats.py +235 -157
  31. deltacat/compute/metastats/model/partition_stats_dict.py +7 -10
  32. deltacat/compute/metastats/model/stats_cluster_size_estimator.py +13 -5
  33. deltacat/compute/metastats/stats.py +95 -64
  34. deltacat/compute/metastats/utils/io.py +100 -53
  35. deltacat/compute/metastats/utils/pyarrow_memory_estimation_function.py +5 -2
  36. deltacat/compute/metastats/utils/ray_utils.py +38 -33
  37. deltacat/compute/stats/basic.py +107 -69
  38. deltacat/compute/stats/models/delta_column_stats.py +11 -8
  39. deltacat/compute/stats/models/delta_stats.py +59 -32
  40. deltacat/compute/stats/models/delta_stats_cache_result.py +4 -1
  41. deltacat/compute/stats/models/manifest_entry_stats.py +12 -6
  42. deltacat/compute/stats/models/stats_result.py +24 -14
  43. deltacat/compute/stats/utils/intervals.py +16 -9
  44. deltacat/compute/stats/utils/io.py +86 -51
  45. deltacat/compute/stats/utils/manifest_stats_file.py +24 -33
  46. deltacat/constants.py +4 -13
  47. deltacat/io/__init__.py +2 -2
  48. deltacat/io/aws/redshift/redshift_datasource.py +157 -143
  49. deltacat/io/dataset.py +14 -17
  50. deltacat/io/read_api.py +36 -33
  51. deltacat/logs.py +94 -42
  52. deltacat/storage/__init__.py +18 -8
  53. deltacat/storage/interface.py +196 -213
  54. deltacat/storage/model/delta.py +45 -51
  55. deltacat/storage/model/list_result.py +12 -8
  56. deltacat/storage/model/namespace.py +4 -5
  57. deltacat/storage/model/partition.py +42 -42
  58. deltacat/storage/model/stream.py +29 -30
  59. deltacat/storage/model/table.py +14 -14
  60. deltacat/storage/model/table_version.py +32 -31
  61. deltacat/storage/model/types.py +1 -0
  62. deltacat/tests/stats/test_intervals.py +11 -24
  63. deltacat/tests/utils/__init__.py +0 -0
  64. deltacat/tests/utils/test_record_batch_tables.py +284 -0
  65. deltacat/types/media.py +3 -4
  66. deltacat/types/tables.py +31 -21
  67. deltacat/utils/common.py +5 -11
  68. deltacat/utils/numpy.py +20 -22
  69. deltacat/utils/pandas.py +73 -100
  70. deltacat/utils/performance.py +3 -9
  71. deltacat/utils/placement.py +259 -230
  72. deltacat/utils/pyarrow.py +302 -89
  73. deltacat/utils/ray_utils/collections.py +2 -1
  74. deltacat/utils/ray_utils/concurrency.py +27 -28
  75. deltacat/utils/ray_utils/dataset.py +28 -28
  76. deltacat/utils/ray_utils/performance.py +5 -9
  77. deltacat/utils/ray_utils/runtime.py +9 -10
  78. {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/METADATA +1 -1
  79. deltacat-0.1.12.dist-info/RECORD +110 -0
  80. deltacat-0.1.10.dev0.dist-info/RECORD +0 -108
  81. {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/LICENSE +0 -0
  82. {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/WHEEL +0 -0
  83. {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/top_level.txt +0 -0
deltacat/utils/pyarrow.py CHANGED
@@ -1,25 +1,31 @@
- import pyarrow as pa
- import gzip
+ # Allow classes to use self-referencing Type hints in Python 3.7.
+ from __future__ import annotations
+
  import bz2
+ import gzip
  import io
  import logging
-
  from functools import partial
- from fsspec import AbstractFileSystem
-
- from pyarrow import feather as paf, parquet as papq, csv as pacsv, \
-     json as pajson
+ from typing import Any, Callable, Dict, Iterable, List, Optional

+ import pyarrow as pa
+ from fsspec import AbstractFileSystem
+ from pyarrow import csv as pacsv
+ from pyarrow import feather as paf
+ from pyarrow import json as pajson
+ from pyarrow import parquet as papq
  from ray.data.datasource import BlockWritePathProvider

  from deltacat import logs
- from deltacat.types.media import ContentType, ContentEncoding, \
-     DELIMITED_TEXT_CONTENT_TYPES, TABULAR_CONTENT_TYPES
- from deltacat.utils.common import ReadKwargsProvider, ContentTypeKwargsProvider
+ from deltacat.types.media import (
+     DELIMITED_TEXT_CONTENT_TYPES,
+     TABULAR_CONTENT_TYPES,
+     ContentEncoding,
+     ContentType,
+ )
+ from deltacat.utils.common import ContentTypeKwargsProvider, ReadKwargsProvider
  from deltacat.utils.performance import timed_invocation

- from typing import Any, Callable, Dict, List, Optional, Iterable, Union
-
  logger = logs.configure_deltacat_logger(logging.getLogger(__name__))

@@ -33,35 +39,28 @@ CONTENT_TYPE_TO_PA_READ_FUNC: Dict[str, Callable] = {
      # Pyarrow.orc is disabled in Pyarrow 0.15, 0.16:
      # https://issues.apache.org/jira/browse/ARROW-7811
      # ContentType.ORC.value: paorc.ContentType.ORCFile,
-     ContentType.JSON.value: pajson.read_json
+     ContentType.JSON.value: pajson.read_json,
  }


  def write_feather(
-         table: pa.Table,
-         path: str,
-         *,
-         filesystem: AbstractFileSystem,
-         **kwargs) -> None:
+     table: pa.Table, path: str, *, filesystem: AbstractFileSystem, **kwargs
+ ) -> None:

      with filesystem.open(path, "wb") as f:
          paf.write_feather(table, f, **kwargs)


  def write_csv(
-         table: pa.Table,
-         path: str,
-         *,
-         filesystem: AbstractFileSystem,
-         **kwargs) -> None:
+     table: pa.Table, path: str, *, filesystem: AbstractFileSystem, **kwargs
+ ) -> None:

      with filesystem.open(path, "wb") as f:
          # TODO (pdames): Add support for client-specified compression types.
          with pa.CompressedOutputStream(f, ContentEncoding.GZIP.value) as out:
              if kwargs.get("write_options") is None:
                  # column names are kept in table metadata, so omit header
-                 kwargs["write_options"] = pacsv.WriteOptions(
-                     include_header=False)
+                 kwargs["write_options"] = pacsv.WriteOptions(include_header=False)
              pacsv.write_csv(table, out, **kwargs)

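For orientation, a minimal sketch of calling the reformatted write_csv above against a local fsspec filesystem; the table contents and output path are illustrative only.

    import pyarrow as pa
    from fsspec.implementations.local import LocalFileSystem

    from deltacat.utils.pyarrow import write_csv

    table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    # Output is gzip-compressed; the header row is omitted by default since
    # column names are kept in the table metadata.
    write_csv(table, "/tmp/example.csv.gz", filesystem=LocalFileSystem())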
@@ -78,13 +77,11 @@ CONTENT_TYPE_TO_PA_WRITE_FUNC: Dict[str, Callable] = {
  def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:
      if content_type == ContentType.UNESCAPED_TSV.value:
          return {
-             "parse_options": pacsv.ParseOptions(
-                 delimiter="\t",
-                 quote_char=False),
+             "parse_options": pacsv.ParseOptions(delimiter="\t", quote_char=False),
              "convert_options": pacsv.ConvertOptions(
                  null_values=[""],  # pyarrow defaults are ["", "NULL", "null"]
                  strings_can_be_null=True,
-             )
+             ),
          }
      if content_type == ContentType.TSV.value:
          return {"parse_options": pacsv.ParseOptions(delimiter="\t")}
@@ -92,9 +89,11 @@ def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:
          return {"parse_options": pacsv.ParseOptions(delimiter=",")}
      if content_type == ContentType.PSV.value:
          return {"parse_options": pacsv.ParseOptions(delimiter="|")}
-     if content_type in {ContentType.PARQUET.value,
-                         ContentType.FEATHER.value,
-                         ContentType.JSON.value}:
+     if content_type in {
+         ContentType.PARQUET.value,
+         ContentType.FEATHER.value,
+         ContentType.JSON.value,
+     }:
          return {}
      # Pyarrow.orc is disabled in Pyarrow 0.15, 0.16:
      # https://issues.apache.org/jira/browse/ARROW-7811
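The kwargs produced by content_type_to_reader_kwargs feed straight into the matching PyArrow reader; a small sketch (the sample file path is hypothetical):

    from pyarrow import csv as pacsv

    from deltacat.types.media import ContentType
    from deltacat.utils.pyarrow import content_type_to_reader_kwargs

    # For TSV this returns {"parse_options": pacsv.ParseOptions(delimiter="\t")}.
    reader_kwargs = content_type_to_reader_kwargs(ContentType.TSV.value)
    table = pacsv.read_csv("data.tsv", **reader_kwargs)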
@@ -105,15 +104,13 @@ def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:

  # TODO (pdames): add deflate and snappy
  ENCODING_TO_FILE_INIT: Dict[str, Callable] = {
-     ContentEncoding.GZIP.value: partial(gzip.GzipFile, mode='rb'),
-     ContentEncoding.BZIP2.value: partial(bz2.BZ2File, mode='rb'),
+     ContentEncoding.GZIP.value: partial(gzip.GzipFile, mode="rb"),
+     ContentEncoding.BZIP2.value: partial(bz2.BZ2File, mode="rb"),
      ContentEncoding.IDENTITY.value: lambda fileobj: fileobj,
  }


- def slice_table(
-         table: pa.Table,
-         max_len: Optional[int]) -> List[pa.Table]:
+ def slice_table(table: pa.Table, max_len: Optional[int]) -> List[pa.Table]:
      """
      Iteratively create 0-copy table slices.
      """
@@ -123,10 +120,7 @@ def slice_table(
      offset = 0
      records_remaining = len(table)
      while records_remaining > 0:
-         records_this_entry = min(
-             max_len,
-             records_remaining
-         )
+         records_this_entry = min(max_len, records_remaining)
          tables.append(table.slice(offset, records_this_entry))
          records_remaining -= records_this_entry
          offset += records_this_entry
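slice_table keeps returning zero-copy views after the reformatting; a quick usage sketch:

    import pyarrow as pa

    from deltacat.utils.pyarrow import slice_table

    table = pa.table({"col1": list(range(10))})
    # Produces zero-copy slices of at most 4 records each: 4, 4, and 2.
    slices = slice_table(table, max_len=4)
    print([len(t) for t in slices])  # [4, 4, 2]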
@@ -138,21 +132,21 @@ class ReadKwargsProviderPyArrowCsvPureUtf8(ContentTypeKwargsProvider):
      as UTF-8 strings (i.e. disables type inference). Useful for ensuring
      lossless reads of UTF-8 delimited text datasets and improving read
      performance in cases where type casting is not required."""
+
      def __init__(self, include_columns: Optional[Iterable[str]] = None):
          self.include_columns = include_columns

-     def _get_kwargs(
-             self,
-             content_type: str,
-             kwargs: Dict[str, Any]) -> Dict[str, Any]:
+     def _get_kwargs(self, content_type: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
          if content_type in DELIMITED_TEXT_CONTENT_TYPES:
-             convert_options: pacsv.ConvertOptions = \
-                 kwargs.get("convert_options")
+             convert_options: pacsv.ConvertOptions = kwargs.get("convert_options")
              if convert_options is None:
                  convert_options = pacsv.ConvertOptions()
              # read only the included columns as strings?
-             column_names = self.include_columns \
-                 if self.include_columns else convert_options.include_columns
+             column_names = (
+                 self.include_columns
+                 if self.include_columns
+                 else convert_options.include_columns
+             )
              if not column_names:
                  # read all columns as strings?
                  read_options: pacsv.ReadOptions = kwargs.get("read_options")
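A sketch of how this provider is typically applied to reader kwargs before a read call; the column names are illustrative:

    from deltacat.types.media import ContentType
    from deltacat.utils.pyarrow import (
        ReadKwargsProviderPyArrowCsvPureUtf8,
        content_type_to_reader_kwargs,
    )

    provider = ReadKwargsProviderPyArrowCsvPureUtf8(include_columns=["id", "name"])
    kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
    # The returned kwargs configure the included columns to be read as UTF-8
    # strings, so no type inference is applied to them.
    kwargs = provider(ContentType.UNESCAPED_TSV.value, kwargs)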
@@ -171,13 +165,26 @@ class ReadKwargsProviderPyArrowSchemaOverride(ContentTypeKwargsProvider):
      """ReadKwargsProvider impl that explicitly maps column names to column types when
      loading dataset files into a PyArrow table. Disables the default type inference
      behavior on the defined columns."""
-     def __init__(self, schema: Optional[pa.Schema] = None):
+
+     def __init__(
+         self,
+         schema: Optional[pa.Schema] = None,
+         pq_coerce_int96_timestamp_unit: Optional[str] = None,
+     ):
+         """
+
+         Args:
+             schema: The schema to use for reading the dataset.
+                 If unspecified, the schema will be inferred from the source.
+             pq_coerce_int96_timestamp_unit: When reading from parquet files, cast timestamps that are stored in INT96
+                 format to a particular resolution (e.g. 'ms'). Setting to None is equivalent to 'ms'
+                 and therefore INT96 timestamps will be inferred as timestamps in milliseconds.
+
+         """
          self.schema = schema
+         self.pq_coerce_int96_timestamp_unit = pq_coerce_int96_timestamp_unit

-     def _get_kwargs(
-             self,
-             content_type: str,
-             kwargs: Dict[str, Any]) -> Dict[str, Any]:
+     def _get_kwargs(self, content_type: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
          if content_type in DELIMITED_TEXT_CONTENT_TYPES:
              convert_options = kwargs.get("convert_options", pacsv.ConvertOptions())
              if self.schema:
@@ -188,14 +195,21 @@ class ReadKwargsProviderPyArrowSchemaOverride(ContentTypeKwargsProvider):
              # Only supported in PyArrow 8.0.0+
              if self.schema:
                  kwargs["schema"] = self.schema
+
+             # Coerce deprecated int96 timestamp to millisecond if unspecified
+             kwargs["coerce_int96_timestamp_unit"] = (
+                 self.pq_coerce_int96_timestamp_unit or "ms"
+             )
+
          return kwargs


  def _add_column_kwargs(
-         content_type: str,
-         column_names: Optional[List[str]],
-         include_columns: Optional[List[str]],
-         kwargs: Dict[str, Any]):
+     content_type: str,
+     column_names: Optional[List[str]],
+     include_columns: Optional[List[str]],
+     kwargs: Dict[str, Any],
+ ):

      if content_type in DELIMITED_TEXT_CONTENT_TYPES:
          read_options: pacsv.ReadOptions = kwargs.get("read_options")
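A hedged sketch of the new pq_coerce_int96_timestamp_unit parameter; leaving it as None coerces INT96 Parquet timestamps to millisecond resolution once the kwargs reach the Parquet reader. The schema below is illustrative:

    import pyarrow as pa

    from deltacat.types.media import ContentType
    from deltacat.utils.pyarrow import ReadKwargsProviderPyArrowSchemaOverride

    schema = pa.schema([("id", pa.int64()), ("created_at", pa.timestamp("ms"))])
    provider = ReadKwargsProviderPyArrowSchemaOverride(
        schema=schema,
        pq_coerce_int96_timestamp_unit="ms",  # None also defaults to "ms"
    )
    # For Parquet content the resulting kwargs carry "schema" and
    # "coerce_int96_timestamp_unit" through to pyarrow.parquet.read_table.
    kwargs = provider(ContentType.PARQUET.value, {})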
@@ -219,25 +233,27 @@ def _add_column_kwargs(
              if include_columns:
                  logger.warning(
                      f"Ignoring request to include columns {include_columns} "
-                     f"for non-tabular content type {content_type}")
+                     f"for non-tabular content type {content_type}"
+                 )


  def s3_file_to_table(
-         s3_url: str,
-         content_type: str,
-         content_encoding: str,
-         column_names: Optional[List[str]] = None,
-         include_columns: Optional[List[str]] = None,
-         pa_read_func_kwargs_provider: Optional[ReadKwargsProvider] = None,
-         **s3_client_kwargs) -> pa.Table:
+     s3_url: str,
+     content_type: str,
+     content_encoding: str,
+     column_names: Optional[List[str]] = None,
+     include_columns: Optional[List[str]] = None,
+     pa_read_func_kwargs_provider: Optional[ReadKwargsProvider] = None,
+     **s3_client_kwargs,
+ ) -> pa.Table:

      from deltacat.aws import s3u as s3_utils
-     logger.debug(f"Reading {s3_url} to PyArrow. Content type: {content_type}. "
-                  f"Encoding: {content_encoding}")
-     s3_obj = s3_utils.get_object_at_url(
-         s3_url,
-         **s3_client_kwargs
+
+     logger.debug(
+         f"Reading {s3_url} to PyArrow. Content type: {content_type}. "
+         f"Encoding: {content_encoding}"
      )
+     s3_obj = s3_utils.get_object_at_url(s3_url, **s3_client_kwargs)
      logger.debug(f"Read S3 object from {s3_url}: {s3_obj}")
      pa_read_func = CONTENT_TYPE_TO_PA_READ_FUNC[content_type]
      input_file_init = ENCODING_TO_FILE_INIT[content_encoding]
@@ -251,11 +267,7 @@ def s3_file_to_table(
          kwargs = pa_read_func_kwargs_provider(content_type, kwargs)

      logger.debug(f"Reading {s3_url} via {pa_read_func} with kwargs: {kwargs}")
-     table, latency = timed_invocation(
-         pa_read_func,
-         *args,
-         **kwargs
-     )
+     table, latency = timed_invocation(pa_read_func, *args, **kwargs)
      logger.debug(f"Time to read {s3_url} into PyArrow table: {latency}s")
      return table

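For reference, a hedged example of the reworked s3_file_to_table call; the bucket and key are placeholders, and credentials come from the usual AWS configuration:

    from deltacat.types.media import ContentEncoding, ContentType
    from deltacat.utils.pyarrow import s3_file_to_table

    # Reads the object into a pa.Table; a ReadKwargsProvider may optionally be
    # passed via pa_read_func_kwargs_provider to customize the reader kwargs.
    table = s3_file_to_table(
        "s3://example-bucket/path/data.parquet",
        ContentType.PARQUET.value,
        ContentEncoding.IDENTITY.value,
    )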
@@ -265,12 +277,13 @@ def table_size(table: pa.Table) -> int:


  def table_to_file(
-         table: pa.Table,
-         base_path: str,
-         file_system: AbstractFileSystem,
-         block_path_provider: BlockWritePathProvider,
-         content_type: str = ContentType.PARQUET.value,
-         **kwargs) -> None:
+     table: pa.Table,
+     base_path: str,
+     file_system: AbstractFileSystem,
+     block_path_provider: BlockWritePathProvider,
+     content_type: str = ContentType.PARQUET.value,
+     **kwargs,
+ ) -> None:
      """
      Writes the given Pyarrow Table to a file.
      """
@@ -279,11 +292,211 @@ def table_to_file(
          raise NotImplementedError(
              f"Pyarrow writer for content type '{content_type}' not "
              f"implemented. Known content types: "
-             f"{CONTENT_TYPE_TO_PA_WRITE_FUNC.keys}")
+             f"{CONTENT_TYPE_TO_PA_WRITE_FUNC.keys}"
+         )
      path = block_path_provider(base_path)
-     writer(
-         table,
-         path,
-         filesystem=file_system,
-         **kwargs
-     )
+     writer(table, path, filesystem=file_system, **kwargs)
+
+
+ class RecordBatchTables:
+     def __init__(self, batch_size: int):
+         """
+         Data structure for maintaining a batched list of tables, where each batched table has
+         a record count of some multiple of the specified record batch size.
+
+         Remaining records are stored in a separate list of tables.
+
+         Args:
+             batch_size: Minimum record count per table to batch by. Batched tables are
+                 guaranteed to have a record count multiple of the batch_size.
+         """
+         self._batched_tables: List[pa.Table] = []
+         self._batched_record_count: int = 0
+         self._remaining_tables: List[pa.Table] = []
+         self._remaining_record_count: int = 0
+         self._batch_size: int = batch_size
+
+     def append(self, table: pa.Table) -> None:
+         """
+         Appends a table for batching.
+
+         Table record counts are added to any previous remaining record count.
+         If the new remainder record count meets or exceeds the configured batch size record count,
+         the remainder will be shifted over to the list of batched tables in FIFO order via table slicing.
+         Batched tables will always have a record count of some multiple of the configured batch size.
+
+         Record ordering is preserved from input tables whenever tables are shifted from the remainder
+         over to the batched list. Records from Table A will always precede records from Table B,
+         if Table A was appended before Table B. Records from the batched list will always precede records
+         from the remainders.
+
+         Ex:
+             bt = RecordBatchTables(8)
+             col1 = pa.array([i for i in range(10)])
+             test_table = pa.Table.from_arrays([col1], names=["col1"])
+             bt.append(test_table)
+
+             print(bt.batched_records) # 8
+             print(bt.batched) # [0, 1, 2, 3, 4, 5, 6, 7]
+             print(bt.remaining_records) # 2
+             print(bt.remaining) # [8, 9]
+
+         Args:
+             table: Input table to add
+
+         """
+         if self._remaining_tables:
+             if self._remaining_record_count + len(table) < self._batch_size:
+                 self._remaining_tables.append(table)
+                 self._remaining_record_count += len(table)
+                 return
+
+             records_to_fit = self._batch_size - self._remaining_record_count
+             fitted_table = table.slice(length=records_to_fit)
+             self._remaining_tables.append(fitted_table)
+             self._remaining_record_count += len(fitted_table)
+             table = table.slice(offset=records_to_fit)
+
+         record_count = len(table)
+         record_multiplier, records_leftover = (
+             record_count // self._batch_size,
+             record_count % self._batch_size,
+         )
+
+         if record_multiplier > 0:
+             batched_table = table.slice(length=record_multiplier * self._batch_size)
+             # Add to remainder tables to preserve record ordering
+             self._remaining_tables.append(batched_table)
+             self._remaining_record_count += len(batched_table)
+
+         if self._remaining_tables:
+             self._shift_remaining_to_new_batch()
+
+         if records_leftover > 0:
+             leftover_table = table.slice(offset=record_multiplier * self._batch_size)
+             self._remaining_tables.append(leftover_table)
+             self._remaining_record_count += len(leftover_table)
+
+     def _shift_remaining_to_new_batch(self) -> None:
+         new_batch = pa.concat_tables(self._remaining_tables)
+         self._batched_tables.append(new_batch)
+         self._batched_record_count += self._remaining_record_count
+         self.clear_remaining()
+
+     @staticmethod
+     def from_tables(tables: List[pa.Table], batch_size: int) -> RecordBatchTables:
+         """
+         Static factory for generating batched tables and remainders given a list of input tables.
+
+         Args:
+             tables: A list of input tables with various record counts
+             batch_size: Minimum record count per table to batch by. Batched tables are
+                 guaranteed to have a record count multiple of the batch_size.
+
+         Returns: A batched tables object
+
+         """
+         rbt = RecordBatchTables(batch_size)
+         for table in tables:
+             rbt.append(table)
+         return rbt
+
+     @property
+     def batched(self) -> List[pa.Table]:
+         """
+         List of tables batched and ready for processing.
+         Each table has N records, where N records are some multiple of the configured records batch size.
+
+         For example, if the configured batch size is 5, then a list of batched tables
+         could have the following record counts: [60, 5, 30, 10]
+
+         Returns: a list of batched tables
+
+         """
+         return self._batched_tables
+
+     @property
+     def batched_record_count(self) -> int:
+         """
+         The number of total records from the batched list.
+
+         Returns: batched record count
+
+         """
+         return self._batched_record_count
+
+     @property
+     def remaining(self) -> List[pa.Table]:
+         """
+         List of tables carried over from table slicing during the batching operation.
+         The sum of all record counts in the remaining tables is guaranteed to be less than the configured batch size.
+
+         Returns: a list of remaining tables
+
+         """
+         return self._remaining_tables
+
+     @property
+     def remaining_record_count(self) -> int:
+         """
+         The number of total records from the remaining tables list.
+
+         Returns: remaining record count
+
+         """
+         return self._remaining_record_count
+
+     @property
+     def batch_size(self) -> int:
+         """
+         The configured batch size.
+
+         Returns: batch size
+
+         """
+         return self._batch_size
+
+     def has_batches(self) -> bool:
+         """
+         Checks if there are any currently batched tables ready for processing.
+
+         Returns: true if batched records exist, otherwise false
+
+         """
+         return self._batched_record_count > 0
+
+     def has_remaining(self) -> bool:
+         """
+         Checks if any remaining tables exist after batching.
+
+         Returns: true if remaining records exist, otherwise false
+
+         """
+         return self._remaining_record_count > 0
+
+     def evict(self) -> List[pa.Table]:
+         """
+         Evicts all batched tables from this object and returns them.
+
+         Returns: a list of batched tables
+
+         """
+         evicted_tables = [*self.batched]
+         self.clear_batches()
+         return evicted_tables
+
+     def clear_batches(self) -> None:
+         """
+         Removes all batched tables and resets batched records.
+
+         """
+         self._batched_tables.clear()
+         self._batched_record_count = 0
+
+     def clear_remaining(self) -> None:
+         """
+         Removes all remaining tables and resets remaining records.
+
+         """
+         self._remaining_tables.clear()
+         self._remaining_record_count = 0
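To complement the docstring example above (whose batched_records / remaining_records names correspond to the batched_record_count / remaining_record_count properties), a short sketch of the from_tables factory and eviction flow:

    import pyarrow as pa

    from deltacat.utils.pyarrow import RecordBatchTables

    tables = [pa.table({"col1": list(range(n))}) for n in (3, 7, 5)]
    rbt = RecordBatchTables.from_tables(tables, batch_size=4)

    print(rbt.batched_record_count)    # 12 of the 15 records are batched
    print(rbt.remaining_record_count)  # 3 records fall short of a full batch
    batches = rbt.evict()              # drains the batched tables for processing
    print(rbt.has_batches())           # False after eviction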
deltacat/utils/ray_utils/collections.py CHANGED
@@ -1,6 +1,7 @@
- import ray
  from collections import Counter

+ import ray
+

  @ray.remote
  class DistributedCounter(object):
deltacat/utils/ray_utils/concurrency.py CHANGED
@@ -1,22 +1,23 @@
- import ray
+ import copy
+ import itertools
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

+ import ray
  from ray._private.ray_constants import MIN_RESOURCE_GRANULARITY
  from ray.types import ObjectRef

  from deltacat.utils.ray_utils.runtime import current_node_resource_key
- import copy

- from typing import Any, Iterable, Callable, Dict, List, Tuple, Union, Optional
- import itertools

  def invoke_parallel(
-         items: Iterable,
-         ray_task: Callable,
-         *args,
-         max_parallelism: Optional[int] = 1000,
-         options_provider: Callable[[int, Any], Dict[str, Any]] = None,
-         kwargs_provider: Callable[[int, Any], Dict[str, Any]] = None,
-         **kwargs) -> List[Union[ObjectRef, Tuple[ObjectRef, ...]]]:
+     items: Iterable,
+     ray_task: Callable,
+     *args,
+     max_parallelism: Optional[int] = 1000,
+     options_provider: Callable[[int, Any], Dict[str, Any]] = None,
+     kwargs_provider: Callable[[int, Any], Dict[str, Any]] = None,
+     **kwargs,
+ ) -> List[Union[ObjectRef, Tuple[ObjectRef, ...]]]:
      """
      Creates a limited number of parallel remote invocations of the given ray
      task. By default each task is provided an ordered item from the input
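A minimal sketch of the reformatted invoke_parallel signature in use; the toy task and inputs are illustrative:

    import ray

    from deltacat.utils.ray_utils.concurrency import invoke_parallel

    ray.init(ignore_reinit_error=True)

    @ray.remote
    def double(x: int) -> int:
        return x * 2

    # One remote task per item, with at most max_parallelism tasks in flight.
    pending = invoke_parallel(range(10), double, max_parallelism=100)
    print(ray.get(pending))  # [0, 2, 4, ..., 18]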
@@ -57,11 +58,11 @@ def invoke_parallel(
                  ray.wait(
                      list(itertools.chain(*pending_ids)),
                      num_returns=int(
-                         len(pending_ids[0])*(len(pending_ids) - max_parallelism)
-                     )
+                         len(pending_ids[0]) * (len(pending_ids) - max_parallelism)
+                     ),
                  )
              else:
-                 ray.wait(pending_ids, num_returns=len(pending_ids)-max_parallelism)
+                 ray.wait(pending_ids, num_returns=len(pending_ids) - max_parallelism)
          opt = {}
          if options_provider:
              opt = options_provider(i, item)
@@ -79,21 +80,17 @@ def current_node_options_provider(*args, **kwargs) -> Dict[str, Any]:
      """Returns a resource dictionary that can be included with ray remote
      options to pin the task or actor on the current node via:
      `foo.options(current_node_options_provider()).remote()`"""
-     return {
-         "resources": {
-             current_node_resource_key(): MIN_RESOURCE_GRANULARITY
-         }
-     }
+     return {"resources": {current_node_resource_key(): MIN_RESOURCE_GRANULARITY}}


  def round_robin_options_provider(
-         i: int,
-         item: Any,
-         resource_keys: List[str],
-         *args,
-         resource_amount_provider: Callable[[int], int] =
-         lambda i: MIN_RESOURCE_GRANULARITY,
-         **kwargs) -> Dict[str, Any]:
+     i: int,
+     item: Any,
+     resource_keys: List[str],
+     *args,
+     resource_amount_provider: Callable[[int], int] = lambda i: MIN_RESOURCE_GRANULARITY,
+     **kwargs,
+ ) -> Dict[str, Any]:
      """Returns a resource dictionary that can be included with ray remote
      options to round robin indexed tasks or actors across a list of resource
      keys. For example, the following code round-robins 100 tasks across all
@@ -108,8 +105,10 @@ def round_robin_options_provider(
      opts = kwargs.get("pg_config")
      if opts:
          new_opts = copy.deepcopy(opts)
-         bundle_key_index = i % len(new_opts['scheduling_strategy'].placement_group.bundle_specs)
-         new_opts['scheduling_strategy'].placement_group_bundle_index = bundle_key_index
+         bundle_key_index = i % len(
+             new_opts["scheduling_strategy"].placement_group.bundle_specs
+         )
+         new_opts["scheduling_strategy"].placement_group_bundle_index = bundle_key_index
          return new_opts
      else:
          assert resource_keys, f"No resource keys given to round robin!"
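Finally, a hedged sketch of pairing round_robin_options_provider with invoke_parallel to spread tasks across cluster nodes; the node resource keys used are whatever the cluster reports:

    from functools import partial

    import ray

    from deltacat.utils.ray_utils.concurrency import (
        invoke_parallel,
        round_robin_options_provider,
    )

    ray.init(ignore_reinit_error=True)

    @ray.remote
    def work(item: int) -> int:
        return item

    # Node resource keys look like "node:10.0.0.1".
    node_keys = [k for k in ray.cluster_resources() if k.startswith("node:")]
    pending = invoke_parallel(
        range(100),
        work,
        options_provider=partial(round_robin_options_provider, resource_keys=node_keys),
    )
    results = ray.get(pending)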