replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def raw_get_offsets(lengths: torch.LongTensor) -> torch.LongTensor:
    """
    Performs offset calculation, defined simply as a cumulative sum of
    the provided lengths tensor, prefixed with a zero.

    :param lengths: A tensor containing lengths of each individual row in a dataset's column.
    :return: A tensor of offsets for each row (``numel(lengths) + 1`` entries).
    """
    zero = torch.zeros((1,), device=lengths.device, dtype=torch.int64)
    cumsum = torch.cumsum(lengths, dim=-1)
    return torch.cat([zero, cumsum])


def get_offsets(lengths: torch.LongTensor) -> torch.LongTensor:
    """
    Sanitizes row lengths, then calculates offsets for each row.
    The calculation itself is performed via the ``raw_get_offsets`` method.

    An empty ``lengths`` tensor is accepted (a zero-row column) and
    yields the single offset ``[0]``.

    :param lengths: A tensor containing lengths of each individual row in a dataset's column.
    :raises ValueError: If the lengths tensor is of invalid shape or contains negative values.

    :return: A tensor of offsets for each row.
    """
    if lengths.ndim != 1:
        msg = f"Lengths must be strictly 1D. Got {lengths.ndim}D."
        raise ValueError(msg)
    # ``torch.min`` raises an opaque RuntimeError on empty tensors, while an
    # empty column is perfectly valid — only validate values when present.
    if torch.numel(lengths) > 0:
        min_length = torch.min(lengths.detach()).cpu().item()
        if min_length < 0:
            msg = f"There is a negative length. Got {min_length}."
            raise ValueError(msg)
    return raw_get_offsets(lengths)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# A target row width may be given either as a plain int or a scalar tensor.
LengthType = Union[int, torch.LongTensor]


def raw_get_mask(
    indices: torch.LongTensor,
    offsets: torch.LongTensor,
    length: LengthType,
) -> tuple[torch.BoolTensor, torch.LongTensor]:
    """
    Performs mask construction for the requested rows.

    Given sampled row indices, the column's offsets and the target row width,
    returns two tensors of shape ``(len(indices), length)``:

    * a padding mask, where ``True`` marks a real element of the row and
      ``False`` marks a left-padding slot;
    * a tensor of flat element indices into the column's value buffer,
      with padded slots set to ``0``.

    Rows longer than ``length`` are truncated, keeping their tail.

    :param indices: A tensor of row indices to be sampled from the dataset.
    :param offsets: A tensor containing individual offsets for each of the column's rows.
    :param length: The target (padded) row width.

    :return: Constructed mask and gather indices.
    """
    length = torch.asarray(length, dtype=torch.int64, device=indices.device)

    # Row ``i`` occupies the half-open span [offsets[i], offsets[i + 1]).
    row_end = offsets[indices + 1]
    row_start = offsets[indices + 0]

    # Number of left-padding slots per row (negative when a row is truncated).
    pad_count = length - (row_end - row_start)

    positions = torch.arange(length, dtype=torch.int64, device=offsets.device)
    candidate = (row_start[:, None] - pad_count[:, None]) + positions[None, :]
    valid = (row_start[:, None] <= candidate) & (candidate < row_end[:, None])

    # Each output row must contain exactly min(row length, target width) real slots.
    assert torch.all(torch.sum(valid, dim=-1, dtype=torch.int64) == torch.minimum(row_end - row_start, length)).cpu().item()

    gathered = torch.where(valid, candidate, 0)
    assert torch.all((torch.max(gathered, dim=-1).values < row_end) | (row_end == row_start)).cpu().item()
    return (valid, gathered)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_mask(
    indices: torch.LongTensor,
    offsets: torch.LongTensor,
    length: LengthType,
) -> tuple[torch.BoolTensor, torch.LongTensor]:
    """
    Performs input sanity checks, then constructs a mask from inputs.
    The mask calculation itself is performed via the ``raw_get_mask`` method.

    :param indices: A tensor of row indices to be sampled from the dataset.
    :param offsets: A tensor containing individual offsets for each of the column's rows.
    :param length: The target (padded) row width passed through to ``raw_get_mask``.

    :raises ValueError: When mishaped or otherwise invalid arguments are provided.
    :raises IndexError: When sampling indices missing from dataset or none at all.
    :raises RuntimeError: When provided tensors are not on the same device.

    :return: Constructed mask.
    """
    if torch.asarray(length).cpu().item() < 1:
        msg = f"Length must be a positive number. Got {length}"
        raise ValueError(msg)
    if torch.numel(indices) < 1:
        msg = f"Indices must be non-empty. Got {torch.numel(indices)}."
        raise IndexError(msg)
    if indices.device != offsets.device:  # pragma: no cover
        msg = f"Devices must match. Got {indices.device} vs {offsets.device}"
        raise RuntimeError(msg)
    if offsets.ndim != 1:
        msg = f"Offsets must be strictly 1D. Got {offsets.ndim}D."
        raise ValueError(msg)
    min_index = torch.min(indices.detach()).cpu().item()
    if min_index < 0:
        msg = f"Index is too small. Got {min_index}."
        raise IndexError(msg)
    max_index = torch.max(indices.detach()).cpu().item()
    # Row ``i`` requires both ``offsets[i]`` and ``offsets[i + 1]``, so the
    # largest addressable row is ``numel(offsets) - 2``. The previous check
    # (``numel(offsets) < max_index``) was off by two and let out-of-range
    # indices fall through to an opaque failure inside ``raw_get_mask``.
    if max_index > torch.numel(offsets) - 2:
        msg = f"Index is too large. Got {max_index}."
        raise IndexError(msg)
    if not torch.all(offsets[:-1] <= offsets[1:]).cpu().item():
        msg = "Offset sequence is not monotonous."
        raise ValueError(msg)
    return raw_get_mask(indices, offsets, length)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from replay.data.nn.parquet.collate import general_collate
