replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/experimental/models/extensions/spark_custom_models/als_extension.py +1 -1
  44. replay/metrics/torch_metrics_builder.py +20 -14
  45. replay/models/nn/loss/sce.py +2 -7
  46. replay/models/nn/optimizer_utils/__init__.py +6 -1
  47. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  48. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  49. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  50. replay/models/nn/sequential/bert4rec/model.py +11 -11
  51. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  52. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  53. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  54. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  55. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  56. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  57. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  58. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  59. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  60. replay/models/nn/sequential/sasrec/model.py +14 -9
  61. replay/nn/__init__.py +8 -0
  62. replay/nn/agg.py +109 -0
  63. replay/nn/attention.py +158 -0
  64. replay/nn/embedding.py +283 -0
  65. replay/nn/ffn.py +135 -0
  66. replay/nn/head.py +49 -0
  67. replay/nn/lightning/__init__.py +1 -0
  68. replay/nn/lightning/callback/__init__.py +9 -0
  69. replay/nn/lightning/callback/metrics_callback.py +183 -0
  70. replay/nn/lightning/callback/predictions_callback.py +314 -0
  71. replay/nn/lightning/module.py +123 -0
  72. replay/nn/lightning/optimizer.py +60 -0
  73. replay/nn/lightning/postprocessor/__init__.py +2 -0
  74. replay/nn/lightning/postprocessor/_base.py +51 -0
  75. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  76. replay/nn/lightning/scheduler.py +91 -0
  77. replay/nn/loss/__init__.py +22 -0
  78. replay/nn/loss/base.py +197 -0
  79. replay/nn/loss/bce.py +216 -0
  80. replay/nn/loss/ce.py +317 -0
  81. replay/nn/loss/login_ce.py +373 -0
  82. replay/nn/loss/logout_ce.py +230 -0
  83. replay/nn/mask.py +87 -0
  84. replay/nn/normalization.py +9 -0
  85. replay/nn/output.py +37 -0
  86. replay/nn/sequential/__init__.py +9 -0
  87. replay/nn/sequential/sasrec/__init__.py +7 -0
  88. replay/nn/sequential/sasrec/agg.py +53 -0
  89. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  90. replay/nn/sequential/sasrec/model.py +377 -0
  91. replay/nn/sequential/sasrec/transformer.py +107 -0
  92. replay/nn/sequential/twotower/__init__.py +2 -0
  93. replay/nn/sequential/twotower/model.py +674 -0
  94. replay/nn/sequential/twotower/reader.py +89 -0
  95. replay/nn/transform/__init__.py +22 -0
  96. replay/nn/transform/copy.py +38 -0
  97. replay/nn/transform/grouping.py +39 -0
  98. replay/nn/transform/negative_sampling.py +182 -0
  99. replay/nn/transform/next_token.py +100 -0
  100. replay/nn/transform/rename.py +33 -0
  101. replay/nn/transform/reshape.py +41 -0
  102. replay/nn/transform/sequence_roll.py +48 -0
  103. replay/nn/transform/template/__init__.py +2 -0
  104. replay/nn/transform/template/sasrec.py +53 -0
  105. replay/nn/transform/template/twotower.py +22 -0
  106. replay/nn/transform/token_mask.py +69 -0
  107. replay/nn/transform/trim.py +51 -0
  108. replay/nn/utils.py +28 -0
  109. replay/preprocessing/filters.py +128 -0
  110. replay/preprocessing/label_encoder.py +36 -33
  111. replay/preprocessing/utils.py +209 -0
  112. replay/splitters/__init__.py +1 -0
  113. replay/splitters/random_next_n_splitter.py +224 -0
  114. replay/utils/common.py +10 -4
  115. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/METADATA +3 -3
  116. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/RECORD +119 -34
  117. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0rc0.dist-info}/licenses/NOTICE +0 -0
replay/__init__.py CHANGED
@@ -4,4 +4,4 @@
  # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
  import setuptools as _

- __version__ = "0.20.3.preview"
+ __version__ = "0.21.0.preview"
replay/data/dataset.py CHANGED
@@ -5,6 +5,7 @@
  from __future__ import annotations

  import json
+ import warnings
  from collections.abc import Iterable, Sequence
  from pathlib import Path
  from typing import Callable, Optional, Union
@@ -45,6 +46,7 @@ class Dataset:
  ):
      """
      :param feature_schema: mapping of column names and feature infos.
+         All features not specified in the schema will be assumed numerical by default.
      :param interactions: dataframe with interactions.
      :param query_features: dataframe with query features,
          defaults: ```None```.
@@ -498,6 +500,15 @@ class Dataset:
              source=FeatureSource.QUERY_FEATURES,
              feature_schema=updated_feature_schema,
          )
+
+         if filled_features:
+             msg = (
+                 "The following features are present in the dataset but have not been specified "
+                 f"by the feature schema: {[(info.column, info.feature_source.value) for info in filled_features]}. "
+                 "These features will be interpreted as NUMERICAL."
+             )
+             warnings.warn(msg, stacklevel=2)
+
          return FeatureSchema(features_list=features_list + filled_features)

      def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
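The new warning can be observed or silenced with the standard `warnings` machinery. A minimal sketch; the `Dataset` construction is only hinted at, and `schema` / `interactions` are placeholders for a feature schema that omits a column present in the interactions dataframe:

```python
# Sketch: capturing the new "unspecified features" warning (placeholders below).
import warnings

from replay.data import Dataset  # import path assumed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dataset = Dataset(feature_schema=schema, interactions=interactions)  # placeholders

unlabeled = [w for w in caught if "interpreted as NUMERICAL" in str(w.message)]
```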
replay/data/nn/__init__.py CHANGED
@@ -1,6 +1,7 @@
  from replay.utils import TORCH_AVAILABLE

  if TORCH_AVAILABLE:
+     from .parquet import ParquetDataset, ParquetModule
      from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
      from .sequence_tokenizer import SequenceTokenizer
      from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
@@ -18,6 +19,8 @@ if TORCH_AVAILABLE:
      "DEFAULT_TRAIN_PADDING_VALUE",
      "MutableTensorMap",
      "PandasSequentialDataset",
+     "ParquetDataset",
+     "ParquetModule",
      "PolarsSequentialDataset",
      "SequenceTokenizer",
      "SequentialDataset",
replay/data/nn/parquet/__init__.py ADDED
@@ -0,0 +1,22 @@
+ """
+ Implementation of ``ParquetDataset`` and its internals.
+
+ ``ParquetDataset`` is a combination of a PyTorch-compatible dataset and sampler which enables
+ training and inference of models on datasets of arbitrary size by leveraging PyArrow
+ Datasets to perform batch-wise reading and processing of data from disk.
+
+ ``ParquetDataset`` includes support for PyTorch's distributed training framework as well as
+ access to remotely stored data via PyArrow's filesystem configs.
+ """
+
+ from .info.replicas import DEFAULT_REPLICAS_INFO, ReplicasInfo, ReplicasInfoProtocol
+ from .parquet_dataset import ParquetDataset
+ from .parquet_module import ParquetModule
+
+ __all__ = [
+     "DEFAULT_REPLICAS_INFO",
+     "ParquetDataset",
+     "ParquetModule",
+     "ReplicasInfo",
+     "ReplicasInfoProtocol",
+ ]
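The `parquet_dataset.py` implementation itself appears later in the diff; the mechanism the docstring describes (batch-wise reading via PyArrow Datasets feeding a PyTorch iterable dataset) can be sketched generically. This is an illustrative pattern only, not the package's actual implementation, and `ParquetBatchStream` is a hypothetical name:

```python
# Illustrative sketch: stream record batches from Parquet files with PyArrow
# and expose them as tensor dicts through a torch IterableDataset.
import pyarrow.dataset as ds
import torch
from torch.utils.data import IterableDataset


class ParquetBatchStream(IterableDataset):  # hypothetical helper, not ParquetDataset
    def __init__(self, path: str, batch_size: int = 1024) -> None:
        self.dataset = ds.dataset(path, format="parquet")
        self.batch_size = batch_size

    def __iter__(self):
        # to_batches() reads the files chunk by chunk, so the full dataset
        # never has to fit in memory.
        for record_batch in self.dataset.to_batches(batch_size=self.batch_size):
            yield {
                name: torch.as_tensor(column.to_numpy(zero_copy_only=False))
                for name, column in zip(record_batch.schema.names, record_batch.columns)
            }
```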
replay/data/nn/parquet/collate.py ADDED
@@ -0,0 +1,29 @@
+ from collections.abc import Sequence
+
+ import torch
+
+ from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralValue
+
+
+ def dict_collate(batch: Sequence[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+     """Simple collate function that concatenates a sequence of tensor dicts into a single tensor dict."""
+     return {k: torch.cat([d[k] for d in batch], dim=0) for k in batch[0]}
+
+
+ def general_collate(batch: Sequence[GeneralBatch]) -> GeneralBatch:
+     """General collate function that concatenates a sequence of nested tensor dicts into a single nested dict."""
+     result = {}
+     test_sample = batch[0]
+
+     if len(batch) == 1:
+         return test_sample
+
+     for key, test_value in test_sample.items():
+         values: Sequence[GeneralValue] = [sample[key] for sample in batch]
+         if torch.is_tensor(test_value):
+             result[key] = torch.cat(values, dim=0)
+         else:
+             assert isinstance(test_value, dict)
+             result[key] = general_collate(values)
+
+     return result
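A small usage sketch of the two collate helpers, assuming the import path given by the file location above:

```python
import torch

from replay.data.nn.parquet.collate import dict_collate, general_collate

flat = [
    {"item_id": torch.tensor([1, 2])},
    {"item_id": torch.tensor([3])},
]
print(dict_collate(flat)["item_id"])  # tensor([1, 2, 3])

nested = [
    {"features": {"item_id": torch.tensor([[1, 2]])}},
    {"features": {"item_id": torch.tensor([[3, 4]])}},
]
# Nested dicts are merged recursively, tensors are concatenated along dim 0.
print(general_collate(nested)["features"]["item_id"].shape)  # torch.Size([2, 2])
```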
replay/data/nn/parquet/constants/__init__.py ADDED (empty file, no content to show)
replay/data/nn/parquet/constants/batches.py ADDED
@@ -0,0 +1,8 @@
+ from typing import Callable, Union
+
+ import torch
+ from typing_extensions import TypeAlias
+
+ GeneralValue: TypeAlias = Union[torch.Tensor, "GeneralBatch"]
+ GeneralBatch: TypeAlias = dict[str, GeneralValue]
+ GeneralCollateFn: TypeAlias = Callable[[GeneralBatch], GeneralBatch]
replay/data/nn/parquet/constants/device.py ADDED
@@ -0,0 +1,3 @@
+ import torch
+
+ DEFAULT_DEVICE = torch.device("cpu")
replay/data/nn/parquet/constants/filesystem.py ADDED
@@ -0,0 +1,3 @@
+ import pyarrow.fs as fs
+
+ DEFAULT_FILESYSTEM = fs.LocalFileSystem()
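`DEFAULT_FILESYSTEM` is the local filesystem; per the module docstring above, remote storage works through PyArrow filesystem objects. A hedged sketch with standard PyArrow APIs only (bucket name and region are placeholders, credentials come from the environment):

```python
import pyarrow.dataset as ds
import pyarrow.fs as fs

# Any pyarrow.fs.FileSystem can stand in for the local default.
remote_fs = fs.S3FileSystem(region="us-east-1")
remote_data = ds.dataset("my-bucket/interactions", format="parquet", filesystem=remote_fs)
```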
replay/data/nn/parquet/constants/metadata.py ADDED
@@ -0,0 +1,5 @@
+ SHAPE_FLAG = "shape"
+ PADDING_FLAG = "padding"
+ DEFAULT_PADDING = -1
+ SEQUENCE_LENGTH_FLAG = "sequence_length"
+ PADDING_FLAG = "padding"
replay/data/nn/parquet/fixed_batch_dataset.py ADDED
@@ -0,0 +1,157 @@
+ import warnings
+ from collections.abc import Iterator
+ from typing import Callable, Optional, Protocol, cast
+
+ import torch
+ from torch.utils.data import IterableDataset
+
+ from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
+ from replay.data.nn.parquet.impl.masking import DEFAULT_COLLATE_FN
+
+
+ def get_batch_size(batch: GeneralBatch, strict: bool = False) -> int:
+     """
+     Retrieves the size of the ``batch`` object.
+
+     :param batch: Batch object.
+     :param strict: If ``True``, performs additional validation. Default: ``False``.
+
+     :raises ValueError: If size mismatch is found in the batch during a strict check.
+
+     :return: Batch size.
+     """
+     batch_size: Optional[int] = None
+
+     for key, value in batch.items():
+         new_batch_size: int
+
+         if torch.is_tensor(value):
+             new_batch_size = value.size(0)
+         else:
+             assert isinstance(value, dict)
+             new_batch_size = get_batch_size(value, strict)
+
+         if batch_size is None:
+             batch_size = new_batch_size
+
+         if strict:
+             if batch_size != new_batch_size:
+                 msg = f"Batch size mismatch {key}: {batch_size} != {new_batch_size}"
+                 raise ValueError(msg)
+         else:
+             break
+     assert batch_size is not None
+     return cast(int, batch_size)
+
+
+ def split_batches(batch: GeneralBatch, split: int) -> tuple[GeneralBatch, GeneralBatch]:
+     left: GeneralBatch = {}
+     right: GeneralBatch = {}
+
+     for key, value in batch.items():
+         if torch.is_tensor(value):
+             sub_left = value[:split, ...]
+             sub_right = value[split:, ...]
+         else:
+             sub_left, sub_right = split_batches(value, split)
+         left[key], right[key] = sub_left, sub_right
+
+     return (left, right)
+
+
+ class DatasetProtocol(Protocol):
+     def __iter__(self) -> Iterator[GeneralBatch]: ...
+     @property
+     def batch_size(self) -> int: ...
+
+
+ class FixedBatchSizeDataset(IterableDataset):
+     """
+     Wrapper for arbitrary datasets that fetches batches of fixed size.
+     Concatenates batches from the wrapped dataset until it reaches the specified size.
+     The last batch may be smaller than the specified size.
+     """
+
+     def __init__(
+         self,
+         dataset: DatasetProtocol,
+         batch_size: Optional[int] = None,
+         collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
+         strict_checks: bool = False,
+     ) -> None:
+         """
+         :param dataset: An iterable object that returns batches.
+             Generally a subclass of ``torch.utils.data.IterableDataset``.
+         :param batch_size: Desired batch size. If ``None``, will search for batch size in ``dataset.batch_size``.
+             Default: ``None``.
+         :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
+         :param strict_checks: If ``True``, additional batch size checks will be performed.
+             May affect performance. Default: ``False``.
+
+         :raises ValueError: If an invalid batch size was provided.
+         """
+         super().__init__()
+
+         self.dataset: DatasetProtocol = dataset
+
+         if batch_size is None:
+             assert hasattr(dataset, "batch_size")
+             batch_size = self.dataset.batch_size
+
+         assert isinstance(batch_size, int)
+         int_batch_size: int = cast(int, batch_size)
+
+         if int_batch_size < 1:
+             msg = f"Insufficient batch size. Got {int_batch_size=}"
+             raise ValueError(msg)
+
+         if int_batch_size < 2:
+             warnings.warn(f"Low batch size. Got {int_batch_size=}. This may cause performance issues.", stacklevel=2)
+
+         self.collate_fn: Callable = collate_fn
+         self.batch_size: int = int_batch_size
+         self.strict_checks: bool = strict_checks
+
+     def get_batch_size(self, batch: GeneralBatch) -> int:
+         return get_batch_size(batch, strict=self.strict_checks)
+
+     def __iter__(self) -> Iterator[GeneralBatch]:
+         iterator: Iterator[GeneralBatch] = iter(self.dataset)
+
+         buffer: list[GeneralBatch] = []
+         buffer_size: int = 0
+
+         while True:
+             while buffer_size < self.batch_size:
+                 try:
+                     batch: GeneralBatch = next(iterator)
+                     size: int = self.get_batch_size(batch)
+
+                     buffer.append(batch)
+                     buffer_size += size
+                 except StopIteration:
+                     break
+
+             if buffer_size == 0:
+                 break
+
+             joined: GeneralBatch = self.collate_fn(buffer)
+             assert buffer_size == self.get_batch_size(joined)
+
+             if self.batch_size < buffer_size:
+                 left, right = split_batches(joined, self.batch_size)
+                 residue: int = buffer_size - self.batch_size
+                 assert residue == self.get_batch_size(right)
+
+                 buffer_size = residue
+                 buffer = [right]
+
+                 yield left
+             else:
+                 buffer_size = 0
+                 buffer = []
+
+                 yield joined
+
+         assert buffer_size == 0
+         assert len(buffer) == 0
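To make the re-batching behaviour concrete, here is a small usage sketch. It wraps a toy iterable that yields unevenly sized batches and passes `general_collate` explicitly, since the real default (`DEFAULT_COLLATE_FN` from `impl/masking.py`) is not shown in this diff; `ToyStream` is a hypothetical stand-in for a ParquetDataset-style source:

```python
import torch

from replay.data.nn.parquet.collate import general_collate
from replay.data.nn.parquet.fixed_batch_dataset import FixedBatchSizeDataset


class ToyStream:  # hypothetical source yielding batches of sizes 3, 5, 2
    batch_size = 4  # only consulted when FixedBatchSizeDataset gets batch_size=None

    def __iter__(self):
        for chunk in (torch.arange(3), torch.arange(3, 8), torch.arange(8, 10)):
            yield {"item_id": chunk}


fixed = FixedBatchSizeDataset(ToyStream(), batch_size=4, collate_fn=general_collate)
for batch in fixed:
    print(batch["item_id"])
# tensor([0, 1, 2, 3]); tensor([4, 5, 6, 7]); tensor([8, 9]) -- the last batch may be smaller
```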
replay/data/nn/parquet/impl/__init__.py ADDED (empty file, no content to show)
replay/data/nn/parquet/impl/array_1d_column.py ADDED
@@ -0,0 +1,140 @@
+ from typing import Any, Union
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import torch
+
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+ from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+ from replay.data.nn.parquet.metadata import (
+     Metadata,
+     get_1d_array_columns,
+     get_padding,
+     get_shape,
+ )
+ from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+ from .column_protocol import OutputType
+ from .indexing import get_mask, get_offsets
+ from .utils import ensure_mutable
+
+
+ class Array1DColumn:
+     """
+     Representation of a 1D array column, containing a
+     list of numbers of varying length in each of its rows.
+     """
+
+     def __init__(
+         self,
+         data: torch.Tensor,
+         lengths: torch.LongTensor,
+         shape: Union[int, list[int]],
+         padding: Any = DEFAULT_PADDING,
+     ) -> None:
+         """
+         :param data: A tensor containing column data.
+         :param lengths: A tensor containing lengths of each individual row array.
+         :param shape: An integer or list of integers representing the target array shapes.
+         :param padding: Padding value to use to fill null values and match target shape.
+             Default: value of ``DEFAULT_PADDING``.
+
+         :raises ValueError: If the shape provided is not one-dimensional.
+         """
+         if isinstance(shape, list) and len(shape) > 1:
+             msg = f"Array1DColumn accepts a shape of size (1,) only. Got {shape=}"
+             raise ValueError(msg)
+
+         self.padding = padding
+         self.data = data
+         self.offsets = get_offsets(lengths)
+         self.shape = shape[0] if isinstance(shape, list) else shape
+         assert self.length == torch.numel(lengths)
+
+     @property
+     def length(self) -> int:
+         return torch.numel(self.offsets) - 1
+
+     def __len__(self) -> int:
+         return self.length
+
+     @property
+     def device(self) -> torch.device:
+         assert self.data.device == self.offsets.device
+         return self.offsets.device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return self.data.dtype
+
+     def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+         indices = indices.to(device=self.device)
+         mask, output = get_mask(indices, self.offsets, self.shape)
+
+         # TODO: Test this for both 1d and 2d arrays. Add same check in 2d arrays
+         if self.data.numel() == 0:
+             mask = torch.zeros((indices.size(0), self.shape), dtype=torch.bool, device=self.device)
+             output = torch.ones((indices.size(0), self.shape), dtype=torch.bool, device=self.device) * self.padding
+             return mask, output
+
+         unmasked_values = torch.take(self.data, output)
+         masked_values = torch.where(mask, unmasked_values, self.padding)
+         assert masked_values.device == self.device
+         assert masked_values.dtype == self.dtype
+         return (mask, masked_values)
+
+
+ def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Converts a PyArrow list array into PyTorch tensors of row lengths and flattened values.
+
+     :param array: Original PyArrow array.
+     :param device: Target device to send the resulting tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+     :return: A tuple of ``(lengths, values)`` PyTorch tensors obtained from the original array.
+     """
+     flatten = pc.list_flatten(array)
+     lengths = pc.list_value_length(array).cast(pa.int64())
+
+     # Copying to be mutable
+     flatten_torch = torch.asarray(
+         ensure_mutable(flatten.to_numpy()),
+         device=device,
+         dtype=pyarrow_to_torch(flatten.type),
+     )
+
+     # Copying to be mutable
+     lengths_torch = torch.asarray(
+         ensure_mutable(lengths.to_numpy()),
+         device=device,
+         dtype=torch.int64,
+     )
+     return (lengths_torch, flatten_torch)
+
+
+ def to_array_1d_columns(
+     data: pa.RecordBatch,
+     metadata: Metadata,
+     device: torch.device = DEFAULT_DEVICE,
+ ) -> dict[str, Array1DColumn]:
+     """
+     Converts a PyArrow batch of data to a set of ``Array1DColumn``s.
+     This function filters only those columns matching its format from the full batch.
+
+     :param data: A PyArrow batch of column data.
+     :param metadata: Metadata containing information about columns' formats.
+     :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+     :return: A dict of ``Array1DColumn``s containing the dataset's 1D array columns.
+     """
+     result: dict[str, Array1DColumn] = {}
+
+     for column_name in get_1d_array_columns(metadata):
+         lengths, torch_array = to_torch(data.column(column_name), device=device)
+         result[column_name] = Array1DColumn(
+             data=torch_array,
+             lengths=lengths,
+             padding=get_padding(metadata, column_name),
+             shape=get_shape(metadata, column_name),
+         )
+     return result
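A short sketch of the flattened lengths-plus-values representation this file builds on. The exact output of `__getitem__` depends on `impl/indexing.py`, which is not shown in this diff, so the last two lines are described rather than asserted:

```python
import pyarrow as pa
import torch

from replay.data.nn.parquet.impl.array_1d_column import Array1DColumn, to_torch

arr = pa.array([[1, 2, 3], [4], [5, 6]])
lengths, values = to_torch(arr)
print(values)   # tensor([1, 2, 3, 4, 5, 6]) -- all rows flattened into one tensor
print(lengths)  # tensor([3, 1, 2])          -- per-row lengths

column = Array1DColumn(data=values, lengths=lengths, shape=4, padding=-1)
mask, padded = column[torch.tensor([0, 2])]
# `padded` is expected to hold rows 0 and 2 padded to length 4 with -1,
# while `mask` marks which positions contain real values.
```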
replay/data/nn/parquet/impl/array_2d_column.py ADDED
@@ -0,0 +1,160 @@
+ from typing import Any
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import torch
+
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+ from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+ from replay.data.nn.parquet.metadata import (
+     Metadata,
+     get_2d_array_columns,
+     get_padding,
+     get_shape,
+ )
+ from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+ from .column_protocol import OutputType
+ from .indexing import get_mask, get_offsets
+ from .utils import ensure_mutable
+
+
+ class Array2DColumn:
+     """
+     Representation of a 2D array column, containing nested
+     lists of numbers of varying length in each of its rows.
+     """
+
+     def __init__(
+         self,
+         data: torch.Tensor,
+         outer_lengths: torch.LongTensor,
+         inner_lengths: torch.LongTensor,
+         shape: list[int],
+         padding: Any = DEFAULT_PADDING,
+     ) -> None:
+         """
+         :param data: A tensor containing column data.
+         :param outer_lengths: A tensor containing outer lengths (first dim) of each individual row array.
+         :param inner_lengths: A tensor containing inner lengths (second dim) of each individual row array.
+         :param shape: A list of two integers representing the target array shape.
+         :param padding: Padding value to use to fill null values and match target shape.
+             Default: value of ``DEFAULT_PADDING``.
+
+         :raises ValueError: If the shape provided is not two-dimensional.
+         """
+         self.padding = padding
+         self.data = data
+         self.inner_offsets = get_offsets(inner_lengths)
+         self.outer_offsets = get_offsets(outer_lengths)
+         if len(shape) != 2:
+             msg = f"Array2DColumn accepts a shape of size (2,) only. Got {shape=}"
+             raise ValueError(msg)
+         self.shape: list[int] = shape
+
+     @property
+     def length(self) -> int:
+         return torch.numel(self.outer_offsets) - 1
+
+     def __len__(self) -> int:
+         return self.length
+
+     @property
+     def device(self) -> torch.device:
+         assert self.data.device == self.inner_offsets.device
+         assert self.data.device == self.outer_offsets.device
+         return self.inner_offsets.device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return self.data.dtype
+
+     def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+         indices = indices.to(device=self.device)
+         outer_mask, outer_output = get_mask(indices, self.outer_offsets, self.shape[0])
+         left_bound = outer_output.min().item()
+         right_bound = outer_output.max().item()
+         outer_output -= left_bound
+
+         inner_indices = torch.arange(left_bound, right_bound + 1, device=indices.device)
+         inner_mask, output = get_mask(inner_indices, self.inner_offsets, self.shape[1])
+
+         final_indices = output[outer_output]
+         inner_final_mask = inner_mask[outer_output]
+
+         unmasked_values = torch.take(self.data, final_indices)
+         outer_final_mask = outer_mask.unsqueeze(-1).repeat(1, 1, unmasked_values.size(-1))
+         mask = inner_final_mask * outer_final_mask
+
+         masked_values = torch.where(mask, unmasked_values, self.padding)
+         assert masked_values.device == self.device
+         assert masked_values.dtype == self.dtype
+         return (mask, masked_values)
+
+
+ def to_torch(
+     array: pa.Array,
+     device: torch.device = DEFAULT_DEVICE,
+ ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
+     """
+     Converts a nested PyArrow list array into PyTorch tensors of outer lengths, inner lengths, and flattened values.
+
+     :param array: Original PyArrow array.
+     :param device: Target device to send the resulting tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+     :return: A tuple of ``(outer_lengths, inner_lengths, values)`` PyTorch tensors obtained from the original array.
+     """
+     flatten_dim0 = pc.list_flatten(array)
+     flatten = pc.list_flatten(flatten_dim0)
+
+     outer_lengths = pc.list_value_length(array).cast(pa.int64())
+     inner_lengths = pc.list_value_length(flatten_dim0).cast(pa.int64())
+
+     # Copying to be mutable
+     flatten_torch = torch.asarray(
+         ensure_mutable(flatten.to_numpy()),
+         device=device,
+         dtype=pyarrow_to_torch(flatten.type),
+     )
+
+     # Copying to be mutable
+     outer_lengths_torch = torch.asarray(
+         ensure_mutable(outer_lengths.to_numpy()),
+         device=device,
+         dtype=torch.int64,
+     )
+     inner_lengths_torch = torch.asarray(
+         ensure_mutable(inner_lengths.to_numpy()),
+         device=device,
+         dtype=torch.int64,
+     )
+     return (outer_lengths_torch, inner_lengths_torch, flatten_torch)
+
+
+ def to_array_2d_columns(
+     data: pa.RecordBatch,
+     metadata: Metadata,
+     device: torch.device = DEFAULT_DEVICE,
+ ) -> dict[str, Array2DColumn]:
+     """
+     Converts a PyArrow batch of data to a set of ``Array2DColumn``s.
+     This function filters only those columns matching its format from the full batch.
+
+     :param data: A PyArrow batch of column data.
+     :param metadata: Metadata containing information about columns' formats.
+     :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+     :return: A dict of ``Array2DColumn``s containing the dataset's 2D array columns.
+     """
+     result = {}
+
+     for column_name in get_2d_array_columns(metadata):
+         outer_lengths, inner_lengths, torch_array = to_torch(data.column(column_name), device=device)
+         result[column_name] = Array2DColumn(
+             data=torch_array,
+             outer_lengths=outer_lengths,
+             inner_lengths=inner_lengths,
+             padding=get_padding(metadata, column_name),
+             shape=get_shape(metadata, column_name),
+         )
+     return result
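The same flatten-plus-lengths trick extends to nested lists; a standalone sketch using only PyArrow compute functions, mirroring the double flattening in `to_torch` above:

```python
import pyarrow as pa
import pyarrow.compute as pc

arr = pa.array([[[1, 2], [3]], [[4, 5, 6]]])
outer_lengths = pc.list_value_length(arr)    # [2, 1]  inner lists per row
inner = pc.list_flatten(arr)                 # [[1, 2], [3], [4, 5, 6]]
inner_lengths = pc.list_value_length(inner)  # [2, 1, 3]
values = pc.list_flatten(inner)              # [1, 2, 3, 4, 5, 6]
```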
replay/data/nn/parquet/impl/column_protocol.py ADDED
@@ -0,0 +1,17 @@
+ from typing import Protocol
+
+ import torch
+
+ OutputType = tuple[torch.BoolTensor, torch.Tensor]
+
+
+ class ColumnProtocol(Protocol):
+     def __len__(self) -> int: ...
+
+     @property
+     def length(self) -> int: ...
+
+     @property
+     def device(self) -> torch.device: ...
+
+     def __getitem__(self, indices: torch.LongTensor) -> OutputType: ...