replay-rec 0.20.3rc0-py3-none-any.whl → 0.21.0rc0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +1 -1
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/METADATA +3 -3
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/RECORD +119 -34
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/licenses/NOTICE +0 -0
replay/__init__.py
CHANGED
replay/data/dataset.py
CHANGED
```diff
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from pathlib import Path
 from typing import Callable, Optional, Union
@@ -45,6 +46,7 @@ class Dataset:
     ):
         """
         :param feature_schema: mapping of column names and feature infos.
+            All features not specified in the schema will be assumed numerical by default.
         :param interactions: dataframe with interactions.
         :param query_features: dataframe with query features,
             defaults: ```None```.
@@ -498,6 +500,15 @@ class Dataset:
             source=FeatureSource.QUERY_FEATURES,
             feature_schema=updated_feature_schema,
         )
+
+        if filled_features:
+            msg = (
+                "The following features are present in the dataset but have not been specified "
+                f"by the feature schema: {[(info.column, info.feature_source.value) for info in filled_features]}. "
+                "These features will be interpreted as NUMERICAL."
+            )
+            warnings.warn(msg, stacklevel=2)
+
         return FeatureSchema(features_list=features_list + filled_features)
 
     def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
```
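The warning added here uses `stacklevel=2`, so Python reports it at the call site that constructed the `Dataset` rather than inside the schema-filling helper. A minimal standalone sketch of the pattern (the helper below is hypothetical, not RePlay code):

```python
import warnings

def _warn_on_unlabeled(unlabeled: list) -> None:
    # Hypothetical helper mirroring the change above: stacklevel=2 attributes
    # the warning to the caller's line instead of this function's body.
    if unlabeled:
        msg = f"Features not specified by the feature schema: {unlabeled}. Interpreted as NUMERICAL."
        warnings.warn(msg, stacklevel=2)

_warn_on_unlabeled([("price", "item_features")])  # UserWarning reported at this line
```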
replay/data/nn/__init__.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 from replay.utils import TORCH_AVAILABLE
 
 if TORCH_AVAILABLE:
+    from .parquet import ParquetDataset, ParquetModule
    from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
    from .sequence_tokenizer import SequenceTokenizer
    from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
@@ -18,6 +19,8 @@ if TORCH_AVAILABLE:
     "DEFAULT_TRAIN_PADDING_VALUE",
     "MutableTensorMap",
     "PandasSequentialDataset",
+    "ParquetDataset",
+    "ParquetModule",
     "PolarsSequentialDataset",
     "SequenceTokenizer",
     "SequentialDataset",
```
replay/data/nn/parquet/__init__.py
ADDED

```diff
@@ -0,0 +1,22 @@
+"""
+Implementation of the ``ParquetDataset`` and its internals.
+
+``ParquetDataset`` is a combination of a PyTorch-compatible dataset and sampler which enables
+training and inference of models on datasets of any arbitrary size by leveraging PyArrow
+Datasets to perform batch-wise reading and processing of data from disk.
+
+``ParquetDataset`` includes support for PyTorch's distributed training framework as well as
+access to remotely stored data via PyArrow's filesystem configs.
+"""
+
+from .info.replicas import DEFAULT_REPLICAS_INFO, ReplicasInfo, ReplicasInfoProtocol
+from .parquet_dataset import ParquetDataset
+from .parquet_module import ParquetModule
+
+__all__ = [
+    "DEFAULT_REPLICAS_INFO",
+    "ParquetDataset",
+    "ParquetModule",
+    "ReplicasInfo",
+    "ReplicasInfoProtocol",
+]
```
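Given the ``__all__`` above and the re-export added to `replay/data/nn/__init__.py` earlier, the new entry points should be importable from either level once torch is available; a quick (assumed) import check:

```python
# Assumes torch is installed: replay.data.nn guards these imports behind TORCH_AVAILABLE.
from replay.data.nn import ParquetDataset, ParquetModule
from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO, ReplicasInfo, ReplicasInfoProtocol
```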
replay/data/nn/parquet/collate.py
ADDED

```diff
@@ -0,0 +1,29 @@
+from collections.abc import Sequence
+
+import torch
+
+from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralValue
+
+
+def dict_collate(batch: Sequence[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+    """Simple collate function that concatenates a sequence of tensor dicts into one tensor dict."""
+    return {k: torch.cat([d[k] for d in batch], dim=0) for k in batch[0]}
+
+
+def general_collate(batch: Sequence[GeneralBatch]) -> GeneralBatch:
+    """General collate function that concatenates a sequence of (possibly nested) tensor dicts."""
+    result = {}
+    test_sample = batch[0]
+
+    if len(batch) == 1:
+        return test_sample
+
+    for key, test_value in test_sample.items():
+        values: Sequence[GeneralValue] = [sample[key] for sample in batch]
+        if torch.is_tensor(test_value):
+            result[key] = torch.cat(values, dim=0)
+        else:
+            assert isinstance(test_value, dict)
+            result[key] = general_collate(values)
+
+    return result
```
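A small usage sketch of `general_collate` as defined above: leaf tensors are concatenated along dim 0, and sub-dicts are merged recursively:

```python
import torch
from replay.data.nn.parquet.collate import general_collate

a = {"item_id": torch.tensor([1, 2]), "features": {"price": torch.tensor([0.5, 0.7])}}
b = {"item_id": torch.tensor([3]), "features": {"price": torch.tensor([0.9])}}

merged = general_collate([a, b])
print(merged["item_id"])            # tensor([1, 2, 3])
print(merged["features"]["price"])  # tensor([0.5000, 0.7000, 0.9000])
```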
replay/data/nn/parquet/constants/__init__.py
ADDED

File without changes (an empty ``__init__.py``).
replay/data/nn/parquet/constants/batches.py
ADDED

```diff
@@ -0,0 +1,8 @@
+from typing import Callable, Union
+
+import torch
+from typing_extensions import TypeAlias
+
+GeneralValue: TypeAlias = Union[torch.Tensor, "GeneralBatch"]
+GeneralBatch: TypeAlias = dict[str, GeneralValue]
+GeneralCollateFn: TypeAlias = Callable[[GeneralBatch], GeneralBatch]
```
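The aliases are recursive: a `GeneralBatch` maps string keys to tensors or to further `GeneralBatch` dicts, which is exactly the shape `general_collate` above consumes. For example:

```python
import torch
from replay.data.nn.parquet.constants.batches import GeneralBatch

batch: GeneralBatch = {
    "item_id": torch.tensor([10, 11]),
    "features": {"price": torch.tensor([1.0, 2.0])},  # a nested dict is itself a GeneralBatch
}
```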
replay/data/nn/parquet/fixed_batch_dataset.py
ADDED

```diff
@@ -0,0 +1,157 @@
+import warnings
+from collections.abc import Iterator
+from typing import Callable, Optional, Protocol, cast
+
+import torch
+from torch.utils.data import IterableDataset
+
+from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
+from replay.data.nn.parquet.impl.masking import DEFAULT_COLLATE_FN
+
+
+def get_batch_size(batch: GeneralBatch, strict: bool = False) -> int:
+    """
+    Retrieves the size of the ``batch`` object.
+
+    :param batch: Batch object.
+    :param strict: If ``True``, performs additional validation. Default: ``False``.
+
+    :raises ValueError: If size mismatch is found in the batch during a strict check.
+
+    :return: Batch size.
+    """
+    batch_size: Optional[int] = None
+
+    for key, value in batch.items():
+        new_batch_size: int
+
+        if torch.is_tensor(value):
+            new_batch_size = value.size(0)
+        else:
+            assert isinstance(value, dict)
+            new_batch_size = get_batch_size(value, strict)
+
+        if batch_size is None:
+            batch_size = new_batch_size
+
+        if strict:
+            if batch_size != new_batch_size:
+                msg = f"Batch size mismatch {key}: {batch_size} != {new_batch_size}"
+                raise ValueError(msg)
+        else:
+            break
+    assert batch_size is not None
+    return cast(int, batch_size)
+
+
+def split_batches(batch: GeneralBatch, split: int) -> tuple[GeneralBatch, GeneralBatch]:
+    left: GeneralBatch = {}
+    right: GeneralBatch = {}
+
+    for key, value in batch.items():
+        if torch.is_tensor(value):
+            sub_left = value[:split, ...]
+            sub_right = value[split:, ...]
+        else:
+            sub_left, sub_right = split_batches(value, split)
+        left[key], right[key] = sub_left, sub_right
+
+    return (left, right)
+
+
+class DatasetProtocol(Protocol):
+    def __iter__(self) -> Iterator[GeneralBatch]: ...
+    @property
+    def batch_size(self) -> int: ...
+
+
+class FixedBatchSizeDataset(IterableDataset):
+    """
+    Wrapper for arbitrary datasets that fetches batches of fixed size.
+    Concatenates batches from the wrapped dataset until it reaches the specified size.
+    The last batch may be smaller than the specified size.
+    """
+
+    def __init__(
+        self,
+        dataset: DatasetProtocol,
+        batch_size: Optional[int] = None,
+        collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
+        strict_checks: bool = False,
+    ) -> None:
+        """
+        :param dataset: An iterable object that returns batches.
+            Generally a subclass of ``torch.utils.data.IterableDataset``.
+        :param batch_size: Desired batch size. If ``None``, will search for batch size in ``dataset.batch_size``.
+            Default: ``None``.
+        :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
+        :param strict_checks: If ``True``, additional batch size checks will be performed.
+            May affect performance. Default: ``False``.
+
+        :raises ValueError: If an invalid batch size was provided.
+        """
+        super().__init__()
+
+        self.dataset: DatasetProtocol = dataset
+
+        if batch_size is None:
+            assert hasattr(dataset, "batch_size")
+            batch_size = self.dataset.batch_size
+
+        assert isinstance(batch_size, int)
+        int_batch_size: int = cast(int, batch_size)
+
+        if int_batch_size < 1:
+            msg = f"Insufficient batch size. Got {int_batch_size=}"
+            raise ValueError(msg)
+
+        if int_batch_size < 2:
+            warnings.warn(f"Low batch size. Got {int_batch_size=}. This may cause performance issues.", stacklevel=2)
+
+        self.collate_fn: Callable = collate_fn
+        self.batch_size: int = int_batch_size
+        self.strict_checks: bool = strict_checks
+
+    def get_batch_size(self, batch: GeneralBatch) -> int:
+        return get_batch_size(batch, strict=self.strict_checks)
+
+    def __iter__(self) -> Iterator[GeneralBatch]:
+        iterator: Iterator[GeneralBatch] = iter(self.dataset)
+
+        buffer: list[GeneralBatch] = []
+        buffer_size: int = 0
+
+        while True:
+            while buffer_size < self.batch_size:
+                try:
+                    batch: GeneralBatch = next(iterator)
+                    size: int = self.get_batch_size(batch)
+
+                    buffer.append(batch)
+                    buffer_size += size
+                except StopIteration:
+                    break
+
+            if buffer_size == 0:
+                break
+
+            joined: GeneralBatch = self.collate_fn(buffer)
+            assert buffer_size == self.get_batch_size(joined)
+
+            if self.batch_size < buffer_size:
+                left, right = split_batches(joined, self.batch_size)
+                residue: int = buffer_size - self.batch_size
+                assert residue == self.get_batch_size(right)
+
+                buffer_size = residue
+                buffer = [right]
+
+                yield left
+            else:
+                buffer_size = 0
+                buffer = []
+
+                yield joined
+
+        assert buffer_size == 0
+        assert len(buffer) == 0
```
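A hedged usage sketch: wrapping a source that yields uneven batches so consumers see fixed-size ones. The collate function is passed explicitly here because the body of `DEFAULT_COLLATE_FN` is not shown in this diff:

```python
import torch
from replay.data.nn.parquet.collate import general_collate
from replay.data.nn.parquet.fixed_batch_dataset import FixedBatchSizeDataset

class UnevenBatches:
    """Toy source yielding batches of sizes 3 and 2."""

    batch_size = 4  # satisfies DatasetProtocol; consulted when batch_size=None

    def __iter__(self):
        yield {"ids": torch.arange(3)}
        yield {"ids": torch.arange(3, 5)}

fixed = FixedBatchSizeDataset(UnevenBatches(), batch_size=4, collate_fn=general_collate)
for batch in fixed:
    print(batch["ids"])  # tensor([0, 1, 2, 3]), then the smaller tail tensor([4])
```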
replay/data/nn/parquet/impl/__init__.py
ADDED

File without changes (an empty ``__init__.py``).
replay/data/nn/parquet/impl/array_1d_column.py
ADDED

```diff
@@ -0,0 +1,140 @@
+from typing import Any, Union
+
+import pyarrow as pa
+import pyarrow.compute as pc
+import torch
+
+from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+from replay.data.nn.parquet.metadata import (
+    Metadata,
+    get_1d_array_columns,
+    get_padding,
+    get_shape,
+)
+from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+from .column_protocol import OutputType
+from .indexing import get_mask, get_offsets
+from .utils import ensure_mutable
+
+
+class Array1DColumn:
+    """
+    Representation of a 1D array column, containing a
+    list of numbers of varying length in each of its rows.
+    """
+
+    def __init__(
+        self,
+        data: torch.Tensor,
+        lengths: torch.LongTensor,
+        shape: Union[int, list[int]],
+        padding: Any = DEFAULT_PADDING,
+    ) -> None:
+        """
+        :param data: A tensor containing column data.
+        :param lengths: A tensor containing lengths of each individual row array.
+        :param shape: An integer or list of integers representing the target array shapes.
+        :param padding: Padding value to use to fill null values and match target shape.
+            Default: value of ``DEFAULT_PADDING``
+
+        :raises ValueError: If the shape provided is not one-dimensional.
+        """
+        if isinstance(shape, list) and len(shape) > 1:
+            msg = f"Array1DColumn accepts a shape of size (1,) only. Got {shape=}"
+            raise ValueError(msg)
+
+        self.padding = padding
+        self.data = data
+        self.offsets = get_offsets(lengths)
+        self.shape = shape[0] if isinstance(shape, list) else shape
+        assert self.length == torch.numel(lengths)
+
+    @property
+    def length(self) -> int:
+        return torch.numel(self.offsets) - 1
+
+    def __len__(self) -> int:
+        return self.length
+
+    @property
+    def device(self) -> torch.device:
+        assert self.data.device == self.offsets.device
+        return self.offsets.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.data.dtype
+
+    def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+        indices = indices.to(device=self.device)
+        mask, output = get_mask(indices, self.offsets, self.shape)
+
+        # TODO: Test this for both 1d and 2d arrays. Add same check in 2d arrays
+        if self.data.numel() == 0:
+            mask = torch.zeros((indices.size(0), self.shape), dtype=torch.bool, device=self.device)
+            output = torch.ones((indices.size(0), self.shape), dtype=torch.bool, device=self.device) * self.padding
+            return mask, output
+
+        unmasked_values = torch.take(self.data, output)
+        masked_values = torch.where(mask, unmasked_values, self.padding)
+        assert masked_values.device == self.device
+        assert masked_values.dtype == self.dtype
+        return (mask, masked_values)
+
+
+def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Converts a PyArrow list array into PyTorch tensors.
+
+    :param array: Original PyArrow array.
+    :param device: Target device to send the resulting tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+    :return: A tuple ``(lengths, flattened values)`` of tensors obtained from the original array.
+    """
+    flatten = pc.list_flatten(array)
+    lengths = pc.list_value_length(array).cast(pa.int64())
+
+    # Copying to be mutable
+    flatten_torch = torch.asarray(
+        ensure_mutable(flatten.to_numpy()),
+        device=device,
+        dtype=pyarrow_to_torch(flatten.type),
+    )
+
+    # Copying to be mutable
+    lengths_torch = torch.asarray(
+        ensure_mutable(lengths.to_numpy()),
+        device=device,
+        dtype=torch.int64,
+    )
+    return (lengths_torch, flatten_torch)
+
+
+def to_array_1d_columns(
+    data: pa.RecordBatch,
+    metadata: Metadata,
+    device: torch.device = DEFAULT_DEVICE,
+) -> dict[str, Array1DColumn]:
+    """
+    Converts a PyArrow batch of data to a set of ``Array1DColumn``s.
+    This function filters only those columns matching its format from the full batch.
+
+    :param data: A PyArrow batch of column data.
+    :param metadata: Metadata containing information about columns' formats.
+    :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``
+
+    :return: A dict containing the dataset's 1D array columns.
+    """
+    result: dict[str, Array1DColumn] = {}
+
+    for column_name in get_1d_array_columns(metadata):
+        lengths, torch_array = to_torch(data.column(column_name), device=device)
+        result[column_name] = Array1DColumn(
+            data=torch_array,
+            lengths=lengths,
+            padding=get_padding(metadata, column_name),
+            shape=get_shape(metadata, column_name),
+        )
+    return result
```
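`Array1DColumn` stores ragged rows as one flat tensor plus an offsets tensor; since `length` is `numel(offsets) - 1`, `get_offsets` presumably prepends a zero to the cumulative sum of lengths. A self-contained sketch of that layout in plain PyTorch (not RePlay's actual helper):

```python
import torch

# Rows [[1, 2], [], [3, 4, 5]] stored flat, with offsets marking row boundaries.
data = torch.tensor([1, 2, 3, 4, 5])
lengths = torch.tensor([2, 0, 3])
offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])  # tensor([0, 2, 2, 5])

row = 2
print(data[offsets[row] : offsets[row + 1]])  # tensor([3, 4, 5])
print(offsets.numel() - 1)                    # 3 rows, matching Array1DColumn.length
```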
replay/data/nn/parquet/impl/array_2d_column.py
ADDED

```diff
@@ -0,0 +1,160 @@
+from typing import Any
+
+import pyarrow as pa
+import pyarrow.compute as pc
+import torch
+
+from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+from replay.data.nn.parquet.metadata import (
+    Metadata,
+    get_2d_array_columns,
+    get_padding,
+    get_shape,
+)
+from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+from .column_protocol import OutputType
+from .indexing import get_mask, get_offsets
+from .utils import ensure_mutable
+
+
+class Array2DColumn:
+    """
+    Representation of a 2D array column, containing nested
+    lists of numbers of varying length in each of its rows.
+    """
+
+    def __init__(
+        self,
+        data: torch.Tensor,
+        outer_lengths: torch.LongTensor,
+        inner_lengths: torch.LongTensor,
+        shape: list[int],
+        padding: Any = DEFAULT_PADDING,
+    ) -> None:
+        """
+        :param data: A tensor containing column data.
+        :param outer_lengths: A tensor containing outer lengths (first dim) of each individual row array.
+        :param inner_lengths: A tensor containing inner lengths (second dim) of each individual row array.
+        :param shape: A list of two integers representing the target array shape.
+        :param padding: Padding value to use to fill null values and match target shape.
+            Default: value of ``DEFAULT_PADDING``
+
+        :raises ValueError: If the shape provided is not two-dimensional.
+        """
+        self.padding = padding
+        self.data = data
+        self.inner_offsets = get_offsets(inner_lengths)
+        self.outer_offsets = get_offsets(outer_lengths)
+        if len(shape) != 2:
+            msg = f"Array2DColumn accepts a shape of size (2,) only. Got {shape=}"
+            raise ValueError(msg)
+        self.shape: list[int] = shape
+
+    @property
+    def length(self) -> int:
+        return torch.numel(self.outer_offsets) - 1
+
+    def __len__(self) -> int:
+        return self.length
+
+    @property
+    def device(self) -> torch.device:
+        assert self.data.device == self.inner_offsets.device
+        assert self.data.device == self.outer_offsets.device
+        return self.inner_offsets.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.data.dtype
+
+    def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+        indices = indices.to(device=self.device)
+        outer_mask, outer_output = get_mask(indices, self.outer_offsets, self.shape[0])
+        left_bound = outer_output.min().item()
+        right_bound = outer_output.max().item()
+        outer_output -= left_bound
+
+        inner_indices = torch.arange(left_bound, right_bound + 1, device=indices.device)
+        inner_mask, output = get_mask(inner_indices, self.inner_offsets, self.shape[1])
+
+        final_indices = output[outer_output]
+        inner_final_mask = inner_mask[outer_output]
+
+        unmasked_values = torch.take(self.data, final_indices)
+        outer_final_mask = outer_mask.unsqueeze(-1).repeat(1, 1, unmasked_values.size(-1))
+        mask = inner_final_mask * outer_final_mask
+
+        masked_values = torch.where(mask, unmasked_values, self.padding)
+        assert masked_values.device == self.device
+        assert masked_values.dtype == self.dtype
+        return (mask, masked_values)
+
+
+def to_torch(
+    array: pa.Array,
+    device: torch.device = DEFAULT_DEVICE,
+) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
+    """
+    Converts a nested PyArrow list array into PyTorch tensors.
+
+    :param array: Original PyArrow array.
+    :param device: Target device to send the resulting tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+    :return: A tuple ``(outer lengths, inner lengths, flattened values)`` of tensors obtained from the original array.
+    """
+    flatten_dim0 = pc.list_flatten(array)
+    flatten = pc.list_flatten(flatten_dim0)
+
+    outer_lengths = pc.list_value_length(array).cast(pa.int64())
+    inner_lengths = pc.list_value_length(flatten_dim0).cast(pa.int64())
+
+    # Copying to be mutable
+    flatten_torch = torch.asarray(
+        ensure_mutable(flatten.to_numpy()),
+        device=device,
+        dtype=pyarrow_to_torch(flatten.type),
+    )
+
+    # Copying to be mutable
+    outer_lengths_torch = torch.asarray(
+        ensure_mutable(outer_lengths.to_numpy()),
+        device=device,
+        dtype=torch.int64,
+    )
+    inner_lengths_torch = torch.asarray(
+        ensure_mutable(inner_lengths.to_numpy()),
+        device=device,
+        dtype=torch.int64,
+    )
+    return (outer_lengths_torch, inner_lengths_torch, flatten_torch)
+
+
+def to_array_2d_columns(
+    data: pa.RecordBatch,
+    metadata: Metadata,
+    device: torch.device = DEFAULT_DEVICE,
+) -> dict[str, Array2DColumn]:
+    """
+    Converts a PyArrow batch of data to a set of ``Array2DColumn``s.
+    This function filters only those columns matching its format from the full batch.
+
+    :param data: A PyArrow batch of column data.
+    :param metadata: Metadata containing information about columns' formats.
+    :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``
+
+    :return: A dict containing the dataset's 2D array columns.
+    """
+    result = {}
+
+    for column_name in get_2d_array_columns(metadata):
+        outer_lengths, inner_lengths, torch_array = to_torch(data.column(column_name), device=device)
+        result[column_name] = Array2DColumn(
+            data=torch_array,
+            outer_lengths=outer_lengths,
+            inner_lengths=inner_lengths,
+            padding=get_padding(metadata, column_name),
+            shape=get_shape(metadata, column_name),
+        )
+    return result
```
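The 2D variant nests the same layout twice: `outer_offsets` delimits each row's run of inner lists, and `inner_offsets` delimits each inner list within the flat value tensor. A plain-PyTorch sketch of the nesting (again, not the actual `get_offsets`/`get_mask` helpers):

```python
import torch

# Rows [[[1, 2], [3]], [[4]]]: two rows, the first holding two inner lists.
data = torch.tensor([1, 2, 3, 4])
outer_lengths = torch.tensor([2, 1])     # inner lists per row
inner_lengths = torch.tensor([2, 1, 1])  # values per inner list
outer_offsets = torch.cat([outer_lengths.new_zeros(1), outer_lengths.cumsum(0)])  # [0, 2, 3]
inner_offsets = torch.cat([inner_lengths.new_zeros(1), inner_lengths.cumsum(0)])  # [0, 2, 3, 4]

row = 0
for inner in range(int(outer_offsets[row]), int(outer_offsets[row + 1])):
    print(data[inner_offsets[inner] : inner_offsets[inner + 1]])  # tensor([1, 2]) then tensor([3])
```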
replay/data/nn/parquet/impl/column_protocol.py
ADDED

```diff
@@ -0,0 +1,17 @@
+from typing import Protocol
+
+import torch
+
+OutputType = tuple[torch.BoolTensor, torch.Tensor]
+
+
+class ColumnProtocol(Protocol):
+    def __len__(self) -> int: ...
+
+    @property
+    def length(self) -> int: ...
+
+    @property
+    def device(self) -> torch.device: ...
+
+    def __getitem__(self, indices: torch.LongTensor) -> OutputType: ...
```