|
|
4
|
+
from replay.data.nn.parquet.constants.batches import GeneralCollateFn
|
|
5
|
+
from replay.data.nn.parquet.info.replicas import ReplicasInfo, ReplicasInfoProtocol
|
|
6
|
+
|
|
7
|
+
# Collate function used for parquet batches unless the caller supplies one.
DEFAULT_COLLATE_FN: GeneralCollateFn = general_collate

# Suffix appended to a column name to derive its padding-mask column name.
DEFAULT_MASK_POSTFIX: str = "_mask"
+
|
|
11
|
+
|
|
12
|
+
def default_make_mask_name(postfix: str) -> Callable[[str], str]:
    """
    Builds a naming function that appends ``postfix`` to a column name.

    :param postfix: Suffix appended to every column name.
    :return: A callable mapping a column name to its mask column name.
    """

    def append_postfix(name: str) -> str:
        return name + postfix

    return append_postfix
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Default column-to-mask-name mapping, e.g. "item_id" -> "item_id_mask".
DEFAULT_MAKE_MASK_NAME = default_make_mask_name(DEFAULT_MASK_POSTFIX)
# Default provider of replica (dataloader worker x distributed rank) metadata.
DEFAULT_REPLICAS_INFO: ReplicasInfoProtocol = ReplicasInfo()
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from replay.data.nn.parquet.impl.masking import DEFAULT_MAKE_MASK_NAME
|
|
7
|
+
|
|
8
|
+
from .column_protocol import ColumnProtocol
|
|
9
|
+
|
|
10
|
+
# A mapping from column (or mask) name to the corresponding batch tensor.
Batch = dict[str, torch.Tensor]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def deduce_device(columns: Sequence[ColumnProtocol]) -> torch.device:
    """
    Sanity check that all of the dataset's columns live on the same device.

    :param columns: A non-empty list of dataset's column data.
    :raises RuntimeError: If any of the columns have mismatching devices.
    :return: The determined columns' device.
    """
    assert len(columns) > 0
    expected = columns[0].device

    if any(column.device != expected for column in columns):  # pragma: no cover
        raise RuntimeError("Columns must be all on the same device.")
    return expected
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def deduce_length(columns: Sequence[ColumnProtocol]) -> int:
    """
    Sanity check that all of the dataset's columns have the same row count.

    :param columns: A non-empty list of dataset's column data.
    :raises RuntimeError: If any of the columns has less rows than others.
    :return: The determined columns' length.
    """
    assert len(columns) > 0
    expected = columns[0].length

    if any(column.length != expected for column in columns):
        raise RuntimeError("Columns must have the same lengths.")
    assert expected > 0
    return expected
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def deduce_length_device(columns: dict[str, ColumnProtocol]) -> tuple[int, torch.device]:
    """A combination check for both matching devices and lengths."""
    column_list = list(columns.values())
    # Length is validated first, matching the standalone checks' order.
    return (deduce_length(column_list), deduce_device(column_list))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class NamedColumns:
    """
    Representation of a data batch read from the filesystem.

    Holds all of the columns read into memory together with their shared
    metadata (row count and device), and produces a tensor batch — values
    plus matching padding masks — when indexed.
    """

    def __init__(
        self,
        columns: dict[str, ColumnProtocol],
        make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
    ) -> None:
        """
        :param columns: Column data read from the filesystem.
        :param make_mask_name: A function generating matching mask names for each column.
        """
        # Validates that all columns agree on row count and device.
        self.columns_length, self.columns_device = deduce_length_device(columns)

        self.columns = columns
        self.make_mask_name = make_mask_name

    @property
    def length(self) -> int:
        return self.columns_length

    @property
    def device(self) -> torch.device:
        return self.columns_device

    def __len__(self) -> int:
        return self.columns_length

    def __getitem__(self, indices: torch.LongTensor) -> Batch:
        local_indices = indices.to(device=self.device)
        batch: Batch = {}
        for name, column in self.columns.items():
            # Each column yields a (mask, values) pair; the mask is stored
            # under the derived mask name, the values under the column name.
            mask, values = column[local_indices]
            batch[self.make_mask_name(name)] = mask
            batch[name] = values
        return batch
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import pyarrow as pa
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
|
|
7
|
+
from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
|
|
8
|
+
from replay.data.nn.parquet.metadata import Metadata, get_numeric_columns
|
|
9
|
+
from replay.data.utils.typing.dtype import pyarrow_to_torch
|
|
10
|
+
|
|
11
|
+
from .column_protocol import OutputType
|
|
12
|
+
from .utils import ensure_mutable
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NumericColumn:
    """A representation of a numeric column, containing a single number in each of its rows."""

    def __init__(
        self,
        data: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
        padding: Any = DEFAULT_PADDING,
    ) -> None:
        """
        :param data: A tensor containing column data.
        :param mask: A mask tensor to differentiate real values from paddings. Default: ``None``.
        :param padding: Padding to use for future indexing of non-existent data. Default: value of ``DEFAULT_PADDING``.
        """
        self.padding: Any = padding
        self.data = data
        self.mask = mask

    @property
    def length(self) -> int:
        total = torch.numel(self.data)
        # Data and mask must stay element-for-element aligned.
        if self.mask is not None:
            assert total == torch.numel(self.mask)
        return total

    def __len__(self) -> int:
        return self.length

    @property
    def device(self) -> torch.device:
        where = self.data.device
        # Data and mask must live on the same device.
        if self.mask is not None:
            assert where == self.mask.device
        return where

    def _get_mask(self, indices: torch.LongTensor) -> torch.BoolTensor:
        # Without an explicit mask, every requested element counts as real.
        if self.mask is None:
            return torch.ones_like(indices, dtype=torch.bool)
        return self.mask[indices]

    def __getitem__(self, indices: torch.LongTensor) -> OutputType:
        indices = indices.to(device=self.device)
        mask = self._get_mask(indices)
        # Null positions are replaced by the configured padding value.
        values = torch.where(mask, self.data[indices], self.padding)
        return (mask, values)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE, padding: Any = DEFAULT_PADDING) -> OutputType:
    """
    Converts a PyArrow array into a PyTorch tensor.

    :param array: Original PyArrow array.
    :param device: Target device to send the resulting tensor to. Default: value of ``DEFAULT_DEVICE``.
    :param padding: Value to fill null values with. Default: value of ``DEFAULT_PADDING``.

    :return: A ``(mask, values)`` pair; ``mask`` is ``None`` when the array contains no nulls.
    """
    target_dtype = pyarrow_to_torch(array.type)

    validity_mask = None
    if array.null_count > 0:
        # ``is_valid`` marks real (non-null) entries; the backing NumPy buffer
        # may be read-only, so it is copied before handing it to torch.
        validity_np = ensure_mutable(array.is_valid().to_numpy(zero_copy_only=False))
        validity_mask = torch.asarray(validity_np, device=device, dtype=torch.bool)

    values_np = ensure_mutable(array.fill_null(padding).to_numpy())
    values = torch.asarray(values_np, device=device, dtype=target_dtype)
    return (validity_mask, values)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def to_numeric_columns(
    data: pa.RecordBatch,
    metadata: Metadata,
    device: torch.device = DEFAULT_DEVICE,
    padding: Any = DEFAULT_PADDING,
) -> dict[str, NumericColumn]:
    """
    Converts a PyArrow batch of data to a set of ``NumericColumn``s.
    Only the columns marked as numeric in ``metadata`` are picked from the full batch.

    :param data: A PyArrow batch of column data.
    :param metadata: Metadata containing information about columns' formats.
    :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``
    :param padding: Padding to use for future indexing of non-existent data. Default: value of ``DEFAULT_PADDING``.

    :return: A dict of tensors containing dataset's numeric columns.
    """

    def build_column(name: str) -> NumericColumn:
        mask, values = to_torch(data.column(name), device, padding)
        return NumericColumn(data=values, mask=mask, padding=padding)

    return {name: build_column(name) for name in get_numeric_columns(metadata)}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
