replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/data/nn/parquet/partitioned_iterable_dataset.py
ADDED
@@ -0,0 +1,56 @@
+from collections.abc import Iterable, Iterator
+from typing import Optional
+
+import torch
+import torch.utils.data as data
+
+from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
+
+from .impl.named_columns import NamedColumns
+from .info.replicas import ReplicasInfoProtocol
+from .iterable_dataset import IterableDataset
+
+Batch = dict[str, torch.Tensor]
+
+
+class PartitionedIterableDataset(data.IterableDataset):
+    """
+    A dataset that implements iteration over partitioned data.
+
+    This implementation allows large amounts of data to be processed in batch-wise mode,
+    which is especially useful when used in distributed training.
+    """
+
+    def __init__(
+        self,
+        iterable: Iterable[NamedColumns],
+        batch_size: int,
+        generator: Optional[torch.Generator] = None,
+        replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
+    ) -> None:
+        """
+        :param iterable: An iterable object that returns data partitions.
+        :param batch_size: Batch size.
+        :param generator: Random number generator for batch shuffling.
+            If ``None``, shuffling will be disabled. Default: ``None``.
+        :param replicas_info: A connector object capable of fetching total replica count and replica id during runtime.
+            Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode.
+        """
+        super().__init__()
+
+        self.iterable = iterable
+
+        self.batch_size = batch_size
+        self.generator = generator
+        self.replicas_info = replicas_info
+
+    def __iter__(self) -> Iterator[Batch]:
+        for partition in iter(self.iterable):
+            iterable = IterableDataset(
+                named_columns=partition,
+                generator=self.generator,
+                batch_size=self.batch_size,
+                replicas_info=self.replicas_info,
+            )
+
+            yield from iter(iterable)
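For orientation, a minimal usage sketch of the class added above. The partition source (`load_partitions`) is hypothetical; any `Iterable[NamedColumns]` would do:

    import torch

    from replay.data.nn.parquet.partitioned_iterable_dataset import PartitionedIterableDataset

    # Hypothetical loader; the dataset only requires an Iterable[NamedColumns].
    partitions = load_partitions()

    dataset = PartitionedIterableDataset(
        iterable=partitions,
        batch_size=256,
        generator=torch.Generator().manual_seed(42),  # None disables shuffling
    )
    for batch in dataset:  # each batch is a dict[str, torch.Tensor]
        ...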
replay/data/nn/parquet/utils/compute_length.py
ADDED
@@ -0,0 +1,66 @@
+import warnings
+from collections.abc import Iterable
+from typing import Protocol
+
+from replay.data.nn.parquet.info.partitioning import partitioning_per_replica
+from replay.data.nn.parquet.iterator import BatchesIterator
+
+
+class HasLengthProtocol(Protocol):
+    def __len__(self) -> int: ...
+
+
+def compute_fixed_size_generic_length_from_sizes(
+    partition_sizes: Iterable[int], batch_size: int, num_replicas: int
+) -> int:
+    residue = 0
+    batch_counter = 0
+    for partition_size in partition_sizes:
+        per_replica = partitioning_per_replica(partition_size, num_replicas)
+        batch_count = per_replica // batch_size
+        residue += per_replica % batch_size
+        if batch_size < residue:
+            batch_count += residue // batch_size
+            residue = residue % batch_size
+        batch_counter += batch_count
+    batch_counter += residue > 0
+    return batch_counter
+
+
+def compute_fixed_size_batches_length(iterable: BatchesIterator, batch_size: int, num_replicas: int) -> int:
+    assert isinstance(iterable, BatchesIterator)
+
+    partition_size = iterable.batch_size
+
+    def default_partitions(fragment_size: int) -> list[int]:
+        full_partitions_count = fragment_size // partition_size
+        result = [partition_size] * full_partitions_count
+        if (residue := (fragment_size % partition_size)) > 0:
+            result.append(residue)
+        return result
+
+    partition_sizes = []
+    for fragment in iterable.dataset.get_fragments():
+        fragment_size = fragment.count_rows()
+        partitions = default_partitions(fragment_size)
+        partition_sizes.extend(partitions)
+
+    result = compute_fixed_size_generic_length_from_sizes(
+        partition_sizes=partition_sizes,
+        num_replicas=num_replicas,
+        batch_size=batch_size,
+    )
+
+    return result
+
+
+def compute_fixed_size_generic_length(iterable: Iterable[HasLengthProtocol], batch_size: int, num_replicas: int) -> int:
+    warnings.warn("Generic length computation. This may cause performance issues.", UserWarning, stacklevel=2)
+    return compute_fixed_size_generic_length_from_sizes(map(len, iterable), batch_size, num_replicas)
+
+
+def compute_fixed_size_length(iterable: Iterable[HasLengthProtocol], batch_size: int, num_replicas: int) -> int:
+    if isinstance(iterable, BatchesIterator):
+        return compute_fixed_size_batches_length(iterable, batch_size, num_replicas)
+    else:
+        return compute_fixed_size_generic_length(iterable, batch_size, num_replicas)
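A worked example of the residue arithmetic in `compute_fixed_size_generic_length_from_sizes`. The sketch below inlines the same loop with per-replica partition sizes supplied directly, since the rounding inside `partitioning_per_replica` is not shown in this diff:

    def batches_for(per_replica_sizes: list[int], batch_size: int) -> int:
        # Same bookkeeping as above: full batches per partition, with
        # leftovers pooled across partitions into extra batches.
        residue = 0
        batch_counter = 0
        for per_replica in per_replica_sizes:
            batch_count = per_replica // batch_size
            residue += per_replica % batch_size
            if batch_size < residue:
                batch_count += residue // batch_size
                residue = residue % batch_size
            batch_counter += batch_count
        return batch_counter + (residue > 0)

    # Partitions of 10 and 7 rows per replica, batch_size=4:
    # 10 -> 2 full batches (residue 2); 7 -> 1 full batch (residue 2+3=5 > 4,
    # so one extra batch, residue 1); the trailing residue adds a final batch.
    assert batches_for([10, 7], batch_size=4) == 5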
replay/data/nn/schema.py
CHANGED
@@ -86,12 +86,14 @@ class TensorFeatureInfo:
         default: ``None``.
     :param feature_sources: columns names and DataFrames feature came from,
         default: ``None``.
-    :param cardinality: cardinality of categorical feature
-
-
-
-    :param
-
+    :param cardinality: cardinality of categorical feature.
+        number of unique items in vocabulary (catalog).
+        The specified cardinality value must not take into account the padding value.
+        Default: ``None``.
+    :param padding_value: value to pad sequences to desired length.
+        It is recommended to set the padding value for categorical features in the `cardinality` value.
+    :param embedding_dim: embedding dimensions of the feature.
+        Default: ``None`` - it means will be used value of ``DEFAULT_EMBEDDING_DIM``.
     :param tensor_dim: tensor dimensions of numerical feature,
         default: ``None``.
     """
@@ -106,8 +108,8 @@ class TensorFeatureInfo:
            raise ValueError(msg)
        self._feature_type = feature_type

-        if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST] and
-            msg = "Cardinality
+        if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST] and cardinality is not None:
+            msg = "Cardinality is needed only with categorical feature type."
            raise ValueError(msg)
        self._cardinality = cardinality

@@ -115,9 +117,8 @@ class TensorFeatureInfo:
            msg = "Tensor dimensions is needed only with numerical feature type."
            raise ValueError(msg)

-
-
-        else:
+        self._embedding_dim = embedding_dim or self.DEFAULT_EMBEDDING_DIM
+        if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST]:
            self._tensor_dim = tensor_dim

    @property
@@ -236,9 +237,6 @@ class TensorFeatureInfo:
        """
        :returns: Embedding dimensions of the feature.
        """
-        if not self.is_cat:
-            msg = f"Can not get embedding dimensions because feature type of {self.name} feature is not categorical."
-            raise RuntimeError(msg)
        return self._embedding_dim

    def _set_embedding_dim(self, embedding_dim: int) -> None:
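The documented `cardinality`/`padding_value` contract in a hedged sketch. Argument names follow the docstring above; the full constructor signature is not part of this diff:

    from replay.data import FeatureType
    from replay.data.nn import TensorFeatureInfo

    item_id = TensorFeatureInfo(
        name="item_id",                        # assumed parameter name
        feature_type=FeatureType.CATEGORICAL,
        cardinality=10_000,                    # unique items only, padding excluded
        padding_value=10_000,                  # docstring recommends the cardinality value
    )
    # After the change above, embedding_dim no longer raises for
    # non-categorical features; unset values fall back to DEFAULT_EMBEDDING_DIM.
    print(item_id.embedding_dim)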
replay/data/nn/sequence_tokenizer.py
CHANGED
@@ -10,6 +10,7 @@ import numpy as np
 import polars as pl
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame
+from typing_extensions import deprecated

 from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
 from replay.data.dataset_utils import DatasetLabelEncoder
@@ -24,6 +25,7 @@ SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
 _T = TypeVar("_T")


+@deprecated("`SequenceTokenizer` class is deprecated.")
 class SequenceTokenizer:
     """
     Data tokenizer for transformers;
@@ -507,6 +509,7 @@ class SequenceTokenizer:
            pickle.dump(self, file)


+@deprecated("`_BaseSequenceProcessor` class is deprecated.", stacklevel=2)
 class _BaseSequenceProcessor(Generic[_T]):
     """
     Base class for sequence processing
@@ -600,6 +603,7 @@ class _BaseSequenceProcessor(Generic[_T]):
        pass


+@deprecated("`_PandasSequenceProcessor` class is deprecated.", stacklevel=2)
 class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
     """
     Class to process sequences of different categorical and numerical features.
@@ -780,6 +784,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
        return values


+@deprecated("`_PolarsSequenceProcessor` class is deprecated.", stacklevel=2)
 class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
     """
     Class to process sequences of different categorical and numerical features.
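The runtime effect of the `@deprecated` markers added throughout this release (PEP 702, via `typing_extensions`), demonstrated on a stand-in class:

    import warnings

    from typing_extensions import deprecated

    @deprecated("`Old` class is deprecated.", stacklevel=2)
    class Old:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Old()  # instantiating the decorated class emits the warning
        assert issubclass(caught[0].category, DeprecationWarning)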
replay/data/nn/sequential_dataset.py
CHANGED
@@ -8,11 +8,13 @@ import pandas as pd
 import polars as pl
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame
+from typing_extensions import deprecated

 if TYPE_CHECKING:
     from .schema import TensorSchema


+@deprecated("`SequentialDataset` class is deprecated.", stacklevel=2)
 class SequentialDataset(abc.ABC):
     """
     Abstract base class for sequential dataset
@@ -138,6 +140,7 @@ class SequentialDataset(abc.ABC):
        return df_converted


+@deprecated("`PandasSequentialDataset` class is deprecated.")
 class PandasSequentialDataset(SequentialDataset):
     """
     Sequential dataset that stores sequences in PandasDataFrame format.
@@ -234,6 +237,7 @@ class PandasSequentialDataset(SequentialDataset):
        return dataset


+@deprecated("`PolarsSequentialDataset` class is deprecated.")
 class PolarsSequentialDataset(PandasSequentialDataset):
     """
     Sequential dataset that stores sequences in PolarsDataFrame format.
replay/data/nn/torch_sequential_dataset.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
 import numpy as np
 import torch
 from torch.utils.data import Dataset as TorchDataset
+from typing_extensions import deprecated

 if TYPE_CHECKING:
     from .schema import TensorFeatureInfo, TensorMap, TensorSchema
@@ -13,6 +14,7 @@ if TYPE_CHECKING:

 # We do not use dataclasses as PyTorch default collate
 # function in dataloader supports only namedtuple
+@deprecated("`TorchSequentialBatch` class is deprecated.", stacklevel=2)
 class TorchSequentialBatch(NamedTuple):
     """
     Batch of TorchSequentialDataset
@@ -23,6 +25,7 @@ class TorchSequentialBatch(NamedTuple):
     features: "TensorMap"


+@deprecated("`TorchSequentialDataset` class is deprecated.")
 class TorchSequentialDataset(TorchDataset):
     """
     Torch dataset for sequential recommender models
@@ -160,6 +163,7 @@ class TorchSequentialDataset(TorchDataset):
        yield (i, offset_from_seq_beginning)


+@deprecated("`TorchSequentialValidationBatch` class is deprecated.", stacklevel=2)
 class TorchSequentialValidationBatch(NamedTuple):
     """
     Batch of TorchSequentialValidationDataset
@@ -176,6 +180,7 @@ DEFAULT_GROUND_TRUTH_PADDING_VALUE = -1
 DEFAULT_TRAIN_PADDING_VALUE = -2


+@deprecated("`TorchSequentialValidationDataset` class is deprecated.")
 class TorchSequentialValidationDataset(TorchDataset):
     """
     Torch dataset for sequential recommender models that additionally stores ground truth
replay/data/utils/batching.py
ADDED
@@ -0,0 +1,69 @@
+from functools import lru_cache
+from typing import Iterator, Tuple
+
+
+def validate_length(length: int) -> int:
+    if length < 1:
+        msg: str = f"Length is invalid. Got {length}."
+        raise ValueError(msg)
+    return length
+
+
+def validate_batch_size(batch_size: int) -> int:
+    if batch_size < 1:
+        msg: str = f"Batch Size is invalid. Got {batch_size}."
+        raise ValueError(msg)
+    return batch_size
+
+
+def validate_input(length: int, batch_size: int) -> Tuple[int, int]:
+    length = validate_length(length)
+    batch_size = validate_batch_size(batch_size)
+    return (length, batch_size)
+
+
+def uniform_batch_count(length: int, batch_size: int) -> int:
+    @lru_cache
+    def _uniform_batch_count(length: int, batch_size: int) -> int:
+        length, batch_size = validate_input(length, batch_size)
+        batch_count: int = length // batch_size
+        batch_count = batch_count + bool(length % batch_size)
+        assert batch_count >= 1
+        assert length <= batch_count * batch_size
+        assert (batch_count - 1) * batch_size < length
+        return batch_count
+
+    return _uniform_batch_count(length, batch_size)
+
+
+class UniformBatching:
+    def __init__(self, length: int, batch_size: int) -> None:
+        length, batch_size = validate_input(length, batch_size)
+
+        self.length: int = length
+        self.batch_size: int = batch_size
+
+    @property
+    def batch_count(self) -> int:
+        return uniform_batch_count(self.length, self.batch_size)
+
+    def __len__(self) -> int:
+        return self.batch_count
+
+    def get_limits(self, index: int) -> Tuple[int, int]:
+        if (index < 0) or (self.batch_count <= index):
+            msg: str = f"Batching Index is invalid. Got {index}."
+            raise IndexError(msg)
+        first: int = index * self.batch_size
+        last: int = min(self.length, first + self.batch_size)
+        assert (first >= 0) and (first < self.length)
+        assert (first < last) and (last <= self.length)
+        return (first, last)
+
+    def __getitem__(self, index: int) -> Tuple[int, int]:
+        return self.get_limits(index)
+
+    def __iter__(self) -> Iterator[Tuple[int, int]]:
+        index: int
+        for index in range(self.batch_count):
+            yield self.get_limits(index)
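Usage of the new `UniformBatching` helper, derived directly from the code above: 10 items in batches of 4 give three `(first, last)` windows, the last one short:

    from replay.data.utils.batching import UniformBatching

    batching = UniformBatching(length=10, batch_size=4)
    assert len(batching) == 3
    assert list(batching) == [(0, 4), (4, 8), (8, 10)]
    assert batching[2] == (8, 10)  # __getitem__ delegates to get_limits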
replay/data/utils/typing/__init__.py
File without changes
replay/data/utils/typing/dtype.py
ADDED
@@ -0,0 +1,65 @@
+from functools import lru_cache
+
+import numpy as np
+import pyarrow as pa
+import torch
+
+
+@lru_cache
+def _torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+    exemplar: torch.Tensor = torch.asarray([0], dtype=dtype)
+    return exemplar.numpy().dtype
+
+
+def torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+    return _torch_to_numpy(dtype)
+
+
+@lru_cache
+def _numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+    exemplar: np.ndarray = np.asarray([0], dtype=dtype)
+    return torch.from_numpy(exemplar).dtype
+
+
+def numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+    return _numpy_to_torch(dtype)
+
+
+@lru_cache
+def _pyarrow_to_numpy(dtype: pa.DataType) -> np.dtype:
+    exemplar: pa.Array = pa.array([0], type=dtype)
+    return exemplar.to_numpy().dtype
+
+
+def pyarrow_to_numpy(dtype: pa.DataType) -> np.dtype:
+    return _pyarrow_to_numpy(dtype)
+
+
+@lru_cache
+def _numpy_to_pyarrow(dtype: np.dtype) -> pa.DataType:
+    exemplar: np.ndarray = np.asarray([0], dtype=dtype)
+    return pa.array(exemplar).type
+
+
+def numpy_to_pyarrow(dtype: np.dtype) -> pa.DataType:
+    return _numpy_to_pyarrow(dtype)
+
+
+@lru_cache
+def _torch_to_pyarrow(dtype: torch.dtype) -> pa.DataType:
+    np_dtype: np.dtype = torch_to_numpy(dtype)
+    return numpy_to_pyarrow(np_dtype)
+
+
+def torch_to_pyarrow(dtype: torch.dtype) -> pa.DataType:
+    return _torch_to_pyarrow(dtype)
+
+
+@lru_cache
+def _pyarrow_to_torch(dtype: pa.DataType) -> torch.dtype:
+    np_dtype: np.dtype = pyarrow_to_numpy(dtype)
+    return numpy_to_torch(np_dtype)
+
+
+def pyarrow_to_torch(dtype: pa.DataType) -> torch.dtype:
+    return _pyarrow_to_torch(dtype)
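Round-tripping dtypes through the new helpers (each public function wraps an `lru_cache`d probe that converts a one-element exemplar array):

    import numpy as np
    import torch

    from replay.data.utils.typing.dtype import (
        numpy_to_torch,
        torch_to_numpy,
        torch_to_pyarrow,
    )

    assert torch_to_numpy(torch.float32) == np.dtype("float32")
    assert numpy_to_torch(np.dtype("int64")) is torch.int64
    print(torch_to_pyarrow(torch.int32))  # int32 (a pyarrow.DataType)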
replay/metrics/torch_metrics_builder.py
CHANGED
@@ -139,7 +139,9 @@ class _CoverageHelper:
        """
        self._ensure_hists_on_device(train.device)
        flatten_train = train.flatten()
-        filtered_train = torch.masked_select(
+        filtered_train = torch.masked_select(
+            flatten_train, ((flatten_train >= 0) & (flatten_train <= self.item_count - 1))
+        )
        self._train_hist += torch.histc(filtered_train.float(), bins=self.item_count, min=0, max=self.item_count - 1)

    def get_metrics(self) -> Mapping[str, float]:
@@ -193,7 +195,7 @@ class _MetricBuilder(abc.ABC):

 class TorchMetricsBuilder(_MetricBuilder):
     """
-    Computes specified metrics over multiple batches
+    Computes specified metrics over multiple batches.
     """

     def __init__(
@@ -203,12 +205,12 @@ class TorchMetricsBuilder(_MetricBuilder):
        item_count: Optional[int] = None,
    ) -> None:
        """
-        :param metrics:
-            Default:
-        :param top_k:
-            Default:
-        :param item_count:
-            You can omit this parameter if you don't need to calculate the Coverage metric.
+        :param metrics: Names of metrics to calculate.
+            Default: ``["map", "ndcg", "recall"]``.
+        :param top_k: Consider the highest k scores in the ranking.
+            Default: ``[1, 5, 10, 20]``.
+        :param item_count: the total number of items in the dataset.
+            You can omit this parameter if you don't need to calculate the ``Coverage`` metric.
        """
        self._mr = _MetricRequirements.from_metrics(
            set(metrics),
@@ -272,12 +274,16 @@ class TorchMetricsBuilder(_MetricBuilder):
        """
        Add a batch with predictions, ground truth and train set to calculate the metrics.

-        :param predictions:
-        :param ground_truth:
-            If users have a test set of different sizes then you need to do
-
-
-
+        :param predictions: A batch with the same number of recommendations for each user.
+        :param ground_truth: A batch corresponding to the test set for each user.
+            If users have a test set of different sizes then you need to do
+            the padding using a value that is not found in the item ID's.
+            For example, these can be negative values.
+        :param train: A batch corresponding to the train set for each user.
+            If users have a train set of different sizes then you need to do
+            the padding using a value that is not found in the item ID's and ``ground_truth``.
+            For example, these can be negative values.
+            You can omit this parameter if you don't need to calculate the ``coverage`` or ``novelty`` metrics.
        """
        self._ensure_constants_on_device(predictions.device)
        metrics_sum = np.array(self._compute_metrics_sum(predictions, ground_truth, train), dtype=np.float64)

replay/models/nn/loss/sce.py
CHANGED
@@ -6,9 +6,9 @@ import torch

 @dataclass(frozen=True)
 class SCEParams:
-    """
+    """
+    Set of parameters for ScalableCrossEntropyLoss.

-    Constructor arguments:
     :param n_buckets: Number of buckets into which samples will be distributed.
     :param bucket_size_x: Number of item hidden representations that will be in each bucket.
     :param bucket_size_y: Number of item embeddings that will be in each bucket.
@@ -33,11 +33,6 @@ class ScalableCrossEntropyLoss:

        :param SCEParams: Dataclass with ScalableCrossEntropyLoss parameters.
            Dataclass contains following values:
-            :param n_buckets: Number of buckets into which samples will be distributed.
-            :param bucket_size_x: Number of item hidden representations that will be in each bucket.
-            :param bucket_size_y: Number of item embeddings that will be in each bucket.
-            :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
-                Default: ``False``.
        """
        assert all(
            param is not None for param in sce_params._get_not_none_params()
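Constructing the params dataclass per the consolidated docstring. Field names are taken from the parameter list above; the full field list and defaults are not shown in this diff, and the values here are illustrative only:

    from replay.models.nn.loss.sce import SCEParams

    params = SCEParams(
        n_buckets=256,      # buckets the samples are distributed into
        bucket_size_x=256,  # item hidden representations per bucket
        bucket_size_y=256,  # item embeddings per bucket
        mix_x=False,        # Default: ``False``
    )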
replay/models/nn/optimizer_utils/__init__.py
CHANGED
@@ -1,4 +1,9 @@
 from replay.utils import TORCH_AVAILABLE

 if TYPE_AVAILABLE if TORCH_AVAILABLE:
-    from .optimizer_factory import
+    from .optimizer_factory import (
+        FatLRSchedulerFactory,
+        FatOptimizerFactory,
+        LRSchedulerFactory,
+        OptimizerFactory,
+    )
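With the expanded re-export, all four factories are importable from the package again (assuming the torch extra is installed, since the block is guarded by `TORCH_AVAILABLE`):

    from replay.models.nn.optimizer_utils import (
        FatLRSchedulerFactory,
        FatOptimizerFactory,
        LRSchedulerFactory,
        OptimizerFactory,
    )

Importing alone does not trigger the deprecation warnings added in optimizer_factory.py below; they fire when the classes are used.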
replay/models/nn/optimizer_utils/optimizer_factory.py
CHANGED
@@ -2,8 +2,13 @@ import abc
 from collections.abc import Iterator

 import torch
+from typing_extensions import deprecated


+@deprecated(
+    "`OptimizerFactory` class is deprecated. Use `replay.nn.lightning.optimizer.BaseOptimizerFactory` instead.",
+    stacklevel=2,
+)
 class OptimizerFactory(abc.ABC):
     """
     Interface for optimizer factory
@@ -20,6 +25,10 @@ class OptimizerFactory(abc.ABC):
     """


+@deprecated(
+    "`LRSchedulerFactory` class is deprecated. Use `replay.nn.lightning.scheduler.BaseLRSchedulerFactory` instead.",
+    stacklevel=2,
+)
 class LRSchedulerFactory(abc.ABC):
     """
     Interface for learning rate scheduler factory
@@ -36,6 +45,9 @@ class LRSchedulerFactory(abc.ABC):
     """


+@deprecated(
+    "`FatOptimizerFactory` class is deprecated. Use `replay.nn.lightning.optimizer.OptimizerFactory` instead.",
+)
 class FatOptimizerFactory(OptimizerFactory):
     """
     Factory that creates optimizer depending on passed parameters
@@ -75,6 +87,9 @@ class FatOptimizerFactory(OptimizerFactory):
        raise ValueError(msg)


+@deprecated(
+    "`FatLRSchedulerFactory` class is deprecated. Use `replay.nn.lightning.scheduler.LRSchedulerFactory` instead.",
+)
 class FatLRSchedulerFactory(LRSchedulerFactory):
     """
     Factory that creates learning rate schedule depending on passed parameters