replay-rec 0.17.1rc0.tar.gz → 0.18.0rc0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/PKG-INFO +13 -11
  2. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/pyproject.toml +18 -15
  3. replay_rec-0.18.0rc0/replay/__init__.py +3 -0
  4. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset.py +3 -2
  5. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +1 -0
  6. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/schema.py +5 -5
  7. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/__init__.py +1 -0
  8. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/base_metric.py +1 -0
  9. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_rec.py +7 -7
  10. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/cql.py +2 -0
  11. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/ddpg.py +6 -4
  12. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/lightfm_wrap.py +2 -2
  13. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/mult_vae.py +1 -0
  14. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/neuromf.py +1 -0
  15. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/scala_als.py +2 -2
  16. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/data_preparator.py +2 -1
  17. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/padder.py +1 -1
  18. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +1 -1
  19. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/model_handler.py +7 -2
  20. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/__init__.py +1 -0
  21. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/als.py +1 -1
  22. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_rec.py +7 -7
  23. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
  24. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
  25. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/model.py +5 -112
  26. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/model.py +8 -5
  27. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/optimization/optuna_objective.py +1 -0
  28. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/converter.py +1 -1
  29. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/filters.py +19 -18
  30. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/history_based_fp.py +5 -5
  31. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/label_encoder.py +1 -0
  32. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/scenarios/__init__.py +1 -0
  33. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/last_n_splitter.py +1 -1
  34. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/time_splitter.py +1 -1
  35. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/two_stage_splitter.py +8 -6
  36. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/distributions.py +1 -0
  37. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/session_handler.py +3 -3
  38. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/spark_utils.py +2 -2
  39. replay_rec-0.17.1rc0/replay/__init__.py +0 -2
  40. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/LICENSE +0 -0
  41. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/NOTICE +0 -0
  42. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/README.md +0 -0
  43. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/__init__.py +0 -0
  44. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
  45. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/__init__.py +0 -0
  46. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/sequence_tokenizer.py +0 -0
  47. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/sequential_dataset.py +0 -0
  48. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/torch_sequential_dataset.py +0 -0
  49. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/utils.py +0 -0
  50. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/schema.py +0 -0
  51. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/spark_schema.py +0 -0
  52. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/__init__.py +0 -0
  53. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/coverage.py +0 -0
  54. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/experiment.py +0 -0
  55. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
  56. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/map.py +0 -0
  57. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/mrr.py +0 -0
  58. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
  59. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
  60. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/precision.py +0 -0
  61. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/recall.py +0 -0
  62. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
  63. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
  64. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
  65. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/__init__.py +0 -0
  66. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/admm_slim.py +0 -0
  67. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_neighbour_rec.py +0 -0
  68. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_torch_rec.py +0 -0
  69. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
  70. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +0 -0
  71. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
  72. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
  73. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/utils.py +0 -0
  74. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  75. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -0
  76. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
  77. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
  78. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/nn/data/schema_builder.py +0 -0
  79. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
  80. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/sequence_generator.py +0 -0
  81. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
  82. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -0
  83. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -0
  84. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -0
  85. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +0 -0
  86. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/__init__.py +0 -0
  87. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +0 -0
  88. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/__init__.py +0 -0
  89. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/logger.py +0 -0
  90. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/session_handler.py +0 -0
  91. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/base_metric.py +0 -0
  92. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/categorical_diversity.py +0 -0
  93. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/coverage.py +0 -0
  94. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/descriptors.py +0 -0
  95. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/experiment.py +0 -0
  96. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/hitrate.py +0 -0
  97. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/map.py +0 -0
  98. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/mrr.py +0 -0
  99. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/ndcg.py +0 -0
  100. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/novelty.py +0 -0
  101. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/offline_metrics.py +0 -0
  102. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/precision.py +0 -0
  103. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/recall.py +0 -0
  104. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/rocauc.py +0 -0
  105. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/surprisal.py +0 -0
  106. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/torch_metrics_builder.py +0 -0
  107. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/unexpectedness.py +0 -0
  108. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/__init__.py +0 -0
  109. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/association_rules.py +0 -0
  110. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_neighbour_rec.py +0 -0
  111. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/cat_pop_rec.py +0 -0
  112. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/cluster.py +0 -0
  113. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/__init__.py +0 -0
  114. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/__init__.py +0 -0
  115. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/ann_mixin.py +0 -0
  116. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/__init__.py +0 -0
  117. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  118. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  119. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  120. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  121. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  122. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  123. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  124. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  125. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  126. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  127. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  128. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  129. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  130. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  131. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  132. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  133. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  134. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  135. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  136. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  137. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  138. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/utils.py +0 -0
  139. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/kl_ucb.py +0 -0
  140. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/knn.py +0 -0
  141. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/__init__.py +0 -0
  142. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  143. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  144. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
  145. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  146. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
  147. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  148. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  149. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  150. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
  151. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  152. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  153. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
  154. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  155. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
  156. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  157. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/pop_rec.py +0 -0
  158. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/query_pop_rec.py +0 -0
  159. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/random_rec.py +0 -0
  160. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/slim.py +0 -0
  161. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/thompson_sampling.py +0 -0
  162. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/ucb.py +0 -0
  163. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/wilson.py +0 -0
  164. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/word2vec.py +0 -0
  165. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/optimization/__init__.py +0 -0
  166. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/__init__.py +0 -0
  167. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/sessionizer.py +0 -0
  168. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/scenarios/fallback.py +0 -0
  169. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/__init__.py +0 -0
  170. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/base_splitter.py +0 -0
  171. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/cold_user_random_splitter.py +0 -0
  172. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/k_folds.py +0 -0
  173. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/new_users_splitter.py +0 -0
  174. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/random_splitter.py +0 -0
  175. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/ratio_splitter.py +0 -0
  176. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/__init__.py +0 -0
  177. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/common.py +0 -0
  178. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
  179. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/model_handler.py +0 -0
  180. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/time.py +0 -0
  181. {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/types.py +0 -0

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/PKG-INFO

@@ -1,11 +1,11 @@
  Metadata-Version: 2.1
  Name: replay-rec
- Version: 0.17.1rc0
+ Version: 0.18.0rc0
  Summary: RecSys Library
  Home-page: https://sb-ai-lab.github.io/RePlay/
  License: Apache-2.0
  Author: AI Lab
- Requires-Python: >=3.8.1,<3.11
+ Requires-Python: >=3.8.1,<3.12
  Classifier: Development Status :: 4 - Beta
  Classifier: Environment :: Console
  Classifier: Intended Audience :: Developers
@@ -16,32 +16,34 @@ Classifier: Operating System :: Unix
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: all
  Provides-Extra: spark
  Provides-Extra: torch
  Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
+ Requires-Dist: fixed-install-nmslib (==2.1.2)
  Requires-Dist: gym (>=0.26.0,<0.27.0)
- Requires-Dist: hnswlib (==0.7.0)
+ Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
  Requires-Dist: implicit (>=0.7.0,<0.8.0)
  Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
  Requires-Dist: lightfm (==1.17)
- Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "all"
  Requires-Dist: llvmlite (>=0.32.1)
- Requires-Dist: nmslib (==2.1.1)
  Requires-Dist: numba (>=0.50)
  Requires-Dist: numpy (>=1.20.0)
  Requires-Dist: optuna (>=3.2.0,<3.3.0)
  Requires-Dist: pandas (>=1.3.5,<=2.2.2)
- Requires-Dist: polars (>=0.20.7,<0.21.0)
- Requires-Dist: psutil (>=5.9.5,<5.10.0)
+ Requires-Dist: polars (>=1.0.0,<1.1.0)
+ Requires-Dist: psutil (>=6.0.0,<6.1.0)
  Requires-Dist: pyarrow (>=12.0.1)
- Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
+ Requires-Dist: pyspark (>=3.0,<3.5) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
+ Requires-Dist: pyspark (>=3.4,<3.5) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
  Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
- Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
+ Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
  Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
- Requires-Dist: scipy (>=1.8.1,<1.9.0)
- Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: scipy (>=1.8.1,<2.0.0)
+ Requires-Dist: torch (>=1.8,<=2.4.0) ; extra == "torch" or extra == "all"
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
  Description-Content-Type: text/markdown


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/pyproject.toml

@@ -7,7 +7,7 @@ build-backend = "poetry_dynamic_versioning.backend"

  [tool.black]
  line-length = 120
- target-versions = ["py38", "py39", "py310"]
+ target-versions = ["py38", "py39", "py310", "py311"]

  [tool.poetry]
  name = "replay-rec"
@@ -40,29 +40,32 @@ classifiers = [
  exclude = [
      "replay/conftest.py",
  ]
- version = "0.17.1.preview"
+ version = "0.18.0.preview"

  [tool.poetry.dependencies]
- python = ">=3.8.1, <3.11"
+ python = ">=3.8.1, <3.12"
  numpy = ">=1.20.0"
- pandas = ">=1.3.5,<=2.2.2"
- polars = "~0.20.7"
+ pandas = ">=1.3.5, <=2.2.2"
+ polars = "~1.0.0"
  optuna = "~3.2.0"
- scipy = "~1.8.1"
- psutil = "~5.9.5"
- pyspark = {version = ">=3.0,<3.5", optional = true}
+ scipy = "^1.8.1"
+ psutil = "~6.0.0"
  scikit-learn = "^1.0.2"
  pyarrow = ">=12.0.1"
- nmslib = "2.1.1"
- hnswlib = "0.7.0"
- torch = "^1.8"
- lightning = "^2.0.2"
+ fixed-install-nmslib = "2.1.2"
+ hnswlib = "^0.7.0"
+ pyspark = [
+     {version = ">=3.4,<3.5", python = ">=3.11,<3.12"},
+     {version = ">=3.0,<3.5", python = ">=3.8.1,<3.11"},
+ ]
+ torch = ">=1.8, <=2.4.0"
+ lightning = ">=2.0.2, <=2.4.0"
  pytorch-ranger = "^0.1.1"
  lightfm = "1.17"
  lightautoml = "~0.3.1"
  numba = ">=0.50"
  llvmlite = ">=0.32.1"
- sb-obp = "^0.5.7"
+ sb-obp = "^0.5.8"
  d3rlpy = "^2.0.4"
  implicit = "~0.7.0"
  gym = "^0.26.0"
@@ -77,7 +80,7 @@ jupyter = "~1.0.0"
  jupyterlab = "^3.6.0"
  pytest = ">=7.1.0"
  pytest-cov = ">=3.0.0"
- statsmodels = "~0.13.5"
+ statsmodels = "~0.14.0"
  black = ">=23.3.0"
  ruff = ">=0.0.261"
  toml-sort = "^0.23.0"
@@ -92,7 +95,7 @@ data-science-types = "0.2.23"

  [tool.poetry-dynamic-versioning]
  enable = false
- format-jinja = """0.17.1{{ env['PACKAGE_SUFFIX'] }}"""
+ format-jinja = """0.18.0{{ env['PACKAGE_SUFFIX'] }}"""
  vcs = "git"

  [tool.ruff]
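
Note: the split pyspark requirement above is expressed through PEP 508 environment markers (visible in PKG-INFO), so Python 3.11 installs get pyspark >=3.4,<3.5 while older interpreters keep the wider >=3.0,<3.5 range. A minimal sketch (not part of the package) of how installer tooling evaluates those markers, using the packaging library:

    from packaging.markers import Marker

    marker = Marker(
        '(python_version >= "3.11" and python_version < "3.12") '
        'and (extra == "spark" or extra == "all")'
    )

    # On Python 3.11 with the "spark" extra, the >=3.4,<3.5 pin applies.
    print(marker.evaluate({"python_version": "3.11", "extra": "spark"}))  # True
    # On Python 3.10 this marker is False, so the >=3.0,<3.5 pin applies instead.
    print(marker.evaluate({"python_version": "3.10", "extra": "spark"}))  # False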

replay_rec-0.18.0rc0/replay/__init__.py (new file)

@@ -0,0 +1,3 @@
+ """ RecSys library """
+
+ __version__ = "0.18.0.preview"
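
Note: with the new top-level __init__.py, the installed package exposes its version at runtime. A minimal usage sketch, assuming a standard install of this build:

    import replay

    print(replay.__version__)  # "0.18.0.preview" for this preview build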

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset.py

@@ -1,6 +1,7 @@
  """
  ``Dataset`` universal dataset class for manipulating interactions and feed data to models.
  """
+
  from __future__ import annotations

  import json
@@ -606,7 +607,7 @@ class Dataset:
          if self.is_pandas:
              min_id = data[column].min()
          elif self.is_spark:
-             min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
+             min_id = data.agg(sf.min(column).alias("min_index")).first()[0]
          else:
              min_id = data[column].min()
          if min_id < 0:
@@ -616,7 +617,7 @@ class Dataset:
          if self.is_pandas:
              max_id = data[column].max()
          elif self.is_spark:
-             max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
+             max_id = data.agg(sf.max(column).alias("max_index")).first()[0]
          else:
              max_id = data[column].max()

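Note on the recurring .collect()[0][0] → .first()[0] change throughout this release: an aggregate without grouping yields exactly one row, so both expressions read the same scalar, but first() goes through take(1) and never builds a full row list on the driver. A minimal PySpark sketch (not part of the package) of the equivalence:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as sf

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1,), (5,), (3,)], ["item_idx"])

    # One-row aggregate: collect()[0][0] and first()[0] return the same value;
    # first() is the clearer form and avoids materializing a list of rows.
    agg = df.agg(sf.max("item_idx").alias("max_index"))
    assert agg.collect()[0][0] == agg.first()[0] == 5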

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py

@@ -4,6 +4,7 @@ Contains classes for encoding categorical data
  ``LabelEncoderTransformWarning`` new category of warning for DatasetLabelEncoder.
  ``DatasetLabelEncoder`` to encode categorical features in `Dataset` objects.
  """
+
  import warnings
  from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/schema.py

@@ -418,11 +418,11 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
              "feature_type": feature.feature_type.name,
              "is_seq": feature.is_seq,
              "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
-             "feature_sources": [
-                 {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
-             ]
-             if feature.feature_sources
-             else None,
+             "feature_sources": (
+                 [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
+                 if feature.feature_sources
+                 else None
+             ),
              "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
              "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
              "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/__init__.py

@@ -47,6 +47,7 @@ For each metric, a formula for its calculation is given, because this is
  important for the correct comparison of algorithms, as mentioned in our
  `article <https://arxiv.org/abs/2206.12858>`_.
  """
+
  from replay.experimental.metrics.base_metric import Metric, NCISMetric
  from replay.experimental.metrics.coverage import Coverage
  from replay.experimental.metrics.hitrate import HitRate

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/base_metric.py

@@ -1,6 +1,7 @@
  """
  Base classes for quality and diversity metrics.
  """
+
  import logging
  from abc import ABC, abstractmethod
  from typing import Dict, List, Optional, Union

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_rec.py

@@ -86,8 +86,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          self.fit_items = sf.broadcast(items)
          self._num_users = self.fit_users.count()
          self._num_items = self.fit_items.count()
-         self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).collect()[0][0] + 1
-         self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).collect()[0][0] + 1
+         self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
+         self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
          self._fit(log, user_features, item_features)

      @abstractmethod
@@ -122,7 +122,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          # count maximal number of items seen by users
          max_seen = 0
          if num_seen.count() > 0:
-             max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+             max_seen = num_seen.select(sf.max("seen_count")).first()[0]

          # crop recommendations to first k + max_seen items for each user
          recs = recs.withColumn(
@@ -335,7 +335,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          setattr(
              self,
              f"_{entity}_dim_size",
-             getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).collect()[0][0] + 1,
+             getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
          )
          return getattr(self, f"_{entity}_dim_size")

@@ -1088,7 +1088,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
          Calculating a fill value a the minimal relevance
          calculated during model training multiplied by weight.
          """
-         return item_popularity.select(sf.min("relevance")).collect()[0][0] * weight
+         return item_popularity.select(sf.min("relevance")).first()[0] * weight

      @staticmethod
      def _check_relevance(log: SparkDataFrame):
@@ -1113,7 +1113,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
          max_hist_len = (
              (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
              .select(sf.max("items_count"))
-             .collect()[0][0]
+             .first()[0]
          )
          # all users have empty history
          if max_hist_len is None:
@@ -1146,7 +1146,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
          users = users.join(user_to_num_items, on="user_idx", how="left")
          users = users.fillna(0, "num_items")
          # 'selected_item_popularity' truncation by k + max_seen
-         max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+         max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
          selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
          return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/cql.py

@@ -1,6 +1,7 @@
  """
  Using CQL implementation from `d3rlpy` package.
  """
+
  import io
  import logging
  import tempfile
@@ -402,6 +403,7 @@ class MdpDatasetBuilder:
      top_k (int): the number of top user items to learn predicting.
      action_randomization_scale (float): the scale of action randomization gaussian noise.
      """
+
      logger: logging.Logger
      top_k: int
      action_randomization_scale: float

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/ddpg.py

@@ -704,13 +704,15 @@ class DDPG(Recommender):
          :param data: pandas DataFrame
          """
          data = data[["user_idx", "item_idx", "relevance"]]
-         train_data = data.values.tolist()
+         users = data["user_idx"].values.tolist()
+         items = data["item_idx"].values.tolist()
+         scores = data["relevance"].values.tolist()

-         user_num = data["user_idx"].max() + 1
-         item_num = data["item_idx"].max() + 1
+         user_num = max(users) + 1
+         item_num = max(items) + 1

          train_mat = defaultdict(float)
-         for user, item, rel in train_data:
+         for user, item, rel in zip(users, items, scores):
              train_mat[user, item] = rel
          train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
          dict.update(train_matrix, train_mat)
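
Note: a self-contained sketch of the dok_matrix construction above, with toy triplets standing in for the (user_idx, item_idx, relevance) columns. The dict.update trick relies on dok_matrix being a dict subclass, which holds for the scipy versions this package targets (newer scipy releases may change this):

    from collections import defaultdict

    import numpy as np
    import scipy.sparse as sp

    users, items, scores = [0, 0, 1], [1, 2, 2], [1.0, 0.5, 2.0]

    train_mat = defaultdict(float)
    for user, item, rel in zip(users, items, scores):
        train_mat[user, item] = rel

    # Bulk-load the (row, col) -> value dict into the sparse matrix.
    train_matrix = sp.dok_matrix((max(users) + 1, max(items) + 1), dtype=np.float32)
    dict.update(train_matrix, train_mat)
    print(train_matrix.toarray())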

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/lightfm_wrap.py

@@ -98,12 +98,12 @@ class LightFMWrap(HybridRecommender):
          fit_dim = getattr(self, f"_{entity}_dim")
          matrix_height = max(
              fit_dim,
-             log_ids_list.select(sf.max(idx_col_name)).collect()[0][0] + 1,
+             log_ids_list.select(sf.max(idx_col_name)).first()[0] + 1,
          )
          if not feature_table.rdd.isEmpty():
              matrix_height = max(
                  matrix_height,
-                 feature_table.select(sf.max(idx_col_name)).collect()[0][0] + 1,
+                 feature_table.select(sf.max(idx_col_name)).first()[0] + 1,
              )

          features_np = (

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/mult_vae.py

@@ -2,6 +2,7 @@
  MultVAE implementation
  (Variational Autoencoders for Collaborative Filtering)
  """
+
  from typing import Optional, Tuple

  import numpy as np

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/neuromf.py

@@ -3,6 +3,7 @@ Generalized Matrix Factorization (GMF),
  Multi-Layer Perceptron (MLP),
  Neural Matrix Factorization (MLP + GMF).
  """
+
  from typing import List, Optional

  import numpy as np

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/scala_als.py

@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
              .groupBy("user_idx")
              .agg(sf.count("user_idx").alias("num_seen"))
              .select(sf.max("num_seen"))
-             .collect()[0][0]
+             .first()[0]
          )
          max_seen = max_seen_in_log if max_seen_in_log is not None else 0

@@ -280,7 +280,7 @@ class ScalaALSWrap(ALSWrap, ANNMixin):
              .groupBy("user_idx")
              .agg(sf.count("user_idx").alias("num_seen"))
              .select(sf.max("num_seen"))
-             .collect()[0][0]
+             .first()[0]
          )
          max_seen = max_seen_in_log if max_seen_in_log is not None else 0


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/data_preparator.py

@@ -6,6 +6,7 @@ Contains classes for data preparation and categorical features transformation.
  ``ToNumericFeatureTransformer`` leaves only numerical features
  by one-hot encoding of some features and deleting the others.
  """
+
  import json
  import logging
  import string
@@ -699,7 +700,7 @@ if PYSPARK_AVAILABLE:
              return

          cat_feat_values_dict = {
-             name: (spark_df.select(sf.collect_set(sf.col(name))).collect()[0][0]) for name in self.cat_cols_list
+             name: (spark_df.select(sf.collect_set(sf.col(name))).first()[0]) for name in self.cat_cols_list
          }
          self.expressions_list = [
              sf.when(sf.col(col_name) == cur_name, 1)

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/padder.py

@@ -179,7 +179,7 @@ class Padder:
          self, df_transformed: SparkDataFrame, col: str, pad_value: Union[str, float, List, None]
      ) -> SparkDataFrame:
          if self.array_size == -1:
-             max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).collect()[0][0]
+             max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).first()[0]
          else:
              max_array_size = self.array_size


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py

@@ -383,7 +383,7 @@ class TwoStagesScenario(HybridRecommender):
              log_to_filter_cached.groupBy("user_idx")
              .agg(sf.count("item_idx").alias("num_positives"))
              .select(sf.max("num_positives"))
-             .collect()[0][0]
+             .first()[0]
          )

          pred = model._predict(

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/model_handler.py

@@ -170,8 +170,13 @@ def load_indexer(path: str) -> Indexer:

      indexer = Indexer(**args)

-     indexer.user_type = getattr(st, user_type)()
-     indexer.item_type = getattr(st, item_type)()
+     if user_type.endswith("()"):
+         user_type = user_type[:-2]
+         item_type = item_type[:-2]
+     user_type = getattr(st, user_type)
+     item_type = getattr(st, item_type)
+     indexer.user_type = user_type()
+     indexer.item_type = item_type()

      indexer.user_indexer = StringIndexerModel.load(join(path, "user_indexer"))
      indexer.item_indexer = StringIndexerModel.load(join(path, "item_indexer"))
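
Note: the load_indexer change makes loading tolerant of both spellings of a stored Spark type name; presumably older saves serialized the type instance's repr (e.g. "IntegerType()") while newer ones store the bare class name. A minimal sketch of the same resolution logic (resolve_spark_type is a hypothetical helper, not part of the package):

    import pyspark.sql.types as st

    def resolve_spark_type(name: str) -> st.DataType:
        # Accept both "IntegerType()" and "IntegerType".
        if name.endswith("()"):
            name = name[:-2]
        return getattr(st, name)()

    assert isinstance(resolve_spark_type("IntegerType()"), st.IntegerType)
    assert isinstance(resolve_spark_type("IntegerType"), st.IntegerType)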

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/__init__.py

@@ -42,6 +42,7 @@ For each metric, a formula for its calculation is given, because this is
  important for the correct comparison of algorithms, as mentioned in our
  `article <https://arxiv.org/abs/2206.12858>`_.
  """
+
  from .base_metric import Metric
  from .categorical_diversity import CategoricalDiversity
  from .coverage import Coverage

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/als.py

@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
              .groupBy(self.query_column)
              .agg(sf.count(self.query_column).alias("num_seen"))
              .select(sf.max("num_seen"))
-             .collect()[0][0]
+             .first()[0]
          )
          max_seen = max_seen_in_interactions if max_seen_in_interactions is not None else 0


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_rec.py

@@ -401,8 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          self.fit_items = sf.broadcast(items)
          self._num_queries = self.fit_queries.count()
          self._num_items = self.fit_items.count()
-         self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
-         self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
+         self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).first()[0] + 1
+         self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).first()[0] + 1
          self._fit(dataset)

      @abstractmethod
@@ -431,7 +431,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          # count maximal number of items seen by queries
          max_seen = 0
          if num_seen.count() > 0:
-             max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+             max_seen = num_seen.select(sf.max("seen_count")).first()[0]

          # crop recommendations to first k + max_seen items for each query
          recs = recs.withColumn(
@@ -708,7 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
          setattr(
              self,
              dim_size,
-             fit_entities.agg({column: "max"}).collect()[0][0] + 1,
+             fit_entities.agg({column: "max"}).first()[0] + 1,
          )
          return getattr(self, dim_size)

@@ -1426,7 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
          Calculating a fill value a the minimal rating
          calculated during model training multiplied by weight.
          """
-         return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight
+         return item_popularity.select(sf.min(rating_column)).first()[0] * weight

      @staticmethod
      def _check_rating(dataset: Dataset):
@@ -1460,7 +1460,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
                  .agg(sf.countDistinct(item_column).alias("items_count"))
              )
              .select(sf.max("items_count"))
-             .collect()[0][0]
+             .first()[0]
          )
          # all queries have empty history
          if max_hist_len is None:
@@ -1495,7 +1495,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
          queries = queries.join(query_to_num_items, on=self.query_column, how="left")
          queries = queries.fillna(0, "num_items")
          # 'selected_item_popularity' truncation by k + max_seen
-         max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+         max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
          selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
          return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")


{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py

@@ -32,9 +32,9 @@ class NmslibFilterIndexInferer(IndexInferer):
          index = index_store.load_index(
              init_index=lambda: create_nmslib_index_instance(index_params),
              load_index=lambda index, path: index.loadIndex(path, load_data=True),
-             configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
-             if index_params.ef_s
-             else None,
+             configure_index=lambda index: (
+                 index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+             ),
          )

          # max number of items to retrieve per batch

{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py

@@ -30,9 +30,9 @@ class NmslibIndexInferer(IndexInferer):
          index = index_store.load_index(
              init_index=lambda: create_nmslib_index_instance(index_params),
              load_index=lambda index, path: index.loadIndex(path, load_data=True),
-             configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
-             if index_params.ef_s
-             else None,
+             configure_index=lambda index: (
+                 index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+             ),
          )

          user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)