# Name of the NumPy flag controlling whether an array's buffer may be written.
WRITEABLE_FLAG: str = "WRITEABLE"


def ensure_mutable(array: np.ndarray) -> np.ndarray:
    """
    Ensures the resulting NumPy array is mutable by making a copy if it's not.

    Note: the annotations previously used ``np.array`` (a factory function);
    the actual type of both the argument and the result is ``np.ndarray``.

    :param array: Array to be checked for mutability.
    :return: ``array`` itself when already writeable, otherwise a writeable copy.
    """
    if array.flags[WRITEABLE_FLAG]:
        return array
    result = array.copy()
    # ``copy`` always yields an owned, writeable buffer.
    assert result.flags[WRITEABLE_FLAG]
    return result
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Protocol
|
|
2
|
+
|
|
3
|
+
import torch.distributed as dist
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DistributedInfo:
    """
    Thin accessor over ``torch.distributed`` runtime metadata.

    Falls back to a single-process view (rank ``0``, world size ``1``)
    when the distributed backend is unavailable or not initialized.
    """

    def __iter__(self):
        # Supports ``rank, world_size = DistributedInfo()`` style unpacking.
        yield self.rank
        yield self.world_size

    @property
    def is_distributed(self) -> bool:
        # ``is_initialized`` may only be queried when the backend is available.
        return dist.is_available() and dist.is_initialized()

    @property
    def rank(self) -> int:
        return dist.get_rank() if self.is_distributed else 0

    @property
    def world_size(self) -> int:
        return dist.get_world_size() if self.is_distributed else 1
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class DistributedInfoProtocol(Protocol):
    """Structural interface for objects exposing distributed rank and world size."""

    @property
    def rank(self) -> int: ...

    @property
    def world_size(self) -> int: ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Module-level singleton used as the default distributed-environment accessor.
