deltacat 0.1.10.dev0__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- deltacat/__init__.py +41 -15
- deltacat/aws/clients.py +12 -31
- deltacat/aws/constants.py +1 -1
- deltacat/aws/redshift/__init__.py +7 -2
- deltacat/aws/redshift/model/manifest.py +54 -50
- deltacat/aws/s3u.py +176 -187
- deltacat/catalog/delegate.py +151 -185
- deltacat/catalog/interface.py +78 -97
- deltacat/catalog/model/catalog.py +21 -21
- deltacat/catalog/model/table_definition.py +11 -9
- deltacat/compute/compactor/__init__.py +12 -16
- deltacat/compute/compactor/compaction_session.py +237 -166
- deltacat/compute/compactor/model/delta_annotated.py +60 -44
- deltacat/compute/compactor/model/delta_file_envelope.py +5 -6
- deltacat/compute/compactor/model/delta_file_locator.py +10 -8
- deltacat/compute/compactor/model/materialize_result.py +6 -7
- deltacat/compute/compactor/model/primary_key_index.py +38 -34
- deltacat/compute/compactor/model/pyarrow_write_result.py +3 -4
- deltacat/compute/compactor/model/round_completion_info.py +25 -19
- deltacat/compute/compactor/model/sort_key.py +18 -15
- deltacat/compute/compactor/steps/dedupe.py +119 -94
- deltacat/compute/compactor/steps/hash_bucket.py +48 -47
- deltacat/compute/compactor/steps/materialize.py +86 -92
- deltacat/compute/compactor/steps/rehash/rehash_bucket.py +13 -13
- deltacat/compute/compactor/steps/rehash/rewrite_index.py +5 -5
- deltacat/compute/compactor/utils/io.py +59 -47
- deltacat/compute/compactor/utils/primary_key_index.py +91 -80
- deltacat/compute/compactor/utils/round_completion_file.py +22 -23
- deltacat/compute/compactor/utils/system_columns.py +33 -45
- deltacat/compute/metastats/meta_stats.py +235 -157
- deltacat/compute/metastats/model/partition_stats_dict.py +7 -10
- deltacat/compute/metastats/model/stats_cluster_size_estimator.py +13 -5
- deltacat/compute/metastats/stats.py +95 -64
- deltacat/compute/metastats/utils/io.py +100 -53
- deltacat/compute/metastats/utils/pyarrow_memory_estimation_function.py +5 -2
- deltacat/compute/metastats/utils/ray_utils.py +38 -33
- deltacat/compute/stats/basic.py +107 -69
- deltacat/compute/stats/models/delta_column_stats.py +11 -8
- deltacat/compute/stats/models/delta_stats.py +59 -32
- deltacat/compute/stats/models/delta_stats_cache_result.py +4 -1
- deltacat/compute/stats/models/manifest_entry_stats.py +12 -6
- deltacat/compute/stats/models/stats_result.py +24 -14
- deltacat/compute/stats/utils/intervals.py +16 -9
- deltacat/compute/stats/utils/io.py +86 -51
- deltacat/compute/stats/utils/manifest_stats_file.py +24 -33
- deltacat/constants.py +4 -13
- deltacat/io/__init__.py +2 -2
- deltacat/io/aws/redshift/redshift_datasource.py +157 -143
- deltacat/io/dataset.py +14 -17
- deltacat/io/read_api.py +36 -33
- deltacat/logs.py +94 -42
- deltacat/storage/__init__.py +18 -8
- deltacat/storage/interface.py +196 -213
- deltacat/storage/model/delta.py +45 -51
- deltacat/storage/model/list_result.py +12 -8
- deltacat/storage/model/namespace.py +4 -5
- deltacat/storage/model/partition.py +42 -42
- deltacat/storage/model/stream.py +29 -30
- deltacat/storage/model/table.py +14 -14
- deltacat/storage/model/table_version.py +32 -31
- deltacat/storage/model/types.py +1 -0
- deltacat/tests/stats/test_intervals.py +11 -24
- deltacat/tests/utils/__init__.py +0 -0
- deltacat/tests/utils/test_record_batch_tables.py +284 -0
- deltacat/types/media.py +3 -4
- deltacat/types/tables.py +31 -21
- deltacat/utils/common.py +5 -11
- deltacat/utils/numpy.py +20 -22
- deltacat/utils/pandas.py +73 -100
- deltacat/utils/performance.py +3 -9
- deltacat/utils/placement.py +259 -230
- deltacat/utils/pyarrow.py +302 -89
- deltacat/utils/ray_utils/collections.py +2 -1
- deltacat/utils/ray_utils/concurrency.py +27 -28
- deltacat/utils/ray_utils/dataset.py +28 -28
- deltacat/utils/ray_utils/performance.py +5 -9
- deltacat/utils/ray_utils/runtime.py +9 -10
- {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/METADATA +1 -1
- deltacat-0.1.12.dist-info/RECORD +110 -0
- deltacat-0.1.10.dev0.dist-info/RECORD +0 -108
- {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/LICENSE +0 -0
- {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/WHEEL +0 -0
- {deltacat-0.1.10.dev0.dist-info → deltacat-0.1.12.dist-info}/top_level.txt +0 -0
deltacat/utils/pyarrow.py
CHANGED
@@ -1,25 +1,31 @@
-
-import
+# Allow classes to use self-referencing Type hints in Python 3.7.
+from __future__ import annotations
+
 import bz2
+import gzip
 import io
 import logging
-
 from functools import partial
-from
-
-from pyarrow import feather as paf, parquet as papq, csv as pacsv, \
-    json as pajson
+from typing import Any, Callable, Dict, Iterable, List, Optional

+import pyarrow as pa
+from fsspec import AbstractFileSystem
+from pyarrow import csv as pacsv
+from pyarrow import feather as paf
+from pyarrow import json as pajson
+from pyarrow import parquet as papq
 from ray.data.datasource import BlockWritePathProvider

 from deltacat import logs
-from deltacat.types.media import
-    DELIMITED_TEXT_CONTENT_TYPES,
-
+from deltacat.types.media import (
+    DELIMITED_TEXT_CONTENT_TYPES,
+    TABULAR_CONTENT_TYPES,
+    ContentEncoding,
+    ContentType,
+)
+from deltacat.utils.common import ContentTypeKwargsProvider, ReadKwargsProvider
 from deltacat.utils.performance import timed_invocation

-from typing import Any, Callable, Dict, List, Optional, Iterable, Union
-
 logger = logs.configure_deltacat_logger(logging.getLogger(__name__))


@@ -33,35 +39,28 @@ CONTENT_TYPE_TO_PA_READ_FUNC: Dict[str, Callable] = {
     # Pyarrow.orc is disabled in Pyarrow 0.15, 0.16:
     # https://issues.apache.org/jira/browse/ARROW-7811
     # ContentType.ORC.value: paorc.ContentType.ORCFile,
-    ContentType.JSON.value: pajson.read_json
+    ContentType.JSON.value: pajson.read_json,
 }


 def write_feather(
-
-
-        *,
-        filesystem: AbstractFileSystem,
-        **kwargs) -> None:
+    table: pa.Table, path: str, *, filesystem: AbstractFileSystem, **kwargs
+) -> None:

     with filesystem.open(path, "wb") as f:
         paf.write_feather(table, f, **kwargs)


 def write_csv(
-
-
-        *,
-        filesystem: AbstractFileSystem,
-        **kwargs) -> None:
+    table: pa.Table, path: str, *, filesystem: AbstractFileSystem, **kwargs
+) -> None:

     with filesystem.open(path, "wb") as f:
         # TODO (pdames): Add support for client-specified compression types.
         with pa.CompressedOutputStream(f, ContentEncoding.GZIP.value) as out:
             if kwargs.get("write_options") is None:
                 # column names are kept in table metadata, so omit header
-                kwargs["write_options"] = pacsv.WriteOptions(
-                    include_header=False)
+                kwargs["write_options"] = pacsv.WriteOptions(include_header=False)
             pacsv.write_csv(table, out, **kwargs)


@@ -78,13 +77,11 @@ CONTENT_TYPE_TO_PA_WRITE_FUNC: Dict[str, Callable] = {
 def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:
     if content_type == ContentType.UNESCAPED_TSV.value:
         return {
-            "parse_options": pacsv.ParseOptions(
-                delimiter="\t",
-                quote_char=False),
+            "parse_options": pacsv.ParseOptions(delimiter="\t", quote_char=False),
             "convert_options": pacsv.ConvertOptions(
                 null_values=[""],  # pyarrow defaults are ["", "NULL", "null"]
                 strings_can_be_null=True,
-            )
+            ),
         }
     if content_type == ContentType.TSV.value:
         return {"parse_options": pacsv.ParseOptions(delimiter="\t")}
@@ -92,9 +89,11 @@ def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:
         return {"parse_options": pacsv.ParseOptions(delimiter=",")}
     if content_type == ContentType.PSV.value:
         return {"parse_options": pacsv.ParseOptions(delimiter="|")}
-    if content_type in {
-
-
+    if content_type in {
+        ContentType.PARQUET.value,
+        ContentType.FEATHER.value,
+        ContentType.JSON.value,
+    }:
         return {}
     # Pyarrow.orc is disabled in Pyarrow 0.15, 0.16:
     # https://issues.apache.org/jira/browse/ARROW-7811
@@ -105,15 +104,13 @@ def content_type_to_reader_kwargs(content_type: str) -> Dict[str, Any]:

 # TODO (pdames): add deflate and snappy
 ENCODING_TO_FILE_INIT: Dict[str, Callable] = {
-    ContentEncoding.GZIP.value: partial(gzip.GzipFile, mode=
-    ContentEncoding.BZIP2.value: partial(bz2.BZ2File, mode=
+    ContentEncoding.GZIP.value: partial(gzip.GzipFile, mode="rb"),
+    ContentEncoding.BZIP2.value: partial(bz2.BZ2File, mode="rb"),
     ContentEncoding.IDENTITY.value: lambda fileobj: fileobj,
 }


-def slice_table(
-        table: pa.Table,
-        max_len: Optional[int]) -> List[pa.Table]:
+def slice_table(table: pa.Table, max_len: Optional[int]) -> List[pa.Table]:
     """
     Iteratively create 0-copy table slices.
     """
@@ -123,10 +120,7 @@ def slice_table(
     offset = 0
     records_remaining = len(table)
     while records_remaining > 0:
-        records_this_entry = min(
-            max_len,
-            records_remaining
-        )
+        records_this_entry = min(max_len, records_remaining)
         tables.append(table.slice(offset, records_this_entry))
         records_remaining -= records_this_entry
         offset += records_this_entry
@@ -138,21 +132,21 @@ class ReadKwargsProviderPyArrowCsvPureUtf8(ContentTypeKwargsProvider):
     as UTF-8 strings (i.e. disables type inference). Useful for ensuring
     lossless reads of UTF-8 delimited text datasets and improving read
     performance in cases where type casting is not required."""
+
     def __init__(self, include_columns: Optional[Iterable[str]] = None):
         self.include_columns = include_columns

-    def _get_kwargs(
-            self,
-            content_type: str,
-            kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    def _get_kwargs(self, content_type: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
         if content_type in DELIMITED_TEXT_CONTENT_TYPES:
-            convert_options: pacsv.ConvertOptions =
-                kwargs.get("convert_options")
+            convert_options: pacsv.ConvertOptions = kwargs.get("convert_options")
             if convert_options is None:
                 convert_options = pacsv.ConvertOptions()
             # read only the included columns as strings?
-            column_names =
-
+            column_names = (
+                self.include_columns
+                if self.include_columns
+                else convert_options.include_columns
+            )
             if not column_names:
                 # read all columns as strings?
                 read_options: pacsv.ReadOptions = kwargs.get("read_options")
@@ -171,13 +165,26 @@ class ReadKwargsProviderPyArrowSchemaOverride(ContentTypeKwargsProvider):
     """ReadKwargsProvider impl that explicitly maps column names to column types when
     loading dataset files into a PyArrow table. Disables the default type inference
     behavior on the defined columns."""
-
+
+    def __init__(
+        self,
+        schema: Optional[pa.Schema] = None,
+        pq_coerce_int96_timestamp_unit: Optional[str] = None,
+    ):
+        """
+
+        Args:
+            schema: The schema to use for reading the dataset.
+                If unspecified, the schema will be inferred from the source.
+            pq_coerce_int96_timestamp_unit: When reading from parquet files, cast timestamps that are stored in INT96
+                format to a particular resolution (e.g. 'ms'). Setting to None is equivalent to 'ms'
+                and therefore INT96 timestamps will be inferred as timestamps in milliseconds.
+
+        """
         self.schema = schema
+        self.pq_coerce_int96_timestamp_unit = pq_coerce_int96_timestamp_unit

-    def _get_kwargs(
-            self,
-            content_type: str,
-            kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    def _get_kwargs(self, content_type: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
         if content_type in DELIMITED_TEXT_CONTENT_TYPES:
             convert_options = kwargs.get("convert_options", pacsv.ConvertOptions())
             if self.schema:
@@ -188,14 +195,21 @@ class ReadKwargsProviderPyArrowSchemaOverride(ContentTypeKwargsProvider):
             # Only supported in PyArrow 8.0.0+
             if self.schema:
                 kwargs["schema"] = self.schema
+
+            # Coerce deprecated int96 timestamp to millisecond if unspecified
+            kwargs["coerce_int96_timestamp_unit"] = (
+                self.pq_coerce_int96_timestamp_unit or "ms"
+            )
+
         return kwargs


 def _add_column_kwargs(
-
-
-
-
+    content_type: str,
+    column_names: Optional[List[str]],
+    include_columns: Optional[List[str]],
+    kwargs: Dict[str, Any],
+):

     if content_type in DELIMITED_TEXT_CONTENT_TYPES:
         read_options: pacsv.ReadOptions = kwargs.get("read_options")
@@ -219,25 +233,27 @@ def _add_column_kwargs(
         if include_columns:
             logger.warning(
                 f"Ignoring request to include columns {include_columns} "
-                f"for non-tabular content type {content_type}"
+                f"for non-tabular content type {content_type}"
+            )


 def s3_file_to_table(
-
-
-
-
-
-
-
+    s3_url: str,
+    content_type: str,
+    content_encoding: str,
+    column_names: Optional[List[str]] = None,
+    include_columns: Optional[List[str]] = None,
+    pa_read_func_kwargs_provider: Optional[ReadKwargsProvider] = None,
+    **s3_client_kwargs,
+) -> pa.Table:

     from deltacat.aws import s3u as s3_utils
-
-
-
-
-        **s3_client_kwargs
+
+    logger.debug(
+        f"Reading {s3_url} to PyArrow. Content type: {content_type}. "
+        f"Encoding: {content_encoding}"
     )
+    s3_obj = s3_utils.get_object_at_url(s3_url, **s3_client_kwargs)
     logger.debug(f"Read S3 object from {s3_url}: {s3_obj}")
     pa_read_func = CONTENT_TYPE_TO_PA_READ_FUNC[content_type]
     input_file_init = ENCODING_TO_FILE_INIT[content_encoding]
@@ -251,11 +267,7 @@ def s3_file_to_table(
         kwargs = pa_read_func_kwargs_provider(content_type, kwargs)

     logger.debug(f"Reading {s3_url} via {pa_read_func} with kwargs: {kwargs}")
-    table, latency = timed_invocation(
-        pa_read_func,
-        *args,
-        **kwargs
-    )
+    table, latency = timed_invocation(pa_read_func, *args, **kwargs)
     logger.debug(f"Time to read {s3_url} into PyArrow table: {latency}s")
     return table

@@ -265,12 +277,13 @@ def table_size(table: pa.Table) -> int:


 def table_to_file(
-
-
-
-
-
-
+    table: pa.Table,
+    base_path: str,
+    file_system: AbstractFileSystem,
+    block_path_provider: BlockWritePathProvider,
+    content_type: str = ContentType.PARQUET.value,
+    **kwargs,
+) -> None:
     """
     Writes the given Pyarrow Table to a file.
     """
@@ -279,11 +292,211 @@ def table_to_file(
         raise NotImplementedError(
             f"Pyarrow writer for content type '{content_type}' not "
             f"implemented. Known content types: "
-            f"{CONTENT_TYPE_TO_PA_WRITE_FUNC.keys}"
+            f"{CONTENT_TYPE_TO_PA_WRITE_FUNC.keys}"
+        )
     path = block_path_provider(base_path)
-    writer(
-
-
-
-
-
+    writer(table, path, filesystem=file_system, **kwargs)
+
+
+class RecordBatchTables:
+    def __init__(self, batch_size: int):
+        """
+        Data structure for maintaining a batched list of tables, where each batched table has
+        a record count of some multiple of the specified record batch size.
+
+        Remaining records are stored in a separate list of tables.
+
+        Args:
+            batch_size: Minimum record count per table to batch by. Batched tables are
+            guaranteed to have a record count multiple of the batch_size.
+        """
+        self._batched_tables: List[pa.Table] = []
+        self._batched_record_count: int = 0
+        self._remaining_tables: List[pa.Table] = []
+        self._remaining_record_count: int = 0
+        self._batch_size: int = batch_size
+
+    def append(self, table: pa.Table) -> None:
+        """
+        Appends a table for batching.
+
+        Table record counts are added to any previous remaining record count.
+        If the new remainder record count meets or exceeds the configured batch size record count,
+        the remainder will be shifted over to the list of batched tables in FIFO order via table slicing.
+        Batched tables will always have a record count of some multiple of the configured batch size.
+
+        Record ordering is preserved from input tables whenever tables are shifted from the remainder
+        over to the batched list. Records from Table A will always precede records from Table B,
+        if Table A was appended before Table B. Records from the batched list will always precede records
+        from the remainders.
+
+        Ex:
+            bt = RecordBatchTables(8)
+            col1 = pa.array([i for i in range(10)])
+            test_table = pa.Table.from_arrays([col1], names=["col1"])
+            bt.append(test_table)
+
+            print(bt.batched_records)  # 8
+            print(bt.batched)  # [0, 1, 2, 3, 4, 5, 6, 7]
+            print(bt.remaining_records)  # 2
+            print(bt.remaining)  # [8, 9]
+
+        Args:
+            table: Input table to add
+
+        """
+        if self._remaining_tables:
+            if self._remaining_record_count + len(table) < self._batch_size:
+                self._remaining_tables.append(table)
+                self._remaining_record_count += len(table)
+                return
+
+            records_to_fit = self._batch_size - self._remaining_record_count
+            fitted_table = table.slice(length=records_to_fit)
+            self._remaining_tables.append(fitted_table)
+            self._remaining_record_count += len(fitted_table)
+            table = table.slice(offset=records_to_fit)
+
+        record_count = len(table)
+        record_multiplier, records_leftover = (
+            record_count // self._batch_size,
+            record_count % self._batch_size,
+        )
+
+        if record_multiplier > 0:
+            batched_table = table.slice(length=record_multiplier * self._batch_size)
+            # Add to remainder tables to preserve record ordering
+            self._remaining_tables.append(batched_table)
+            self._remaining_record_count += len(batched_table)
+
+        if self._remaining_tables:
+            self._shift_remaining_to_new_batch()
+
+        if records_leftover > 0:
+            leftover_table = table.slice(offset=record_multiplier * self._batch_size)
+            self._remaining_tables.append(leftover_table)
+            self._remaining_record_count += len(leftover_table)
+
+    def _shift_remaining_to_new_batch(self) -> None:
+        new_batch = pa.concat_tables(self._remaining_tables)
+        self._batched_tables.append(new_batch)
+        self._batched_record_count += self._remaining_record_count
+        self.clear_remaining()
+
+    @staticmethod
+    def from_tables(tables: List[pa.Table], batch_size: int) -> RecordBatchTables:
+        """
+        Static factory for generating batched tables and remainders given a list of input tables.
+
+        Args:
+            tables: A list of input tables with various record counts
+            batch_size: Minimum record count per table to batch by. Batched tables are
+            guaranteed to have a record count multiple of the batch_size.
+
+        Returns: A batched tables object
+
+        """
+        rbt = RecordBatchTables(batch_size)
+        for table in tables:
+            rbt.append(table)
+        return rbt
+
+    @property
+    def batched(self) -> List[pa.Table]:
+        """
+        List of tables batched and ready for processing.
+        Each table has N records, where N records are some multiple of the configured records batch size.
+
+        For example, if the configured batch size is 5, then a list of batched tables
+        could have the following record counts: [60, 5, 30, 10]
+
+        Returns: a list of batched tables
+
+        """
+        return self._batched_tables
+
+    @property
+    def batched_record_count(self) -> int:
+        """
+        The number of total records from the batched list.
+
+        Returns: batched record count
+
+        """
+        return self._batched_record_count
+
+    @property
+    def remaining(self) -> List[pa.Table]:
+        """
+        List of tables carried over from table slicing during the batching operation.
+        The sum of all record counts in the remaining tables is guaranteed to be less than the configured batch size.
+
+        Returns: a list of remaining tables
+
+        """
+        return self._remaining_tables
+
+    @property
+    def remaining_record_count(self) -> int:
+        """
+        The number of total records from the remaining tables list.
+
+        Returns: remaining record count
+
+        """
+        return self._remaining_record_count
+
+    @property
+    def batch_size(self) -> int:
+        """
+        The configured batch size.
+
+        Returns: batch size
+
+        """
+        return self._batch_size
+
+    def has_batches(self) -> bool:
+        """
+        Checks if there are any currently batched tables ready for processing.
+
+        Returns: true if batched records exist, otherwise false
+
+        """
+        return self._batched_record_count > 0
+
+    def has_remaining(self) -> bool:
+        """
+        Checks if any remaining tables exist after batching.
+
+        Returns: true if remaining records exist, otherwise false
+
+        """
+        return self._remaining_record_count > 0
+
+    def evict(self) -> List[pa.Table]:
+        """
+        Evicts all batched tables from this object and returns them.
+
+        Returns: a list of batched tables
+
+        """
+        evicted_tables = [*self.batched]
+        self.clear_batches()
+        return evicted_tables
+
+    def clear_batches(self) -> None:
+        """
+        Removes all batched tables and resets batched records.
+
+        """
+        self._batched_tables.clear()
+        self._batched_record_count = 0
+
+    def clear_remaining(self) -> None:
+        """
+        Removes all remaining tables and resets remaining records.
+
+        """
+        self._remaining_tables.clear()
+        self._remaining_record_count = 0
deltacat/utils/ray_utils/concurrency.py
CHANGED
@@ -1,22 +1,23 @@
-import
+import copy
+import itertools
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

+import ray
 from ray._private.ray_constants import MIN_RESOURCE_GRANULARITY
 from ray.types import ObjectRef

 from deltacat.utils.ray_utils.runtime import current_node_resource_key
-import copy

-from typing import Any, Iterable, Callable, Dict, List, Tuple, Union, Optional
-import itertools

 def invoke_parallel(
-
-
-
-
-
-
-
+    items: Iterable,
+    ray_task: Callable,
+    *args,
+    max_parallelism: Optional[int] = 1000,
+    options_provider: Callable[[int, Any], Dict[str, Any]] = None,
+    kwargs_provider: Callable[[int, Any], Dict[str, Any]] = None,
+    **kwargs,
+) -> List[Union[ObjectRef, Tuple[ObjectRef, ...]]]:
     """
     Creates a limited number of parallel remote invocations of the given ray
     task. By default each task is provided an ordered item from the input
@@ -57,11 +58,11 @@ def invoke_parallel(
                 ray.wait(
                     list(itertools.chain(*pending_ids)),
                     num_returns=int(
-                        len(pending_ids[0])*(len(pending_ids) - max_parallelism)
-                    )
+                        len(pending_ids[0]) * (len(pending_ids) - max_parallelism)
+                    ),
                 )
             else:
-                ray.wait(pending_ids, num_returns=len(pending_ids)-max_parallelism)
+                ray.wait(pending_ids, num_returns=len(pending_ids) - max_parallelism)
         opt = {}
         if options_provider:
             opt = options_provider(i, item)
@@ -79,21 +80,17 @@ def current_node_options_provider(*args, **kwargs) -> Dict[str, Any]:
     """Returns a resource dictionary that can be included with ray remote
     options to pin the task or actor on the current node via:
     `foo.options(current_node_options_provider()).remote()`"""
-    return {
-        "resources": {
-            current_node_resource_key(): MIN_RESOURCE_GRANULARITY
-        }
-    }
+    return {"resources": {current_node_resource_key(): MIN_RESOURCE_GRANULARITY}}


 def round_robin_options_provider(
-
-
-
-
-
-
-
+    i: int,
+    item: Any,
+    resource_keys: List[str],
+    *args,
+    resource_amount_provider: Callable[[int], int] = lambda i: MIN_RESOURCE_GRANULARITY,
+    **kwargs,
+) -> Dict[str, Any]:
     """Returns a resource dictionary that can be included with ray remote
     options to round robin indexed tasks or actors across a list of resource
     keys. For example, the following code round-robins 100 tasks across all
@@ -108,8 +105,10 @@ def round_robin_options_provider(
     opts = kwargs.get("pg_config")
     if opts:
         new_opts = copy.deepcopy(opts)
-        bundle_key_index = i % len(
-
+        bundle_key_index = i % len(
+            new_opts["scheduling_strategy"].placement_group.bundle_specs
+        )
+        new_opts["scheduling_strategy"].placement_group_bundle_index = bundle_key_index
         return new_opts
     else:
         assert resource_keys, f"No resource keys given to round robin!"
|