replay-rec 0.17.1rc0__tar.gz → 0.18.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/PKG-INFO +12 -18
  2. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/pyproject.toml +19 -23
  3. replay_rec-0.18.0/replay/__init__.py +3 -0
  4. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/dataset.py +3 -2
  5. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/dataset_utils/dataset_label_encoder.py +1 -0
  6. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/schema.py +5 -5
  7. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/__init__.py +1 -0
  8. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/als.py +1 -1
  9. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/base_rec.py +7 -7
  10. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
  11. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
  12. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/bert4rec/model.py +5 -112
  13. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/sasrec/model.py +8 -5
  14. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/optimization/optuna_objective.py +1 -0
  15. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/converter.py +1 -1
  16. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/filters.py +19 -18
  17. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/history_based_fp.py +5 -5
  18. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/label_encoder.py +1 -0
  19. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/scenarios/__init__.py +1 -0
  20. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/last_n_splitter.py +1 -1
  21. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/time_splitter.py +1 -1
  22. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/two_stage_splitter.py +8 -6
  23. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/distributions.py +1 -0
  24. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/session_handler.py +3 -3
  25. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/spark_utils.py +2 -2
  26. replay_rec-0.17.1rc0/NOTICE +0 -41
  27. replay_rec-0.17.1rc0/replay/__init__.py +0 -2
  28. replay_rec-0.17.1rc0/replay/experimental/metrics/__init__.py +0 -61
  29. replay_rec-0.17.1rc0/replay/experimental/metrics/base_metric.py +0 -601
  30. replay_rec-0.17.1rc0/replay/experimental/metrics/coverage.py +0 -97
  31. replay_rec-0.17.1rc0/replay/experimental/metrics/experiment.py +0 -175
  32. replay_rec-0.17.1rc0/replay/experimental/metrics/hitrate.py +0 -26
  33. replay_rec-0.17.1rc0/replay/experimental/metrics/map.py +0 -30
  34. replay_rec-0.17.1rc0/replay/experimental/metrics/mrr.py +0 -18
  35. replay_rec-0.17.1rc0/replay/experimental/metrics/ncis_precision.py +0 -31
  36. replay_rec-0.17.1rc0/replay/experimental/metrics/ndcg.py +0 -49
  37. replay_rec-0.17.1rc0/replay/experimental/metrics/precision.py +0 -22
  38. replay_rec-0.17.1rc0/replay/experimental/metrics/recall.py +0 -25
  39. replay_rec-0.17.1rc0/replay/experimental/metrics/rocauc.py +0 -49
  40. replay_rec-0.17.1rc0/replay/experimental/metrics/surprisal.py +0 -90
  41. replay_rec-0.17.1rc0/replay/experimental/metrics/unexpectedness.py +0 -76
  42. replay_rec-0.17.1rc0/replay/experimental/models/__init__.py +0 -10
  43. replay_rec-0.17.1rc0/replay/experimental/models/admm_slim.py +0 -205
  44. replay_rec-0.17.1rc0/replay/experimental/models/base_neighbour_rec.py +0 -204
  45. replay_rec-0.17.1rc0/replay/experimental/models/base_rec.py +0 -1271
  46. replay_rec-0.17.1rc0/replay/experimental/models/base_torch_rec.py +0 -234
  47. replay_rec-0.17.1rc0/replay/experimental/models/cql.py +0 -452
  48. replay_rec-0.17.1rc0/replay/experimental/models/ddpg.py +0 -921
  49. replay_rec-0.17.1rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
  50. replay_rec-0.17.1rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
  51. replay_rec-0.17.1rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
  52. replay_rec-0.17.1rc0/replay/experimental/models/dt4rec/utils.py +0 -265
  53. replay_rec-0.17.1rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  54. replay_rec-0.17.1rc0/replay/experimental/models/implicit_wrap.py +0 -131
  55. replay_rec-0.17.1rc0/replay/experimental/models/lightfm_wrap.py +0 -302
  56. replay_rec-0.17.1rc0/replay/experimental/models/mult_vae.py +0 -331
  57. replay_rec-0.17.1rc0/replay/experimental/models/neuromf.py +0 -405
  58. replay_rec-0.17.1rc0/replay/experimental/models/scala_als.py +0 -296
  59. replay_rec-0.17.1rc0/replay/experimental/nn/data/__init__.py +0 -1
  60. replay_rec-0.17.1rc0/replay/experimental/nn/data/schema_builder.py +0 -55
  61. replay_rec-0.17.1rc0/replay/experimental/preprocessing/__init__.py +0 -3
  62. replay_rec-0.17.1rc0/replay/experimental/preprocessing/data_preparator.py +0 -838
  63. replay_rec-0.17.1rc0/replay/experimental/preprocessing/padder.py +0 -229
  64. replay_rec-0.17.1rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
  65. replay_rec-0.17.1rc0/replay/experimental/scenarios/__init__.py +0 -1
  66. replay_rec-0.17.1rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  67. replay_rec-0.17.1rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  68. replay_rec-0.17.1rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  69. replay_rec-0.17.1rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  70. replay_rec-0.17.1rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
  71. replay_rec-0.17.1rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  72. replay_rec-0.17.1rc0/replay/experimental/utils/logger.py +0 -24
  73. replay_rec-0.17.1rc0/replay/experimental/utils/model_handler.py +0 -181
  74. replay_rec-0.17.1rc0/replay/experimental/utils/session_handler.py +0 -44
  75. replay_rec-0.17.1rc0/replay/models/extensions/ann/__init__.py +0 -0
  76. replay_rec-0.17.1rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
  77. replay_rec-0.17.1rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  78. replay_rec-0.17.1rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  79. replay_rec-0.17.1rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  80. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/LICENSE +0 -0
  81. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/README.md +0 -0
  82. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/__init__.py +0 -0
  83. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/dataset_utils/__init__.py +0 -0
  84. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/__init__.py +0 -0
  85. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/sequence_tokenizer.py +0 -0
  86. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/sequential_dataset.py +0 -0
  87. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/torch_sequential_dataset.py +0 -0
  88. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/nn/utils.py +0 -0
  89. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/schema.py +0 -0
  90. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/data/spark_schema.py +0 -0
  91. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/base_metric.py +0 -0
  92. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/categorical_diversity.py +0 -0
  93. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/coverage.py +0 -0
  94. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/descriptors.py +0 -0
  95. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/experiment.py +0 -0
  96. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/hitrate.py +0 -0
  97. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/map.py +0 -0
  98. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/mrr.py +0 -0
  99. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/ndcg.py +0 -0
  100. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/novelty.py +0 -0
  101. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/offline_metrics.py +0 -0
  102. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/precision.py +0 -0
  103. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/recall.py +0 -0
  104. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/rocauc.py +0 -0
  105. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/surprisal.py +0 -0
  106. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/torch_metrics_builder.py +0 -0
  107. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/metrics/unexpectedness.py +0 -0
  108. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/__init__.py +0 -0
  109. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/association_rules.py +0 -0
  110. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/base_neighbour_rec.py +0 -0
  111. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/cat_pop_rec.py +0 -0
  112. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/cluster.py +0 -0
  113. {replay_rec-0.17.1rc0/replay/experimental → replay_rec-0.18.0/replay/models/extensions}/__init__.py +0 -0
  114. {replay_rec-0.17.1rc0/replay/experimental/models/dt4rec → replay_rec-0.18.0/replay/models/extensions/ann}/__init__.py +0 -0
  115. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/ann_mixin.py +0 -0
  116. {replay_rec-0.17.1rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.18.0/replay/models/extensions/ann/entities}/__init__.py +0 -0
  117. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  118. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  119. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  120. {replay_rec-0.17.1rc0/replay/experimental/scenarios/two_stages → replay_rec-0.18.0/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
  121. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  122. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  123. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  124. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  125. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  126. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  127. {replay_rec-0.17.1rc0/replay/experimental/utils → replay_rec-0.18.0/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
  128. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  129. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  130. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  131. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  132. {replay_rec-0.17.1rc0/replay/models/extensions → replay_rec-0.18.0/replay/models/extensions/ann/index_stores}/__init__.py +0 -0
  133. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  134. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  135. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  136. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  137. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  138. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/extensions/ann/utils.py +0 -0
  139. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/kl_ucb.py +0 -0
  140. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/knn.py +0 -0
  141. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/__init__.py +0 -0
  142. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  143. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  144. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/__init__.py +0 -0
  145. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  146. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
  147. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  148. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  149. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  150. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
  151. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  152. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  153. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
  154. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  155. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
  156. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  157. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/pop_rec.py +0 -0
  158. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/query_pop_rec.py +0 -0
  159. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/random_rec.py +0 -0
  160. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/slim.py +0 -0
  161. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/thompson_sampling.py +0 -0
  162. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/ucb.py +0 -0
  163. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/wilson.py +0 -0
  164. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/models/word2vec.py +0 -0
  165. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/optimization/__init__.py +0 -0
  166. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/__init__.py +0 -0
  167. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/preprocessing/sessionizer.py +0 -0
  168. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/scenarios/fallback.py +0 -0
  169. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/__init__.py +0 -0
  170. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/base_splitter.py +0 -0
  171. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/cold_user_random_splitter.py +0 -0
  172. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/k_folds.py +0 -0
  173. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/new_users_splitter.py +0 -0
  174. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/random_splitter.py +0 -0
  175. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/splitters/ratio_splitter.py +0 -0
  176. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/__init__.py +0 -0
  177. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/common.py +0 -0
  178. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/dataframe_bucketizer.py +0 -0
  179. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/model_handler.py +0 -0
  180. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/time.py +0 -0
  181. {replay_rec-0.17.1rc0 → replay_rec-0.18.0}/replay/utils/types.py +0 -0
--- replay_rec-0.17.1rc0/PKG-INFO
+++ replay_rec-0.18.0/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: replay-rec
-Version: 0.17.1rc0
+Version: 0.18.0
 Summary: RecSys Library
 Home-page: https://sb-ai-lab.github.io/RePlay/
 License: Apache-2.0
 Author: AI Lab
-Requires-Python: >=3.8.1,<3.11
+Requires-Python: >=3.8.1,<3.12
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Intended Audience :: Developers
@@ -16,32 +16,26 @@ Classifier: Operating System :: Unix
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Provides-Extra: all
 Provides-Extra: spark
 Provides-Extra: torch
-Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
-Requires-Dist: gym (>=0.26.0,<0.27.0)
-Requires-Dist: hnswlib (==0.7.0)
-Requires-Dist: implicit (>=0.7.0,<0.8.0)
-Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
-Requires-Dist: lightfm (==1.17)
-Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
-Requires-Dist: llvmlite (>=0.32.1)
-Requires-Dist: nmslib (==2.1.1)
-Requires-Dist: numba (>=0.50)
+Requires-Dist: fixed-install-nmslib (==2.1.2)
+Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
+Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "all"
 Requires-Dist: numpy (>=1.20.0)
 Requires-Dist: optuna (>=3.2.0,<3.3.0)
 Requires-Dist: pandas (>=1.3.5,<=2.2.2)
-Requires-Dist: polars (>=0.20.7,<0.21.0)
-Requires-Dist: psutil (>=5.9.5,<5.10.0)
+Requires-Dist: polars (>=1.0.0,<1.1.0)
+Requires-Dist: psutil (>=6.0.0,<6.1.0)
 Requires-Dist: pyarrow (>=12.0.1)
-Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
+Requires-Dist: pyspark (>=3.0,<3.6) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
+Requires-Dist: pyspark (>=3.4,<3.6) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
 Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
-Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
 Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
-Requires-Dist: scipy (>=1.8.1,<1.9.0)
-Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
+Requires-Dist: scipy (>=1.8.1,<2.0.0)
+Requires-Dist: torch (>=1.8,<=2.4.0) ; extra == "torch" or extra == "all"
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
 Description-Content-Type: text/markdown
 
--- replay_rec-0.17.1rc0/pyproject.toml
+++ replay_rec-0.18.0/pyproject.toml
@@ -7,7 +7,7 @@ build-backend = "poetry_dynamic_versioning.backend"
 
 [tool.black]
 line-length = 120
-target-versions = ["py38", "py39", "py310"]
+target-versions = ["py38", "py39", "py310", "py311"]
 
 [tool.poetry]
 name = "replay-rec"
@@ -39,33 +39,29 @@ classifiers = [
 ]
 exclude = [
     "replay/conftest.py",
+    "replay/experimental",
 ]
-version = "0.17.1.preview"
+version = "0.18.0"
 
 [tool.poetry.dependencies]
-python = ">=3.8.1, <3.11"
+python = ">=3.8.1, <3.12"
 numpy = ">=1.20.0"
-pandas = ">=1.3.5,<=2.2.2"
-polars = "~0.20.7"
+pandas = ">=1.3.5, <=2.2.2"
+polars = "~1.0.0"
 optuna = "~3.2.0"
-scipy = "~1.8.1"
-psutil = "~5.9.5"
-pyspark = {version = ">=3.0,<3.5", optional = true}
+scipy = "^1.8.1"
+psutil = "~6.0.0"
 scikit-learn = "^1.0.2"
 pyarrow = ">=12.0.1"
-nmslib = "2.1.1"
-hnswlib = "0.7.0"
-torch = "^1.8"
-lightning = "^2.0.2"
-pytorch-ranger = "^0.1.1"
-lightfm = "1.17"
-lightautoml = "~0.3.1"
-numba = ">=0.50"
-llvmlite = ">=0.32.1"
-sb-obp = "^0.5.7"
-d3rlpy = "^2.0.4"
-implicit = "~0.7.0"
-gym = "^0.26.0"
+pyspark = [
+    {version = ">=3.4,<3.6", python = ">=3.11,<3.12", optional = true},
+    {version = ">=3.0,<3.6", python = ">=3.8.1,<3.11", optional = true},
+]
+torch = {version = ">=1.8, <=2.4.0", optional = true}
+lightning = {version = ">=2.0.2, <=2.4.0", optional = true}
+pytorch-ranger = {version = "^0.1.1", optional = true}
+fixed-install-nmslib = "2.1.2"
+hnswlib = "^0.7.0"
 
 [tool.poetry.extras]
 spark = ["pyspark"]
@@ -77,7 +73,7 @@ jupyter = "~1.0.0"
 jupyterlab = "^3.6.0"
 pytest = ">=7.1.0"
 pytest-cov = ">=3.0.0"
-statsmodels = "~0.13.5"
+statsmodels = "~0.14.0"
 black = ">=23.3.0"
 ruff = ">=0.0.261"
 toml-sort = "^0.23.0"
@@ -92,7 +88,7 @@ data-science-types = "0.2.23"
 
 [tool.poetry-dynamic-versioning]
 enable = false
-format-jinja = """0.17.1{{ env['PACKAGE_SUFFIX'] }}"""
+format-jinja = """0.18.0{{ env['PACKAGE_SUFFIX'] }}"""
 vcs = "git"
 
 [tool.ruff]
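
A note on the new conditional pyspark requirement: the two pyspark `Requires-Dist` lines in PKG-INFO carry mutually exclusive environment markers, so an installer selects exactly one constraint per interpreter (this mirrors the two-entry `pyspark` list in pyproject.toml above). A quick way to convince yourself, sketched here with the `packaging` library (not a RePlay dependency; environment values are illustrative):

    from packaging.markers import Marker

    wide = Marker(
        '(python_full_version >= "3.8.1" and python_version < "3.11")'
        ' and (extra == "spark" or extra == "all")'
    )
    narrow = Marker(
        '(python_version >= "3.11" and python_version < "3.12")'
        ' and (extra == "spark" or extra == "all")'
    )

    py310 = {"python_full_version": "3.10.14", "python_version": "3.10", "extra": "spark"}
    py311 = {"python_full_version": "3.11.9", "python_version": "3.11", "extra": "spark"}

    # Python 3.10 gets pyspark >=3.0,<3.6; Python 3.11 gets pyspark >=3.4,<3.6
    assert wide.evaluate(py310) and not narrow.evaluate(py310)
    assert narrow.evaluate(py311) and not wide.evaluate(py311)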
--- /dev/null
+++ replay_rec-0.18.0/replay/__init__.py
@@ -0,0 +1,3 @@
+""" RecSys library """
+
+__version__ = "0.18.0"
--- replay_rec-0.17.1rc0/replay/data/dataset.py
+++ replay_rec-0.18.0/replay/data/dataset.py
@@ -1,6 +1,7 @@
 """
 ``Dataset`` universal dataset class for manipulating interactions and feed data to models.
 """
+
 from __future__ import annotations
 
 import json
@@ -606,7 +607,7 @@ class Dataset:
         if self.is_pandas:
             min_id = data[column].min()
         elif self.is_spark:
-            min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
+            min_id = data.agg(sf.min(column).alias("min_index")).first()[0]
         else:
             min_id = data[column].min()
         if min_id < 0:
@@ -616,7 +617,7 @@ class Dataset:
         if self.is_pandas:
             max_id = data[column].max()
         elif self.is_spark:
-            max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
+            max_id = data.agg(sf.max(column).alias("max_index")).first()[0]
         else:
             max_id = data[column].max()
 
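
The `.collect()[0][0]` → `.first()[0]` substitution here (and again in `als.py` and `base_rec.py` below) is behavior-preserving for single-row aggregates: `first()` fetches only one row to the driver instead of materializing the whole result as a list. A minimal sketch with a local Spark session (data is illustrative):

    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1,), (7,), (3,)], ["item_id"])

    # The aggregation yields exactly one row, so both forms return the same scalar.
    agg = df.agg(sf.max("item_id").alias("max_index"))
    assert agg.first()[0] == agg.collect()[0][0] == 7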
--- replay_rec-0.17.1rc0/replay/data/dataset_utils/dataset_label_encoder.py
+++ replay_rec-0.18.0/replay/data/dataset_utils/dataset_label_encoder.py
@@ -4,6 +4,7 @@ Contains classes for encoding categorical data
 ``LabelEncoderTransformWarning`` new category of warning for DatasetLabelEncoder.
 ``DatasetLabelEncoder`` to encode categorical features in `Dataset` objects.
 """
+
 import warnings
 from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union
 
--- replay_rec-0.17.1rc0/replay/data/nn/schema.py
+++ replay_rec-0.18.0/replay/data/nn/schema.py
@@ -418,11 +418,11 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
                 "feature_type": feature.feature_type.name,
                 "is_seq": feature.is_seq,
                 "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
-                "feature_sources": [
-                    {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
-                ]
-                if feature.feature_sources
-                else None,
+                "feature_sources": (
+                    [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
+                    if feature.feature_sources
+                    else None
+                ),
                 "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
                 "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
                 "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
--- replay_rec-0.17.1rc0/replay/metrics/__init__.py
+++ replay_rec-0.18.0/replay/metrics/__init__.py
@@ -42,6 +42,7 @@ For each metric, a formula for its calculation is given, because this is
 important for the correct comparison of algorithms, as mentioned in our
 `article <https://arxiv.org/abs/2206.12858>`_.
 """
+
 from .base_metric import Metric
 from .categorical_diversity import CategoricalDiversity
 from .coverage import Coverage
--- replay_rec-0.17.1rc0/replay/models/als.py
+++ replay_rec-0.18.0/replay/models/als.py
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
                 .groupBy(self.query_column)
                 .agg(sf.count(self.query_column).alias("num_seen"))
                 .select(sf.max("num_seen"))
-                .collect()[0][0]
+                .first()[0]
             )
             max_seen = max_seen_in_interactions if max_seen_in_interactions is not None else 0
 
--- replay_rec-0.17.1rc0/replay/models/base_rec.py
+++ replay_rec-0.18.0/replay/models/base_rec.py
@@ -401,8 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         self.fit_items = sf.broadcast(items)
         self._num_queries = self.fit_queries.count()
         self._num_items = self.fit_items.count()
-        self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
-        self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
+        self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).first()[0] + 1
+        self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).first()[0] + 1
         self._fit(dataset)
 
     @abstractmethod
@@ -431,7 +431,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         # count maximal number of items seen by queries
         max_seen = 0
         if num_seen.count() > 0:
-            max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+            max_seen = num_seen.select(sf.max("seen_count")).first()[0]
 
         # crop recommendations to first k + max_seen items for each query
         recs = recs.withColumn(
@@ -708,7 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         setattr(
             self,
             dim_size,
-            fit_entities.agg({column: "max"}).collect()[0][0] + 1,
+            fit_entities.agg({column: "max"}).first()[0] + 1,
         )
         return getattr(self, dim_size)
 
@@ -1426,7 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         Calculating a fill value a the minimal rating
         calculated during model training multiplied by weight.
         """
-        return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight
+        return item_popularity.select(sf.min(rating_column)).first()[0] * weight
 
     @staticmethod
     def _check_rating(dataset: Dataset):
@@ -1460,7 +1460,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
                 .agg(sf.countDistinct(item_column).alias("items_count"))
             )
             .select(sf.max("items_count"))
-            .collect()[0][0]
+            .first()[0]
         )
         # all queries have empty history
         if max_hist_len is None:
@@ -1495,7 +1495,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         queries = queries.join(query_to_num_items, on=self.query_column, how="left")
         queries = queries.fillna(0, "num_items")
         # 'selected_item_popularity' truncation by k + max_seen
-        max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+        max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
         selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
         return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
 
--- replay_rec-0.17.1rc0/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py
+++ replay_rec-0.18.0/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py
@@ -32,9 +32,9 @@ class NmslibFilterIndexInferer(IndexInferer):
         index = index_store.load_index(
             init_index=lambda: create_nmslib_index_instance(index_params),
             load_index=lambda index, path: index.loadIndex(path, load_data=True),
-            configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
-            if index_params.ef_s
-            else None,
+            configure_index=lambda index: (
+                index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+            ),
         )
 
         # max number of items to retrieve per batch
--- replay_rec-0.17.1rc0/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py
+++ replay_rec-0.18.0/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py
@@ -30,9 +30,9 @@ class NmslibIndexInferer(IndexInferer):
         index = index_store.load_index(
             init_index=lambda: create_nmslib_index_instance(index_params),
             load_index=lambda index, path: index.loadIndex(path, load_data=True),
-            configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
-            if index_params.ef_s
-            else None,
+            configure_index=lambda index: (
+                index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+            ),
         )
 
         user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
--- replay_rec-0.17.1rc0/replay/models/nn/sequential/bert4rec/model.py
+++ replay_rec-0.18.0/replay/models/nn/sequential/bert4rec/model.py
@@ -1,7 +1,7 @@
 import contextlib
 import math
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple, Union, cast
+from typing import Dict, Optional, Union
 
 import torch
 
@@ -115,13 +115,10 @@ class Bert4RecModel(torch.nn.Module):
         # (B x L x E)
         x = self.item_embedder(inputs, token_mask)
 
-        # (B x 1 x L x L)
-        pad_mask_for_attention = self._get_attention_mask_from_padding(pad_mask)
-
         # Running over multiple transformer blocks
         for transformer in self.transformer_blocks:
             for _ in range(self.num_passes_over_block):
-                x = transformer(x, pad_mask_for_attention)
+                x = transformer(x, pad_mask)
 
         return x
 
@@ -147,11 +144,6 @@ class Bert4RecModel(torch.nn.Module):
         """
         return self.forward_step(inputs, pad_mask, token_mask)[:, -1, :]
 
-    def _get_attention_mask_from_padding(self, pad_mask: torch.BoolTensor) -> torch.BoolTensor:
-        # (B x L) -> (B x 1 x L x L)
-        pad_mask_for_attention = pad_mask.unsqueeze(1).repeat(1, self.max_len, 1).unsqueeze(1)
-        return cast(torch.BoolTensor, pad_mask_for_attention)
-
     def _init(self) -> None:
         for _, param in self.named_parameters():
             with contextlib.suppress(ValueError):
@@ -456,7 +448,7 @@ class TransformerBlock(torch.nn.Module):
         :param dropout: Dropout rate.
         """
         super().__init__()
-        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden_size, dropout=dropout)
+        self.attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=dropout, batch_first=True)
         self.attention_dropout = torch.nn.Dropout(dropout)
         self.attention_norm = LayerNorm(hidden_size)
 
@@ -479,7 +471,8 @@ class TransformerBlock(torch.nn.Module):
         """
         # Attention + skip-connection
         x_norm = self.attention_norm(x)
-        y = x + self.attention_dropout(self.attention(x_norm, x_norm, x_norm, mask))
+        attent_emb, _ = self.attention(x_norm, x_norm, x_norm, key_padding_mask=~mask, need_weights=False)
+        y = x + self.attention_dropout(attent_emb)
 
         # PFF + skip-connection
         z = y + self.pff_dropout(self.pff(self.pff_norm(y)))
@@ -487,106 +480,6 @@ class TransformerBlock(torch.nn.Module):
         return self.dropout(z)
 
 
-class Attention(torch.nn.Module):
-    """
-    Compute Scaled Dot Product Attention
-    """
-
-    def __init__(self, dropout: float) -> None:
-        """
-        :param dropout: Dropout rate.
-        """
-        super().__init__()
-        self.dropout = torch.nn.Dropout(p=dropout)
-
-    def forward(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.BoolTensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        :param query: Query feature vector.
-        :param key: Key feature vector.
-        :param value: Value feature vector.
-        :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
-        :returns: Tuple of scaled dot product attention
-            and attention logits for each element.
-        """
-        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
-
-        scores = scores.masked_fill(mask == 0, -1e9)
-        p_attn = torch.nn.functional.softmax(scores, dim=-1)
-        p_attn = self.dropout(p_attn)
-
-        return torch.matmul(p_attn, value), p_attn
-
-
-class MultiHeadedAttention(torch.nn.Module):
-    """
-    Take in model size and number of heads.
-    """
-
-    def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
-        """
-        :param h: Head sizes of multi-head attention.
-        :param d_model: Embedding dimension.
-        :param dropout: Dropout rate.
-            Default: ``0.1``.
-        """
-        super().__init__()
-        assert d_model % h == 0
-
-        # We assume d_v always equals d_k
-        self.d_k = d_model // h
-        self.h = h
-
-        # 3 linear projections for Q, K, V
-        self.qkv_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(3)])
-
-        # 2 linear projections for P -> P_q, P_k
-        self.pos_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(2)])
-
-        self.output_linear = torch.nn.Linear(d_model, d_model)
-
-        self.attention = Attention(dropout)
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        mask: torch.BoolTensor,
-    ) -> torch.Tensor:
-        """
-        :param query: Query feature vector.
-        :param key: Key feature vector.
-        :param value: Value feature vector.
-        :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
-        :returns: Attention outputs.
-        """
-        batch_size = query.size(0)
-
-        # B - batch size
-        # L - sequence length (max_len)
-        # E - embedding size for tokens fed into transformer
-        # K - max relative distance
-        # H - attention head count
-
-        # Do all the linear projections in batch from d_model => h x d_k
-        # (B x L x E) -> (B x H x L x (E / H))
-        query, key, value = [
-            layer(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
-            for layer, x in zip(self.qkv_linear_layers, (query, key, value))
-        ]
-
-        x, _ = self.attention(query, key, value, mask)
-
-        # Concat using a view and apply a final linear.
-        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
-
-        return self.output_linear(x)
-
-
 class LayerNorm(torch.nn.Module):
     """
     Construct a layernorm module (See citation for details).
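
The padding-mask plumbing above changes convention: the removed custom `Attention` kept positions where `mask` was 1 and filled `-1e9` where it was 0, whereas `torch.nn.MultiheadAttention`'s `key_padding_mask` ignores positions that are True. That is why the new `TransformerBlock` passes `~mask` and no longer needs the `(B x 1 x L x L)` expansion. A minimal sketch (shapes are illustrative):

    import torch

    batch, seq_len, hidden, heads = 2, 4, 8, 2
    x = torch.randn(batch, seq_len, hidden)
    # 1 = real token, 0 = padding, as in the old convention
    pad_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

    mha = torch.nn.MultiheadAttention(hidden, heads, batch_first=True)
    # key_padding_mask marks positions to IGNORE, hence the inversion
    out, _ = mha(x, x, x, key_padding_mask=~pad_mask, need_weights=False)
    assert out.shape == (batch, seq_len, hidden)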
--- replay_rec-0.17.1rc0/replay/models/nn/sequential/sasrec/model.py
+++ replay_rec-0.18.0/replay/models/nn/sequential/sasrec/model.py
@@ -401,7 +401,12 @@ class SasRecLayers(torch.nn.Module):
         """
         super().__init__()
         self.attention_layers = self._layers_stacker(
-            num_blocks, torch.nn.MultiheadAttention, hidden_size, num_heads, dropout
+            num_blocks,
+            torch.nn.MultiheadAttention,
+            hidden_size,
+            num_heads,
+            dropout,
+            batch_first=True,
         )
         self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
         self.forward_layers = self._layers_stacker(num_blocks, SasRecPointWiseFeedForward, hidden_size, dropout)
@@ -422,11 +427,9 @@ class SasRecLayers(torch.nn.Module):
         """
         length = len(self.attention_layers)
         for i in range(length):
-            seqs = torch.transpose(seqs, 0, 1)
             query = self.attention_layernorms[i](seqs)
-            attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask)
+            attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask, need_weights=False)
             seqs = query + attent_emb
-            seqs = torch.transpose(seqs, 0, 1)
 
             seqs = self.forward_layernorms[i](seqs)
             seqs = self.forward_layers[i](seqs)
@@ -492,7 +495,7 @@ class SasRecPointWiseFeedForward(torch.nn.Module):
 
         :returns: Output tensors.
         """
-        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
+        outputs = self.dropout2(self.conv2(self.dropout1(self.relu(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
         outputs += inputs
 
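
Two things change in SasRec here. With `batch_first=True`, the attention layers consume `(B, L, E)` directly, so the `torch.transpose(seqs, 0, 1)` calls around each block go away; and in the point-wise feed-forward, dropout now follows the ReLU rather than acting on raw convolution outputs. A sketch of both, with illustrative sizes:

    import torch

    attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    seqs = torch.randn(3, 5, 8)  # (B, L, E), no transpose to (L, B, E) needed
    out, _ = attn(seqs, seqs, seqs, need_weights=False)

    conv1 = torch.nn.Conv1d(8, 8, kernel_size=1)
    relu, drop = torch.nn.ReLU(), torch.nn.Dropout(0.5)
    x = torch.randn(3, 8, 5)     # (B, E, L) after inputs.transpose(-1, -2)
    old = relu(drop(conv1(x)))   # 0.17.x order: dropout zeroes pre-activations
    new = drop(relu(conv1(x)))   # 0.18.0 order: dropout zeroes activations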
--- replay_rec-0.17.1rc0/replay/optimization/optuna_objective.py
+++ replay_rec-0.18.0/replay/optimization/optuna_objective.py
@@ -1,6 +1,7 @@
 """
 This class calculates loss function for optimization process
 """
+
 import collections
 import logging
 from functools import partial
--- replay_rec-0.17.1rc0/replay/preprocessing/converter.py
+++ replay_rec-0.18.0/replay/preprocessing/converter.py
@@ -102,6 +102,6 @@ class CSRConverter:
         row_count = self.row_count if self.row_count is not None else _get_max(rows_data) + 1
         col_count = self.column_count if self.column_count is not None else _get_max(cols_data) + 1
         return csr_matrix(
-            (data, (rows_data, cols_data)),
+            (data.tolist(), (rows_data.tolist(), cols_data.tolist())),
             shape=(row_count, col_count),
         )
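
`scipy.sparse.csr_matrix` accepts the COO-style `(data, (rows, cols))` triplet; the added `.tolist()` calls coerce whatever array-like `CSRConverter` received into plain Python lists first, presumably for uniform handling across backends (that motivation is my reading, not stated in the diff). A minimal sketch with illustrative data:

    import numpy as np
    from scipy.sparse import csr_matrix

    data = np.array([1.0, 0.5, 3.0])
    rows = np.array([0, 1, 1])
    cols = np.array([0, 1, 2])

    # Plain-list triplets build the same matrix as the ndarray originals.
    m = csr_matrix((data.tolist(), (rows.tolist(), cols.tolist())), shape=(2, 3))
    assert m.toarray()[1, 2] == 3.0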
--- replay_rec-0.17.1rc0/replay/preprocessing/filters.py
+++ replay_rec-0.18.0/replay/preprocessing/filters.py
@@ -1,6 +1,7 @@
 """
 Select or remove data by some criteria
 """
+
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
 from typing import Callable, Optional, Tuple, Union
@@ -355,8 +356,8 @@ class NumInteractionsFilter(_BaseFilter):
     >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
     ...                        "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
     ...                        "rating": [1., 0.5, 3, 1, 0, 1],
-    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
-    ...                                      "2020-02-01", "2020-01-01 00:04:15",
+    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+    ...                                      "2020-02-01 00:00:01", "2020-01-01 00:04:15",
     ...                                      "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ...                       )
     >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -367,7 +368,7 @@ class NumInteractionsFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u3|     i2|   0.0|2020-01-02 00:04:14|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
@@ -393,7 +394,7 @@ class NumInteractionsFilter(_BaseFilter):
     |user_id|item_id|rating|          timestamp|
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
-    |     u2|     i2|   0.5|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
     +-------+-------+------+-------------------+
     <BLANKLINE>
@@ -403,7 +404,7 @@ class NumInteractionsFilter(_BaseFilter):
     |user_id|item_id|rating|          timestamp|
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
     +-------+-------+------+-------------------+
     <BLANKLINE>
@@ -482,7 +483,7 @@ class NumInteractionsFilter(_BaseFilter):
 
         return (
             interactions.sort(sorting_columns, descending=descending)
-            .with_columns(pl.col(self.query_column).cumcount().over(self.query_column).alias("temp_rank"))
+            .with_columns(pl.col(self.query_column).cum_count().over(self.query_column).alias("temp_rank"))
             .filter(pl.col("temp_rank") <= self.num_interactions)
             .drop("temp_rank")
         )
@@ -497,8 +498,8 @@ class EntityDaysFilter(_BaseFilter):
     >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
     ...                        "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
     ...                        "rating": [1., 0.5, 3, 1, 0, 1],
-    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
-    ...                                      "2020-02-01", "2020-01-01 00:04:15",
+    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+    ...                                      "2020-02-01 00:00:01", "2020-01-01 00:04:15",
     ...                                      "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ...                       )
     >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -509,7 +510,7 @@ class EntityDaysFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u3|     i2|   0.0|2020-01-02 00:04:14|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
@@ -524,7 +525,7 @@ class EntityDaysFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u3|     i2|   0.0|2020-01-02 00:04:14|
     +-------+-------+------+-------------------+
@@ -539,7 +540,7 @@ class EntityDaysFilter(_BaseFilter):
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     +-------+-------+------+-------------------+
     <BLANKLINE>
     """
@@ -636,8 +637,8 @@ class GlobalDaysFilter(_BaseFilter):
     >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
     ...                        "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
     ...                        "rating": [1., 0.5, 3, 1, 0, 1],
-    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
-    ...                                      "2020-02-01", "2020-01-01 00:04:15",
+    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+    ...                                      "2020-02-01 00:00:01", "2020-01-01 00:04:15",
     ...                                      "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ...                       )
     >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -648,7 +649,7 @@ class GlobalDaysFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u3|     i2|   0.0|2020-01-02 00:04:14|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
@@ -670,7 +671,7 @@ class GlobalDaysFilter(_BaseFilter):
     |user_id|item_id|rating|          timestamp|
     +-------+-------+------+-------------------+
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     +-------+-------+------+-------------------+
     <BLANKLINE>
     """
@@ -738,8 +739,8 @@ class TimePeriodFilter(_BaseFilter):
     >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
     ...                        "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
     ...                        "rating": [1., 0.5, 3, 1, 0, 1],
-    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
-    ...                                      "2020-02-01", "2020-01-01 00:04:15",
+    ...                        "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+    ...                                      "2020-02-01 00:00:01", "2020-01-01 00:04:15",
     ...                                      "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ...                       )
     >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -750,7 +751,7 @@ class TimePeriodFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     |     u1|     i1|   1.0|2020-01-01 23:59:59|
     |     u2|     i2|   0.5|2020-02-01 00:00:00|
-    |     u2|     i3|   3.0|2020-02-01 00:00:00|
+    |     u2|     i3|   3.0|2020-02-01 00:00:01|
     |     u3|     i1|   1.0|2020-01-01 00:04:15|
     |     u3|     i2|   0.0|2020-01-02 00:04:14|
     |     u3|     i3|   1.0|2020-01-05 23:59:59|
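
The `cumcount()` → `cum_count()` change tracks the polars 1.x API, where the old deprecated spelling is gone; the doctest timestamps were also disambiguated (the duplicated "2020-02-01" became 00:00:00 and 00:00:01) so sort order, and therefore the filtered output, is deterministic. A sketch of the renamed expression (data is illustrative):

    import polars as pl

    df = pl.DataFrame({"user_id": ["u1", "u2", "u2"], "item_id": ["i1", "i2", "i3"]})
    first_per_user = (
        df.with_columns(pl.col("user_id").cum_count().over("user_id").alias("temp_rank"))
        .filter(pl.col("temp_rank") <= 1)
        .drop("temp_rank")
    )
    assert first_per_user.height == 2  # one interaction kept per user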