DEFAULT_DISTRIBUTED_INFO: DistributedInfo = DistributedInfo()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from math import ceil
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def validate_length(length: int) -> int:
    """
    Validates that a dataset length is positive.

    :param length: Total number of rows.
    :raises ValueError: If ``length`` is not positive.
    :return: The validated length.
    """
    if length < 1:
        msg = f"Length is invalid. Got {length}."
        raise ValueError(msg)
    return length


def validate_num_replicas(num_replicas: int) -> int:
    """
    Validates that a replica count is positive.

    :param num_replicas: Total number of active replicas.
    :raises ValueError: If ``num_replicas`` is not positive.
    :return: The validated replica count.
    """
    if num_replicas < 1:
        msg = f"Num Replicas is invalid. Got {num_replicas}."
        raise ValueError(msg)
    return num_replicas


def validate_curr_replica(curr_replica: int, num_replicas: int) -> int:
    """
    Validates that a replica id lies within ``[0, num_replicas)``.

    :param curr_replica: Id of the current replica.
    :param num_replicas: Total number of active replicas.
    :raises ValueError: If either argument is out of range.
    :return: The validated replica id.
    """
    num_replicas = validate_num_replicas(num_replicas)
    if (curr_replica < 0) or (num_replicas <= curr_replica):
        msg = f"Curr Replicas is invalid. Got {curr_replica}."
        raise ValueError(msg)
    return curr_replica


