replay-rec 0.19.0rc0__tar.gz → 0.20.0rc0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/PKG-INFO +58 -42
  2. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/README.md +32 -1
  3. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/pyproject.toml +56 -70
  4. replay_rec-0.20.0rc0/replay/__init__.py +7 -0
  5. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset.py +19 -18
  6. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  7. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/schema.py +9 -18
  8. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/sequence_tokenizer.py +54 -47
  9. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/sequential_dataset.py +16 -11
  10. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/torch_sequential_dataset.py +18 -16
  11. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/utils.py +3 -2
  12. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/schema.py +3 -12
  13. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/base_metric.py +6 -5
  14. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/coverage.py +5 -5
  15. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/experiment.py +2 -2
  16. replay_rec-0.20.0rc0/replay/experimental/models/__init__.py +50 -0
  17. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/admm_slim.py +59 -7
  18. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_neighbour_rec.py +6 -10
  19. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_rec.py +58 -12
  20. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_torch_rec.py +2 -2
  21. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/cql.py +6 -6
  22. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/ddpg.py +47 -38
  23. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +3 -3
  24. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/utils.py +4 -5
  25. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +5 -5
  26. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/lightfm_wrap.py +4 -3
  27. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/mult_vae.py +4 -4
  28. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/neural_ts.py +13 -13
  29. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/neuromf.py +4 -4
  30. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/scala_als.py +14 -17
  31. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/nn/data/schema_builder.py +4 -4
  32. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/data_preparator.py +13 -13
  33. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/padder.py +7 -7
  34. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/sequence_generator.py +7 -7
  35. replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  36. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +5 -5
  37. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +4 -4
  38. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +3 -5
  39. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +4 -4
  40. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +18 -18
  41. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/session_handler.py +2 -2
  42. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/base_metric.py +12 -11
  43. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/categorical_diversity.py +8 -8
  44. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/coverage.py +11 -15
  45. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/experiment.py +6 -6
  46. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/hitrate.py +1 -3
  47. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/map.py +1 -3
  48. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/mrr.py +1 -3
  49. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/ndcg.py +1 -2
  50. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/novelty.py +3 -3
  51. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/offline_metrics.py +18 -18
  52. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/precision.py +1 -3
  53. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/recall.py +1 -3
  54. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/rocauc.py +1 -3
  55. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/surprisal.py +4 -4
  56. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/torch_metrics_builder.py +13 -12
  57. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/unexpectedness.py +2 -2
  58. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/__init__.py +19 -0
  59. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/als.py +2 -2
  60. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/association_rules.py +5 -7
  61. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/base_neighbour_rec.py +8 -10
  62. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/base_rec.py +54 -302
  63. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/cat_pop_rec.py +4 -2
  64. replay_rec-0.20.0rc0/replay/models/common.py +69 -0
  65. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/ann_mixin.py +31 -25
  66. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  67. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  68. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  69. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/utils.py +4 -3
  70. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/knn.py +18 -17
  71. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/lin_ucb.py +3 -3
  72. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  73. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +3 -3
  74. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  75. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/model.py +2 -2
  76. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +14 -14
  77. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  78. replay_rec-0.20.0rc0/replay/models/nn/sequential/compiled/__init__.py +15 -0
  79. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/base_compiled_model.py +8 -6
  80. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  81. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  82. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +2 -3
  83. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
  84. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +1 -1
  85. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +3 -3
  86. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/model.py +9 -9
  87. replay_rec-0.20.0rc0/replay/models/optimization/__init__.py +14 -0
  88. replay_rec-0.20.0rc0/replay/models/optimization/optuna_mixin.py +279 -0
  89. {replay_rec-0.19.0rc0/replay → replay_rec-0.20.0rc0/replay/models}/optimization/optuna_objective.py +13 -15
  90. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/slim.py +4 -6
  91. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/ucb.py +2 -2
  92. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/word2vec.py +9 -14
  93. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/discretizer.py +9 -9
  94. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/filters.py +4 -4
  95. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/history_based_fp.py +7 -7
  96. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/label_encoder.py +9 -8
  97. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/scenarios/fallback.py +4 -3
  98. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/base_splitter.py +3 -3
  99. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/cold_user_random_splitter.py +17 -11
  100. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/k_folds.py +4 -4
  101. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/last_n_splitter.py +27 -20
  102. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/new_users_splitter.py +4 -4
  103. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/random_splitter.py +4 -4
  104. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/ratio_splitter.py +10 -10
  105. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/time_splitter.py +6 -6
  106. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/two_stage_splitter.py +4 -4
  107. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/__init__.py +7 -2
  108. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/common.py +5 -3
  109. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/model_handler.py +11 -31
  110. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/session_handler.py +4 -4
  111. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/spark_utils.py +8 -7
  112. replay_rec-0.20.0rc0/replay/utils/types.py +50 -0
  113. replay_rec-0.20.0rc0/replay/utils/warnings.py +26 -0
  114. replay_rec-0.19.0rc0/replay/__init__.py +0 -3
  115. replay_rec-0.19.0rc0/replay/experimental/models/__init__.py +0 -13
  116. replay_rec-0.19.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  117. replay_rec-0.19.0rc0/replay/models/nn/sequential/compiled/__init__.py +0 -5
  118. replay_rec-0.19.0rc0/replay/optimization/__init__.py +0 -5
  119. replay_rec-0.19.0rc0/replay/utils/types.py +0 -38
  120. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/LICENSE +0 -0
  121. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/NOTICE +0 -0
  122. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/__init__.py +0 -0
  123. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
  124. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/__init__.py +6 -6
  125. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/spark_schema.py +0 -0
  126. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/__init__.py +0 -0
  127. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/__init__.py +0 -0
  128. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
  129. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/map.py +0 -0
  130. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/mrr.py +0 -0
  131. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
  132. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
  133. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/precision.py +0 -0
  134. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/recall.py +0 -0
  135. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
  136. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
  137. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
  138. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
  139. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
  140. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
  141. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  142. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/hierarchical_recommender.py +0 -0
  143. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
  144. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/u_lin_ucb.py +0 -0
  145. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
  146. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
  147. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
  148. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/__init__.py +0 -0
  149. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/__init__.py +0 -0
  150. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/logger.py +0 -0
  151. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/model_handler.py +0 -0
  152. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/__init__.py +0 -0
  153. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/descriptors.py +0 -0
  154. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/cluster.py +0 -0
  155. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/__init__.py +0 -0
  156. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/__init__.py +0 -0
  157. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/__init__.py +0 -0
  158. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  159. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  160. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  161. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  162. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  163. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  164. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  165. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  166. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  167. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  168. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  169. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  170. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  171. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  172. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  173. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  174. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  175. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  176. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  177. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  178. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/kl_ucb.py +0 -0
  179. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/__init__.py +0 -0
  180. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/loss/__init__.py +0 -0
  181. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/loss/sce.py +0 -0
  182. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  183. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
  184. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  185. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  186. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  187. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  188. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/pop_rec.py +0 -0
  189. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/query_pop_rec.py +0 -0
  190. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/random_rec.py +0 -0
  191. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/thompson_sampling.py +0 -0
  192. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/wilson.py +0 -0
  193. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/__init__.py +0 -0
  194. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/converter.py +0 -0
  195. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/sessionizer.py +0 -0
  196. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/scenarios/__init__.py +0 -0
  197. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/__init__.py +0 -0
  198. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
  199. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/distributions.py +0 -0
  200. {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/time.py +0 -0
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/PKG-INFO
@@ -1,53 +1,38 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: replay-rec
- Version: 0.19.0rc0
+ Version: 0.20.0rc0
  Summary: RecSys Library
- Home-page: https://sb-ai-lab.github.io/RePlay/
- License: Apache-2.0
+ License-Expression: Apache-2.0
+ License-File: LICENSE
+ License-File: NOTICE
  Author: AI Lab
- Requires-Python: >=3.8.1,<3.12
+ Requires-Python: >=3.9, <3.13
+ Classifier: Operating System :: Unix
  Classifier: Development Status :: 4 - Beta
  Classifier: Environment :: Console
  Classifier: Intended Audience :: Developers
  Classifier: Intended Audience :: Science/Research
- Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Natural Language :: English
- Classifier: Operating System :: Unix
- Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.9
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Provides-Extra: all
- Provides-Extra: spark
- Provides-Extra: torch
- Provides-Extra: torch-openvino
- Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
- Requires-Dist: fixed-install-nmslib (==2.1.2)
- Requires-Dist: gym (>=0.26.0,<0.27.0)
- Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
- Requires-Dist: implicit (>=0.7.0,<0.8.0)
- Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
- Requires-Dist: lightfm (==1.17)
- Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "torch-openvino" or extra == "all"
- Requires-Dist: llvmlite (>=0.32.1)
- Requires-Dist: numba (>=0.50)
- Requires-Dist: numpy (>=1.20.0)
- Requires-Dist: onnx (>=1.16.2,<1.17.0) ; extra == "torch-openvino" or extra == "all"
- Requires-Dist: openvino (>=2024.3.0,<2024.4.0) ; extra == "torch-openvino" or extra == "all"
- Requires-Dist: optuna (>=3.2.0,<3.3.0)
- Requires-Dist: pandas (>=1.3.5,<=2.2.2)
- Requires-Dist: polars (>=1.0.0,<1.1.0)
- Requires-Dist: psutil (>=6.0.0,<6.1.0)
- Requires-Dist: pyarrow (>=12.0.1)
- Requires-Dist: pyspark (>=3.0,<3.5) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
- Requires-Dist: pyspark (>=3.4,<3.5) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
- Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "torch-openvino" or extra == "all"
- Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
- Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
- Requires-Dist: scipy (>=1.8.1,<2.0.0)
- Requires-Dist: torch (>=1.8,<3.0.0) ; (python_version >= "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
- Requires-Dist: torch (>=1.8,<=2.4.1) ; (python_version >= "3.8" and python_version < "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
+ Requires-Dist: d3rlpy (>=2.8.1,<2.9)
+ Requires-Dist: implicit (>=0.7.2,<0.8)
+ Requires-Dist: lightautoml (>=0.4.1,<0.5)
+ Requires-Dist: lightning (>=2.0.2,<=2.4.0)
+ Requires-Dist: numba (>=0.50,<1)
+ Requires-Dist: numpy (>=1.20.0,<2)
+ Requires-Dist: pandas (>=1.3.5,<2.4.0)
+ Requires-Dist: polars (<2.0)
+ Requires-Dist: psutil (<=7.0.0)
+ Requires-Dist: pyarrow (<22.0)
+ Requires-Dist: pyspark (>=3.0,<3.5)
+ Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
+ Requires-Dist: sb-obp (>=0.5.10,<0.6)
+ Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
+ Requires-Dist: scipy (>=1.13.1,<1.14)
+ Requires-Dist: setuptools
+ Requires-Dist: torch (>=1.8,<3.0.0)
+ Requires-Dist: tqdm (>=4.67,<5)
+ Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
  Description-Content-Type: text/markdown
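The new pins above are recorded in the wheel's core metadata and can be checked locally. A minimal sketch, assuming `replay-rec` is already installed in the current environment and using only the standard library:

```python
# Minimal sketch (assumes replay-rec is installed): read the same PKG-INFO
# fields shown above via the standard importlib.metadata module.
from importlib.metadata import metadata, requires

meta = metadata("replay-rec")
print(meta["Name"], meta["Version"], meta["Requires-Python"])
for requirement in requires("replay-rec") or []:
    print(requirement)  # e.g. "numpy (>=1.20.0,<2)"
```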
@@ -216,7 +201,6 @@ pip install replay-rec==XX.YY.ZZrc0
  In addition to the core package, several extras are also provided, including:
  - `[spark]`: Install PySpark functionality
  - `[torch]`: Install PyTorch and Lightning functionality
- - `[all]`: `[spark]` `[torch]`

  Example:
  ```bash
@@ -227,9 +211,41 @@ pip install replay-rec[spark]
  pip install replay-rec[spark]==XX.YY.ZZrc0
  ```

+ Additionally, `replay-rec[torch]` may be installed with a CPU-only build of `torch` by providing the respective index URL during installation:
+ ```bash
+ # Install the package with the CPU version of torch
+ pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
+ ```
+
+
  To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).


+ ### Optional features
+ RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
+
+ 1) Hyperparameter search via Optuna:
+ ```bash
+ pip install optuna
+ ```
+
+ 2) Model compilation via OpenVINO:
+ ```bash
+ pip install openvino onnx
+ ```
+
+ 3) Vector database and hierarchical search support:
+ ```bash
+ pip install hnswlib fixed-install-nmslib
+ ```
+
+ 4) (Experimental) LightFM model support:
+ ```bash
+ pip install lightfm
+ ```
+ > **_NOTE_**: LightFM is not officially supported for Python 3.12 due to discontinued maintenance of the library. If you wish to install it locally, you'll have to use a patched fork of LightFM, such as the [one used internally](https://github.com/daviddavo/lightfm).
+
+
  <a name="examples"></a>
  ## 📑 Resources
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/README.md
@@ -163,7 +163,6 @@ pip install replay-rec==XX.YY.ZZrc0
  In addition to the core package, several extras are also provided, including:
  - `[spark]`: Install PySpark functionality
  - `[torch]`: Install PyTorch and Lightning functionality
- - `[all]`: `[spark]` `[torch]`

  Example:
  ```bash
@@ -174,9 +173,41 @@ pip install replay-rec[spark]
  pip install replay-rec[spark]==XX.YY.ZZrc0
  ```

+ Additionally, `replay-rec[torch]` may be installed with a CPU-only build of `torch` by providing the respective index URL during installation:
+ ```bash
+ # Install the package with the CPU version of torch
+ pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
+ ```
+
+
  To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).


+ ### Optional features
+ RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
+
+ 1) Hyperparameter search via Optuna:
+ ```bash
+ pip install optuna
+ ```
+
+ 2) Model compilation via OpenVINO:
+ ```bash
+ pip install openvino onnx
+ ```
+
+ 3) Vector database and hierarchical search support:
+ ```bash
+ pip install hnswlib fixed-install-nmslib
+ ```
+
+ 4) (Experimental) LightFM model support:
+ ```bash
+ pip install lightfm
+ ```
+ > **_NOTE_**: LightFM is not officially supported for Python 3.12 due to discontinued maintenance of the library. If you wish to install it locally, you'll have to use a patched fork of LightFM, such as the [one used internally](https://github.com/daviddavo/lightfm).
+
+
  <a name="examples"></a>
  ## 📑 Resources
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/pyproject.toml
@@ -1,35 +1,28 @@
  [build-system]
  requires = [
- "poetry-core>=1.0.0",
+ "poetry-core>=2.0.0",
  "poetry-dynamic-versioning>=1.0.0,<2.0.0",
+ "setuptools",
  ]
  build-backend = "poetry_dynamic_versioning.backend"

- [tool.black]
- line-length = 120
- target-versions = ["py38", "py39", "py310", "py311"]
-
- [tool.poetry]
+ [project]
  name = "replay-rec"
- packages = [{include = "replay"}]
  license = "Apache-2.0"
  description = "RecSys Library"
  authors = [
- "AI Lab",
- "Alexey Vasilev",
- "Anna Volodkevich",
- "Alexey Grishanov",
- "Yan-Martin Tamm",
- "Boris Shminke",
- "Alexander Sidorenko",
- "Roza Aysina",
+ {name = "AI Lab"},
+ {name = "Alexey Vasilev"},
+ {name = "Anna Volodkevich"},
+ {name = "Alexey Grishanov"},
+ {name = "Yan-Martin Tamm"},
+ {name = "Boris Shminke"},
+ {name = "Alexander Sidorenko"},
+ {name = "Roza Aysina"},
  ]
  readme = "README.md"
- homepage = "https://sb-ai-lab.github.io/RePlay/"
- repository = "https://github.com/sb-ai-lab/RePlay"
  classifiers = [
  "Operating System :: Unix",
- "Intended Audience :: Science/Research",
  "Development Status :: 4 - Beta",
  "Environment :: Console",
  "Intended Audience :: Developers",
@@ -37,51 +30,46 @@ classifiers = [
  "Natural Language :: English",
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
  ]
- exclude = [
- "replay/conftest.py",
+ requires-python = ">=3.9, <3.13"
+ dependencies = [
+ "setuptools",
+ "numpy (>=1.20.0,<2)",
+ "pandas (>=1.3.5,<2.4.0)",
+ "polars (<2.0)",
+ "scipy (>=1.13.1,<1.14)",
+ "scikit-learn (>=1.6.1,<1.7.0)",
+ "pyarrow (<22.0)",
+ "tqdm (>=4.67,<5)",
+ "torch (>=1.8,<3.0.0)",
+ "lightning (>=2.0.2,<=2.4.0)",
+ "pytorch-optimizer (>=3.8.0,<4)",
+ "lightautoml (>=0.4.1,<0.5)",
+ "numba (>=0.50,<1)",
+ "sb-obp (>=0.5.10,<0.6)",
+ "d3rlpy (>=2.8.1,<2.9)",
+ "implicit (>=0.7.2,<0.8)",
+ "pyspark (>=3.0,<3.5)",
+ "psutil (<=7.0.0)",
  ]
- version = "0.19.0.preview"
+ dynamic = ["dependencies"]
+ version = "0.20.0.preview"

- [tool.poetry.dependencies]
- python = ">=3.8.1, <3.12"
- numpy = ">=1.20.0"
- pandas = ">=1.3.5, <=2.2.2"
- polars = "~1.0.0"
- optuna = "~3.2.0"
- scipy = "^1.8.1"
- psutil = "~6.0.0"
- scikit-learn = "^1.0.2"
- pyarrow = ">=12.0.1"
- openvino = {version = "~2024.3.0", optional = true}
- onnx = {version = "~1.16.2", optional = true}
- fixed-install-nmslib = "2.1.2"
- hnswlib = "^0.7.0"
- pyspark = [
- {version = ">=3.4,<3.5", python = ">=3.11,<3.12"},
- {version = ">=3.0,<3.5", python = ">=3.8.1,<3.11"},
- ]
- torch = [
- {version = ">=1.8, <3.0.0", python = ">=3.9", optional = true},
- {version = ">=1.8, <=2.4.1", python = ">=3.8,<3.9", optional = true},
- ]
- lightning = ">=2.0.2, <=2.4.0"
- pytorch-ranger = "^0.1.1"
- lightfm = "1.17"
- lightautoml = "~0.3.1"
- numba = ">=0.50"
- llvmlite = ">=0.32.1"
- sb-obp = "^0.5.8"
- d3rlpy = "^2.0.4"
- implicit = "~0.7.0"
- gym = "^0.26.0"
+ [project.urls]
+ homepage = "https://sb-ai-lab.github.io/RePlay/"
+ repository = "https://github.com/sb-ai-lab/RePlay"

- [tool.poetry.extras]
- spark = ["pyspark"]
- torch = ["torch", "pytorch-ranger", "lightning"]
- torch-openvino = ["torch", "pytorch-ranger", "lightning", "openvino", "onnx"]
- all = ["pyspark", "torch", "pytorch-ranger", "lightning", "openvino", "onnx"]
+ [tool.black]
+ line-length = 120
+ target-version = ["py39", "py310", "py311", "py312"]
+
+ [tool.poetry]
+ packages = [{include = "replay"}]
+ exclude = [
+ "replay/conftest.py",
+ ]

  [tool.poetry.group.dev.dependencies]
+ coverage-conditional-plugin = "^0.9.0"
  jupyter = "~1.0.0"
  jupyterlab = "^3.6.0"
  pytest = ">=7.1.0"
@@ -102,31 +90,29 @@ filelock = "~3.14.0"

  [tool.poetry-dynamic-versioning]
  enable = false
- format-jinja = """0.19.0{{ env['PACKAGE_SUFFIX'] }}"""
+ format-jinja = """0.20.0{{ env['PACKAGE_SUFFIX'] }}"""
  vcs = "git"

  [tool.ruff]
  exclude = [".git", ".venv", "__pycache__", "env", "venv", "docs", "projects", "examples"]
- extend-select = ["C90", "T10", "T20", "UP004"]
  line-length = 120
+
+ [tool.ruff.lint]
  select = ["ARG", "C4", "E", "EM", "ERA", "F", "FLY", "I", "INP", "ISC", "N", "PERF", "PGH", "PIE", "PYI", "Q", "RUF", "SIM", "TID", "W"]
+ extend-select = ["C90", "T10", "T20", "UP004"]
+ ignore = ["SIM115"]
+ mccabe = {max-complexity = 13}
+ isort = {combine-as-imports = true, force-wrap-aliases = true}

- [tool.ruff.flake8-quotes]
+ [tool.ruff.lint.flake8-quotes]
  docstring-quotes = "double"
  inline-quotes = "double"
  multiline-quotes = "double"

- [tool.ruff.flake8-unused-arguments]
+ [tool.ruff.lint.flake8-unused-arguments]
  ignore-variadic-names = false

- [tool.ruff.isort]
- combine-as-imports = true
- force-wrap-aliases = true
-
- [tool.ruff.mccabe]
- max-complexity = 13
-
- [tool.ruff.per-file-ignores]
+ [tool.ruff.lint.per-file-ignores]
  "*/" = ["PERF203", "RUF001", "RUF002", "RUF012", "E402"]
  "__init__.py" = ["F401"]
  "replay/utils/model_handler.py" = ["F403", "F405"]
replay_rec-0.20.0rc0/replay/__init__.py
@@ -0,0 +1,7 @@
+ """RecSys library"""
+
+ # NOTE: This ensures distutils monkey-patching is performed before any
+ # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
+ import setuptools as _
+
+ __version__ = "0.20.0.preview"
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset.py
@@ -5,8 +5,9 @@
  from __future__ import annotations

  import json
+ from collections.abc import Iterable, Sequence
  from pathlib import Path
- from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
+ from typing import Callable, Optional, Union

  import numpy as np
  from pandas import read_parquet as pd_read_parquet
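This release drops the deprecated `typing` aliases in favour of builtin generics and `collections.abc`, which is why `Dict`, `List`, `Set`, `Iterable`, and `Sequence` disappear from imports throughout the diff. A minimal sketch of the equivalent annotation style (not RePlay code):

```python
# Minimal sketch (not RePlay code): Python 3.9+ annotations with builtin
# generics and collections.abc, matching the style adopted in this release.
from collections.abc import Iterable, Sequence


def total(values: Iterable[int]) -> int:
    return sum(values)


def head(items: Sequence[str], n: int = 2) -> list[str]:
    # list[str] replaces typing.List[str]; dict and set work the same way.
    return list(items[:n])


print(total([1, 2, 3]), head(("a", "b", "c")))
```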
@@ -315,7 +316,7 @@ class Dataset:
  :returns: Loaded Dataset.
  """
  base_path = Path(path).with_suffix(".replay").resolve()
- with open(base_path / "init_args.json", "r") as file:
+ with open(base_path / "init_args.json") as file:
  dataset_dict = json.loads(file.read())

  if dataframe_type not in ["pandas", "spark", "polars", None]:
@@ -436,14 +437,14 @@ class Dataset:
  )

  def _get_feature_source_map(self):
- self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
+ self._feature_source_map: dict[FeatureSource, DataFrameLike] = {
  FeatureSource.INTERACTIONS: self.interactions,
  FeatureSource.QUERY_FEATURES: self.query_features,
  FeatureSource.ITEM_FEATURES: self.item_features,
  }

  def _get_ids_source_map(self):
- self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
+ self._ids_feature_map: dict[FeatureHint, DataFrameLike] = {
  FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
  FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
  }
@@ -499,10 +500,10 @@ class Dataset:
  )
  return FeatureSchema(features_list=features_list + filled_features)

- def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+ def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
  features_list = list(feature_schema.all_features)

- source_mapping: Dict[str, FeatureSource] = {}
+ source_mapping: dict[str, FeatureSource] = {}
  for source in FeatureSource:
  dataframe = self._feature_source_map[source]
  if dataframe is not None:
@@ -524,7 +525,7 @@ class Dataset:
  self._set_cardinality(features_list=features_list)
  return features_list

- def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+ def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
  set_source_dataframe_columns = set(self._feature_source_map[source].columns)
  set_labeled_dataframe_columns = set(feature_schema.columns)
  unlabeled_columns = set_source_dataframe_columns - set_labeled_dataframe_columns
@@ -534,13 +535,13 @@
  ]
  return unlabeled_features_list

- def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+ def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
  unlabeled_columns = self._get_unlabeled_columns(source=source, feature_schema=feature_schema)
  self._set_features_source(feature_list=unlabeled_columns, source=source)
  self._set_cardinality(features_list=unlabeled_columns)
  return unlabeled_columns

- def _set_features_source(self, feature_list: List[FeatureInfo], source: FeatureSource) -> None:
+ def _set_features_source(self, feature_list: list[FeatureInfo], source: FeatureSource) -> None:
  for feature in feature_list:
  feature._set_feature_source(source)
@@ -610,9 +611,9 @@ class Dataset:
  if self.is_pandas:
  try:
  data[column] = data[column].astype(int)
- except Exception:
+ except Exception as exc:
  msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
- raise ValueError(msg)
+ raise ValueError(msg) from exc

  if self.is_pandas:
  is_int = np.issubdtype(dict(data.dtypes)[column], int)
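The `raise ... from exc` form added here chains the original failure onto the new `ValueError`, so the failed cast stays visible in the traceback as `__cause__`. A minimal sketch of the pattern, using a hypothetical `to_int` helper rather than RePlay code:

```python
# Minimal sketch (not RePlay code): exception chaining keeps the root cause.
def to_int(value: str) -> int:
    try:
        return int(value)
    except ValueError as exc:
        msg = f"ID {value!r} is not encoded. It is not an int."
        raise ValueError(msg) from exc


try:
    to_int("user_42")
except ValueError as err:
    print(err)
    print("caused by:", repr(err.__cause__))
```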
@@ -775,10 +776,10 @@ def check_dataframes_types_equal(dataframe: DataFrameLike, other: DataFrameLike)

  :returns: True if dataframes have same type.
  """
- if isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame):
- return True
- if isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame):
- return True
- if isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame):
- return True
- return False
+ return any(
+ [
+ isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame),
+ isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame),
+ isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame),
+ ]
+ )
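The rewritten helper collapses three `if`/`return True` branches into a single `any(...)` over per-backend checks. A minimal sketch of the same pattern against plain Python types, with a hypothetical `same_backend` helper rather than RePlay's API:

```python
# Minimal sketch (hypothetical helper): True only when both objects belong to
# the same "backend" type, mirroring check_dataframes_types_equal above.
def same_backend(left: object, right: object) -> bool:
    backends = (list, dict, tuple)
    return any(isinstance(left, t) and isinstance(right, t) for t in backends)


print(same_backend([1], [2]))  # True: both are lists
print(same_backend([1], {}))   # False: list vs dict
```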
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py
@@ -6,7 +6,8 @@ Contains classes for encoding categorical data
  """

  import warnings
- from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union
+ from collections.abc import Iterable, Iterator, Sequence
+ from typing import Optional, Union

  from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
  from replay.preprocessing import LabelEncoder, LabelEncodingRule, SequenceEncodingRule
@@ -45,9 +46,9 @@ class DatasetLabelEncoder:
  """
  self._handle_unknown_rule = handle_unknown_rule
  self._default_value_rule = default_value_rule
- self._encoding_rules: Dict[str, LabelEncodingRule] = {}
+ self._encoding_rules: dict[str, LabelEncodingRule] = {}

- self._features_columns: Dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
+ self._features_columns: dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}

  def fit(self, dataset: Dataset) -> "DatasetLabelEncoder":
  """
@@ -161,7 +162,7 @@ class DatasetLabelEncoder:
  """
  self._check_if_initialized()

- columns_set: Set[str]
+ columns_set: set[str]
  columns_set = {columns} if isinstance(columns, str) else {*columns}

  def get_encoding_rules() -> Iterator[LabelEncodingRule]:
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/schema.py
@@ -1,17 +1,8 @@
+ from collections import OrderedDict
+ from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, Sequence, ValuesView
  from typing import (
- Dict,
- ItemsView,
- Iterable,
- Iterator,
- KeysView,
- List,
- Mapping,
  Optional,
- OrderedDict,
- Sequence,
- Set,
  Union,
- ValuesView,
  )

  import torch
@@ -20,7 +11,7 @@ from replay.data import FeatureHint, FeatureSource, FeatureType

  # Alias
  TensorMap = Mapping[str, torch.Tensor]
- MutableTensorMap = Dict[str, torch.Tensor]
+ MutableTensorMap = dict[str, torch.Tensor]


  class TensorFeatureSource:
@@ -79,7 +70,7 @@ class TensorFeatureInfo:
  feature_type: FeatureType,
  is_seq: bool = False,
  feature_hint: Optional[FeatureHint] = None,
- feature_sources: Optional[List[TensorFeatureSource]] = None,
+ feature_sources: Optional[list[TensorFeatureSource]] = None,
  cardinality: Optional[int] = None,
  padding_value: int = 0,
  embedding_dim: Optional[int] = None,
@@ -154,13 +145,13 @@ class TensorFeatureInfo:
  self._feature_hint = hint

  @property
- def feature_sources(self) -> Optional[List[TensorFeatureSource]]:
+ def feature_sources(self) -> Optional[list[TensorFeatureSource]]:
  """
  :returns: List of sources feature came from.
  """
  return self._feature_sources

- def _set_feature_sources(self, sources: List[TensorFeatureSource]) -> None:
+ def _set_feature_sources(self, sources: list[TensorFeatureSource]) -> None:
  self._feature_sources = sources

  @property
@@ -276,7 +267,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):

  :returns: New tensor schema of given features.
  """
- features: Set[TensorFeatureInfo] = set()
+ features: set[TensorFeatureInfo] = set()
  for feature_name in features_to_keep:
  features.add(self._tensor_schema[feature_name])
  return TensorSchema(list(features))
@@ -432,7 +423,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  return None
  return rating_features.item().name

- def _get_object_args(self) -> Dict:
+ def _get_object_args(self) -> dict:
  """
  Returns list of features represented as dictionaries.
  """
@@ -456,7 +447,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  return features

  @classmethod
- def _create_object_by_args(cls, args: Dict) -> "TensorSchema":
+ def _create_object_by_args(cls, args: dict) -> "TensorSchema":
  features_list = []
  for feature_data in args:
  feature_data["feature_sources"] = (