replay-rec 0.18.1rc0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. replay/__init__.py +1 -1
  2. replay/data/nn/schema.py +3 -1
  3. replay/metrics/surprisal.py +4 -2
  4. replay/models/lin_ucb.py +2 -3
  5. replay/models/nn/loss/__init__.py +1 -0
  6. replay/models/nn/loss/sce.py +131 -0
  7. replay/models/nn/sequential/bert4rec/lightning.py +36 -4
  8. replay/models/nn/sequential/bert4rec/model.py +5 -46
  9. replay/models/nn/sequential/sasrec/lightning.py +27 -3
  10. replay/models/nn/sequential/sasrec/model.py +1 -1
  11. replay/preprocessing/filters.py +102 -1
  12. replay/preprocessing/label_encoder.py +8 -4
  13. {replay_rec-0.18.1rc0.dist-info → replay_rec-0.19.0.dist-info}/METADATA +5 -12
  14. {replay_rec-0.18.1rc0.dist-info → replay_rec-0.19.0.dist-info}/RECORD +16 -70
  15. {replay_rec-0.18.1rc0.dist-info → replay_rec-0.19.0.dist-info}/WHEEL +1 -1
  16. replay/experimental/__init__.py +0 -0
  17. replay/experimental/metrics/__init__.py +0 -62
  18. replay/experimental/metrics/base_metric.py +0 -602
  19. replay/experimental/metrics/coverage.py +0 -97
  20. replay/experimental/metrics/experiment.py +0 -175
  21. replay/experimental/metrics/hitrate.py +0 -26
  22. replay/experimental/metrics/map.py +0 -30
  23. replay/experimental/metrics/mrr.py +0 -18
  24. replay/experimental/metrics/ncis_precision.py +0 -31
  25. replay/experimental/metrics/ndcg.py +0 -49
  26. replay/experimental/metrics/precision.py +0 -22
  27. replay/experimental/metrics/recall.py +0 -25
  28. replay/experimental/metrics/rocauc.py +0 -49
  29. replay/experimental/metrics/surprisal.py +0 -90
  30. replay/experimental/metrics/unexpectedness.py +0 -76
  31. replay/experimental/models/__init__.py +0 -13
  32. replay/experimental/models/admm_slim.py +0 -205
  33. replay/experimental/models/base_neighbour_rec.py +0 -204
  34. replay/experimental/models/base_rec.py +0 -1340
  35. replay/experimental/models/base_torch_rec.py +0 -234
  36. replay/experimental/models/cql.py +0 -454
  37. replay/experimental/models/ddpg.py +0 -923
  38. replay/experimental/models/dt4rec/__init__.py +0 -0
  39. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  40. replay/experimental/models/dt4rec/gpt1.py +0 -401
  41. replay/experimental/models/dt4rec/trainer.py +0 -127
  42. replay/experimental/models/dt4rec/utils.py +0 -265
  43. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  44. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  45. replay/experimental/models/hierarchical_recommender.py +0 -331
  46. replay/experimental/models/implicit_wrap.py +0 -131
  47. replay/experimental/models/lightfm_wrap.py +0 -302
  48. replay/experimental/models/mult_vae.py +0 -332
  49. replay/experimental/models/neural_ts.py +0 -986
  50. replay/experimental/models/neuromf.py +0 -406
  51. replay/experimental/models/scala_als.py +0 -296
  52. replay/experimental/models/u_lin_ucb.py +0 -115
  53. replay/experimental/nn/data/__init__.py +0 -1
  54. replay/experimental/nn/data/schema_builder.py +0 -102
  55. replay/experimental/preprocessing/__init__.py +0 -3
  56. replay/experimental/preprocessing/data_preparator.py +0 -839
  57. replay/experimental/preprocessing/padder.py +0 -229
  58. replay/experimental/preprocessing/sequence_generator.py +0 -208
  59. replay/experimental/scenarios/__init__.py +0 -1
  60. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  61. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  62. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  63. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  64. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  65. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  66. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  67. replay/experimental/utils/__init__.py +0 -0
  68. replay/experimental/utils/logger.py +0 -24
  69. replay/experimental/utils/model_handler.py +0 -186
  70. replay/experimental/utils/session_handler.py +0 -44
  71. replay_rec-0.18.1rc0.dist-info/NOTICE +0 -41
  72. {replay_rec-0.18.1rc0.dist-info → replay_rec-0.19.0.dist-info}/LICENSE +0 -0
replay/__init__.py CHANGED
@@ -1,3 +1,3 @@
  """ RecSys library """

- __version__ = "0.18.1.preview"
+ __version__ = "0.19.0"
replay/data/nn/schema.py CHANGED
@@ -7,6 +7,7 @@ from typing import (
      List,
      Mapping,
      Optional,
+     OrderedDict,
      Sequence,
      Set,
      Union,
@@ -262,6 +263,8 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
          """
          :param features_list: list of tensor feature infos.
          """
+         if isinstance(features_list, OrderedDict):
+             features_list = list(features_list.values())
          features_list = [features_list] if not isinstance(features_list, Sequence) else features_list
          self._tensor_schema = {feature.name: feature for feature in features_list}

@@ -501,7 +504,6 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
                  filtered_features,
              )
          )
-
          return TensorSchema(filtered_features)

      @staticmethod
replay/metrics/surprisal.py CHANGED
@@ -129,7 +129,9 @@ class Surprisal(Metric):
          item_weights = train.group_by(self.item_column).agg(
              (np.log2(n_users / pl.col(self.query_column).n_unique()) / np.log2(n_users)).alias("weight")
          )
-         recommendations = recommendations.join(item_weights, on=self.item_column, how="left").fill_nan(1.0)
+         recommendations = recommendations.join(item_weights, on=self.item_column, how="left").with_columns(
+             pl.col("weight").fill_null(1.0)
+         )

          sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
          return self._rearrange_columns(sorted_by_score_recommendations)
@@ -175,7 +177,7 @@ class Surprisal(Metric):

          weights = self._get_recommendation_weights(recommendations, train)
          return self._dict_call(
-             list(train),
+             list(recommendations),
              pred_item_id=recommendations,
              pred_weight=weights,
          )
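Note on the change above: in Polars, a left join fills unmatched rows with nulls, not NaNs, so the previous fill_nan(1.0) never assigned the default weight to items absent from the train set; fill_null(1.0) does. A tiny, self-contained illustration (toy data, not from the package):

import polars as pl

recs = pl.DataFrame({"item_id": [1, 2]})
weights = pl.DataFrame({"item_id": [1], "weight": [0.5]})
joined = recs.join(weights, on="item_id", how="left")
print(joined["weight"].to_list())  # [0.5, None] - the unmatched item is null, not NaN
print(joined.with_columns(pl.col("weight").fill_null(1.0))["weight"].to_list())  # [0.5, 1.0]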
replay/models/lin_ucb.py CHANGED
@@ -98,9 +98,8 @@ class LinUCB(HybridRecommender):
      The model assumes a linear relationship between user context, item features and action rewards,
      making it efficient for high-dimensional contexts.

-     Note:
-         It's recommended to scale features to a similar range (e.g., using StandardScaler or MinMaxScaler)
-         to ensure proper convergence and prevent numerical instability (since relationships to learn are linear).
+     Note: It's recommended to scale features to a similar range (e.g., using StandardScaler or MinMaxScaler)
+     to ensure proper convergence and prevent numerical instability (since relationships to learn are linear).

      >>> import pandas as pd
      >>> from replay.data.dataset import (
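The reworded note above recommends scaling features before fitting LinUCB. A minimal, hedged sketch of what that can look like with scikit-learn (column names are hypothetical; scaling happens outside the model):

import pandas as pd
from sklearn.preprocessing import MinMaxScaler

user_features = pd.DataFrame({"user_id": [0, 1, 2], "age": [18, 35, 60], "income": [20_000, 55_000, 90_000]})
user_features[["age", "income"]] = MinMaxScaler().fit_transform(user_features[["age", "income"]])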
replay/models/nn/loss/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sce import ScalableCrossEntropyLoss, SCEParams
replay/models/nn/loss/sce.py ADDED
@@ -0,0 +1,131 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+
+
+ @dataclass(frozen=True)
+ class SCEParams:
+     """Set of parameters for ScalableCrossEntropyLoss.
+
+     Constructor arguments:
+     :param n_buckets: Number of buckets into which samples will be distributed.
+     :param bucket_size_x: Number of item hidden representations that will be in each bucket.
+     :param bucket_size_y: Number of item embeddings that will be in each bucket.
+     :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
+         Default: ``False``.
+     """
+
+     n_buckets: int
+     bucket_size_x: int
+     bucket_size_y: int
+     mix_x: bool = False
+
+     def _get_not_none_params(self):
+         return [self.n_buckets, self.bucket_size_x, self.bucket_size_y]
+
+
+ class ScalableCrossEntropyLoss:
+     def __init__(self, sce_params: SCEParams):
+         """
+         ScalableCrossEntropyLoss for Sequential Recommendations with Large Item Catalogs.
+         Reference article may be found at https://arxiv.org/pdf/2409.18721.
+
+         :param SCEParams: Dataclass with ScalableCrossEntropyLoss parameters.
+             Dataclass contains following values:
+             :param n_buckets: Number of buckets into which samples will be distributed.
+             :param bucket_size_x: Number of item hidden representations that will be in each bucket.
+             :param bucket_size_y: Number of item embeddings that will be in each bucket.
+             :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
+                 Default: ``False``.
+         """
+         assert all(
+             param is not None for param in sce_params._get_not_none_params()
+         ), "You should define ``n_buckets``, ``bucket_size_x``, ``bucket_size_y`` when using SCE loss function."
+         self._n_buckets = sce_params.n_buckets
+         self._bucket_size_x = sce_params.bucket_size_x
+         self._bucket_size_y = sce_params.bucket_size_y
+         self._mix_x = sce_params.mix_x
+
+     def __call__(
+         self,
+         embeddings: torch.Tensor,
+         positive_labels: torch.LongTensor,
+         all_embeddings: torch.Tensor,
+         padding_mask: torch.BoolTensor,
+         tokens_mask: Optional[torch.BoolTensor] = None,
+     ) -> torch.Tensor:
+         """
+         ScalableCrossEntropyLoss computation.
+
+         :param embeddings: Matrix of the last transformer block outputs.
+         :param positive_labels: Positive labels.
+         :param all_embeddings: Matrix of all item embeddings.
+         :param padding_mask: Padding mask.
+         :param tokens_mask: Tokens mask (need only for Bert4Rec).
+             Default: ``None``.
+         """
+         masked_tokens = padding_mask if tokens_mask is None else ~(~padding_mask + tokens_mask)
+
+         hd = torch.tensor(embeddings.shape[-1])
+         x = embeddings.view(-1, hd)
+         y = positive_labels.view(-1)
+         w = all_embeddings
+
+         correct_class_logits_ = (x * torch.index_select(w, dim=0, index=y)).sum(dim=1)  # (bs,)
+
+         with torch.no_grad():
+             if self._mix_x:
+                 omega = 1 / torch.sqrt(torch.sqrt(hd)) * torch.randn(x.shape[0], self._n_buckets, device=x.device)
+                 buckets = omega.T @ x
+                 del omega
+             else:
+                 buckets = (
+                     1 / torch.sqrt(torch.sqrt(hd)) * torch.randn(self._n_buckets, hd, device=x.device)
+                 )  # (n_b, hd)
+
+         with torch.no_grad():
+             x_bucket = buckets @ x.T  # (n_b, hd) x (hd, b) -> (n_b, b)
+             x_bucket[:, ~padding_mask.view(-1)] = float("-inf")
+             _, top_x_bucket = torch.topk(x_bucket, dim=1, k=self._bucket_size_x)  # (n_b, bs_x)
+             del x_bucket
+
+             y_bucket = buckets @ w.T  # (n_b, hd) x (hd, n_cl) -> (n_b, n_cl)
+
+             _, top_y_bucket = torch.topk(y_bucket, dim=1, k=self._bucket_size_y)  # (n_b, bs_y)
+             del y_bucket
+
+         x_bucket = torch.gather(x, 0, top_x_bucket.view(-1, 1).expand(-1, hd)).view(
+             self._n_buckets, self._bucket_size_x, hd
+         )  # (n_b, bs_x, hd)
+         y_bucket = torch.gather(w, 0, top_y_bucket.view(-1, 1).expand(-1, hd)).view(
+             self._n_buckets, self._bucket_size_y, hd
+         )  # (n_b, bs_y, hd)
+
+         wrong_class_logits = x_bucket @ y_bucket.transpose(-1, -2)  # (n_b, bs_x, bs_y)
+         mask = (
+             torch.index_select(y, dim=0, index=top_x_bucket.view(-1)).view(self._n_buckets, self._bucket_size_x)[
+                 :, :, None
+             ]
+             == top_y_bucket[:, None, :]
+         )  # (n_b, bs_x, bs_y)
+         wrong_class_logits = wrong_class_logits.masked_fill(mask, float("-inf"))  # (n_b, bs_x, bs_y)
+         correct_class_logits = torch.index_select(correct_class_logits_, dim=0, index=top_x_bucket.view(-1)).view(
+             self._n_buckets, self._bucket_size_x
+         )[
+             :, :, None
+         ]  # (n_b, bs_x, 1)
+         logits = torch.cat((wrong_class_logits, correct_class_logits), dim=2)  # (n_b, bs_x, bs_y + 1)
+
+         loss_ = torch.nn.functional.cross_entropy(
+             logits.view(-1, logits.shape[-1]),
+             (logits.shape[-1] - 1)
+             * torch.ones(logits.shape[0] * logits.shape[1], dtype=torch.int64, device=logits.device),
+             reduction="none",
+         )  # (n_b * bs_x,)
+         loss = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
+         loss.scatter_reduce_(0, top_x_bucket.view(-1), loss_, reduce="amax", include_self=False)
+         loss = loss[(loss != 0) & (masked_tokens).view(-1)]
+         loss = torch.mean(loss)
+
+         return loss
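A minimal sketch of calling the new loss directly; shapes and parameter values below are illustrative only, and in normal use SasRec wires the loss up internally when loss_type="SCE" (see the sasrec/lightning.py changes further down):

import torch
from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams

batch_size, seq_len, hidden_dim, n_items = 4, 8, 64, 1_000
loss_fn = ScalableCrossEntropyLoss(SCEParams(n_buckets=16, bucket_size_x=8, bucket_size_y=32))

embeddings = torch.randn(batch_size, seq_len, hidden_dim)           # last transformer block outputs
positive_labels = torch.randint(0, n_items, (batch_size, seq_len))  # target item ids
all_embeddings = torch.randn(n_items, hidden_dim)                   # full item embedding table
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)    # no padding in this toy batch

loss = loss_fn(embeddings, positive_labels, all_embeddings, padding_mask)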
replay/models/nn/sequential/bert4rec/lightning.py CHANGED
@@ -1,5 +1,5 @@
  import math
- from typing import Any, Dict, Optional, Tuple, Union, cast
+ from typing import Any, Dict, Literal, Optional, Tuple, Union, cast

  import lightning
  import torch
@@ -27,7 +27,7 @@ class Bert4Rec(lightning.LightningModule):
          pass_per_transformer_block_count: int = 1,
          enable_positional_embedding: bool = True,
          enable_embedding_tying: bool = False,
-         loss_type: str = "CE",
+         loss_type: Literal["BCE", "CE", "CE_restricted"] = "CE",
          loss_sample_count: Optional[int] = None,
          negative_sampling_strategy: str = "global_uniform",
          negatives_sharing: bool = False,
@@ -54,7 +54,7 @@ class Bert4Rec(lightning.LightningModule):
              If `True` - result scores are calculated by dot product of input and output embeddings,
              if `False` - default linear layer is applied to calculate logits for each item.
              Default: ``False``.
-         :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``.
+         :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"CE_restricted"``.
              Default: ``CE``.
          :param loss_sample_count (Optional[int]): Sample count to calculate loss.
              Default: ``None``.
@@ -197,6 +197,8 @@
              loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
          elif self._loss_type == "CE":
              loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
+         elif self._loss_type == "CE_restricted":
+             loss_func = self._compute_loss_ce_restricted
          else:
              msg = f"Not supported loss type: {self._loss_type}"
              raise ValueError(msg)
@@ -316,6 +318,20 @@
          loss = self._loss(logits, labels_flat)
          return loss

+     def _compute_loss_ce_restricted(
+         self,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,
+         tokens_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         (logits, labels) = self._get_restricted_logits_for_ce_loss(
+             feature_tensors, positive_labels, padding_mask, tokens_mask
+         )
+
+         loss = self._loss(logits, labels)
+         return loss
+
      def _get_sampled_logits(
          self,
          feature_tensors: TensorMap,
@@ -398,11 +414,27 @@
              vocab_size,
          )

+     def _get_restricted_logits_for_ce_loss(
+         self,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,
+         tokens_mask: torch.BoolTensor,
+     ):
+         labels_mask = (~padding_mask) + tokens_mask
+         masked_tokens = ~labels_mask
+         positive_labels = cast(
+             torch.LongTensor, torch.masked_select(positive_labels, masked_tokens)
+         )  # (masked_batch_seq_size,)
+         output_emb = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)[masked_tokens]
+         logits = self._model.get_logits(output_emb)
+         return (logits, positive_labels)
+
      def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntropyLoss]:
          if self._loss_type == "BCE":
              return torch.nn.BCEWithLogitsLoss(reduction="sum")

-         if self._loss_type == "CE":
+         if self._loss_type == "CE" or self._loss_type == "CE_restricted":
              return torch.nn.CrossEntropyLoss()

          msg = "Not supported loss_type"
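What the new "CE_restricted" option restricts: logits are computed only at positions where padding_mask is True and tokens_mask is False, and plain cross-entropy is applied to that subset. A toy illustration of the mask arithmetic used by _get_restricted_logits_for_ce_loss (values made up):

import torch

padding_mask = torch.tensor([[True, True, True, False]])
tokens_mask = torch.tensor([[True, False, True, False]])
labels_mask = (~padding_mask) + tokens_mask   # boolean "+" acts as logical OR
masked_tokens = ~labels_mask
print(masked_tokens)  # tensor([[False,  True, False, False]]) -> positions where the loss is taken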
replay/models/nn/sequential/bert4rec/model.py CHANGED
@@ -1,9 +1,9 @@
  import contextlib
- import math
  from abc import ABC, abstractmethod
  from typing import Dict, Optional, Union

  import torch
+ import torch.nn as nn

  from replay.data.nn import TensorFeatureInfo, TensorMap, TensorSchema

@@ -379,7 +379,7 @@ class BaseHead(ABC, torch.nn.Module):
              item_embeddings = item_embeddings[item_ids]
              bias = bias[item_ids]

-         logits = item_embeddings.matmul(out_embeddings.unsqueeze(-1)).squeeze(-1) + bias
+         logits = torch.nn.functional.linear(out_embeddings, item_embeddings, bias)
          return logits

      @abstractmethod
@@ -471,11 +471,11 @@ class TransformerBlock(torch.nn.Module):
          super().__init__()
          self.attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=dropout, batch_first=True)
          self.attention_dropout = torch.nn.Dropout(dropout)
-         self.attention_norm = LayerNorm(hidden_size)
+         self.attention_norm = torch.nn.LayerNorm(hidden_size)

          self.pff = PositionwiseFeedForward(d_model=hidden_size, d_ff=feed_forward_hidden, dropout=dropout)
          self.pff_dropout = torch.nn.Dropout(dropout)
-         self.pff_norm = LayerNorm(hidden_size)
+         self.pff_norm = torch.nn.LayerNorm(hidden_size)

          self.dropout = torch.nn.Dropout(p=dropout)

@@ -501,33 +501,6 @@
          return self.dropout(z)


- class LayerNorm(torch.nn.Module):
-     """
-     Construct a layernorm module (See citation for details).
-     """
-
-     def __init__(self, features: int, eps: float = 1e-6):
-         """
-         :param features: Number of features.
-         :param eps: A value added to the denominator for numerical stability.
-             Default: ``1e-6``.
-         """
-         super().__init__()
-         self.a_2 = torch.nn.Parameter(torch.ones(features))
-         self.b_2 = torch.nn.Parameter(torch.zeros(features))
-         self.eps = eps
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         :param x: Input tensor.
-
-         :returns: Normalized input tensor.
-         """
-         mean = x.mean(-1, keepdim=True)
-         std = x.std(-1, keepdim=True)
-         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
-
-
  class PositionwiseFeedForward(torch.nn.Module):
      """
      Implements FFN equation.
@@ -544,7 +517,7 @@ class PositionwiseFeedForward(torch.nn.Module):
          self.w_1 = torch.nn.Linear(d_model, d_ff)
          self.w_2 = torch.nn.Linear(d_ff, d_model)
          self.dropout = torch.nn.Dropout(dropout)
-         self.activation = GELU()
+         self.activation = nn.GELU()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """
@@ -553,17 +526,3 @@
          :returns: Position wised output.
          """
          return self.w_2(self.dropout(self.activation(self.w_1(x))))
-
-
- class GELU(torch.nn.Module):
-     """
-     Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
-     """
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         :param x: Input tensor.
-
-         :returns: Activated input tensor.
-         """
-         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
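The hand-rolled LayerNorm and GELU above are dropped in favour of the torch.nn built-ins. Note the swap is not bit-identical: the removed GELU used the tanh approximation (torch.nn.GELU defaults to the exact erf form), and the removed LayerNorm added eps to the standard deviation rather than to the variance. A quick standalone check of the GELU point (not part of the package):

import math
import torch

x = torch.randn(2, 8)
old_gelu = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print(torch.allclose(torch.nn.GELU(approximate="tanh")(x), old_gelu, atol=1e-6))  # True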
replay/models/nn/sequential/sasrec/lightning.py CHANGED
@@ -1,10 +1,11 @@
  import math
- from typing import Any, Dict, Optional, Tuple, Union, cast
+ from typing import Any, Dict, Literal, Optional, Tuple, Union, cast

  import lightning
  import torch

  from replay.data.nn import TensorMap, TensorSchema
+ from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams
  from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory

  from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch
@@ -29,12 +30,13 @@ class SasRec(lightning.LightningModule):
          dropout_rate: float = 0.2,
          ti_modification: bool = False,
          time_span: int = 256,
-         loss_type: str = "CE",
+         loss_type: Literal["BCE", "CE", "SCE"] = "CE",
          loss_sample_count: Optional[int] = None,
          negative_sampling_strategy: str = "global_uniform",
          negatives_sharing: bool = False,
          optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
          lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
+         sce_params: Optional[SCEParams] = None,
      ):
          """
          :param tensor_schema: Tensor schema of features.
@@ -52,9 +54,10 @@
              Default: ``False``.
          :param time_span: Time span value.
              Default: ``256``.
-         :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``.
+         :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"SCE"``.
              Default: ``CE``.
          :param loss_sample_count (Optional[int]): Sample count to calculate loss.
+             Suitable for ``"CE"`` and ``"BCE"`` loss functions.
              Default: ``None``.
          :param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
              Is used when large count of items in dataset.
@@ -66,6 +69,8 @@
              Default: ``FatOptimizerFactory``.
          :param lr_scheduler_factory: Learning rate schedule factory.
              Default: ``None``.
+         :param sce_params: Dataclass with SCE parameters. Need to be defined if ``loss_type`` is ``SCE``.
+             Default: ``None``.
          """
          super().__init__()
          self.save_hyperparameters()
@@ -85,9 +90,12 @@
          self._negatives_sharing = negatives_sharing
          self._optimizer_factory = optimizer_factory
          self._lr_scheduler_factory = lr_scheduler_factory
+         self._sce_params = sce_params
          self._loss = self._create_loss()
          self._schema = tensor_schema
          assert negative_sampling_strategy in {"global_uniform", "inbatch"}
+         if self._loss_type == "SCE":
+             assert sce_params is not None, "You should define ``sce_params`` when using SCE loss function."

          item_count = tensor_schema.item_id_features.item().cardinality
          assert item_count
@@ -197,6 +205,8 @@
              loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
          elif self._loss_type == "CE":
              loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
+         elif self._loss_type == "SCE":
+             loss_func = self._compute_loss_scalable_ce
          else:
              msg = f"Not supported loss type: {self._loss_type}"
              raise ValueError(msg)
@@ -314,6 +324,17 @@
          loss = self._loss(logits, labels_flat)
          return loss

+     def _compute_loss_scalable_ce(
+         self,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,
+         tokens_mask: torch.BoolTensor,  # noqa: ARG002
+     ) -> torch.Tensor:
+         emb = self._model.forward_step(feature_tensors, padding_mask)
+         all_embeddings = self.get_all_embeddings()["item_embedding"]
+         return self._loss(emb, positive_labels, all_embeddings, padding_mask)
+
      def _get_sampled_logits(
          self,
          feature_tensors: TensorMap,
@@ -401,6 +422,9 @@
          if self._loss_type == "CE":
              return torch.nn.CrossEntropyLoss()

+         if self._loss_type == "SCE":
+             return ScalableCrossEntropyLoss(self._sce_params)
+
          msg = "Not supported loss_type"
          raise NotImplementedError(msg)

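A hedged end-to-end sketch of enabling SCE in SasRec 0.19.0. The one-feature schema and the import paths below follow the library's public data API as I understand it and may need adjusting to your setup; bucket sizes are illustrative:

from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema
from replay.models.nn.loss import SCEParams
from replay.models.nn.sequential import SasRec

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        "item_id",
        feature_type=FeatureType.CATEGORICAL,
        is_seq=True,
        feature_hint=FeatureHint.ITEM_ID,
        cardinality=10_000,
    )
)
model = SasRec(
    tensor_schema,
    loss_type="SCE",
    sce_params=SCEParams(n_buckets=64, bucket_size_x=128, bucket_size_y=128),  # required when loss_type="SCE"
)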
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -298,7 +298,7 @@ class EmbeddingTyingHead(torch.nn.Module):
          if len(item_embeddings.shape) > 2:  # global_uniform, negative sharing=False, train only
              logits = (item_embeddings * out_embeddings.unsqueeze(-2)).sum(dim=-1)
          else:
-             logits = item_embeddings.matmul(out_embeddings.unsqueeze(-1)).squeeze(-1)
+             logits = torch.matmul(out_embeddings, item_embeddings.t())
          return logits


replay/preprocessing/filters.py CHANGED
@@ -4,7 +4,8 @@ Select or remove data by some criteria

  from abc import ABC, abstractmethod
  from datetime import datetime, timedelta
- from typing import Callable, Optional, Tuple, Union
+ from typing import Callable, Literal, Optional, Tuple, Union
+ from uuid import uuid4

  import numpy as np
  import pandas as pd
@@ -989,3 +990,103 @@ class QuantileItemsFilter(_BaseFilter):
          )
          short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
          return long_tail.select(df.columns).union(short_tail.select(df.columns))
+
+
+ class ConsecutiveDuplicatesFilter(_BaseFilter):
+     """Removes consecutive duplicate items from sequential dataset.
+
+     >>> import datetime as dt
+     >>> import pandas as pd
+     >>> from replay.utils.spark_utils import convert2spark
+     >>> interactions = pd.DataFrame({
+     ...     "user_id": ["u0", "u1", "u1", "u0", "u0", "u0", "u1", "u0"],
+     ...     "item_id": ["i0", "i1", "i1", "i2", "i0", "i1", "i2", "i1"],
+     ...     "timestamp": [dt.datetime(2024, 1, 1) + dt.timedelta(days=i) for i in range(8)]
+     ... })
+     >>> interactions = convert2spark(interactions)
+     >>> interactions.show()
+     +-------+-------+-------------------+
+     |user_id|item_id|          timestamp|
+     +-------+-------+-------------------+
+     |     u0|     i0|2024-01-01 00:00:00|
+     |     u1|     i1|2024-01-02 00:00:00|
+     |     u1|     i1|2024-01-03 00:00:00|
+     |     u0|     i2|2024-01-04 00:00:00|
+     |     u0|     i0|2024-01-05 00:00:00|
+     |     u0|     i1|2024-01-06 00:00:00|
+     |     u1|     i2|2024-01-07 00:00:00|
+     |     u0|     i1|2024-01-08 00:00:00|
+     +-------+-------+-------------------+
+     <BLANKLINE>
+
+     >>> ConsecutiveDuplicatesFilter(query_column="user_id").transform(interactions).show()
+     +-------+-------+-------------------+
+     |user_id|item_id|          timestamp|
+     +-------+-------+-------------------+
+     |     u0|     i0|2024-01-01 00:00:00|
+     |     u0|     i2|2024-01-04 00:00:00|
+     |     u0|     i0|2024-01-05 00:00:00|
+     |     u0|     i1|2024-01-06 00:00:00|
+     |     u1|     i1|2024-01-02 00:00:00|
+     |     u1|     i2|2024-01-07 00:00:00|
+     +-------+-------+-------------------+
+     <BLANKLINE>
+     """
+
+     def __init__(
+         self,
+         keep: Literal["first", "last"] = "first",
+         query_column: str = "query_id",
+         item_column: str = "item_id",
+         timestamp_column: str = "timestamp",
+     ) -> None:
+         """
+         :param keep: whether to keep first or last occurrence,
+             Default: ``first``.
+         :param query_column: query column,
+             Default: ``query_id``.
+         :param item_column: item column,
+             Default: ``item_id``.
+         :param timestamp_column: timestamp column,
+             Default: ``timestamp``.
+         """
+         super().__init__()
+         self.query_column = query_column
+         self.item_column = item_column
+         self.timestamp_column = timestamp_column
+
+         if keep not in ("first", "last"):
+             msg = "`keep` must be either 'first' or 'last'"
+             raise ValueError(msg)
+
+         self.bias = 1 if keep == "first" else -1
+         self.temporary_column = f"__shifted_{uuid4().hex[:8]}"
+
+     def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
+         interactions = interactions.sort_values(self.timestamp_column)
+         interactions[self.temporary_column] = interactions.groupby(self.query_column)[self.item_column].shift(
+             periods=self.bias
+         )
+         return (
+             interactions[interactions[self.item_column] != interactions[self.temporary_column]]
+             .drop(self.temporary_column, axis=1)
+             .reset_index(drop=True)
+         )
+
+     def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
+         return (
+             interactions.sort(self.timestamp_column)
+             .with_columns(
+                 pl.col(self.item_column).shift(n=self.bias).over(self.query_column).alias(self.temporary_column)
+             )
+             .filter((pl.col(self.item_column) != pl.col(self.temporary_column)).fill_null(True))
+             .drop(self.temporary_column)
+         )
+
+     def _filter_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
+         window = Window.partitionBy(self.query_column).orderBy(self.timestamp_column)
+         return (
+             interactions.withColumn(self.temporary_column, sf.lag(self.item_column, offset=self.bias).over(window))
+             .where((sf.col(self.item_column) != sf.col(self.temporary_column)) | sf.col(self.temporary_column).isNull())
+             .drop(self.temporary_column)
+         )
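The docstring above demonstrates the Spark path; the filter also ships pandas and Polars implementations. A hedged pandas example with keep="last" (same column names as the doctest; behaviour assumed symmetric to the Spark example):

import datetime as dt
import pandas as pd
from replay.preprocessing.filters import ConsecutiveDuplicatesFilter

interactions = pd.DataFrame({
    "user_id": ["u0", "u0", "u0"],
    "item_id": ["i0", "i0", "i1"],
    "timestamp": [dt.datetime(2024, 1, 1) + dt.timedelta(days=i) for i in range(3)],
})
filtered = ConsecutiveDuplicatesFilter(keep="last", query_column="user_id").transform(interactions)
print(filtered["item_id"].tolist())  # expected ["i0", "i1"]: the first element of the duplicate run is dropped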
replay/preprocessing/label_encoder.py CHANGED
@@ -10,7 +10,6 @@ import abc
  import json
  import os
  import warnings
- from itertools import chain
  from pathlib import Path
  from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union

@@ -27,7 +26,7 @@ from replay.utils import (

  if PYSPARK_AVAILABLE:
      from pyspark.sql import Window, functions as sf  # noqa: I001
-     from pyspark.sql.types import LongType
+     from pyspark.sql.types import LongType, IntegerType, ArrayType

  HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]

@@ -336,6 +335,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
                  "with `handle_unknown_strategy=drop` leads to empty dataframe",
                  LabelEncoderTransformWarning,
              )
+             joined_df[self._target_col] = joined_df[self._target_col].astype("int")
          elif self._handle_unknown == "error":
              unknown_unique_labels = joined_df[self._col][unknown_mask].unique().tolist()
              msg = f"Found unknown labels {unknown_unique_labels} in column {self._col} during transform"
@@ -629,8 +629,12 @@ class SequenceEncodingRule(LabelEncodingRule):
          return self

      def _transform_spark(self, df: SparkDataFrame, default_value: Optional[int]) -> SparkDataFrame:
-         map_expr = sf.create_map([sf.lit(x) for x in chain(*self.get_mapping().items())])
-         encoded_df = df.withColumn(self._target_col, sf.transform(self.column, lambda x: map_expr.getItem(x)))
+         def mapper_udf(x):
+             return [mapping.get(value) for value in x]  # pragma: no cover
+
+         mapping = self.get_mapping()
+         call_mapper_udf = sf.udf(mapper_udf, ArrayType(IntegerType()))
+         encoded_df = df.withColumn(self._target_col, call_mapper_udf(sf.col(self.column)))

          if self._handle_unknown == "drop":
              encoded_df = encoded_df.withColumn(self._target_col, sf.filter(self._target_col, lambda x: x.isNotNull()))