@lru_cache
def _partitioning_length(length: int, num_replicas: int) -> int:
    """Cached core of ``partitioning_length``."""
    length = validate_length(length)
    num_replicas = validate_num_replicas(num_replicas)

    # Integer ceiling division; the previous ``ceil(length / num_replicas)``
    # went through float division, which silently rounds incorrectly for
    # lengths above 2**53.
    per_replica = -(-length // num_replicas)
    result = per_replica * num_replicas

    assert (result - length) < num_replicas
    assert result % num_replicas == 0
    assert length <= result
    return result


def partitioning_length(length: int, num_replicas: int) -> int:
    """Smallest multiple of ``num_replicas`` that is at least ``length``."""
    return _partitioning_length(length, num_replicas)


@lru_cache
def _partitioning_per_replica(length: int, num_replicas: int) -> int:
    """Cached core of ``partitioning_per_replica``."""
    full_length = partitioning_length(length, num_replicas)
    result = full_length // num_replicas
    assert result <= length
    assert result > 0
    return result


def partitioning_per_replica(length: int, num_replicas: int) -> int:
    """Number of indices each replica receives after padding to a multiple."""
    return _partitioning_per_replica(length, num_replicas)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Partitioning:
    """Utility class for calculating valid indices across multiple replicas."""

    def __init__(
        self,
        curr_replica: int,
        num_replicas: int,
        device: Union[torch.device, str] = DEFAULT_DEVICE,
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        :param curr_replica: Id of the current replica.
        :param num_replicas: Total number of active replicas.
        :param device: Target device to send the indices tensor to.
            Default: value of ``DEFAULT_DEVICE``.
        :param generator: A pseudo-random number generator for index shuffling;
            when ``None``, indices are generated sequentially. Default: ``None``.
        """
        self.device = torch.device(device)
        self.generator = generator
        self.num_replicas = validate_num_replicas(num_replicas)
        self.curr_replica = validate_curr_replica(curr_replica, self.num_replicas)

    def generate_raw_indices(self, length: int) -> torch.LongTensor:
        """Produces the full index pool, padded up to a multiple of ``num_replicas``."""
        padded_length = partitioning_length(length, self.num_replicas)

        if self.generator is not None:
            # Shuffled pool; randperm ignores ``device`` when a generator is
            # supplied, so the result is moved explicitly afterwards.
            pool = torch.randperm(padded_length, dtype=torch.int64, generator=self.generator)
            pool = pool.to(device=self.device)
        else:
            pool = torch.arange(padded_length, dtype=torch.int64, device=self.device)

        assert torch.max(pool).cpu().item() < padded_length
        assert torch.numel(pool) == padded_length
        assert pool.device == self.device

        return pool

    def replica_indices(self, raw_indices: torch.LongTensor) -> torch.LongTensor:
        """Takes this replica's strided share of the full index pool."""
        padded_length = torch.numel(raw_indices)
        share = raw_indices[self.curr_replica : padded_length : self.num_replicas].clone()

        assert torch.max(share).cpu().item() < padded_length

        return share

    def generate(self, length: int) -> torch.LongTensor:
        """Generates this replica's dataset indices for a dataset of ``length`` rows."""
        raw_indices = self.generate_raw_indices(length)
        padded_length = partitioning_length(length, self.num_replicas)

        assert torch.numel(raw_indices) == padded_length

        share = self.replica_indices(raw_indices)
        per_replica = partitioning_per_replica(length, self.num_replicas)

        assert torch.numel(share) == per_replica

        # Padded positions (>= length) wrap back onto real rows.
        indices = torch.remainder(share, length)

        assert torch.max(indices).cpu().item() < length
        assert torch.numel(indices) == per_replica
        assert indices.device == self.device

        return indices

    def __call__(self, length: int) -> torch.LongTensor:
        return self.generate(length)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Protocol
|
|
2
|
+
|
|
3
|
+
from .distributed_info import DEFAULT_DISTRIBUTED_INFO, DistributedInfoProtocol
|
|
4
|
+
from .worker_info import DEFAULT_WORKER_INFO, WorkerInfoProtocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def num_replicas(
    worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
    distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
) -> int:
    """Total replica count: dataloader workers times distributed world size."""
    return worker_info.num_workers * distributed_info.world_size


def curr_replica(
    worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
    distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
) -> int:
    """Globally unique id of the current replica (worker id offset by rank)."""
    replica_id = worker_info.id + worker_info.num_workers * distributed_info.rank
    assert replica_id < num_replicas(worker_info, distributed_info)
    return replica_id
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ReplicasInfoProtocol(Protocol):
    """Structural interface for objects exposing replica count and current replica id."""

    @property
    def num_replicas(self) -> int: ...

    @property
    def curr_replica(self) -> int: ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ReplicasInfo:
    """
    A replica metadata generator.

    By default, assumes the standard Torch DDP training/inference procedure,
    where each replica (a distinct worker on a specific device) is expected
    to process a separate chunk of the dataset.

    Custom ``worker_info`` and ``distributed_info`` objects may be supplied
    to override how the local worker count and the world size/rank are obtained.
    """

    def __init__(
        self,
        worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
        distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
    ) -> None:
        """
        :param worker_info: An object adhering to the ``WorkerInfoProtocol``, used to obtain
            the local worker count. Default: ``DEFAULT_WORKER_INFO`` — an implementation
            backed by ``torch.utils.data.get_worker_info()``.
        :param distributed_info: An object adhering to the ``DistributedInfoProtocol``, used
            to obtain the world size and rank. Default: ``DEFAULT_DISTRIBUTED_INFO`` — an
            implementation backed by the ``torch.distributed`` module.
        """
        self.worker_info = worker_info
        self.distributed_info = distributed_info

    @property
    def num_replicas(self) -> int:
        return num_replicas(self.worker_info, self.distributed_info)

    @property
    def curr_replica(self) -> int:
        return curr_replica(self.worker_info, self.distributed_info)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Module-level singleton using the default worker/distributed accessors.
DEFAULT_REPLICAS_INFO: ReplicasInfoProtocol = ReplicasInfo()
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Any, Iterator, Optional, Protocol
|
|
2
|
+
|
|
3
|
+
import torch.utils.data as data
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class WorkerInfoProtocol(Protocol):
    """Structural interface for objects exposing a dataloader worker id and worker count."""

    @property
    def id(self) -> int: ...

    @property
    def num_workers(self) -> int: ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WorkerInfo:
    """
    Thin accessor over the PyTorch dataloader worker metadata.

    Falls back to a single-worker view (id ``0``, one worker) when used
    outside of a ``DataLoader`` worker process.

    NOTE(review): unlike ``DistributedInfo``, ``__iter__`` yields only the
    worker id, not the worker count — confirm this asymmetry is intentional.
    """

    def __iter__(self) -> Iterator[int]:
        yield self.id

    @property
    def worker_info(self) -> Optional[Any]:
        # ``None`` when called outside of a DataLoader worker process.
        return data.get_worker_info()

    @property
    def is_parallel(self) -> bool:
        return self.worker_info is not None

    @property
    def id(self) -> int:
        info = self.worker_info
        return 0 if info is None else info.id

    @property
    def num_workers(self) -> int:
        info = self.worker_info
        return 1 if info is None else info.num_workers
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Module-level singleton used as the default worker-metadata accessor.
DEFAULT_WORKER_INFO: WorkerInfo = WorkerInfo()
|