replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0

replay/nn/transform/token_mask.py
ADDED

@@ -0,0 +1,69 @@
+from typing import Optional
+
+import torch
+
+
+class TokenMaskTransform(torch.nn.Module):
+    """
+    For the feature tensor specified by ``token_field``, randomly masks items
+    in the sequence based on a uniform distribution with the specified masking probability.
+    In effect, this transform creates the mask for the recommendation analogue of the Masked Language Modeling (MLM) task.
+
+    Example:
+
+    .. code-block:: python
+
+        >>> _ = torch.manual_seed(0)
+        >>> input_tensor = {"padding_id": torch.BoolTensor([0, 1, 1])}
+        >>> transform = TokenMaskTransform("padding_id")
+        >>> output_tensor = transform(input_tensor)
+        >>> output_tensor
+        {'padding_id': tensor([False, True, True]),
+        'token_mask': tensor([False, True, False])}
+
+    """
+
+    def __init__(
+        self,
+        token_field: str,
+        out_feature_name: str = "token_mask",
+        mask_prob: float = 0.15,
+        generator: Optional[torch.Generator] = None,
+    ) -> None:
+        """
+        :param token_field: Name of the column containing the unmasked tokens.
+        :param out_feature_name: Name of the resulting mask column. Default: ``token_mask``.
+        :param mask_prob: Probability of masking the item, i.e. setting it to ``0``. Default: ``0.15``.
+        :param generator: Random number generator to be used for generating
+            the uniform distribution. Default: ``None``.
+        """
+        super().__init__()
+        self.token_field = token_field
+        self.out_feature_name = out_feature_name
+        self.mask_prob = mask_prob
+        self.generator = generator
+
+    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        output_batch = dict(batch.items())
+
+        paddings = batch[self.token_field]
+
+        assert paddings.dtype == torch.bool, "Source tensor for token mask should be boolean."
+
+        mask_prob = torch.rand(paddings.size(-1), dtype=torch.float32, generator=self.generator).to(
+            device=paddings.device
+        )
+
+        # mask[i], 0 ~ mask_prob, 1 ~ (1 - mask_prob)
+        mask = (mask_prob * paddings) >= self.mask_prob
+
+        # Fix corner cases in mask
+        # 1. If no token is masked, mask the last one
+        if mask.all() or mask[paddings].all():
+            mask[-1] = 0
+        # 2. If all tokens are masked, unmask the one before the last
+        elif (not mask.any()) and (len(mask) > 1):
+            mask[-2] = 1
+
+        output_batch[self.out_feature_name] = mask
+        return output_batch
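
A minimal usage sketch for the new transform (not part of the diff). The import path is assumed from the file location above, and the seeded generator is shown only because it makes the sampled mask reproducible across runs.

    import torch

    # Import path assumed from the file listing above (replay/nn/transform/token_mask.py).
    from replay.nn.transform.token_mask import TokenMaskTransform

    # Boolean tensor marking real (non-padding) positions, as in the class docstring.
    batch = {"padding_mask": torch.tensor([False, True, True, True, True])}

    # A seeded generator makes the sampled mask reproducible.
    generator = torch.Generator().manual_seed(42)
    transform = TokenMaskTransform("padding_mask", mask_prob=0.15, generator=generator)

    out = transform(batch)
    print(out["token_mask"])  # boolean mask written under the default out_feature_name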

replay/nn/transform/trim.py
ADDED

@@ -0,0 +1,51 @@
+from typing import List, Union
+
+import torch
+
+
+class TrimTransform(torch.nn.Module):
+    """
+    Trims the sequences named in ``feature_names``, keeping the last ``seq_len`` elements of each.
+
+    Example:
+
+    .. code-block:: python
+
+        >>> input_batch = {
+        ...     "user_id": torch.LongTensor([111]),
+        ...     "item_id": torch.LongTensor([[5, 4, 0, 7, 4]]),
+        ...     "seen_ids": torch.LongTensor([[5, 4, 0, 7, 4]]),
+        ... }
+        >>> transform = TrimTransform(seq_len=3, feature_names="item_id")
+        >>> output_batch = transform(input_batch)
+        >>> output_batch
+        {'user_id': tensor([111]),
+        'item_id': tensor([[0, 7, 4]]),
+        'seen_ids': tensor([[5, 4, 0, 7, 4]])}
+
+    """
+
+    def __init__(
+        self,
+        seq_len: int,
+        feature_names: Union[List[str], str],
+    ) -> None:
+        """
+        :param seq_len: max sequence length used in the model. Must be positive.
+        :param feature_names: name(s) of the features in the batch to be trimmed.
+        """
+        super().__init__()
+        assert seq_len > 0
+        self.seq_len = seq_len
+        self.feature_names = [feature_names] if isinstance(feature_names, str) else feature_names
+
+    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        output_batch = dict(batch.items())
+
+        for name in self.feature_names:
+            assert output_batch[name].shape[1] >= self.seq_len
+
+            trimmed_seq = output_batch[name][:, -self.seq_len :, ...].clone()
+            output_batch[name] = trimmed_seq
+
+        return output_batch
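
Both new transforms map a feature dict to a feature dict, so they can be chained. A hedged sketch of such a composition (import paths assumed from the file list; chaining via ``torch.nn.Sequential`` is an illustration, not necessarily the library's intended pipeline API):

    import torch

    # Import paths assumed from the file listing; the package may re-export them elsewhere.
    from replay.nn.transform.token_mask import TokenMaskTransform
    from replay.nn.transform.trim import TrimTransform

    # Trim sequences to the last 3 positions, then sample a token mask for them.
    pipeline = torch.nn.Sequential(
        TrimTransform(seq_len=3, feature_names=["item_id", "padding_mask"]),
        TokenMaskTransform("padding_mask"),
    )

    batch = {
        "user_id": torch.LongTensor([111]),
        "item_id": torch.LongTensor([[5, 4, 0, 7, 4]]),
        "padding_mask": torch.BoolTensor([[1, 1, 1, 1, 1]]),
    }

    out = pipeline(batch)  # item_id and padding_mask are trimmed, then a token_mask entry is added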
replay/nn/utils.py
ADDED

@@ -0,0 +1,28 @@
+import warnings
+from typing import Callable, Literal, Tuple
+
+import torch
+
+
+def warning_is_not_none(msg: str) -> Callable:
+    def checker(value: Tuple[torch.Tensor, str]) -> bool:
+        if value[0] is not None:
+            warnings.warn(msg.format(value[1]), RuntimeWarning, stacklevel=2)
+            return False
+        return True
+
+    return checker
+
+
+def create_activation(
+    activation: Literal["relu", "gelu", "sigmoid"],
+) -> torch.nn.Module:
+    """Create an activation function module from its name."""
+    if activation == "relu":
+        return torch.nn.ReLU()
+    if activation == "gelu":
+        return torch.nn.GELU()
+    if activation == "sigmoid":
+        return torch.nn.Sigmoid()
+    msg = "Expected to get activation relu/gelu/sigmoid"
+    raise ValueError(msg)
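
A small sketch of how the helper can be used to build a configurable feed-forward block (the surrounding module is illustrative and not taken from the package):

    import torch

    # Import path assumed from the file shown above (replay/nn/utils.py).
    from replay.nn.utils import create_activation


    class FeedForward(torch.nn.Module):
        """Illustrative two-layer block whose activation is chosen by name."""

        def __init__(self, dim: int, hidden_dim: int, activation: str = "gelu") -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(dim, hidden_dim),
                create_activation(activation),  # returns ReLU/GELU/Sigmoid or raises ValueError
                torch.nn.Linear(hidden_dim, dim),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)


    block = FeedForward(dim=64, hidden_dim=256, activation="relu")
    out = block(torch.randn(2, 64))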
replay/preprocessing/filters.py
CHANGED

@@ -1090,3 +1090,131 @@ class ConsecutiveDuplicatesFilter(_BaseFilter):
             .where((sf.col(self.item_column) != sf.col(self.temporary_column)) | sf.col(self.temporary_column).isNull())
             .drop(self.temporary_column)
         )
+
+
+def _check_col_present(
+    target: DataFrameLike,
+    reference: DataFrameLike,
+    columns_to_process: list[str],
+) -> None:
+    target_columns = set(target.columns)
+    reference_columns = set(reference.columns)
+    for column in columns_to_process:
+        if column not in target_columns or column not in reference_columns:
+            msg = f"Column '{column}' must be in both dataframes"
+            raise KeyError(msg)
+
+
+def _filter_cold_pandas(
+    target: PandasDataFrame,
+    reference: PandasDataFrame,
+    columns_to_process: list[str],
+) -> PandasDataFrame:
+    for column in columns_to_process:
+        allowed_values = reference[column].unique()
+        target = target[target[column].isin(allowed_values)]
+    return target
+
+
+def _filter_cold_polars(
+    target: PolarsDataFrame,
+    reference: PolarsDataFrame,
+    columns_to_process: list[str],
+) -> PolarsDataFrame:
+    for column in columns_to_process:
+        allowed_values = reference.select(column).unique()
+        target = target.join(allowed_values, on=column, how="semi")
+    return target
+
+
+def _filter_cold_spark(
+    target: SparkDataFrame,
+    reference: SparkDataFrame,
+    columns_to_process: list[str],
+) -> SparkDataFrame:
+    for column in columns_to_process:
+        allowed_values = reference.select(column).distinct()
+        target = target.join(allowed_values, on=column, how="left_semi")
+    return target
+
+
+def filter_cold(
+    target: DataFrameLike,
+    reference: DataFrameLike,
+    mode: Literal["items", "users", "both"] = "items",
+    query_column: str = "query_id",
+    item_column: str = "item_id",
+) -> DataFrameLike:
+    """
+    Filter rows in ``target`` keeping only users/items that exist in ``reference``.
+
+    This function works with pandas, Polars and Spark DataFrames. ``target`` and
+    ``reference`` must be of the same backend type. Depending on ``mode``, it
+    removes rows whose ``item_column`` and/or ``query_column`` values are not
+    present in the corresponding columns of ``reference``.
+
+    Parameters
+    ----------
+    target : DataFrameLike
+        Dataset to be filtered (pandas/Polars/Spark).
+    reference : DataFrameLike
+        Dataset that defines the allowed universe of users/items.
+    mode : {"items", "users", "both"}, default "items"
+        What to filter: only items, only users, or both.
+    query_column : str, default "query_id"
+        Name of the user (query) column.
+    item_column : str, default "item_id"
+        Name of the item column.
+
+    Returns
+    -------
+    DataFrameLike
+        Filtered ``target`` of the same backend type as the input.
+
+    Raises
+    ------
+    ValueError
+        If ``mode`` is not one of {"items", "users", "both"}.
+    TypeError
+        If ``target`` and ``reference`` are of different backend types.
+    KeyError
+        If required columns are missing in either dataset.
+    NotImplementedError
+        If the input dataframe type is not supported.
+    """
+    if mode not in {"items", "users", "both"}:
+        msg = "mode must be 'items' | 'users' | 'both'"
+        raise ValueError(msg)
+    if not isinstance(target, type(reference)):
+        msg = "Target and reference must be of the same type"
+        raise TypeError(msg)
+
+    if mode == "both":
+        columns_to_process = [query_column, item_column]
+    elif mode == "items":
+        columns_to_process = [item_column]
+    elif mode == "users":
+        columns_to_process = [query_column]
+
+    _check_col_present(target, reference, columns_to_process)
+
+    if isinstance(target, PandasDataFrame):
+        return _filter_cold_pandas(
+            target,
+            reference,
+            columns_to_process,
+        )
+    if isinstance(target, PolarsDataFrame):
+        return _filter_cold_polars(
+            target,
+            reference,
+            columns_to_process,
+        )
+    if isinstance(target, SparkDataFrame):
+        return _filter_cold_spark(
+            target,
+            reference,
+            columns_to_process,
+        )
+    msg = f"Unsupported data frame type: {type(target)}"
+    raise NotImplementedError(msg)
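
A minimal pandas sketch of the new ``filter_cold`` helper as defined above (the example data is illustrative):

    import pandas as pd

    # Module path follows the file shown above (replay/preprocessing/filters.py).
    from replay.preprocessing.filters import filter_cold

    train = pd.DataFrame({"query_id": [1, 1, 2], "item_id": [10, 11, 12]})
    test = pd.DataFrame({"query_id": [1, 2, 3], "item_id": [10, 12, 99]})

    # Keep only test rows whose users and items were seen in train.
    warm_test = filter_cold(test, reference=train, mode="both")
    # The row with query_id=3 / item_id=99 (cold user and cold item) is removed.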
replay/preprocessing/label_encoder.py
CHANGED

@@ -26,7 +26,7 @@ from replay.utils import (
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import Window, functions as sf  # noqa: I001
-    from pyspark.sql.types import LongType
+    from pyspark.sql.types import LongType
     from replay.utils.session_handler import get_spark_session
 
 HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]

@@ -185,11 +185,11 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         self._mapping = mapping_on_spark.rdd.collectAsMap()
 
     def _fit_pandas(self, df: PandasDataFrame) -> None:
-        unique_col_values = df[self._col].drop_duplicates().reset_index(drop=True)
+        unique_col_values = df[self._col].sort_values().drop_duplicates().reset_index(drop=True)
         self._mapping = {val: key for key, val in unique_col_values.to_dict().items()}
 
     def _fit_polars(self, df: PolarsDataFrame) -> None:
-        unique_col_values = df.select(self._col).unique()
+        unique_col_values = df.sort(self._col).select(self._col).unique()
         self._mapping = {key: val for val, key in enumerate(unique_col_values.to_series().to_list())}
 
     def fit(self, df: DataFrameLike) -> "LabelEncodingRule":
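
The fit changes above sort the values before deduplication, so the produced mapping no longer depends on row order. A hedged sketch of the resulting behaviour for the pandas backend (import path and single-argument constructor assumed from the diff context):

    import pandas as pd

    # Import path assumed; the rule class lives in replay/preprocessing/label_encoder.py.
    from replay.preprocessing import LabelEncodingRule

    df = pd.DataFrame({"item_id": ["b", "c", "a", "c"]})
    rule = LabelEncodingRule("item_id").fit(df)

    # With sorted unique values the mapping is deterministic regardless of row order:
    # {"a": 0, "b": 1, "c": 2}
    print(rule.get_mapping())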

@@ -630,37 +630,40 @@ class SequenceEncodingRule(LabelEncodingRule):
         return self
 
     def _transform_spark(self, df: SparkDataFrame, default_value: Optional[int]) -> SparkDataFrame:
-
-            return [mapping.get(value) for value in x]  # pragma: no cover
+        other_columns = [col for col in df.columns if col != self._col]
 
-
-
-
+        mapping_on_spark = get_spark_session().createDataFrame(
+            data=list(self.get_mapping().items()), schema=[self._col, self._target_col]
+        )
+        encoded_df = (
+            df.select(*other_columns, sf.posexplode(self._col))
+            .withColumnRenamed("col", self._col)
+            .join(mapping_on_spark, on=self._col, how="left")
+        )
+
+        if self._handle_unknown == "error":
+            if encoded_df.filter(sf.col(self._target_col).isNull()).count() > 0:
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+        else:
+            if default_value is not None:
+                encoded_df = encoded_df.fillna(default_value, subset=[self._target_col])
 
+        result = encoded_df.groupBy(other_columns).agg(
+            sf.sort_array(sf.collect_list(sf.struct("pos", self._target_col)))
+            .getItem(self._target_col)
+            .alias(self._col)
+        )
         if self._handle_unknown == "drop":
-
-            if
+            result = result.withColumn(self._col, sf.filter(self._col, lambda x: x.isNotNull()))
+            if result.select(sf.max(sf.size(self._col))).first()[0] == 0:
                 warnings.warn(
                     f"You are trying to transform dataframe with all values are unknown for {self._col}, "
                     "with `handle_unknown_strategy=drop` leads to empty dataframe",
                     LabelEncoderTransformWarning,
                 )
-        elif self._handle_unknown == "error":
-            if (
-                encoded_df.select(sf.sum(sf.array_contains(self._target_col, -1).isNull().cast("integer"))).first()[0]
-                != 0
-            ):
-                msg = f"Found unknown labels in column {self._col} during transform"
-                raise ValueError(msg)
-        else:
-            if default_value:
-                encoded_df = encoded_df.withColumn(
-                    self._target_col,
-                    sf.transform(self._target_col, lambda x: sf.when(x.isNull(), default_value).otherwise(x)),
-                )
 
-
-        return result_df
+        return result
 
     def _transform_pandas(self, df: PandasDataFrame, default_value: Optional[int]) -> PandasDataFrame:
         mapping = self.get_mapping()

@@ -771,7 +774,7 @@ class SequenceEncodingRule(LabelEncodingRule):
     def _inverse_transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
         array_expr = sf.array([sf.lit(x) for x in self._inverse_mapping_list])
         decoded_df = df.withColumn(
-            self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x + 1))
+            self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x.cast("int") + 1))
         )
         return decoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
 
@@ -800,19 +803,19 @@ class LabelEncoder:
     >>> mapped_interactions = encoder.fit_transform(user_interactions)
     >>> mapped_interactions
        user_id  item_1  item_2       list
-    0        0       0       0  [
-    1        1       1       1  [
-    2        2       2       2  [
+    0        0       0       0  [2, 3, 4]
+    1        1       1       1  [4, 5, 6]
+    2        2       2       2  [1, 0, 5]
     >>> encoder.mapping
     {'user_id': {'u1': 0, 'u2': 1, 'u3': 2},
     'item_1': {'item_1': 0, 'item_2': 1, 'item_3': 2},
     'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2},
-    'list': {
+    'list': {-2: 0, -1: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6}}
     >>> encoder.inverse_mapping
     {'user_id': {0: 'u1', 1: 'u2', 2: 'u3'},
     'item_1': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
     'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
-    'list': {0:
+    'list': {0: -2, 1: -1, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5}}
     >>> new_encoder = LabelEncoder([
     ...     LabelEncodingRule("user_id", encoder.mapping["user_id"]),
     ...     LabelEncodingRule("item_1", encoder.mapping["item_1"]),

@@ -834,14 +837,14 @@ class LabelEncoder:
         self.rules = rules
 
     @property
-    def mapping(self) ->
+    def mapping(self) -> dict[str, Mapping]:
         """
         Returns mapping of each column in given rules.
         """
         return {r.column: r.get_mapping() for r in self.rules}
 
     @property
-    def inverse_mapping(self) ->
+    def inverse_mapping(self) -> dict[str, Mapping]:
         """
         Returns inverse mapping of each column in given rules.
         """
replay/preprocessing/utils.py
ADDED

@@ -0,0 +1,209 @@
+import logging
+from typing import (
+    Literal,
+    Optional,
+    Sequence,
+)
+
+import pandas as pd
+import polars as pl
+
+from replay.utils import (
+    PYSPARK_AVAILABLE,
+    DataFrameLike,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
+
+if PYSPARK_AVAILABLE:
+    import pyspark.sql.functions as sf
+
+
+def _ensure_columns_match(df, ref_cols, index: int, check_columns: bool) -> None:
+    if check_columns and set(df.columns) != set(ref_cols):
+        msg = f"Columns mismatch in dataframe #{index}: {sorted(df.columns)} != {sorted(ref_cols)}"
+        raise ValueError(msg)
+
+
+def _merge_subsets_pandas(
+    dfs: Sequence[PandasDataFrame],
+    columns: Optional[Sequence[str]],
+    check_columns: bool,
+    subset_for_duplicates: Optional[Sequence[str]],
+    on_duplicate: Literal["error", "drop", "ignore"],
+) -> PandasDataFrame:
+    ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+    aligned: list[PandasDataFrame] = []
+    for i, df in enumerate(dfs):
+        _ensure_columns_match(df, ref_cols, i, check_columns)
+        aligned.append(df[ref_cols])
+
+    merged = pd.concat(aligned, axis=0, ignore_index=True)
+
+    if on_duplicate == "ignore":
+        return merged
+
+    dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+    dup_mask = merged.duplicated(subset=dup_subset, keep="first")
+    dup_count = int(dup_mask.sum())
+
+    if dup_count > 0:
+        if on_duplicate == "error":
+            msg = f"Found {dup_count} duplicate rows on subset {dup_subset}"
+            raise ValueError(msg)
+        if on_duplicate == "drop":
+            merged = merged.drop_duplicates(subset=dup_subset, keep="first").reset_index(drop=True)
+            logging.getLogger("replay").warning(
+                f"Found {dup_count} duplicate rows on subset {dup_subset} and dropped them"
+            )
+
+    return merged
+
+
+def _merge_subsets_polars(
+    dfs: Sequence[PolarsDataFrame],
+    columns: Optional[Sequence[str]],
+    check_columns: bool,
+    subset_for_duplicates: Optional[Sequence[str]],
+    on_duplicate: Literal["error", "drop", "ignore"],
+) -> PolarsDataFrame:
+    ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+    aligned: list[PolarsDataFrame] = []
+    for i, df in enumerate(dfs):
+        _ensure_columns_match(df, ref_cols, i, check_columns)
+        aligned.append(df.select(ref_cols))
+
+    merged = pl.concat(aligned, how="vertical")
+
+    if on_duplicate == "ignore":
+        return merged
+
+    dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+    dup_mask = merged.select(dup_subset).is_duplicated()
+    dup_count = int(dup_mask.sum())
+
+    if dup_count > 0:
+        if on_duplicate == "error":
+            msg = f"Found {dup_count} duplicate rows on subset {dup_subset}"
+            raise ValueError(msg)
+        if on_duplicate == "drop":
+            merged = merged.unique(subset=dup_subset, keep="first", maintain_order=True)
+            logging.getLogger("replay").warning(
+                f"Found {dup_count} duplicate rows on subset {dup_subset} and dropped them"
+            )
+
+    return merged
+
+
+def _merge_subsets_spark(
+    dfs: Sequence[SparkDataFrame],
+    columns: Optional[Sequence[str]],
+    check_columns: bool,
+    subset_for_duplicates: Optional[Sequence[str]],
+    on_duplicate: Literal["error", "drop", "ignore"],
+) -> SparkDataFrame:
+    ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+    merged = None
+    for i, df in enumerate(dfs):
+        _ensure_columns_match(df, ref_cols, i, check_columns)
+        part = df.select(*ref_cols)
+        merged = part if merged is None else merged.unionByName(part)
+
+    if on_duplicate == "ignore":
+        return merged
+
+    dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+    if on_duplicate == "error" and merged.groupBy(*dup_subset).count().filter(sf.col("count") > 1).limit(1).count() > 0:
+        msg = f"Found duplicate rows on subset {dup_subset}"
+        raise ValueError(msg)
+    if on_duplicate == "drop":
+        unique = merged.dropDuplicates(dup_subset)
+        logging.getLogger("replay").warning(
+            f"Found {merged.count() - unique.count()} duplicate rows on subset {dup_subset} and dropped them"
+        )
+        merged = unique
+
+    return merged
+
+
+def merge_subsets(
+    dfs: Sequence[DataFrameLike],
+    columns: Optional[Sequence[str]] = None,
+    check_columns: bool = True,
+    subset_for_duplicates: Optional[Sequence[str]] = None,
+    on_duplicate: Literal["error", "drop", "ignore"] = "error",
+) -> DataFrameLike:
+    """Merge multiple dataframes of the same backend into a single one.
+
+    All inputs must be of the same dataframe type (pandas/Polars/Spark). Before
+    concatenation, each dataframe is aligned to a common set of columns: either
+    the provided ``columns`` or the columns of the first dataframe. Duplicate
+    rows are handled according to ``on_duplicate``.
+
+    Parameters
+    ----------
+    dfs : Sequence[DataFrameLike]
+        Dataframes to merge.
+    columns : Optional[Sequence[str]]
+        Columns to align to. If ``None``, columns of the first dataframe are used.
+    check_columns : bool
+        Whether to validate that all inputs have the same column set.
+    subset_for_duplicates : Optional[Sequence[str]]
+        Columns subset used to detect duplicates. If ``None``, all aligned columns
+        are used.
+    on_duplicate : {"error", "drop", "ignore"}
+        How to handle duplicates: raise an error, drop them, or ignore.
+
+    Returns
+    -------
+    DataFrameLike
+        Merged dataframe of the same backend as the inputs.
+
+    Raises
+    ------
+    ValueError
+        If ``dfs`` is empty, if duplicates are found with ``on_duplicate='error'``,
+        or if column sets differ when validation is enabled.
+    TypeError
+        If inputs are of different dataframe types.
+    """
+    if not dfs:
+        msg = "At least one dataframe is required"
+        raise ValueError(msg)
+
+    first = dfs[0]
+    if any(not isinstance(df, type(first)) for df in dfs):
+        msg = "All input dataframes must be of the same type"
+        raise TypeError(msg)
+
+    if isinstance(first, PandasDataFrame):
+        return _merge_subsets_pandas(
+            dfs,
+            columns=columns,
+            check_columns=check_columns,
+            subset_for_duplicates=subset_for_duplicates,
+            on_duplicate=on_duplicate,
+        )
+    if isinstance(first, PolarsDataFrame):
+        return _merge_subsets_polars(
+            dfs,
+            columns=columns,
+            check_columns=check_columns,
+            subset_for_duplicates=subset_for_duplicates,
+            on_duplicate=on_duplicate,
+        )
+    if isinstance(first, SparkDataFrame):
+        return _merge_subsets_spark(
+            dfs,
+            columns=columns,
+            check_columns=check_columns,
+            subset_for_duplicates=subset_for_duplicates,
+            on_duplicate=on_duplicate,
+        )
+
+    msg = f"Unsupported data frame type: {type(first)}"
+    raise NotImplementedError(msg)
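
A minimal pandas sketch for the new ``merge_subsets`` helper as defined above (the example data is illustrative):

    import pandas as pd

    # Module path follows the file shown above (replay/preprocessing/utils.py).
    from replay.preprocessing.utils import merge_subsets

    train = pd.DataFrame({"query_id": [1, 2], "item_id": [10, 11]})
    extra = pd.DataFrame({"item_id": [11, 12], "query_id": [2, 3]})  # same columns, different order

    # Columns are aligned to the first dataframe; the duplicate (2, 11) row is dropped with a warning.
    merged = merge_subsets(
        [train, extra],
        subset_for_duplicates=["query_id", "item_id"],
        on_duplicate="drop",
    )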
replay/splitters/__init__.py
CHANGED

@@ -7,6 +7,7 @@ from .cold_user_random_splitter import ColdUserRandomSplitter
 from .k_folds import KFolds
 from .last_n_splitter import LastNSplitter
 from .new_users_splitter import NewUsersSplitter
+from .random_next_n_splitter import RandomNextNSplitter
 from .random_splitter import RandomSplitter
 from .ratio_splitter import RatioSplitter
 from .time_splitter import TimeSplitter
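
With the export above, the new splitter becomes importable from the package namespace; its constructor arguments are not part of this diff, so only the import is shown:

    from replay.splitters import RandomNextNSplitter  # newly exported in 0.21.0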