replay-rec 0.20.0rc0.tar.gz → 0.20.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/PKG-INFO +17 -11
  2. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/pyproject.toml +24 -12
  3. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/__init__.py +1 -1
  4. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequence_tokenizer.py +10 -3
  5. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequential_dataset.py +18 -14
  6. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/torch_sequential_dataset.py +12 -12
  7. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/lin_ucb.py +55 -9
  8. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/dataset.py +3 -16
  9. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  10. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/dataset.py +3 -16
  11. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/__init__.py +0 -1
  12. replay_rec-0.20.0rc0/replay/experimental/metrics/__init__.py +0 -62
  13. replay_rec-0.20.0rc0/replay/experimental/metrics/base_metric.py +0 -603
  14. replay_rec-0.20.0rc0/replay/experimental/metrics/coverage.py +0 -97
  15. replay_rec-0.20.0rc0/replay/experimental/metrics/experiment.py +0 -175
  16. replay_rec-0.20.0rc0/replay/experimental/metrics/hitrate.py +0 -26
  17. replay_rec-0.20.0rc0/replay/experimental/metrics/map.py +0 -30
  18. replay_rec-0.20.0rc0/replay/experimental/metrics/mrr.py +0 -18
  19. replay_rec-0.20.0rc0/replay/experimental/metrics/ncis_precision.py +0 -31
  20. replay_rec-0.20.0rc0/replay/experimental/metrics/ndcg.py +0 -49
  21. replay_rec-0.20.0rc0/replay/experimental/metrics/precision.py +0 -22
  22. replay_rec-0.20.0rc0/replay/experimental/metrics/recall.py +0 -25
  23. replay_rec-0.20.0rc0/replay/experimental/metrics/rocauc.py +0 -49
  24. replay_rec-0.20.0rc0/replay/experimental/metrics/surprisal.py +0 -90
  25. replay_rec-0.20.0rc0/replay/experimental/metrics/unexpectedness.py +0 -76
  26. replay_rec-0.20.0rc0/replay/experimental/models/__init__.py +0 -50
  27. replay_rec-0.20.0rc0/replay/experimental/models/admm_slim.py +0 -257
  28. replay_rec-0.20.0rc0/replay/experimental/models/base_neighbour_rec.py +0 -200
  29. replay_rec-0.20.0rc0/replay/experimental/models/base_rec.py +0 -1386
  30. replay_rec-0.20.0rc0/replay/experimental/models/base_torch_rec.py +0 -234
  31. replay_rec-0.20.0rc0/replay/experimental/models/cql.py +0 -454
  32. replay_rec-0.20.0rc0/replay/experimental/models/ddpg.py +0 -932
  33. replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
  34. replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
  35. replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
  36. replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/utils.py +0 -264
  37. replay_rec-0.20.0rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  38. replay_rec-0.20.0rc0/replay/experimental/models/hierarchical_recommender.py +0 -331
  39. replay_rec-0.20.0rc0/replay/experimental/models/implicit_wrap.py +0 -131
  40. replay_rec-0.20.0rc0/replay/experimental/models/lightfm_wrap.py +0 -303
  41. replay_rec-0.20.0rc0/replay/experimental/models/mult_vae.py +0 -332
  42. replay_rec-0.20.0rc0/replay/experimental/models/neural_ts.py +0 -986
  43. replay_rec-0.20.0rc0/replay/experimental/models/neuromf.py +0 -406
  44. replay_rec-0.20.0rc0/replay/experimental/models/scala_als.py +0 -293
  45. replay_rec-0.20.0rc0/replay/experimental/models/u_lin_ucb.py +0 -115
  46. replay_rec-0.20.0rc0/replay/experimental/nn/data/__init__.py +0 -1
  47. replay_rec-0.20.0rc0/replay/experimental/nn/data/schema_builder.py +0 -102
  48. replay_rec-0.20.0rc0/replay/experimental/preprocessing/__init__.py +0 -3
  49. replay_rec-0.20.0rc0/replay/experimental/preprocessing/data_preparator.py +0 -839
  50. replay_rec-0.20.0rc0/replay/experimental/preprocessing/padder.py +0 -229
  51. replay_rec-0.20.0rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
  52. replay_rec-0.20.0rc0/replay/experimental/scenarios/__init__.py +0 -1
  53. replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  54. replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  55. replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  56. replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  57. replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
  58. replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  59. replay_rec-0.20.0rc0/replay/experimental/utils/logger.py +0 -24
  60. replay_rec-0.20.0rc0/replay/experimental/utils/model_handler.py +0 -186
  61. replay_rec-0.20.0rc0/replay/experimental/utils/session_handler.py +0 -44
  62. replay_rec-0.20.0rc0/replay/models/extensions/ann/__init__.py +0 -0
  63. replay_rec-0.20.0rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
  64. replay_rec-0.20.0rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  65. replay_rec-0.20.0rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  66. replay_rec-0.20.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  67. replay_rec-0.20.0rc0/replay/utils/warnings.py +0 -26
  68. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/LICENSE +0 -0
  69. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/NOTICE +0 -0
  70. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/README.md +0 -0
  71. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/__init__.py +0 -0
  72. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset.py +0 -0
  73. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset_utils/__init__.py +0 -0
  74. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
  75. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/__init__.py +0 -0
  76. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/schema.py +0 -0
  77. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/utils.py +0 -0
  78. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/schema.py +0 -0
  79. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/spark_schema.py +0 -0
  80. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/__init__.py +0 -0
  81. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/base_metric.py +0 -0
  82. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/categorical_diversity.py +0 -0
  83. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/coverage.py +0 -0
  84. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/descriptors.py +0 -0
  85. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/experiment.py +0 -0
  86. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/hitrate.py +0 -0
  87. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/map.py +0 -0
  88. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/mrr.py +0 -0
  89. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/ndcg.py +0 -0
  90. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/novelty.py +0 -0
  91. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/offline_metrics.py +0 -0
  92. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/precision.py +0 -0
  93. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/recall.py +0 -0
  94. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/rocauc.py +0 -0
  95. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/surprisal.py +0 -0
  96. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/torch_metrics_builder.py +0 -0
  97. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/unexpectedness.py +0 -0
  98. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/__init__.py +0 -0
  99. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/als.py +0 -0
  100. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/association_rules.py +0 -0
  101. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/base_neighbour_rec.py +0 -0
  102. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/base_rec.py +0 -0
  103. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/cat_pop_rec.py +0 -0
  104. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/cluster.py +0 -0
  105. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/common.py +0 -0
  106. {replay_rec-0.20.0rc0/replay/experimental → replay_rec-0.20.1/replay/models/extensions}/__init__.py +0 -0
  107. {replay_rec-0.20.0rc0/replay/experimental/models/dt4rec → replay_rec-0.20.1/replay/models/extensions/ann}/__init__.py +0 -0
  108. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/ann_mixin.py +0 -0
  109. {replay_rec-0.20.0rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.20.1/replay/models/extensions/ann/entities}/__init__.py +0 -0
  110. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  111. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  112. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  113. {replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages → replay_rec-0.20.1/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
  114. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  115. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  116. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  117. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  118. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  119. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  120. {replay_rec-0.20.0rc0/replay/experimental/utils → replay_rec-0.20.1/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
  121. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  122. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  123. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  124. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  125. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  126. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  127. {replay_rec-0.20.0rc0/replay/models/extensions → replay_rec-0.20.1/replay/models/extensions/ann/index_stores}/__init__.py +0 -0
  128. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  129. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  130. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  131. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  132. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  133. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/utils.py +0 -0
  134. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/kl_ucb.py +0 -0
  135. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/knn.py +0 -0
  136. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/__init__.py +0 -0
  137. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/loss/__init__.py +0 -0
  138. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/loss/sce.py +0 -0
  139. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  140. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  141. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/__init__.py +0 -0
  142. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  143. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  144. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/model.py +0 -0
  145. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  146. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  147. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
  148. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/__init__.py +0 -0
  149. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/base_compiled_model.py +0 -0
  150. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +0 -0
  151. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/sasrec_compiled.py +0 -0
  152. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  153. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  154. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  155. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  156. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/model.py +0 -0
  157. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/__init__.py +0 -0
  158. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/optuna_mixin.py +0 -0
  159. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/optuna_objective.py +0 -0
  160. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/pop_rec.py +0 -0
  161. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/query_pop_rec.py +0 -0
  162. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/random_rec.py +0 -0
  163. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/slim.py +0 -0
  164. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/thompson_sampling.py +0 -0
  165. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/ucb.py +0 -0
  166. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/wilson.py +0 -0
  167. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/word2vec.py +0 -0
  168. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/__init__.py +0 -0
  169. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/converter.py +0 -0
  170. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/discretizer.py +0 -0
  171. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/filters.py +0 -0
  172. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/history_based_fp.py +0 -0
  173. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/label_encoder.py +0 -0
  174. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/sessionizer.py +0 -0
  175. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/scenarios/__init__.py +0 -0
  176. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/scenarios/fallback.py +0 -0
  177. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/__init__.py +0 -0
  178. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/base_splitter.py +0 -0
  179. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/cold_user_random_splitter.py +0 -0
  180. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/k_folds.py +0 -0
  181. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/last_n_splitter.py +0 -0
  182. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/new_users_splitter.py +0 -0
  183. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/random_splitter.py +0 -0
  184. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/ratio_splitter.py +0 -0
  185. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/time_splitter.py +0 -0
  186. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/two_stage_splitter.py +0 -0
  187. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/common.py +0 -0
  188. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/dataframe_bucketizer.py +0 -0
  189. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/distributions.py +0 -0
  190. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/model_handler.py +0 -0
  191. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/session_handler.py +0 -0
  192. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/spark_utils.py +0 -0
  193. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/time.py +0 -0
  194. {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/types.py +0 -0
{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: replay-rec
- Version: 0.20.0rc0
+ Version: 0.20.1
  Summary: RecSys Library
  License-Expression: Apache-2.0
  License-File: LICENSE
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
  Classifier: Intended Audience :: Science/Research
  Classifier: Natural Language :: English
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Requires-Dist: d3rlpy (>=2.8.1,<2.9)
- Requires-Dist: implicit (>=0.7.2,<0.8)
- Requires-Dist: lightautoml (>=0.4.1,<0.5)
- Requires-Dist: lightning (>=2.0.2,<=2.4.0)
- Requires-Dist: numba (>=0.50,<1)
+ Provides-Extra: spark
+ Provides-Extra: torch
+ Provides-Extra: torch-cpu
+ Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: lightning ; extra == "torch"
+ Requires-Dist: lightning ; extra == "torch-cpu"
  Requires-Dist: numpy (>=1.20.0,<2)
  Requires-Dist: pandas (>=1.3.5,<2.4.0)
  Requires-Dist: polars (<2.0)
- Requires-Dist: psutil (<=7.0.0)
+ Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
+ Requires-Dist: psutil ; extra == "spark"
  Requires-Dist: pyarrow (<22.0)
- Requires-Dist: pyspark (>=3.0,<3.5)
- Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
- Requires-Dist: sb-obp (>=0.5.10,<0.6)
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
+ Requires-Dist: pyspark ; extra == "spark"
+ Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: pytorch-optimizer ; extra == "torch"
+ Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
  Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
  Requires-Dist: scipy (>=1.13.1,<1.14)
  Requires-Dist: setuptools
- Requires-Dist: torch (>=1.8,<3.0.0)
+ Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: torch ; extra == "torch"
+ Requires-Dist: torch ; extra == "torch-cpu"
  Requires-Dist: tqdm (>=4.67,<5)
  Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
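
Note on the metadata change above: the heavy optional stacks (d3rlpy, implicit, lightautoml, numba, sb-obp) leave the requirements entirely, while pyspark/psutil and torch/pytorch-optimizer/lightning move behind the new spark, torch, and torch-cpu extras. A hedged sketch of how a resolver evaluates the new environment markers, using the third-party packaging library (not part of this diff):

    from packaging.markers import Marker

    # Marker copied verbatim from the new Requires-Dist lines above.
    torch_marker = Marker('extra == "torch" or extra == "torch-cpu"')

    assert torch_marker.evaluate({"extra": "torch"}) is True  # [torch] extra requested
    assert torch_marker.evaluate({"extra": ""}) is False      # plain install skips torch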

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/pyproject.toml

@@ -40,19 +40,19 @@ dependencies = [
  "scikit-learn (>=1.6.1,<1.7.0)",
  "pyarrow (<22.0)",
  "tqdm (>=4.67,<5)",
- "torch (>=1.8,<3.0.0)",
- "lightning (>=2.0.2,<=2.4.0)",
- "pytorch-optimizer (>=3.8.0,<4)",
- "lightautoml (>=0.4.1,<0.5)",
- "numba (>=0.50,<1)",
- "sb-obp (>=0.5.10,<0.6)",
- "d3rlpy (>=2.8.1,<2.9)",
- "implicit (>=0.7.2,<0.8)",
- "pyspark (>=3.0,<3.5)",
- "psutil (<=7.0.0)",
+ "pyspark (>=3.0,<3.5); extra == 'spark'",
+ "psutil (<=7.0.0); extra == 'spark'",
+ "torch (>=1.8, <3.0.0); extra == 'torch' or extra == 'torch-cpu'",
+ "pytorch-optimizer (>=3.8.0,<3.9.0); extra == 'torch' or extra == 'torch-cpu'",
+ "lightning (<2.6.0); extra == 'torch' or extra == 'torch-cpu'",
  ]
  dynamic = ["dependencies"]
- version = "0.20.0.preview"
+ version = "0.20.1"
+
+ [project.optional-dependencies]
+ spark = ["pyspark", "psutil"]
+ torch = ["torch", "pytorch-optimizer", "lightning"]
+ torch-cpu = ["torch", "pytorch-optimizer", "lightning"]

  [project.urls]
  homepage = "https://sb-ai-lab.github.io/RePlay/"
@@ -66,6 +66,13 @@ target-version = ["py39", "py310", "py311", "py312"]
  packages = [{include = "replay"}]
  exclude = [
  "replay/conftest.py",
+ "replay/experimental",
+ ]
+
+ [tool.poetry.dependencies]
+ torch = [
+ {markers = "extra == 'torch-cpu' and extra !='torch'", source = "torch-cpu-mirror"},
+ {markers = "extra == 'torch' and extra !='torch-cpu'", source = "PyPI"},
  ]

  [tool.poetry.group.dev.dependencies]
@@ -88,9 +95,14 @@ docutils = "0.16"
  data-science-types = "0.2.23"
  filelock = "~3.14.0"

+ [[tool.poetry.source]]
+ name = "torch-cpu-mirror"
+ url = "https://download.pytorch.org/whl/cpu"
+ priority = "explicit"
+
  [tool.poetry-dynamic-versioning]
  enable = false
- format-jinja = """0.20.0{{ env['PACKAGE_SUFFIX'] }}"""
+ format-jinja = """0.20.1{{ env['PACKAGE_SUFFIX'] }}"""
  vcs = "git"

  [tool.ruff]
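
Two consequences of the pyproject change worth noting: replay/experimental is no longer packaged (it joins the build exclude list, which is why the experimental modules disappear from the file list above), and the torch-cpu extra resolves torch from the dedicated CPU wheel index. For consumers, a plain pip install of replay-rec no longer brings in Spark or Torch; a minimal sketch of guarding the now-optional imports (module path taken from the file list, the export name is assumed):

    import importlib.util

    # Present only with `pip install "replay-rec[spark]"` / `"replay-rec[torch]"`.
    HAS_SPARK = importlib.util.find_spec("pyspark") is not None
    HAS_TORCH = importlib.util.find_spec("torch") is not None

    if HAS_TORCH:
        from replay.models.nn.sequential import SasRec  # torch-only model, export name assumed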

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/__init__.py

@@ -4,4 +4,4 @@
  # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
  import setuptools as _

- __version__ = "0.20.0.preview"
+ __version__ = "0.20.1"

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequence_tokenizer.py

@@ -15,7 +15,6 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, Feat
  from replay.data.dataset_utils import DatasetLabelEncoder
  from replay.preprocessing import LabelEncoder, LabelEncodingRule
  from replay.preprocessing.label_encoder import HandleUnknownStrategies
- from replay.utils import deprecation_warning

  if TYPE_CHECKING:
  from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
@@ -406,7 +405,6 @@ class SequenceTokenizer:
  tensor_feature._set_cardinality(dataset_feature.cardinality)

  @classmethod
- @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
  def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
  """
  Load tokenizer object from the given path.
@@ -450,12 +448,16 @@ class SequenceTokenizer:
  tokenizer._encoder._features_columns = encoder_features_columns
  tokenizer._encoder._encoding_rules = tokenizer_dict["encoder"]["encoding_rules"]
  else:
+ warnings.warn(
+ "with `use_pickle` equals to `True` will be deprecated in future versions",
+ DeprecationWarning,
+ stacklevel=2,
+ )
  with open(path, "rb") as file:
  tokenizer = pickle.load(file)

  return tokenizer

- @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
  def save(self, path: str, use_pickle: bool = False) -> None:
  """
  Save the tokenizer to the given path.
@@ -496,6 +498,11 @@ class SequenceTokenizer:
  with open(base_path / "init_args.json", "w+") as file:
  json.dump(tokenizer_dict, file)
  else:
+ warnings.warn(
+ "with `use_pickle` equals to `True` will be deprecated in future versions",
+ DeprecationWarning,
+ stacklevel=2,
+ )
  with open(path, "wb") as file:
  pickle.dump(self, file)

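The removed @deprecation_warning decorator (dropped package-wide in this release, see replay/utils below) warned on every save/load call; the inlined warnings.warn now fires only when the pickle branch is actually taken. A hedged usage sketch, assuming `tokenizer` is a fitted SequenceTokenizer:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        tokenizer.save("/tmp/tokenizer_dir")                   # JSON path: no warning
        tokenizer.save("/tmp/tokenizer.pkl", use_pickle=True)  # pickle path: warns
    assert sum(issubclass(w.category, DeprecationWarning) for w in caught) == 1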

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequential_dataset.py

@@ -110,17 +110,27 @@ class SequentialDataset(abc.ABC):

  sequential_dict = {}
  sequential_dict["_class_name"] = self.__class__.__name__
- self._sequences.reset_index().to_json(base_path / "sequences.json")
+
+ df = SequentialDataset._convert_array_to_list(self._sequences)
+ df.reset_index().to_parquet(base_path / "sequences.parquet")
  sequential_dict["init_args"] = {
  "tensor_schema": self._tensor_schema._get_object_args(),
  "query_id_column": self._query_id_column,
  "item_id_column": self._item_id_column,
- "sequences_path": "sequences.json",
+ "sequences_path": "sequences.parquet",
  }

  with open(base_path / "init_args.json", "w+") as file:
  json.dump(sequential_dict, file)

+ @staticmethod
+ def _convert_array_to_list(df):
+ return df.map(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
+
+ @staticmethod
+ def _convert_list_to_array(df):
+ return df.map(lambda x: np.array(x) if isinstance(x, list) else x)
+

  class PandasSequentialDataset(SequentialDataset):
  """
@@ -149,7 +159,7 @@ class PandasSequentialDataset(SequentialDataset):
  if sequences.index.name != query_id_column:
  sequences = sequences.set_index(query_id_column)

- self._sequences = sequences
+ self._sequences = SequentialDataset._convert_list_to_array(sequences)

  def __len__(self) -> int:
  return len(self._sequences)
@@ -206,7 +216,8 @@ class PandasSequentialDataset(SequentialDataset):
  with open(base_path / "init_args.json") as file:
  sequential_dict = json.loads(file.read())

- sequences = pd.read_json(base_path / sequential_dict["init_args"]["sequences_path"])
+ sequences = pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"])
+ sequences = cls._convert_array_to_list(sequences)
  dataset = cls(
  tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
  query_id_column=sequential_dict["init_args"]["query_id_column"],
@@ -258,18 +269,11 @@ class PolarsSequentialDataset(PandasSequentialDataset):

  def _convert_polars_to_pandas(self, df: PolarsDataFrame) -> PandasDataFrame:
  pandas_df = PandasDataFrame(df.to_dict(as_series=False))
-
- for column in pandas_df.select_dtypes(include="object").columns:
- if isinstance(pandas_df[column].iloc[0], list):
- pandas_df[column] = pandas_df[column].apply(lambda x: np.array(x))
-
+ pandas_df = SequentialDataset._convert_list_to_array(pandas_df)
  return pandas_df

  def _convert_pandas_to_polars(self, df: PandasDataFrame) -> PolarsDataFrame:
- for column in df.select_dtypes(include="object").columns:
- if isinstance(df[column].iloc[0], np.ndarray):
- df[column] = df[column].apply(lambda x: x.tolist())
-
+ df = SequentialDataset._convert_array_to_list(df)
  return pl.from_dict(df.to_dict("list"))

  @classmethod
@@ -290,7 +294,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
  with open(base_path / "init_args.json") as file:
  sequential_dict = json.loads(file.read())

- sequences = pl.DataFrame(pd.read_json(base_path / sequential_dict["init_args"]["sequences_path"]))
+ sequences = pl.from_pandas(pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"]))
  dataset = cls(
  tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
  query_id_column=sequential_dict["init_args"]["query_id_column"],
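
Sequences are now persisted as Parquet rather than JSON, and the ad-hoc per-column loops give way to two cell-wise helpers that normalize object columns between numpy arrays (the in-memory form) and plain lists (the form written to Parquet and exchanged with Polars). A self-contained sketch of the round trip those helpers perform (DataFrame.map requires pandas >= 2.1):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"item_ids": [np.array([1, 2, 3]), np.array([4, 5])]})

    # _convert_array_to_list equivalent: make cells Parquet-friendly lists.
    as_lists = df.map(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
    as_lists.to_parquet("/tmp/sequences.parquet")

    # _convert_list_to_array equivalent on the way back in.
    restored = pd.read_parquet("/tmp/sequences.parquet")
    restored = restored.map(lambda x: np.array(x) if isinstance(x, list) else x)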

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/torch_sequential_dataset.py

@@ -1,3 +1,4 @@
+ import warnings
  from collections.abc import Generator, Sequence
  from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast

@@ -5,8 +6,6 @@ import numpy as np
  import torch
  from torch.utils.data import Dataset as TorchDataset

- from replay.utils import deprecation_warning
-
  if TYPE_CHECKING:
  from .schema import TensorFeatureInfo, TensorMap, TensorSchema
  from .sequential_dataset import SequentialDataset
@@ -29,16 +28,12 @@ class TorchSequentialDataset(TorchDataset):
  Torch dataset for sequential recommender models
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: "SequentialDataset",
  max_sequence_length: int,
  sliding_window_step: Optional[int] = None,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  ) -> None:
  """
  :param sequential: sequential dataset
@@ -53,6 +48,15 @@
  self._sequential = sequential
  self._max_sequence_length = max_sequence_length
  self._sliding_window_step = sliding_window_step
+ if padding_value is not None:
+ warnings.warn(
+ "`padding_value` parameter will be removed in future versions. "
+ "Instead, you should specify `padding_value` for each column in TensorSchema",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ padding_value = 0
  self._padding_value = padding_value
  self._index2sequence_map = self._build_index2sequence_map()

@@ -177,17 +181,13 @@
  Torch dataset for sequential recommender models that additionally stores ground truth
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: "SequentialDataset",
  ground_truth: "SequentialDataset",
  train: "SequentialDataset",
  max_sequence_length: int,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  sliding_window_step: Optional[int] = None,
  label_feature_name: Optional[str] = None,
  ):
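
The decorator that warned on every construction is replaced by a None sentinel: the DeprecationWarning now fires only when a caller still passes padding_value explicitly, and the old default of 0 is restored internally. The same pattern recurs in the bert4rec and sasrec datasets below. A distilled sketch (the helper name is hypothetical):

    import warnings
    from typing import Optional

    def resolve_padding_value(padding_value: Optional[int] = None) -> int:
        if padding_value is not None:  # only explicit callers are warned
            warnings.warn(
                "`padding_value` parameter will be removed in future versions.",
                DeprecationWarning,
                stacklevel=2,
            )
        else:
            padding_value = 0  # silent default, behavior unchanged
        return padding_value

    assert resolve_padding_value() == 0  # no warning emitted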

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/lin_ucb.py

@@ -1,5 +1,6 @@
  import warnings
- from typing import Union
+ from os.path import join
+ from typing import Optional, Union

  import numpy as np
  import pandas as pd
@@ -8,7 +9,11 @@ from tqdm import tqdm

  from replay.data.dataset import Dataset
  from replay.utils import SparkDataFrame
- from replay.utils.spark_utils import convert2spark
+ from replay.utils.spark_utils import (
+ convert2spark,
+ load_pickled_from_parquet,
+ save_picklable_to_parquet,
+ )

  from .base_rec import HybridRecommender

@@ -177,6 +182,7 @@ class LinUCB(HybridRecommender):
  _study = None  # field required for proper optuna's optimization
  linucb_arms: list[Union[DisjointArm, HybridArm]]  # initialize only when working within fit method
  rel_matrix: np.array  # matrix with relevance scores from predict method
+ _num_items: int  # number of items/arms

  def __init__(
  self,
@@ -195,7 +201,7 @@

  @property
  def _init_args(self):
- return {"is_hybrid": self.is_hybrid}
+ return {"is_hybrid": self.is_hybrid, "eps": self.eps, "alpha": self.alpha}

  def _verify_features(self, dataset: Dataset):
  if dataset.query_features is None:
@@ -230,6 +236,7 @@
  self._num_items = item_features.shape[0]
  self._user_dim_size = user_features.shape[1] - 1
  self._item_dim_size = item_features.shape[1] - 1
+ self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)

  # now initialize an arm object for each potential arm instance
  if self.is_hybrid:
@@ -248,11 +255,14 @@
  ]

  for i in tqdm(range(self._num_items)):
- B = log.loc[log[feature_schema.item_id_column] == i]  # noqa: N806
- idxs_list = B[feature_schema.query_id_column].values
- rel_list = B[feature_schema.interactions_rating_column].values
+ B = log.loc[  # noqa: N806
+ (log[feature_schema.item_id_column] == i)
+ & (log[feature_schema.query_id_column].isin(self._user_idxs_list))
+ ]
  if not B.empty:
  # if we have at least one user interacting with the hand i
+ idxs_list = B[feature_schema.query_id_column].values
+ rel_list = B[feature_schema.interactions_rating_column].values
  cur_usrs = scs.csr_matrix(
  user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
  .drop(columns=[feature_schema.query_id_column])
@@ -284,11 +294,14 @@
  ]

  for i in range(self._num_items):
- B = log.loc[log[feature_schema.item_id_column] == i]  # noqa: N806
- idxs_list = B[feature_schema.query_id_column].values  # noqa: F841
- rel_list = B[feature_schema.interactions_rating_column].values
+ B = log.loc[  # noqa: N806
+ (log[feature_schema.item_id_column] == i)
+ & (log[feature_schema.query_id_column].isin(self._user_idxs_list))
+ ]
  if not B.empty:
  # if we have at least one user interacting with the hand i
+ idxs_list = B[feature_schema.query_id_column].values  # noqa: F841
+ rel_list = B[feature_schema.interactions_rating_column].values
  cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
  columns=[feature_schema.query_id_column]
  )
@@ -318,8 +331,10 @@
  user_features = dataset.query_features
  item_features = dataset.item_features
  big_k = min(oversample * k, item_features.shape[0])
+ self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)

  users = users.toPandas()
+ users = users[users[feature_schema.query_id_column].isin(self._user_idxs_list)]
  num_user_pred = users.shape[0]
  rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)

@@ -404,3 +419,34 @@
  warnings.warn(warn_msg)
  dataset.to_spark()
  return convert2spark(res_df)
+
+ def _save_model(self, path: str, additional_params: Optional[dict] = None):
+ super()._save_model(path, additional_params)
+
+ save_picklable_to_parquet(self.linucb_arms, join(path, "linucb_arms.dump"))
+
+ if self.is_hybrid:
+ linucb_hybrid_shared_params = {
+ "A_0": self.A_0,
+ "A_0_inv": self.A_0_inv,
+ "b_0": self.b_0,
+ "beta": self.beta,
+ }
+ save_picklable_to_parquet(
+ linucb_hybrid_shared_params,
+ join(path, "linucb_hybrid_shared_params.dump"),
+ )
+
+ def _load_model(self, path: str):
+ super()._load_model(path)
+
+ loaded_linucb_arms = load_pickled_from_parquet(join(path, "linucb_arms.dump"))
+ self.linucb_arms = loaded_linucb_arms
+ self._num_items = len(loaded_linucb_arms)
+
+ if self.is_hybrid:
+ loaded_linucb_hybrid_shared_params = load_pickled_from_parquet(
+ join(path, "linucb_hybrid_shared_params.dump")
+ )
+ for param, value in loaded_linucb_hybrid_shared_params.items():
+ setattr(self, param, value)
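
Besides restricting fit and predict to users that actually have features (the new _user_idxs_list filter), LinUCB becomes persistable: _save_model pickles the per-item arms, plus the shared A_0/A_0_inv/b_0/beta parameters in the hybrid case, into Parquet sidecar files, and _load_model restores them, recovering _num_items from the arm count. A hedged usage sketch; the generic save/load helpers from replay.utils.model_handler are assumed to be the public entry points, and the constructor arguments follow _init_args above:

    from replay.models import LinUCB                   # import path assumed
    from replay.utils.model_handler import save, load  # assumed entry points

    model = LinUCB(eps=1.0, alpha=1.0, is_hybrid=True)
    model.fit(dataset)              # `dataset` with query/item features, prepared elsewhere
    save(model, "/tmp/linucb")      # writes linucb_arms.dump + hybrid shared params
    restored = load("/tmp/linucb")  # arms, _num_items, A_0/A_0_inv/b_0/beta come back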

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/dataset.py

@@ -12,7 +12,6 @@ from replay.data.nn import (
  TorchSequentialDataset,
  TorchSequentialValidationDataset,
  )
- from replay.utils import deprecation_warning


  class Bert4RecTrainingBatch(NamedTuple):
@@ -89,10 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
  Dataset that generates samples to train BERT-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
@@ -101,7 +96,7 @@
  sliding_window_step: Optional[int] = None,
  label_feature_name: Optional[str] = None,
  custom_masker: Optional[Bert4RecMasker] = None,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  ) -> None:
  """
  :param sequential: Sequential dataset with training data.
@@ -181,15 +176,11 @@ class Bert4RecPredictionDataset(TorchDataset):
  Dataset that generates samples to infer BERT-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
  max_sequence_length: int,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  ) -> None:
  """
  :param sequential: Sequential dataset with data to make predictions at.
@@ -239,17 +230,13 @@ class Bert4RecValidationDataset(TorchDataset):
  Dataset that generates samples to infer and validate BERT-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
  ground_truth: SequentialDataset,
  train: SequentialDataset,
  max_sequence_length: int,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  label_feature_name: Optional[str] = None,
  ):
  """

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/postprocessors.py

@@ -51,7 +51,7 @@ class RemoveSeenItems(BasePostProcessor):

  def _compute_scores(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> torch.Tensor:
  flat_seen_item_ids = self._get_flat_seen_item_ids(query_ids)
- return self._fill_item_ids(scores, flat_seen_item_ids, -np.inf)
+ return self._fill_item_ids(scores.clone(), flat_seen_item_ids, -np.inf)

  def _fill_item_ids(
  self,
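
This one-call fix is about aliasing: _fill_item_ids writes -inf into the tensor it receives, so without clone() the caller's scores were mutated in place. A minimal torch sketch of the behavior the clone() preserves:

    import torch

    scores = torch.tensor([[0.9, 0.8, 0.7]])

    masked = scores.clone()
    masked[0, 1] = float("-inf")         # mask a "seen" item in the copy only

    assert torch.isfinite(scores).all()  # the original stays intact for other consumers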

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/dataset.py

@@ -10,7 +10,6 @@ from replay.data.nn import (
  TorchSequentialDataset,
  TorchSequentialValidationDataset,
  )
- from replay.utils import deprecation_warning


  class SasRecTrainingBatch(NamedTuple):
@@ -31,17 +30,13 @@ class SasRecTrainingDataset(TorchDataset):
  Dataset that generates samples to train SasRec-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
  max_sequence_length: int,
  sequence_shift: int = 1,
  sliding_window_step: Optional[None] = None,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  label_feature_name: Optional[str] = None,
  ) -> None:
  """
@@ -127,15 +122,11 @@ class SasRecPredictionDataset(TorchDataset):
  Dataset that generates samples to infer SasRec-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
  max_sequence_length: int,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  ) -> None:
  """
  :param sequential: Sequential dataset with data to make predictions at.
@@ -179,17 +170,13 @@ class SasRecValidationDataset(TorchDataset):
  Dataset that generates samples to infer and validate SasRec-like model
  """

- @deprecation_warning(
- "`padding_value` parameter will be removed in future versions. "
- "Instead, you should specify `padding_value` for each column in TensorSchema"
- )
  def __init__(
  self,
  sequential: SequentialDataset,
  ground_truth: SequentialDataset,
  train: SequentialDataset,
  max_sequence_length: int,
- padding_value: int = 0,
+ padding_value: Optional[int] = None,
  label_feature_name: Optional[str] = None,
  ):
  """

{replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/__init__.py

@@ -15,4 +15,3 @@ from .types import (
  PolarsDataFrame,
  SparkDataFrame,
  )
- from .warnings import deprecation_warning

replay_rec-0.20.0rc0/replay/experimental/metrics/__init__.py (removed; replay/experimental is excluded from the 0.20.1 build, per the pyproject.toml exclude above)

@@ -1,62 +0,0 @@
- """
- Most metrics require dataframe with recommendations
- and dataframe with ground truth values —
- which objects each user interacted with.
-
- - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
- predictions of a recommender system,
- DataFrame with columns ``[user_id, item_id, relevance]``
- - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
- test data, DataFrame with columns
- ``[user_id, item_id, timestamp, relevance]``
-
- Metric is calculated for all users, presented in ``ground_truth``
- for accurate metric calculation in case when the recommender system generated
- recommendation not for all users. It is assumed, that all users,
- we want to calculate metric for, have positive interactions.
-
- But if we have users, who observed the recommendations, but have not responded,
- those users will be ignored and metric will be overestimated.
- For such case we propose additional optional parameter ``ground_truth_users``,
- the dataframe with all users, which should be considered during the metric calculation.
-
- - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
- full list of users to calculate metric for, DataFrame with ``user_id`` column
-
- Every metric is calculated using top ``K`` items for each user.
- It is also possible to calculate metrics
- using multiple values for ``K`` simultaneously.
- In this case the result will be a dictionary and not a number.
-
- Make sure your recommendations do not contain user-item duplicates
- as duplicates could lead to the wrong calculation results.
-
- - k (Union[Iterable[int], int]):
- a single number or a list, specifying the
- truncation length for recommendation list for each user
-
- By default, metrics are averaged by users,
- but you can alternatively use method ``metric.median``.
- Also, you can get the lower bound
- of ``conf_interval`` for a given ``alpha``.
-
- Diversity metrics require extra parameters on initialization stage,
- but do not use ``ground_truth`` parameter.
-
- For each metric, a formula for its calculation is given, because this is
- important for the correct comparison of algorithms, as mentioned in our
- `article <https://arxiv.org/abs/2206.12858>`_.
- """
-
- from replay.experimental.metrics.base_metric import Metric, NCISMetric
- from replay.experimental.metrics.coverage import Coverage
- from replay.experimental.metrics.hitrate import HitRate
- from replay.experimental.metrics.map import MAP
- from replay.experimental.metrics.mrr import MRR
- from replay.experimental.metrics.ncis_precision import NCISPrecision
- from replay.experimental.metrics.ndcg import NDCG
- from replay.experimental.metrics.precision import Precision
- from replay.experimental.metrics.recall import Recall
- from replay.experimental.metrics.rocauc import RocAuc
- from replay.experimental.metrics.surprisal import Surprisal
- from replay.experimental.metrics.unexpectedness import Unexpectedness