replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/preprocessing/converter.py
@@ -3,13 +3,12 @@ from typing import Optional
 import numpy as np
 from scipy.sparse import csr_matrix
 
-from replay.utils import DataFrameLike, SparkDataFrame, PYSPARK_AVAILABLE
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     from replay.utils.spark_utils import spark_to_pandas
 
 
-# pylint: disable=too-few-public-methods
 class CSRConverter:
     """
     Convert input data to csr sparse matrix.
@@ -47,7 +46,6 @@ class CSRConverter:
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
    def __init__(
        self,
        first_dim_column: str,
@@ -96,10 +94,7 @@ class CSRConverter:
 
        rows_data = data[self.first_dim_column].values
        cols_data = data[self.second_dim_column].values
-        if self.data_column is not None:
-            data = data[self.data_column].values
-        else:
-            data = np.ones(data.shape[0])
+        data = data[self.data_column].values if self.data_column is not None else np.ones(data.shape[0])
 
        def _get_max(data: np.ndarray) -> int:
            return np.max(data) if data.shape[0] > 0 else 0
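The last hunk collapses the data_column branch into a single conditional expression: explicit values are used when a data column is given, otherwise the matrix is filled with ones. A minimal standalone sketch of that construction with scipy (the to_csr helper and the column names below are illustrative, not the package's API):

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

def to_csr(data, first_dim_column, second_dim_column, data_column=None):
    """Build a csr matrix; values default to ones when no data column is given."""
    rows = data[first_dim_column].values
    cols = data[second_dim_column].values
    # Same pattern as the refactored line above.
    values = data[data_column].values if data_column is not None else np.ones(data.shape[0])
    shape = (rows.max() + 1, cols.max() + 1) if data.shape[0] > 0 else (0, 0)
    return csr_matrix((values, (rows, cols)), shape=shape)

interactions = pd.DataFrame({"user_id": [0, 0, 1], "item_id": [0, 2, 1], "rating": [3, 5, 4]})
print(to_csr(interactions, "user_id", "item_id").toarray())             # implicit ones
print(to_csr(interactions, "user_id", "item_id", "rating").toarray())   # explicit ratings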
replay/preprocessing/filters.py
@@ -1,22 +1,23 @@
 """
 Select or remove data by some criteria
 """
-import polars as pl
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Callable, Optional, Union, Tuple
+from typing import Callable, Optional, Tuple, Union
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+import polars as pl
 
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import Window
-    from pyspark.sql import functions as sf
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
     from pyspark.sql.functions import col
     from pyspark.sql.types import TimestampType
 
 
-# pylint: disable=too-few-public-methods
 class _BaseFilter(ABC):
    def transform(self, interactions: DataFrameLike) -> DataFrameLike:
        r"""Filter interactions.
@@ -32,7 +33,8 @@ class _BaseFilter(ABC):
        elif isinstance(interactions, PolarsDataFrame):
            return self._filter_polars(interactions)
        else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(interactions)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(interactions)}"
+            raise NotImplementedError(msg)
 
    @abstractmethod
    def _filter_spark(self, interactions: SparkDataFrame):  # pragma: no cover
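The error-handling change above binds the message to a variable before raising. For orientation, the dispatch that _BaseFilter.transform performs can be sketched as follows; this is a simplified, pandas-only sketch, while the real class also routes Spark and Polars frames to their own backends:

from abc import ABC, abstractmethod

import pandas as pd

class _ExampleFilter(ABC):
    """Sketch of the transform dispatch: route by dataframe type, raise for anything else."""

    def transform(self, interactions):
        if isinstance(interactions, pd.DataFrame):
            return self._filter_pandas(interactions)
        # The message is built first, then raised, matching the refactored code above.
        msg = f"{self.__class__.__name__} is not implemented for {type(interactions)}"
        raise NotImplementedError(msg)

    @abstractmethod
    def _filter_pandas(self, interactions: pd.DataFrame) -> pd.DataFrame:
        ...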
@@ -47,7 +49,6 @@ class _BaseFilter(ABC):
        pass
 
 
-# pylint: disable=too-few-public-methods, too-many-instance-attributes
 class InteractionEntriesFilter(_BaseFilter):
    """
    Remove interactions less than minimum constraint value and greater
@@ -81,7 +82,6 @@ class InteractionEntriesFilter(_BaseFilter):
    <BLANKLINE>
    """
 
-    # pylint: disable=too-many-arguments
    def __init__(
        self,
        query_column: str = "user_id",
@@ -166,11 +166,10 @@ class InteractionEntriesFilter(_BaseFilter):
                interactions, interaction_count, agg_column, non_agg_column, min_inter, max_inter
            )
            is_dropped_user_item[current_index] = bool(dropped_interact)
-            current_index = (current_index + 1) % 2 # current_index only in (0, 1)
+            current_index = (current_index + 1) % 2  # current_index only in (0, 1)
 
        return interactions
 
-    # pylint: disable=no-self-use
    def _filter_column_pandas(
        self,
        interactions: PandasDataFrame,
@@ -196,7 +195,6 @@ class InteractionEntriesFilter(_BaseFilter):
 
        return filtered_interactions, different_len, end_len_dataframe
 
-    # pylint: disable=no-self-use
    def _filter_column_spark(
        self,
        interactions: SparkDataFrame,
@@ -223,7 +221,6 @@ class InteractionEntriesFilter(_BaseFilter):
 
        return filtered_interactions, different_len, end_len_dataframe
 
-    # pylint: disable=no-self-use
    def _filter_column_polars(
        self,
        interactions: PolarsDataFrame,
@@ -234,8 +231,7 @@ class InteractionEntriesFilter(_BaseFilter):
        max_inter: Optional[int] = None,
    ) -> Tuple[PolarsDataFrame, int, int]:
        filtered_interactions = interactions.with_columns(
-            pl.col(non_agg_column).count().over(pl.col(agg_column))
-            .alias("count")
+            pl.col(non_agg_column).count().over(pl.col(agg_column)).alias("count")
        )
        if min_inter:
            filtered_interactions = filtered_interactions.filter(pl.col("count") >= min_inter)
@@ -294,32 +290,20 @@ class MinCountFilter(_BaseFilter):
 
    def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
        filtered_interactions = interactions.copy(deep=True)
-        filtered_interactions["count"] = (
-            filtered_interactions
-            .groupby(self.groupby_column)[self.groupby_column]
-            .transform(len)
-        )
-        return (
-            filtered_interactions[filtered_interactions["count"] >= self.num_entries]
-            .drop(columns=["count"])
-        )
+        filtered_interactions["count"] = filtered_interactions.groupby(self.groupby_column)[
+            self.groupby_column
+        ].transform(len)
+        return filtered_interactions[filtered_interactions["count"] >= self.num_entries].drop(columns=["count"])
 
    def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
        filtered_interactions = interactions.clone()
        count_by_group = (
-            filtered_interactions
-            .group_by(self.groupby_column)
-            .agg(
-                pl.col(self.groupby_column).count().alias(f"{self.groupby_column}_temp_count")
-            )
-            .filter(
-                pl.col(f"{self.groupby_column}_temp_count") >= self.num_entries
-            )
+            filtered_interactions.group_by(self.groupby_column)
+            .agg(pl.col(self.groupby_column).count().alias(f"{self.groupby_column}_temp_count"))
+            .filter(pl.col(f"{self.groupby_column}_temp_count") >= self.num_entries)
        )
-        return (
-            filtered_interactions
-            .join(count_by_group, on=self.groupby_column)
-            .drop(f"{self.groupby_column}_temp_count")
+        return filtered_interactions.join(count_by_group, on=self.groupby_column).drop(
+            f"{self.groupby_column}_temp_count"
        )
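In both backends the rewritten MinCountFilter code keeps the same logic: attach a per-group row count and keep only groups with at least num_entries rows. A self-contained pandas sketch of that idea on a toy frame (column names and data are illustrative):

import pandas as pd

interactions = pd.DataFrame({"user_id": [1, 1, 1, 2, 2, 3], "item_id": [10, 11, 12, 10, 13, 14]})
num_entries = 2

counted = interactions.copy()
# Per-group row count, broadcast back onto every row of the group.
counted["count"] = counted.groupby("user_id")["user_id"].transform(len)
filtered = counted[counted["count"] >= num_entries].drop(columns=["count"])
print(filtered)  # keeps users 1 and 2, drops user 3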
 
 
@@ -423,7 +407,6 @@ class NumInteractionsFilter(_BaseFilter):
    <BLANKLINE>
    """
 
-    # pylint: disable=too-many-arguments
    def __init__(
        self,
        num_interactions: int = 10,
@@ -480,14 +463,12 @@ class NumInteractionsFilter(_BaseFilter):
        ascending = [self.first] * len(sorting_columns)
 
        filtered_interactions["temp_rank"] = (
-            filtered_interactions
-            .sort_values(sorting_columns, ascending=ascending)
+            filtered_interactions.sort_values(sorting_columns, ascending=ascending)
            .groupby(self.query_column)
            .cumcount()
        )
-        return (
-            filtered_interactions[filtered_interactions["temp_rank"] < self.num_interactions]
-            .drop(columns=["temp_rank"])
+        return filtered_interactions[filtered_interactions["temp_rank"] < self.num_interactions].drop(
+            columns=["temp_rank"]
        )
 
    def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
@@ -498,15 +479,10 @@ class NumInteractionsFilter(_BaseFilter):
        descending = not self.first
 
        return (
-            interactions
-            .sort(sorting_columns, descending=descending)
-            .with_columns(
-                pl.col(self.query_column)
-                .cumcount()
-                .over(self.query_column)
-                .alias("temp_rank")
-            )
-            .filter(pl.col("temp_rank") <= self.num_interactions).drop("temp_rank")
+            interactions.sort(sorting_columns, descending=descending)
+            .with_columns(pl.col(self.query_column).cumcount().over(self.query_column).alias("temp_rank"))
+            .filter(pl.col("temp_rank") <= self.num_interactions)
+            .drop("temp_rank")
        )
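The reshuffled chains above still compute a per-query rank with cumcount after sorting and keep only the first num_interactions rows. A pandas sketch mirroring that branch (toy data and column names are illustrative):

import pandas as pd

interactions = pd.DataFrame(
    {
        "user_id": [1, 1, 1, 2, 2],
        "item_id": [10, 11, 12, 10, 13],
        "timestamp": pd.to_datetime(
            ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-01", "2020-01-05"]
        ),
    }
)
num_interactions = 2
first = True  # keep the earliest interactions per user, as in the pandas branch above

ranked = interactions.copy()
# Rank rows within each user after sorting by time; index alignment restores the original order.
ranked["temp_rank"] = (
    ranked.sort_values("timestamp", ascending=first).groupby("user_id").cumcount()
)
result = ranked[ranked["temp_rank"] < num_interactions].drop(columns=["temp_rank"])
print(result)  # at most the two earliest rows per user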
 
 
@@ -598,18 +574,13 @@ class EntityDaysFilter(_BaseFilter):
        if self.first:
            filtered_interactions = (
                interactions.withColumn("min_date", sf.min(col(self.timestamp_column)).over(window))
-                .filter(
-                    col(self.timestamp_column)
-                    < col("min_date") + sf.expr(f"INTERVAL {self.days} days")
-                )
+                .filter(col(self.timestamp_column) < col("min_date") + sf.expr(f"INTERVAL {self.days} days"))
                .drop("min_date")
            )
        else:
            filtered_interactions = (
                interactions.withColumn("max_date", sf.max(col(self.timestamp_column)).over(window))
-                .filter(
-                    col(self.timestamp_column) > col("max_date") - sf.expr(f"INTERVAL {self.days} days")
-                )
+                .filter(col(self.timestamp_column) > col("max_date") - sf.expr(f"INTERVAL {self.days} days"))
                .drop("max_date")
            )
        return filtered_interactions
@@ -618,57 +589,36 @@ class EntityDaysFilter(_BaseFilter):
        filtered_interactions = interactions.copy(deep=True)
 
        if self.first:
-            filtered_interactions["min_date"] = (
-                filtered_interactions
-                .groupby(self.entity_column)[self.timestamp_column]
-                .transform(min)
-            )
-            return (
-                filtered_interactions[
-                    (
-                        filtered_interactions[self.timestamp_column]
-                        - filtered_interactions["min_date"]
-                    ).dt.days < self.days
-                ]
-                .drop(columns=["min_date"])
-            )
-        filtered_interactions["max_date"] = (
-            filtered_interactions
-            .groupby(self.entity_column)[self.timestamp_column]
-            .transform(max)
-        )
-        return (
-            filtered_interactions[
-                (
-                    filtered_interactions["max_date"]
-                    - filtered_interactions[self.timestamp_column]
-                ).dt.days < self.days
-            ]
-            .drop(columns=["max_date"])
-        )
+            filtered_interactions["min_date"] = filtered_interactions.groupby(self.entity_column)[
+                self.timestamp_column
+            ].transform(min)
+            return filtered_interactions[
+                (filtered_interactions[self.timestamp_column] - filtered_interactions["min_date"]).dt.days < self.days
+            ].drop(columns=["min_date"])
+        filtered_interactions["max_date"] = filtered_interactions.groupby(self.entity_column)[
+            self.timestamp_column
+        ].transform(max)
+        return filtered_interactions[
+            (filtered_interactions["max_date"] - filtered_interactions[self.timestamp_column]).dt.days < self.days
+        ].drop(columns=["max_date"])
 
    def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
        if self.first:
            return (
-                interactions
-                .with_columns(
+                interactions.with_columns(
                    (
-                        pl.col(self.timestamp_column)
-                        .min().over(pl.col(self.entity_column)) + pl.duration(days=self.days)
-                    )
-                    .alias("min_date")
+                        pl.col(self.timestamp_column).min().over(pl.col(self.entity_column))
+                        + pl.duration(days=self.days)
+                    ).alias("min_date")
                )
                .filter(pl.col(self.timestamp_column) < pl.col("min_date"))
                .drop("min_date")
            )
        return (
-            interactions
-            .with_columns(
+            interactions.with_columns(
                (
-                    pl.col(self.timestamp_column)
-                    .max().over(pl.col(self.entity_column)) - pl.duration(days=self.days)
-                )
-                .alias("max_date")
+                    pl.col(self.timestamp_column).max().over(pl.col(self.entity_column)) - pl.duration(days=self.days)
+                ).alias("max_date")
            )
            .filter(pl.col(self.timestamp_column) > pl.col("max_date"))
            .drop("max_date")
@@ -749,15 +699,11 @@ class GlobalDaysFilter(_BaseFilter):
    def _filter_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
        if self.first:
            start_date = interactions.agg(sf.min(self.timestamp_column)).first()[0]
-            end_date = sf.lit(start_date).cast(TimestampType()) + sf.expr(
-                f"INTERVAL {self.days} days"
-            )
+            end_date = sf.lit(start_date).cast(TimestampType()) + sf.expr(f"INTERVAL {self.days} days")
            return interactions.filter(col(self.timestamp_column) < end_date)
 
        end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0]
-        start_date = sf.lit(end_date).cast(TimestampType()) - sf.expr(
-            f"INTERVAL {self.days} days"
-        )
+        start_date = sf.lit(end_date).cast(TimestampType()) - sf.expr(f"INTERVAL {self.days} days")
        return interactions.filter(col(self.timestamp_column) > start_date)
 
    def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
@@ -765,33 +711,19 @@ class GlobalDaysFilter(_BaseFilter):
 
        if self.first:
            start_date = filtered_interactions[self.timestamp_column].min()
-            return (
-                filtered_interactions[
-                    (filtered_interactions[self.timestamp_column] - start_date).dt.days < self.days
-                ]
-            )
-        end_date = filtered_interactions[self.timestamp_column].max()
-        return (
-            filtered_interactions[
-                (end_date - filtered_interactions[self.timestamp_column]).dt.days < self.days
+            return filtered_interactions[
+                (filtered_interactions[self.timestamp_column] - start_date).dt.days < self.days
            ]
-        )
+        end_date = filtered_interactions[self.timestamp_column].max()
+        return filtered_interactions[(end_date - filtered_interactions[self.timestamp_column]).dt.days < self.days]
 
    def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
        if self.first:
-            return (
-                interactions
-                .filter(
-                    pl.col(self.timestamp_column)
-                    < (pl.col(self.timestamp_column).min() + pl.duration(days=self.days))
-                )
-            )
-        return (
-            interactions
-            .filter(
-                pl.col(self.timestamp_column)
-                > (pl.col(self.timestamp_column).max() - pl.duration(days=self.days))
+            return interactions.filter(
+                pl.col(self.timestamp_column) < (pl.col(self.timestamp_column).min() + pl.duration(days=self.days))
            )
+        return interactions.filter(
+            pl.col(self.timestamp_column) > (pl.col(self.timestamp_column).max() - pl.duration(days=self.days))
        )
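GlobalDaysFilter applies the same idea globally: the boundary is the minimum (or maximum) timestamp of the whole log rather than a per-entity one. A pandas sketch of the first=True branch (toy data, illustrative column names):

import pandas as pd

interactions = pd.DataFrame(
    {
        "user_id": [1, 2, 3, 4],
        "timestamp": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-06", "2020-01-20"]),
    }
)
days = 5  # keep only the first five days of the whole log

start_date = interactions["timestamp"].min()
result = interactions[(interactions["timestamp"] - start_date).dt.days < days]
print(result)  # keeps 2020-01-01 and 2020-01-02, drops the later rows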
 
 
@@ -823,7 +755,10 @@ class TimePeriodFilter(_BaseFilter):
    +-------+-------+------+-------------------+
    <BLANKLINE>
 
-    >>> TimePeriodFilter(start_date="2020-01-01 14:00:00", end_date=datetime(2020, 1, 3, 0, 0, 0)).transform(log_sp).show()
+    >>> TimePeriodFilter(
+    ...     start_date="2020-01-01 14:00:00",
+    ...     end_date=datetime(2020, 1, 3, 0, 0, 0)
+    ... ).transform(log_sp).show()
    +-------+-------+------+-------------------+
    |user_id|item_id|rating| timestamp|
    +-------+-------+------+-------------------+
@@ -861,9 +796,7 @@ class TimePeriodFilter(_BaseFilter):
        if self.start_date is None:
            self.start_date = interactions.agg(sf.min(self.timestamp_column)).first()[0]
        if self.end_date is None:
-            self.end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0] + timedelta(
-                seconds=1
-            )
+            self.end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0] + timedelta(seconds=1)
 
        return interactions.filter(
            (col(self.timestamp_column) >= sf.lit(self.start_date))
@@ -874,9 +807,7 @@ class TimePeriodFilter(_BaseFilter):
        if self.start_date is None:
            self.start_date = interactions[self.timestamp_column].min()
        if self.end_date is None:
-            self.end_date = interactions[self.timestamp_column].max() + timedelta(
-                seconds=1
-            )
+            self.end_date = interactions[self.timestamp_column].max() + timedelta(seconds=1)
 
        return interactions[
            (interactions[self.timestamp_column] >= self.start_date)
@@ -887,14 +818,8 @@ class TimePeriodFilter(_BaseFilter):
        if self.start_date is None:
            self.start_date = interactions.select(self.timestamp_column).min()[0, 0]
        if self.end_date is None:
-            self.end_date = interactions.select(self.timestamp_column).max()[0, 0] + pl.duration(
-                seconds=1
-            )
+            self.end_date = interactions.select(self.timestamp_column).max()[0, 0] + pl.duration(seconds=1)
 
-        return (
-            interactions
-            .filter(
-                pl.col(self.timestamp_column)
-                .is_between(self.start_date, self.end_date, closed="left")
-            )
+        return interactions.filter(
+            pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
        )
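The simplified TimePeriodFilter code keeps the half-open interval semantics: a missing start defaults to the earliest timestamp, a missing end to the latest timestamp plus one second, and rows pass when start_date <= timestamp < end_date. A pandas sketch of that window (toy data; the one-second shift is why the latest row still passes the strict upper bound):

from datetime import datetime, timedelta

import pandas as pd

interactions = pd.DataFrame(
    {
        "user_id": [1, 2, 3],
        "timestamp": pd.to_datetime(
            ["2020-01-01 10:00:00", "2020-01-02 14:00:00", "2020-01-05 09:00:00"]
        ),
    }
)

start_date = datetime(2020, 1, 1, 14, 0, 0)
# Default upper bound: max timestamp plus one second, so the latest interaction is kept.
end_date = interactions["timestamp"].max() + timedelta(seconds=1)

result = interactions[
    (interactions["timestamp"] >= start_date) & (interactions["timestamp"] < end_date)
]
print(result)  # drops the 2020-01-01 10:00 row, keeps the rest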