replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/data/nn/parquet/partitioned_iterable_dataset.py ADDED
@@ -0,0 +1,56 @@
+ from collections.abc import Iterable, Iterator
+ from typing import Optional
+
+ import torch
+ import torch.utils.data as data
+
+ from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
+
+ from .impl.named_columns import NamedColumns
+ from .info.replicas import ReplicasInfoProtocol
+ from .iterable_dataset import IterableDataset
+
+ Batch = dict[str, torch.Tensor]
+
+
+ class PartitionedIterableDataset(data.IterableDataset):
+     """
+     A dataset that implements iteration over partitioned data.
+
+     This implementation allows large amounts of data to be processed in batch-wise mode,
+     which is especially useful when used in distributed training.
+     """
+
+     def __init__(
+         self,
+         iterable: Iterable[NamedColumns],
+         batch_size: int,
+         generator: Optional[torch.Generator] = None,
+         replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
+     ) -> None:
+         """
+         :param iterable: An iterable object that returns data partitions.
+         :param batch_size: Batch size.
+         :param generator: Random number generator for batch shuffling.
+             If ``None``, shuffling will be disabled. Default: ``None``.
+         :param replicas_info: A connector object capable of fetching total replica count and replica id during runtime.
+             Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode.
+         """
+         super().__init__()
+
+         self.iterable = iterable
+
+         self.batch_size = batch_size
+         self.generator = generator
+         self.replicas_info = replicas_info
+
+     def __iter__(self) -> Iterator[Batch]:
+         for partition in iter(self.iterable):
+             iterable = IterableDataset(
+                 named_columns=partition,
+                 generator=self.generator,
+                 batch_size=self.batch_size,
+                 replicas_info=self.replicas_info,
+             )
+
+             yield from iter(iterable)
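Note: a minimal usage sketch for the new dataset, assuming a hypothetical make_partitions() helper that yields NamedColumns partitions; only the constructor arguments shown in the diff above come from the package itself.

    import torch
    from torch.utils.data import DataLoader

    from replay.data.nn.parquet.partitioned_iterable_dataset import PartitionedIterableDataset

    # make_partitions() is a placeholder for any iterable of NamedColumns partitions.
    dataset = PartitionedIterableDataset(
        iterable=make_partitions(),
        batch_size=256,
        generator=torch.Generator().manual_seed(0),  # pass None to disable shuffling
    )
    # Batching is handled by the dataset itself, so the DataLoader does no extra batching.
    loader = DataLoader(dataset, batch_size=None)
    for batch in loader:  # each batch is a dict[str, torch.Tensor]
        ...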
replay/data/nn/parquet/utils/compute_length.py ADDED
@@ -0,0 +1,66 @@
+ import warnings
+ from collections.abc import Iterable
+ from typing import Protocol
+
+ from replay.data.nn.parquet.info.partitioning import partitioning_per_replica
+ from replay.data.nn.parquet.iterator import BatchesIterator
+
+
+ class HasLengthProtocol(Protocol):
+     def __len__(self) -> int: ...
+
+
+ def compute_fixed_size_generic_length_from_sizes(
+     partition_sizes: Iterable[int], batch_size: int, num_replicas: int
+ ) -> int:
+     residue = 0
+     batch_counter = 0
+     for partition_size in partition_sizes:
+         per_replica = partitioning_per_replica(partition_size, num_replicas)
+         batch_count = per_replica // batch_size
+         residue += per_replica % batch_size
+         if batch_size < residue:
+             batch_count += residue // batch_size
+             residue = residue % batch_size
+         batch_counter += batch_count
+     batch_counter += residue > 0
+     return batch_counter
+
+
+ def compute_fixed_size_batches_length(iterable: BatchesIterator, batch_size: int, num_replicas: int) -> int:
+     assert isinstance(iterable, BatchesIterator)
+
+     partition_size = iterable.batch_size
+
+     def default_partitions(fragment_size: int) -> list[int]:
+         full_partitions_count = fragment_size // partition_size
+         result = [partition_size] * full_partitions_count
+         if (residue := (fragment_size % partition_size)) > 0:
+             result.append(residue)
+         return result
+
+     partition_sizes = []
+     for fragment in iterable.dataset.get_fragments():
+         fragment_size = fragment.count_rows()
+         partitions = default_partitions(fragment_size)
+         partition_sizes.extend(partitions)
+
+     result = compute_fixed_size_generic_length_from_sizes(
+         partition_sizes=partition_sizes,
+         num_replicas=num_replicas,
+         batch_size=batch_size,
+     )
+
+     return result
+
+
+ def compute_fixed_size_generic_length(iterable: Iterable[HasLengthProtocol], batch_size: int, num_replicas: int) -> int:
+     warnings.warn("Generic length computation. This may cause performance issues.", UserWarning, stacklevel=2)
+     return compute_fixed_size_generic_length_from_sizes(map(len, iterable), batch_size, num_replicas)
+
+
+ def compute_fixed_size_length(iterable: Iterable[HasLengthProtocol], batch_size: int, num_replicas: int) -> int:
+     if isinstance(iterable, BatchesIterator):
+         return compute_fixed_size_batches_length(iterable, batch_size, num_replicas)
+     else:
+         return compute_fixed_size_generic_length(iterable, batch_size, num_replicas)
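Note: a worked example of the batch-counting arithmetic above for the single-replica case, assuming partitioning_per_replica(n, 1) returns n. With partition sizes [10, 10] and batch_size=4, each partition yields 2 full batches and a residue of 2; the residues add up to 4 and are flushed as one final batch, so the result is 5 batches for 20 rows.

    # Hypothetical single-replica check; assumes partitioning_per_replica(n, 1) == n.
    from replay.data.nn.parquet.utils.compute_length import compute_fixed_size_generic_length_from_sizes

    assert compute_fixed_size_generic_length_from_sizes([10, 10], batch_size=4, num_replicas=1) == 5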
replay/data/nn/schema.py CHANGED
@@ -86,12 +86,14 @@ class TensorFeatureInfo:
          default: ``None``.
      :param feature_sources: columns names and DataFrames feature came from,
          default: ``None``.
-     :param cardinality: cardinality of categorical feature, required for ids columns,
-         optional for others,
-         default: ``None``.
-     :param padding_value: value to pad sequences to desired length
-     :param embedding_dim: embedding dimensions of categorical feature,
-         default: ``None``.
+     :param cardinality: cardinality of categorical feature.
+         number of unique items in vocabulary (catalog).
+         The specified cardinality value must not take into account the padding value.
+         Default: ``None``.
+     :param padding_value: value to pad sequences to desired length.
+         It is recommended to set the padding value for categorical features in the `cardinality` value.
+     :param embedding_dim: embedding dimensions of the feature.
+         Default: ``None`` - it means will be used value of ``DEFAULT_EMBEDDING_DIM``.
      :param tensor_dim: tensor dimensions of numerical feature,
          default: ``None``.
      """
@@ -106,8 +108,8 @@ class TensorFeatureInfo:
              raise ValueError(msg)
          self._feature_type = feature_type

-         if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST] and (cardinality or embedding_dim):
-             msg = "Cardinality and embedding dimensions are needed only with categorical feature type."
+         if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST] and cardinality is not None:
+             msg = "Cardinality is needed only with categorical feature type."
              raise ValueError(msg)
          self._cardinality = cardinality

@@ -115,9 +117,8 @@ class TensorFeatureInfo:
              msg = "Tensor dimensions is needed only with numerical feature type."
              raise ValueError(msg)

-         if feature_type in [FeatureType.CATEGORICAL, FeatureType.CATEGORICAL_LIST]:
-             self._embedding_dim = embedding_dim or self.DEFAULT_EMBEDDING_DIM
-         else:
+         self._embedding_dim = embedding_dim or self.DEFAULT_EMBEDDING_DIM
+         if feature_type in [FeatureType.NUMERICAL, FeatureType.NUMERICAL_LIST]:
              self._tensor_dim = tensor_dim

      @property
@@ -236,9 +237,6 @@ class TensorFeatureInfo:
          """
          :returns: Embedding dimensions of the feature.
          """
-         if not self.is_cat:
-             msg = f"Can not get embedding dimensions because feature type of {self.name} feature is not categorical."
-             raise RuntimeError(msg)
          return self._embedding_dim

      def _set_embedding_dim(self, embedding_dim: int) -> None:
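Note: a hedged construction sketch reflecting the updated validation above; the keyword names follow the docstring, but the full TensorFeatureInfo constructor signature (including the name argument) is assumed rather than shown in this diff.

    from replay.data import FeatureType
    from replay.data.nn.schema import TensorFeatureInfo

    # Categorical feature: cardinality excludes the padding value, and the padding
    # value itself is recommended to be set equal to the cardinality.
    item_id = TensorFeatureInfo(
        name="item_id",  # assumed keyword
        feature_type=FeatureType.CATEGORICAL,
        cardinality=10_000,
        padding_value=10_000,
        embedding_dim=64,
    )

    # Numerical feature: passing `cardinality` now raises ValueError, while
    # `embedding_dim` is accepted for any feature type and defaults to DEFAULT_EMBEDDING_DIM.
    price = TensorFeatureInfo(
        name="price",  # assumed keyword
        feature_type=FeatureType.NUMERICAL,
        tensor_dim=1,
    )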
replay/data/nn/sequence_tokenizer.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
  import polars as pl
  from pandas import DataFrame as PandasDataFrame
  from polars import DataFrame as PolarsDataFrame
+ from typing_extensions import deprecated

  from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
  from replay.data.dataset_utils import DatasetLabelEncoder
@@ -24,6 +25,7 @@ SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
  _T = TypeVar("_T")


+ @deprecated("`SequenceTokenizer` class is deprecated.")
  class SequenceTokenizer:
      """
      Data tokenizer for transformers;
@@ -507,6 +509,7 @@ class SequenceTokenizer:
              pickle.dump(self, file)


+ @deprecated("`_BaseSequenceProcessor` class is deprecated.", stacklevel=2)
  class _BaseSequenceProcessor(Generic[_T]):
      """
      Base class for sequence processing
@@ -600,6 +603,7 @@ class _BaseSequenceProcessor(Generic[_T]):
          pass


+ @deprecated("`_PandasSequenceProcessor` class is deprecated.", stacklevel=2)
  class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
      """
      Class to process sequences of different categorical and numerical features.
@@ -780,6 +784,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
          return values


+ @deprecated("`_PolarsSequenceProcessor` class is deprecated.", stacklevel=2)
  class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
      """
      Class to process sequences of different categorical and numerical features.
replay/data/nn/sequential_dataset.py CHANGED
@@ -8,11 +8,13 @@ import pandas as pd
  import polars as pl
  from pandas import DataFrame as PandasDataFrame
  from polars import DataFrame as PolarsDataFrame
+ from typing_extensions import deprecated

  if TYPE_CHECKING:
      from .schema import TensorSchema


+ @deprecated("`SequentialDataset` class is deprecated.", stacklevel=2)
  class SequentialDataset(abc.ABC):
      """
      Abstract base class for sequential dataset
@@ -138,6 +140,7 @@ class SequentialDataset(abc.ABC):
          return df_converted


+ @deprecated("`PandasSequentialDataset` class is deprecated.")
  class PandasSequentialDataset(SequentialDataset):
      """
      Sequential dataset that stores sequences in PandasDataFrame format.
@@ -234,6 +237,7 @@ class PandasSequentialDataset(SequentialDataset):
          return dataset


+ @deprecated("`PolarsSequentialDataset` class is deprecated.")
  class PolarsSequentialDataset(PandasSequentialDataset):
      """
      Sequential dataset that stores sequences in PolarsDataFrame format.
replay/data/nn/torch_sequential_dataset.py CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
  import numpy as np
  import torch
  from torch.utils.data import Dataset as TorchDataset
+ from typing_extensions import deprecated

  if TYPE_CHECKING:
      from .schema import TensorFeatureInfo, TensorMap, TensorSchema
@@ -13,6 +14,7 @@ if TYPE_CHECKING:

  # We do not use dataclasses as PyTorch default collate
  # function in dataloader supports only namedtuple
+ @deprecated("`TorchSequentialBatch` class is deprecated.", stacklevel=2)
  class TorchSequentialBatch(NamedTuple):
      """
      Batch of TorchSequentialDataset
@@ -23,6 +25,7 @@
      features: "TensorMap"


+ @deprecated("`TorchSequentialDataset` class is deprecated.")
  class TorchSequentialDataset(TorchDataset):
      """
      Torch dataset for sequential recommender models
@@ -160,6 +163,7 @@ class TorchSequentialDataset(TorchDataset):
              yield (i, offset_from_seq_beginning)


+ @deprecated("`TorchSequentialValidationBatch` class is deprecated.", stacklevel=2)
  class TorchSequentialValidationBatch(NamedTuple):
      """
      Batch of TorchSequentialValidationDataset
@@ -176,6 +180,7 @@ DEFAULT_GROUND_TRUTH_PADDING_VALUE = -1
  DEFAULT_TRAIN_PADDING_VALUE = -2


+ @deprecated("`TorchSequentialValidationDataset` class is deprecated.")
  class TorchSequentialValidationDataset(TorchDataset):
      """
      Torch dataset for sequential recommender models that additionally stores ground truth
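Note: the decorators above come from typing_extensions.deprecated, so importing these classes still works, but constructing them emits a DeprecationWarning. A self-contained sketch of that behavior with a toy class (not one of the replay classes):

    import warnings

    from typing_extensions import deprecated

    @deprecated("`Legacy` class is deprecated.")
    class Legacy:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Legacy()  # instantiation triggers the warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)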
replay/data/utils/batching.py ADDED
@@ -0,0 +1,69 @@
+ from functools import lru_cache
+ from typing import Iterator, Tuple
+
+
+ def validate_length(length: int) -> int:
+     if length < 1:
+         msg: str = f"Length is invalid. Got {length}."
+         raise ValueError(msg)
+     return length
+
+
+ def validate_batch_size(batch_size: int) -> int:
+     if batch_size < 1:
+         msg: str = f"Batch Size is invalid. Got {batch_size}."
+         raise ValueError(msg)
+     return batch_size
+
+
+ def validate_input(length: int, batch_size: int) -> Tuple[int, int]:
+     length = validate_length(length)
+     batch_size = validate_batch_size(batch_size)
+     return (length, batch_size)
+
+
+ def uniform_batch_count(length: int, batch_size: int) -> int:
+     @lru_cache
+     def _uniform_batch_count(length: int, batch_size: int) -> int:
+         length, batch_size = validate_input(length, batch_size)
+         batch_count: int = length // batch_size
+         batch_count = batch_count + bool(length % batch_size)
+         assert batch_count >= 1
+         assert length <= batch_count * batch_size
+         assert (batch_count - 1) * batch_size < length
+         return batch_count
+
+     return _uniform_batch_count(length, batch_size)
+
+
+ class UniformBatching:
+     def __init__(self, length: int, batch_size: int) -> None:
+         length, batch_size = validate_input(length, batch_size)
+
+         self.length: int = length
+         self.batch_size: int = batch_size
+
+     @property
+     def batch_count(self) -> int:
+         return uniform_batch_count(self.length, self.batch_size)
+
+     def __len__(self) -> int:
+         return self.batch_count
+
+     def get_limits(self, index: int) -> Tuple[int, int]:
+         if (index < 0) or (self.batch_count <= index):
+             msg: str = f"Batching Index is invalid. Got {index}."
+             raise IndexError(msg)
+         first: int = index * self.batch_size
+         last: int = min(self.length, first + self.batch_size)
+         assert (first >= 0) and (first < self.length)
+         assert (first < last) and (last <= self.length)
+         return (first, last)
+
+     def __getitem__(self, index: int) -> Tuple[int, int]:
+         return self.get_limits(index)
+
+     def __iter__(self) -> Iterator[Tuple[int, int]]:
+         index: int
+         for index in range(self.batch_count):
+             yield self.get_limits(index)
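Note: the whole module is shown above, so its behavior can be illustrated directly; 10 items split into batches of 4 produce three half-open index ranges.

    from replay.data.utils.batching import UniformBatching

    batching = UniformBatching(length=10, batch_size=4)
    assert len(batching) == 3
    assert list(batching) == [(0, 4), (4, 8), (8, 10)]
    # Out-of-range indices raise IndexError, e.g. batching[3].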
replay/data/utils/typing/dtype.py ADDED
@@ -0,0 +1,65 @@
+ from functools import lru_cache
+
+ import numpy as np
+ import pyarrow as pa
+ import torch
+
+
+ @lru_cache
+ def _torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+     exemplar: torch.Tensor = torch.asarray([0], dtype=dtype)
+     return exemplar.numpy().dtype
+
+
+ def torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+     return _torch_to_numpy(dtype)
+
+
+ @lru_cache
+ def _numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+     exemplar: np.ndarray = np.asarray([0], dtype=dtype)
+     return torch.from_numpy(exemplar).dtype
+
+
+ def numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+     return _numpy_to_torch(dtype)
+
+
+ @lru_cache
+ def _pyarrow_to_numpy(dtype: pa.DataType) -> np.dtype:
+     exemplar: pa.Array = pa.array([0], type=dtype)
+     return exemplar.to_numpy().dtype
+
+
+ def pyarrow_to_numpy(dtype: pa.DataType) -> np.dtype:
+     return _pyarrow_to_numpy(dtype)
+
+
+ @lru_cache
+ def _numpy_to_pyarrow(dtype: np.dtype) -> pa.DataType:
+     exemplar: np.ndarray = np.asarray([0], dtype=dtype)
+     return pa.array(exemplar).type
+
+
+ def numpy_to_pyarrow(dtype: np.dtype) -> pa.DataType:
+     return _numpy_to_pyarrow(dtype)
+
+
+ @lru_cache
+ def _torch_to_pyarrow(dtype: torch.dtype) -> pa.DataType:
+     np_dtype: np.dtype = torch_to_numpy(dtype)
+     return numpy_to_pyarrow(np_dtype)
+
+
+ def torch_to_pyarrow(dtype: torch.dtype) -> pa.DataType:
+     return _torch_to_pyarrow(dtype)
+
+
+ @lru_cache
+ def _pyarrow_to_torch(dtype: pa.DataType) -> torch.dtype:
+     np_dtype: np.dtype = pyarrow_to_numpy(dtype)
+     return numpy_to_torch(np_dtype)
+
+
+ def pyarrow_to_torch(dtype: pa.DataType) -> torch.dtype:
+     return _pyarrow_to_torch(dtype)
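Note: the conversion helpers above are cached and symmetric; a quick round-trip check based only on the functions shown:

    import numpy as np
    import pyarrow as pa
    import torch

    from replay.data.utils.typing.dtype import numpy_to_pyarrow, pyarrow_to_torch, torch_to_numpy

    assert torch_to_numpy(torch.float32) == np.dtype("float32")
    assert numpy_to_pyarrow(np.dtype("int64")) == pa.int64()
    assert pyarrow_to_torch(pa.int32()) == torch.int32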
replay/metrics/torch_metrics_builder.py CHANGED
@@ -139,7 +139,9 @@ class _CoverageHelper:
          """
          self._ensure_hists_on_device(train.device)
          flatten_train = train.flatten()
-         filtered_train = torch.masked_select(flatten_train, flatten_train != -2)
+         filtered_train = torch.masked_select(
+             flatten_train, ((flatten_train >= 0) & (flatten_train <= self.item_count - 1))
+         )
          self._train_hist += torch.histc(filtered_train.float(), bins=self.item_count, min=0, max=self.item_count - 1)

      def get_metrics(self) -> Mapping[str, float]:
@@ -193,7 +195,7 @@ class _MetricBuilder(abc.ABC):

  class TorchMetricsBuilder(_MetricBuilder):
      """
-     Computes specified metrics over multiple batches
+     Computes specified metrics over multiple batches.
      """

      def __init__(
@@ -203,12 +205,12 @@ class TorchMetricsBuilder(_MetricBuilder):
          item_count: Optional[int] = None,
      ) -> None:
          """
-         :param metrics: (list[MetricName]): Names of metrics to calculate.
-             Default: `["map", "ndcg", "recall"]`.
-         :param top_k: (list): Consider the highest k scores in the ranking.
-             Default: `[1, 5, 10, 20]`.
-         :param item_count: (optional, int): the total number of items in the dataset.
-             You can omit this parameter if you don't need to calculate the Coverage metric.
+         :param metrics: Names of metrics to calculate.
+             Default: ``["map", "ndcg", "recall"]``.
+         :param top_k: Consider the highest k scores in the ranking.
+             Default: ``[1, 5, 10, 20]``.
+         :param item_count: the total number of items in the dataset.
+             You can omit this parameter if you don't need to calculate the ``Coverage`` metric.
          """
          self._mr = _MetricRequirements.from_metrics(
              set(metrics),
@@ -272,12 +274,16 @@
          """
          Add a batch with predictions, ground truth and train set to calculate the metrics.

-         :param predictions: (torch.LongTensor): A batch with the same number of recommendations for each user.
-         :param ground_truth: (torch.LongTensor): A batch corresponding to the test set for each user.
-             If users have a test set of different sizes then you need to do the padding using -1.
-         :param train: (optional, int): A batch corresponding to the train set for each user.
-             If users have a train set of different sizes then you need to do the padding using -2.
-             You can omit this parameter if you don't need to calculate the coverage or novelty metrics.
+         :param predictions: A batch with the same number of recommendations for each user.
+         :param ground_truth: A batch corresponding to the test set for each user.
+             If users have a test set of different sizes then you need to do
+             the padding using a value that is not found in the item ID's.
+             For example, these can be negative values.
+         :param train: A batch corresponding to the train set for each user.
+             If users have a train set of different sizes then you need to do
+             the padding using a value that is not found in the item ID's and ``ground_truth``.
+             For example, these can be negative values.
+             You can omit this parameter if you don't need to calculate the ``coverage`` or ``novelty`` metrics.
          """
          self._ensure_constants_on_device(predictions.device)
          metrics_sum = np.array(self._compute_metrics_sum(predictions, ground_truth, train), dtype=np.float64)
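Note: the padding contract above no longer hard-codes -1/-2; any value that cannot collide with a real item ID is acceptable. A small sketch of padding ragged ground-truth rows with a negative sentinel before passing them to the batch-adding method documented above:

    import torch

    ITEM_COUNT = 10
    PAD = -1  # any value outside [0, ITEM_COUNT) works as padding

    # Two users with test sets of different sizes, padded to a common width.
    ground_truth = torch.tensor([
        [3, 7, PAD, PAD],
        [1, 4, 9, PAD],
    ])
    # The same number of recommendations per user, as required for `predictions`.
    predictions = torch.tensor([
        [7, 2, 3, 5],
        [9, 1, 0, 6],
    ])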
replay/models/nn/loss/sce.py CHANGED
@@ -6,9 +6,9 @@ import torch

  @dataclass(frozen=True)
  class SCEParams:
-     """Set of parameters for ScalableCrossEntropyLoss.
+     """
+     Set of parameters for ScalableCrossEntropyLoss.

-     Constructor arguments:
      :param n_buckets: Number of buckets into which samples will be distributed.
      :param bucket_size_x: Number of item hidden representations that will be in each bucket.
      :param bucket_size_y: Number of item embeddings that will be in each bucket.
@@ -33,11 +33,6 @@ class ScalableCrossEntropyLoss:

          :param SCEParams: Dataclass with ScalableCrossEntropyLoss parameters.
              Dataclass contains following values:
-             :param n_buckets: Number of buckets into which samples will be distributed.
-             :param bucket_size_x: Number of item hidden representations that will be in each bucket.
-             :param bucket_size_y: Number of item embeddings that will be in each bucket.
-             :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
-                 Default: ``False``.
          """
          assert all(
              param is not None for param in sce_params._get_not_none_params()
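Note: a hedged construction sketch for the parameter dataclass, using only the fields listed in its docstring above (the dataclass may define additional optional fields):

    from replay.models.nn.loss.sce import SCEParams

    sce_params = SCEParams(
        n_buckets=64,
        bucket_size_x=256,
        bucket_size_y=256,
        mix_x=False,
    )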
replay/models/nn/optimizer_utils/__init__.py CHANGED
@@ -1,4 +1,9 @@
  from replay.utils import TORCH_AVAILABLE

  if TORCH_AVAILABLE:
-     from .optimizer_factory import FatLRSchedulerFactory, FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
+     from .optimizer_factory import (
+         FatLRSchedulerFactory,
+         FatOptimizerFactory,
+         LRSchedulerFactory,
+         OptimizerFactory,
+     )
replay/models/nn/optimizer_utils/optimizer_factory.py CHANGED
@@ -2,8 +2,13 @@ import abc
  from collections.abc import Iterator

  import torch
+ from typing_extensions import deprecated


+ @deprecated(
+     "`OptimizerFactory` class is deprecated. Use `replay.nn.lightning.optimizer.BaseOptimizerFactory` instead.",
+     stacklevel=2,
+ )
  class OptimizerFactory(abc.ABC):
      """
      Interface for optimizer factory
@@ -20,6 +25,10 @@ class OptimizerFactory(abc.ABC):
      """


+ @deprecated(
+     "`LRSchedulerFactory` class is deprecated. Use `replay.nn.lightning.scheduler.BaseLRSchedulerFactory` instead.",
+     stacklevel=2,
+ )
  class LRSchedulerFactory(abc.ABC):
      """
      Interface for learning rate scheduler factory
@@ -36,6 +45,9 @@ class LRSchedulerFactory(abc.ABC):
      """


+ @deprecated(
+     "`FatOptimizerFactory` class is deprecated. Use `replay.nn.lightning.optimizer.OptimizerFactory` instead.",
+ )
  class FatOptimizerFactory(OptimizerFactory):
      """
      Factory that creates optimizer depending on passed parameters
@@ -75,6 +87,9 @@ class FatOptimizerFactory(OptimizerFactory):
          raise ValueError(msg)


+ @deprecated(
+     "`FatLRSchedulerFactory` class is deprecated. Use `replay.nn.lightning.scheduler.LRSchedulerFactory` instead.",
+ )
  class FatLRSchedulerFactory(LRSchedulerFactory):
      """
      Factory that creates learning rate schedule depending on passed parameters