replay-rec 0.17.1rc0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. replay/__init__.py +2 -1
  2. replay/data/dataset.py +3 -2
  3. replay/data/dataset_utils/dataset_label_encoder.py +1 -0
  4. replay/data/nn/schema.py +5 -5
  5. replay/metrics/__init__.py +1 -0
  6. replay/models/als.py +1 -1
  7. replay/models/base_rec.py +7 -7
  8. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
  9. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
  10. replay/models/nn/sequential/bert4rec/model.py +5 -112
  11. replay/models/nn/sequential/sasrec/model.py +8 -5
  12. replay/optimization/optuna_objective.py +1 -0
  13. replay/preprocessing/converter.py +1 -1
  14. replay/preprocessing/filters.py +19 -18
  15. replay/preprocessing/history_based_fp.py +5 -5
  16. replay/preprocessing/label_encoder.py +1 -0
  17. replay/scenarios/__init__.py +1 -0
  18. replay/splitters/last_n_splitter.py +1 -1
  19. replay/splitters/time_splitter.py +1 -1
  20. replay/splitters/two_stage_splitter.py +8 -6
  21. replay/utils/distributions.py +1 -0
  22. replay/utils/session_handler.py +3 -3
  23. replay/utils/spark_utils.py +2 -2
  24. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/METADATA +12 -18
  25. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/RECORD +27 -80
  26. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/WHEEL +1 -1
  27. replay/experimental/__init__.py +0 -0
  28. replay/experimental/metrics/__init__.py +0 -61
  29. replay/experimental/metrics/base_metric.py +0 -601
  30. replay/experimental/metrics/coverage.py +0 -97
  31. replay/experimental/metrics/experiment.py +0 -175
  32. replay/experimental/metrics/hitrate.py +0 -26
  33. replay/experimental/metrics/map.py +0 -30
  34. replay/experimental/metrics/mrr.py +0 -18
  35. replay/experimental/metrics/ncis_precision.py +0 -31
  36. replay/experimental/metrics/ndcg.py +0 -49
  37. replay/experimental/metrics/precision.py +0 -22
  38. replay/experimental/metrics/recall.py +0 -25
  39. replay/experimental/metrics/rocauc.py +0 -49
  40. replay/experimental/metrics/surprisal.py +0 -90
  41. replay/experimental/metrics/unexpectedness.py +0 -76
  42. replay/experimental/models/__init__.py +0 -10
  43. replay/experimental/models/admm_slim.py +0 -205
  44. replay/experimental/models/base_neighbour_rec.py +0 -204
  45. replay/experimental/models/base_rec.py +0 -1271
  46. replay/experimental/models/base_torch_rec.py +0 -234
  47. replay/experimental/models/cql.py +0 -452
  48. replay/experimental/models/ddpg.py +0 -921
  49. replay/experimental/models/dt4rec/__init__.py +0 -0
  50. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  51. replay/experimental/models/dt4rec/gpt1.py +0 -401
  52. replay/experimental/models/dt4rec/trainer.py +0 -127
  53. replay/experimental/models/dt4rec/utils.py +0 -265
  54. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  55. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  56. replay/experimental/models/implicit_wrap.py +0 -131
  57. replay/experimental/models/lightfm_wrap.py +0 -302
  58. replay/experimental/models/mult_vae.py +0 -331
  59. replay/experimental/models/neuromf.py +0 -405
  60. replay/experimental/models/scala_als.py +0 -296
  61. replay/experimental/nn/data/__init__.py +0 -1
  62. replay/experimental/nn/data/schema_builder.py +0 -55
  63. replay/experimental/preprocessing/__init__.py +0 -3
  64. replay/experimental/preprocessing/data_preparator.py +0 -838
  65. replay/experimental/preprocessing/padder.py +0 -229
  66. replay/experimental/preprocessing/sequence_generator.py +0 -208
  67. replay/experimental/scenarios/__init__.py +0 -1
  68. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  69. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  70. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  71. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  72. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  73. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  74. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  75. replay/experimental/utils/__init__.py +0 -0
  76. replay/experimental/utils/logger.py +0 -24
  77. replay/experimental/utils/model_handler.py +0 -181
  78. replay/experimental/utils/session_handler.py +0 -44
  79. replay_rec-0.17.1rc0.dist-info/NOTICE +0 -41
  80. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/LICENSE +0 -0
replay/__init__.py CHANGED
@@ -1,2 +1,3 @@
  """ RecSys library """
- __version__ = "0.17.1.preview"
+
+ __version__ = "0.18.0"
replay/data/dataset.py CHANGED
@@ -1,6 +1,7 @@
  """
  ``Dataset`` universal dataset class for manipulating interactions and feed data to models.
  """
+
  from __future__ import annotations

  import json
@@ -606,7 +607,7 @@ class Dataset:
  if self.is_pandas:
  min_id = data[column].min()
  elif self.is_spark:
- min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
+ min_id = data.agg(sf.min(column).alias("min_index")).first()[0]
  else:
  min_id = data[column].min()
  if min_id < 0:
@@ -616,7 +617,7 @@
  if self.is_pandas:
  max_id = data[column].max()
  elif self.is_spark:
- max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
+ max_id = data.agg(sf.max(column).alias("max_index")).first()[0]
  else:
  max_id = data[column].max()
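Several files in this release (dataset.py above, plus als.py, base_rec.py, history_based_fp.py, and time_splitter.py below) replace .collect()[0][0] with .first()[0] when reading a single aggregated value from Spark. A minimal sketch of the pattern, assuming a local SparkSession and illustrative data not taken from the package:

from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
data = spark.createDataFrame([(1,), (5,), (3,)], ["item_id"])

# For a single-row aggregate both forms return the same value;
# .first() fetches one Row instead of collecting a list and indexing into it.
min_id = data.agg(sf.min("item_id").alias("min_index")).first()[0]
max_id = data.agg(sf.max("item_id").alias("max_index")).first()[0]
print(min_id, max_id)  # 1 5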
replay/data/dataset_utils/dataset_label_encoder.py CHANGED
@@ -4,6 +4,7 @@ Contains classes for encoding categorical data
  ``LabelEncoderTransformWarning`` new category of warning for DatasetLabelEncoder.
  ``DatasetLabelEncoder`` to encode categorical features in `Dataset` objects.
  """
+
  import warnings
  from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union
replay/data/nn/schema.py CHANGED
@@ -418,11 +418,11 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  "feature_type": feature.feature_type.name,
  "is_seq": feature.is_seq,
  "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
- "feature_sources": [
- {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
- ]
- if feature.feature_sources
- else None,
+ "feature_sources": (
+ [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
+ if feature.feature_sources
+ else None
+ ),
  "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
  "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
  "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
replay/metrics/__init__.py CHANGED
@@ -42,6 +42,7 @@ For each metric, a formula for its calculation is given, because this is
  important for the correct comparison of algorithms, as mentioned in our
  `article <https://arxiv.org/abs/2206.12858>`_.
  """
+
  from .base_metric import Metric
  from .categorical_diversity import CategoricalDiversity
  from .coverage import Coverage
replay/models/als.py CHANGED
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
  .groupBy(self.query_column)
  .agg(sf.count(self.query_column).alias("num_seen"))
  .select(sf.max("num_seen"))
- .collect()[0][0]
+ .first()[0]
  )
  max_seen = max_seen_in_interactions if max_seen_in_interactions is not None else 0
replay/models/base_rec.py CHANGED
@@ -401,8 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  self.fit_items = sf.broadcast(items)
  self._num_queries = self.fit_queries.count()
  self._num_items = self.fit_items.count()
- self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
- self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
+ self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).first()[0] + 1
+ self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).first()[0] + 1
  self._fit(dataset)

  @abstractmethod
@@ -431,7 +431,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  # count maximal number of items seen by queries
  max_seen = 0
  if num_seen.count() > 0:
- max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+ max_seen = num_seen.select(sf.max("seen_count")).first()[0]

  # crop recommendations to first k + max_seen items for each query
  recs = recs.withColumn(
@@ -708,7 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  setattr(
  self,
  dim_size,
- fit_entities.agg({column: "max"}).collect()[0][0] + 1,
+ fit_entities.agg({column: "max"}).first()[0] + 1,
  )
  return getattr(self, dim_size)

@@ -1426,7 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  Calculating a fill value a the minimal rating
  calculated during model training multiplied by weight.
  """
- return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight
+ return item_popularity.select(sf.min(rating_column)).first()[0] * weight

  @staticmethod
  def _check_rating(dataset: Dataset):
@@ -1460,7 +1460,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  .agg(sf.countDistinct(item_column).alias("items_count"))
  )
  .select(sf.max("items_count"))
- .collect()[0][0]
+ .first()[0]
  )
  # all queries have empty history
  if max_hist_len is None:
@@ -1495,7 +1495,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  queries = queries.join(query_to_num_items, on=self.query_column, how="left")
  queries = queries.fillna(0, "num_items")
  # 'selected_item_popularity' truncation by k + max_seen
- max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+ max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
  selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
  return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py CHANGED
@@ -32,9 +32,9 @@ class NmslibFilterIndexInferer(IndexInferer):
  index = index_store.load_index(
  init_index=lambda: create_nmslib_index_instance(index_params),
  load_index=lambda index, path: index.loadIndex(path, load_data=True),
- configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
- if index_params.ef_s
- else None,
+ configure_index=lambda index: (
+ index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+ ),
  )

  # max number of items to retrieve per batch
replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py CHANGED
@@ -30,9 +30,9 @@ class NmslibIndexInferer(IndexInferer):
  index = index_store.load_index(
  init_index=lambda: create_nmslib_index_instance(index_params),
  load_index=lambda index, path: index.loadIndex(path, load_data=True),
- configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
- if index_params.ef_s
- else None,
+ configure_index=lambda index: (
+ index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+ ),
  )

  user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
replay/models/nn/sequential/bert4rec/model.py CHANGED
@@ -1,7 +1,7 @@
  import contextlib
  import math
  from abc import ABC, abstractmethod
- from typing import Dict, Optional, Tuple, Union, cast
+ from typing import Dict, Optional, Union

  import torch

@@ -115,13 +115,10 @@ class Bert4RecModel(torch.nn.Module):
  # (B x L x E)
  x = self.item_embedder(inputs, token_mask)

- # (B x 1 x L x L)
- pad_mask_for_attention = self._get_attention_mask_from_padding(pad_mask)
-
  # Running over multiple transformer blocks
  for transformer in self.transformer_blocks:
  for _ in range(self.num_passes_over_block):
- x = transformer(x, pad_mask_for_attention)
+ x = transformer(x, pad_mask)

  return x

@@ -147,11 +144,6 @@ class Bert4RecModel(torch.nn.Module):
  """
  return self.forward_step(inputs, pad_mask, token_mask)[:, -1, :]

- def _get_attention_mask_from_padding(self, pad_mask: torch.BoolTensor) -> torch.BoolTensor:
- # (B x L) -> (B x 1 x L x L)
- pad_mask_for_attention = pad_mask.unsqueeze(1).repeat(1, self.max_len, 1).unsqueeze(1)
- return cast(torch.BoolTensor, pad_mask_for_attention)
-
  def _init(self) -> None:
  for _, param in self.named_parameters():
  with contextlib.suppress(ValueError):
@@ -456,7 +448,7 @@ class TransformerBlock(torch.nn.Module):
  :param dropout: Dropout rate.
  """
  super().__init__()
- self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden_size, dropout=dropout)
+ self.attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=dropout, batch_first=True)
  self.attention_dropout = torch.nn.Dropout(dropout)
  self.attention_norm = LayerNorm(hidden_size)

@@ -479,7 +471,8 @@ class TransformerBlock(torch.nn.Module):
  """
  # Attention + skip-connection
  x_norm = self.attention_norm(x)
- y = x + self.attention_dropout(self.attention(x_norm, x_norm, x_norm, mask))
+ attent_emb, _ = self.attention(x_norm, x_norm, x_norm, key_padding_mask=~mask, need_weights=False)
+ y = x + self.attention_dropout(attent_emb)

  # PFF + skip-connection
  z = y + self.pff_dropout(self.pff(self.pff_norm(y)))
@@ -487,106 +480,6 @@ class TransformerBlock(torch.nn.Module):
  return self.dropout(z)


- class Attention(torch.nn.Module):
- """
- Compute Scaled Dot Product Attention
- """
-
- def __init__(self, dropout: float) -> None:
- """
- :param dropout: Dropout rate.
- """
- super().__init__()
- self.dropout = torch.nn.Dropout(p=dropout)
-
- def forward(
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.BoolTensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- :param query: Query feature vector.
- :param key: Key feature vector.
- :param value: Value feature vector.
- :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
- :returns: Tuple of scaled dot product attention
- and attention logits for each element.
- """
- scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
-
- scores = scores.masked_fill(mask == 0, -1e9)
- p_attn = torch.nn.functional.softmax(scores, dim=-1)
- p_attn = self.dropout(p_attn)
-
- return torch.matmul(p_attn, value), p_attn
-
-
- class MultiHeadedAttention(torch.nn.Module):
- """
- Take in model size and number of heads.
- """
-
- def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
- """
- :param h: Head sizes of multi-head attention.
- :param d_model: Embedding dimension.
- :param dropout: Dropout rate.
- Default: ``0.1``.
- """
- super().__init__()
- assert d_model % h == 0
-
- # We assume d_v always equals d_k
- self.d_k = d_model // h
- self.h = h
-
- # 3 linear projections for Q, K, V
- self.qkv_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(3)])
-
- # 2 linear projections for P -> P_q, P_k
- self.pos_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(2)])
-
- self.output_linear = torch.nn.Linear(d_model, d_model)
-
- self.attention = Attention(dropout)
-
- def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: torch.BoolTensor,
- ) -> torch.Tensor:
- """
- :param query: Query feature vector.
- :param key: Key feature vector.
- :param value: Value feature vector.
- :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
- :returns: Attention outputs.
- """
- batch_size = query.size(0)
-
- # B - batch size
- # L - sequence length (max_len)
- # E - embedding size for tokens fed into transformer
- # K - max relative distance
- # H - attention head count
-
- # Do all the linear projections in batch from d_model => h x d_k
- # (B x L x E) -> (B x H x L x (E / H))
- query, key, value = [
- layer(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
- for layer, x in zip(self.qkv_linear_layers, (query, key, value))
- ]
-
- x, _ = self.attention(query, key, value, mask)
-
- # Concat using a view and apply a final linear.
- x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
-
- return self.output_linear(x)
-
-
  class LayerNorm(torch.nn.Module):
  """
  Construct a layernorm module (See citation for details).
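The bert4rec change above drops the hand-written Attention and MultiHeadedAttention modules together with the (B x 1 x L x L) mask expansion, and wires TransformerBlock to torch.nn.MultiheadAttention, passing the (B x L) padding mask through key_padding_mask. A minimal sketch of that call pattern (toy shapes, not the model's own hyperparameters):

import torch

batch_size, seq_len, hidden_size, attn_heads = 2, 5, 8, 2
attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=0.1, batch_first=True)

x = torch.randn(batch_size, seq_len, hidden_size)  # (B x L x E)
pad_mask = torch.tensor([[True, True, True, False, False],
                         [True, True, True, True, True]])  # True = real token, False = padding

# key_padding_mask expects True at positions to ignore, hence the ~mask inversion in the diff;
# need_weights=False skips returning the attention-weight tensor.
out, _ = attention(x, x, x, key_padding_mask=~pad_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 5, 8])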
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -401,7 +401,12 @@ class SasRecLayers(torch.nn.Module):
  """
  super().__init__()
  self.attention_layers = self._layers_stacker(
- num_blocks, torch.nn.MultiheadAttention, hidden_size, num_heads, dropout
+ num_blocks,
+ torch.nn.MultiheadAttention,
+ hidden_size,
+ num_heads,
+ dropout,
+ batch_first=True,
  )
  self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
  self.forward_layers = self._layers_stacker(num_blocks, SasRecPointWiseFeedForward, hidden_size, dropout)
@@ -422,11 +427,9 @@ class SasRecLayers(torch.nn.Module):
  """
  length = len(self.attention_layers)
  for i in range(length):
- seqs = torch.transpose(seqs, 0, 1)
  query = self.attention_layernorms[i](seqs)
- attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask)
+ attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask, need_weights=False)
  seqs = query + attent_emb
- seqs = torch.transpose(seqs, 0, 1)

  seqs = self.forward_layernorms[i](seqs)
  seqs = self.forward_layers[i](seqs)
@@ -492,7 +495,7 @@ class SasRecPointWiseFeedForward(torch.nn.Module):

  :returns: Output tensors.
  """
- outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
+ outputs = self.dropout2(self.conv2(self.dropout1(self.relu(self.conv1(inputs.transpose(-1, -2))))))
  outputs = outputs.transpose(-1, -2)
  outputs += inputs
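In the sasrec hunks above, batch_first=True is now passed when stacking torch.nn.MultiheadAttention layers, so SasRecLayers no longer transposes between (L x B x E) and (B x L x E) around each attention call, and the point-wise feed-forward applies dropout after the ReLU rather than before it. A small sketch of a batch-first attention call; the boolean mask below is a generic causal mask, since how the package builds attention_mask is not shown in this diff:

import torch

batch_size, seq_len, hidden_size, num_heads = 2, 4, 8, 2
attn = torch.nn.MultiheadAttention(hidden_size, num_heads, dropout=0.0, batch_first=True)

seqs = torch.randn(batch_size, seq_len, hidden_size)  # (B x L x E), no transpose needed
# Boolean attn_mask: True marks positions a query may NOT attend to (here, future items).
attention_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out, _ = attn(seqs, seqs, seqs, attn_mask=attention_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 4, 8])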
replay/optimization/optuna_objective.py CHANGED
@@ -1,6 +1,7 @@
  """
  This class calculates loss function for optimization process
  """
+
  import collections
  import logging
  from functools import partial
replay/preprocessing/converter.py CHANGED
@@ -102,6 +102,6 @@ class CSRConverter:
  row_count = self.row_count if self.row_count is not None else _get_max(rows_data) + 1
  col_count = self.column_count if self.column_count is not None else _get_max(cols_data) + 1
  return csr_matrix(
- (data, (rows_data, cols_data)),
+ (data.tolist(), (rows_data.tolist(), cols_data.tolist())),
  shape=(row_count, col_count),
  )
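The CSRConverter hunk above converts the value, row, and column arrays to plain Python lists before building the SciPy matrix; a plausible reason is to sidestep array types SciPy does not accept directly (an assumption, the diff does not state the motivation). A small sketch of the construction with illustrative data:

import numpy as np
from scipy.sparse import csr_matrix

rows_data = np.array([0, 0, 1, 2])
cols_data = np.array([0, 2, 1, 0])
data = np.array([1.0, 2.0, 3.0, 4.0])

# COO-style triplets converted to lists, mirroring the new code path.
matrix = csr_matrix(
    (data.tolist(), (rows_data.tolist(), cols_data.tolist())),
    shape=(int(rows_data.max()) + 1, int(cols_data.max()) + 1),
)
print(matrix.toarray())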
replay/preprocessing/filters.py CHANGED
@@ -1,6 +1,7 @@
  """
  Select or remove data by some criteria
  """
+
  from abc import ABC, abstractmethod
  from datetime import datetime, timedelta
  from typing import Callable, Optional, Tuple, Union
@@ -355,8 +356,8 @@ class NumInteractionsFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -367,7 +368,7 @@ class NumInteractionsFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -393,7 +394,7 @@ class NumInteractionsFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
- | u2| i2| 0.5|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i3| 1.0|2020-01-05 23:59:59|
  +-------+-------+------+-------------------+
  <BLANKLINE>
@@ -403,7 +404,7 @@ class NumInteractionsFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i3| 1.0|2020-01-05 23:59:59|
  +-------+-------+------+-------------------+
  <BLANKLINE>
@@ -482,7 +483,7 @@ class NumInteractionsFilter(_BaseFilter):

  return (
  interactions.sort(sorting_columns, descending=descending)
- .with_columns(pl.col(self.query_column).cumcount().over(self.query_column).alias("temp_rank"))
+ .with_columns(pl.col(self.query_column).cum_count().over(self.query_column).alias("temp_rank"))
  .filter(pl.col("temp_rank") <= self.num_interactions)
  .drop("temp_rank")
  )
@@ -497,8 +498,8 @@ class EntityDaysFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -509,7 +510,7 @@ class EntityDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -524,7 +525,7 @@ class EntityDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  +-------+-------+------+-------------------+
@@ -539,7 +540,7 @@ class EntityDaysFilter(_BaseFilter):
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  +-------+-------+------+-------------------+
  <BLANKLINE>
  """
@@ -636,8 +637,8 @@ class GlobalDaysFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -648,7 +649,7 @@ class GlobalDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -670,7 +671,7 @@ class GlobalDaysFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  +-------+-------+------+-------------------+
  <BLANKLINE>
  """
@@ -738,8 +739,8 @@ class TimePeriodFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -750,7 +751,7 @@ class TimePeriodFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
replay/preprocessing/history_based_fp.py CHANGED
@@ -179,8 +179,8 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
  abnormality_aggs = [sf.mean(sf.col("abnormality")).alias("abnormality")]

  # Abnormality CR:
- max_std = item_features.select(sf.max("i_std")).collect()[0][0]
- min_std = item_features.select(sf.min("i_std")).collect()[0][0]
+ max_std = item_features.select(sf.max("i_std")).first()[0]
+ min_std = item_features.select(sf.min("i_std")).first()[0]
  if max_std - min_std != 0:
  abnormality_df = abnormality_df.withColumn(
  "controversy",
@@ -201,15 +201,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
  :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
  """
  self.calc_timestamp_based = (isinstance(log.schema["timestamp"].dataType, TimestampType)) & (
- log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0] > 1
+ log.select(sf.countDistinct(sf.col("timestamp"))).first()[0] > 1
  )
- self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0] > 1
+ self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).first()[0] > 1

  user_log_features = log.groupBy("user_idx").agg(*self._create_log_aggregates(agg_col="user_idx"))
  item_log_features = log.groupBy("item_idx").agg(*self._create_log_aggregates(agg_col="item_idx"))

  if self.calc_timestamp_based:
- last_date = log.select(sf.max("timestamp")).collect()[0][0]
+ last_date = log.select(sf.max("timestamp")).first()[0]
  user_log_features = self._add_ts_based(features=user_log_features, max_log_date=last_date, prefix="u")

  item_log_features = self._add_ts_based(features=item_log_features, max_log_date=last_date, prefix="i")
replay/preprocessing/label_encoder.py CHANGED
@@ -5,6 +5,7 @@ Contains classes for encoding categorical data
  Recommended to use together with the LabelEncoder.
  ``LabelEncoder`` to apply multiple LabelEncodingRule to dataframe.
  """
+
  import abc
  import warnings
  from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
replay/scenarios/__init__.py CHANGED
@@ -1,4 +1,5 @@
  """
  Scenarios are a series of actions for recommendations
  """
+
  from .fallback import Fallback
replay/splitters/last_n_splitter.py CHANGED
@@ -193,7 +193,7 @@ class LastNSplitter(Splitter):

  def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
  res = interactions.sort(self.timestamp_column).with_columns(
- pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column)).alias("row_num")
+ pl.col(self.divide_column).cum_count().over(pl.col(self.divide_column)).alias("row_num")
  )

  return res
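Both filters.py (NumInteractionsFilter) and last_n_splitter.py above replace the deprecated polars Expr.cumcount() with Expr.cum_count(). A small illustration on hypothetical data of the per-group running count that the code ranks rows by:

import polars as pl

interactions = pl.DataFrame({
    "user_id": ["u1", "u1", "u2", "u2", "u2"],
    "timestamp": [1, 2, 1, 2, 3],
})

# cum_count over the user column numbers each user's interactions in timestamp order;
# the filter and splitter then compare this running count against their row limits.
ranked = interactions.sort("timestamp").with_columns(
    pl.col("user_id").cum_count().over("user_id").alias("row_num")
)
print(ranked)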
replay/splitters/time_splitter.py CHANGED
@@ -193,7 +193,7 @@ class TimeSplitter(Splitter):
  )
  test_start = int(dates.count() * (1 - threshold)) + 1
  test_start = (
- dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).collect()[0][0]
+ dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).first()[0]
  )
  res = interactions.withColumn("is_test", sf.col(self.timestamp_column) >= test_start)
  else:
replay/splitters/two_stage_splitter.py CHANGED
@@ -1,8 +1,10 @@
  """
  This splitter split data by two columns.
  """
+
  from typing import Optional, Tuple

+ import numpy as np
  import polars as pl

  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
@@ -124,15 +126,15 @@ class TwoStageSplitter(Splitter):
  :return: DataFrame with single column `first_divide_column`
  """
  if isinstance(interactions, SparkDataFrame):
- all_values = interactions.select(self.first_divide_column).distinct()
+ all_values = interactions.select(self.first_divide_column).distinct().sort(self.first_divide_column)
  user_count = all_values.count()
  elif isinstance(interactions, PandasDataFrame):
  all_values = PandasDataFrame(
- interactions[self.first_divide_column].unique(), columns=[self.first_divide_column]
+ np.sort(interactions[self.first_divide_column].unique()), columns=[self.first_divide_column]
  )
  user_count = len(all_values)
  else:
- all_values = interactions.select(self.first_divide_column).unique()
+ all_values = interactions.select(self.first_divide_column).unique().sort(self.first_divide_column)
  user_count = len(all_values)

  value_error = False
@@ -152,7 +154,7 @@ class TwoStageSplitter(Splitter):
  if isinstance(interactions, SparkDataFrame):
  test_users = (
  all_values.withColumn("_rand", sf.rand(self.seed))
- .withColumn("_row_num", sf.row_number().over(Window.orderBy("_rand")))
+ .withColumn("_row_num", sf.row_number().over(Window.partitionBy(sf.lit(0)).orderBy("_rand")))
  .filter(f"_row_num <= {test_user_count}")
  .drop("_rand", "_row_num")
  )
@@ -240,10 +242,10 @@ class TwoStageSplitter(Splitter):
  res = res.fill_null(False)

  train = res.filter((pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test"))).drop(
- "_rand", "_row_num", "count", "_frac", "is_test"
+ "_row_num", "count", "_frac", "is_test"
  )
  test = res.filter((pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")).drop(
- "_rand", "_row_num", "count", "_frac", "is_test"
+ "_row_num", "count", "_frac", "is_test"
  )

  return train, test
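In the two_stage_splitter.py hunks above, the distinct first_divide_column values are now sorted before sampling, and the Spark branch ranks rows over Window.partitionBy(sf.lit(0)).orderBy("_rand") instead of a bare global window; a reasonable reading is that the sort makes the seeded sample independent of incoming row order, while partitionBy(lit(0)) gives the ranking window an explicit partition, though the diff itself does not state the motivation. A sketch of the Spark sampling pattern, assuming a local SparkSession and hypothetical user ids:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
all_values = spark.createDataFrame([("u1",), ("u2",), ("u3",), ("u4",)], ["user_id"]).sort("user_id")

test_user_count = 2
test_users = (
    all_values.withColumn("_rand", sf.rand(seed=42))
    # row_number over the seeded random column picks a reproducible subset;
    # partitionBy(lit(0)) ranks all rows inside one explicit partition.
    .withColumn("_row_num", sf.row_number().over(Window.partitionBy(sf.lit(0)).orderBy("_rand")))
    .filter(sf.col("_row_num") <= test_user_count)
    .drop("_rand", "_row_num")
)
test_users.show()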
replay/utils/distributions.py CHANGED
@@ -1,4 +1,5 @@
  """Distribution calculations"""
+
  from .types import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame

  if PYSPARK_AVAILABLE: