replay-rec 0.21.0rc0__tar.gz → 0.21.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/PKG-INFO +17 -11
  2. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/pyproject.toml +19 -12
  3. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/__init__.py +1 -1
  4. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/parquet_module.py +1 -1
  5. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/torch_metrics_builder.py +1 -1
  6. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/validation_callback.py +14 -4
  7. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/callback/metrics_callback.py +18 -9
  8. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/callback/predictions_callback.py +2 -2
  9. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/base.py +3 -3
  10. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/login_ce.py +1 -1
  11. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/model.py +1 -1
  12. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/reader.py +14 -5
  13. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/template/sasrec.py +3 -3
  14. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/template/twotower.py +1 -1
  15. replay_rec-0.21.0rc0/replay/experimental/metrics/__init__.py +0 -62
  16. replay_rec-0.21.0rc0/replay/experimental/metrics/base_metric.py +0 -603
  17. replay_rec-0.21.0rc0/replay/experimental/metrics/coverage.py +0 -97
  18. replay_rec-0.21.0rc0/replay/experimental/metrics/experiment.py +0 -175
  19. replay_rec-0.21.0rc0/replay/experimental/metrics/hitrate.py +0 -26
  20. replay_rec-0.21.0rc0/replay/experimental/metrics/map.py +0 -30
  21. replay_rec-0.21.0rc0/replay/experimental/metrics/mrr.py +0 -18
  22. replay_rec-0.21.0rc0/replay/experimental/metrics/ncis_precision.py +0 -31
  23. replay_rec-0.21.0rc0/replay/experimental/metrics/ndcg.py +0 -49
  24. replay_rec-0.21.0rc0/replay/experimental/metrics/precision.py +0 -22
  25. replay_rec-0.21.0rc0/replay/experimental/metrics/recall.py +0 -25
  26. replay_rec-0.21.0rc0/replay/experimental/metrics/rocauc.py +0 -49
  27. replay_rec-0.21.0rc0/replay/experimental/metrics/surprisal.py +0 -90
  28. replay_rec-0.21.0rc0/replay/experimental/metrics/unexpectedness.py +0 -76
  29. replay_rec-0.21.0rc0/replay/experimental/models/__init__.py +0 -50
  30. replay_rec-0.21.0rc0/replay/experimental/models/admm_slim.py +0 -257
  31. replay_rec-0.21.0rc0/replay/experimental/models/base_neighbour_rec.py +0 -200
  32. replay_rec-0.21.0rc0/replay/experimental/models/base_rec.py +0 -1386
  33. replay_rec-0.21.0rc0/replay/experimental/models/base_torch_rec.py +0 -234
  34. replay_rec-0.21.0rc0/replay/experimental/models/cql.py +0 -454
  35. replay_rec-0.21.0rc0/replay/experimental/models/ddpg.py +0 -932
  36. replay_rec-0.21.0rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
  37. replay_rec-0.21.0rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
  38. replay_rec-0.21.0rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
  39. replay_rec-0.21.0rc0/replay/experimental/models/dt4rec/utils.py +0 -264
  40. replay_rec-0.21.0rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  41. replay_rec-0.21.0rc0/replay/experimental/models/hierarchical_recommender.py +0 -331
  42. replay_rec-0.21.0rc0/replay/experimental/models/implicit_wrap.py +0 -131
  43. replay_rec-0.21.0rc0/replay/experimental/models/lightfm_wrap.py +0 -303
  44. replay_rec-0.21.0rc0/replay/experimental/models/mult_vae.py +0 -332
  45. replay_rec-0.21.0rc0/replay/experimental/models/neural_ts.py +0 -986
  46. replay_rec-0.21.0rc0/replay/experimental/models/neuromf.py +0 -406
  47. replay_rec-0.21.0rc0/replay/experimental/models/scala_als.py +0 -293
  48. replay_rec-0.21.0rc0/replay/experimental/models/u_lin_ucb.py +0 -115
  49. replay_rec-0.21.0rc0/replay/experimental/nn/data/__init__.py +0 -1
  50. replay_rec-0.21.0rc0/replay/experimental/nn/data/schema_builder.py +0 -102
  51. replay_rec-0.21.0rc0/replay/experimental/preprocessing/__init__.py +0 -3
  52. replay_rec-0.21.0rc0/replay/experimental/preprocessing/data_preparator.py +0 -839
  53. replay_rec-0.21.0rc0/replay/experimental/preprocessing/padder.py +0 -229
  54. replay_rec-0.21.0rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
  55. replay_rec-0.21.0rc0/replay/experimental/scenarios/__init__.py +0 -1
  56. replay_rec-0.21.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  57. replay_rec-0.21.0rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  58. replay_rec-0.21.0rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  59. replay_rec-0.21.0rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  60. replay_rec-0.21.0rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
  61. replay_rec-0.21.0rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  62. replay_rec-0.21.0rc0/replay/experimental/utils/logger.py +0 -24
  63. replay_rec-0.21.0rc0/replay/experimental/utils/model_handler.py +0 -186
  64. replay_rec-0.21.0rc0/replay/experimental/utils/session_handler.py +0 -44
  65. replay_rec-0.21.0rc0/replay/models/extensions/ann/__init__.py +0 -0
  66. replay_rec-0.21.0rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
  67. replay_rec-0.21.0rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  68. replay_rec-0.21.0rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  69. replay_rec-0.21.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  70. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/LICENSE +0 -0
  71. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/NOTICE +0 -0
  72. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/README.md +0 -0
  73. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/__init__.py +0 -0
  74. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/dataset.py +0 -0
  75. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/dataset_utils/__init__.py +0 -0
  76. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
  77. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/__init__.py +0 -0
  78. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/__init__.py +0 -0
  79. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/collate.py +0 -0
  80. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/__init__.py +0 -0
  81. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/batches.py +0 -0
  82. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/device.py +0 -0
  83. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/filesystem.py +0 -0
  84. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/metadata.py +0 -0
  85. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/fixed_batch_dataset.py +0 -0
  86. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/__init__.py +0 -0
  87. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/array_1d_column.py +0 -0
  88. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/array_2d_column.py +0 -0
  89. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/column_protocol.py +0 -0
  90. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/indexing.py +0 -0
  91. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/masking.py +0 -0
  92. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/named_columns.py +0 -0
  93. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/numeric_column.py +0 -0
  94. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/utils.py +0 -0
  95. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/__init__.py +0 -0
  96. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/distributed_info.py +0 -0
  97. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/partitioning.py +0 -0
  98. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/replicas.py +0 -0
  99. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/worker_info.py +0 -0
  100. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/iterable_dataset.py +0 -0
  101. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/iterator.py +0 -0
  102. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/metadata/__init__.py +0 -0
  103. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/metadata/metadata.py +0 -0
  104. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/parquet_dataset.py +0 -0
  105. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/partitioned_iterable_dataset.py +0 -0
  106. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/utils/__init__.py +0 -0
  107. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/utils/compute_length.py +0 -0
  108. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/schema.py +0 -0
  109. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/sequence_tokenizer.py +0 -0
  110. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/sequential_dataset.py +0 -0
  111. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/torch_sequential_dataset.py +0 -0
  112. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/utils.py +0 -0
  113. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/schema.py +0 -0
  114. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/spark_schema.py +0 -0
  115. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/utils/__init__.py +0 -0
  116. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/utils/batching.py +0 -0
  117. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/utils/typing/__init__.py +0 -0
  118. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/utils/typing/dtype.py +0 -0
  119. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/__init__.py +0 -0
  120. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/base_metric.py +0 -0
  121. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/categorical_diversity.py +0 -0
  122. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/coverage.py +0 -0
  123. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/descriptors.py +0 -0
  124. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/experiment.py +0 -0
  125. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/hitrate.py +0 -0
  126. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/map.py +0 -0
  127. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/mrr.py +0 -0
  128. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/ndcg.py +0 -0
  129. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/novelty.py +0 -0
  130. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/offline_metrics.py +0 -0
  131. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/precision.py +0 -0
  132. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/recall.py +0 -0
  133. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/rocauc.py +0 -0
  134. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/surprisal.py +0 -0
  135. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/unexpectedness.py +0 -0
  136. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/__init__.py +0 -0
  137. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/als.py +0 -0
  138. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/association_rules.py +0 -0
  139. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/base_neighbour_rec.py +0 -0
  140. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/base_rec.py +0 -0
  141. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/cat_pop_rec.py +0 -0
  142. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/cluster.py +0 -0
  143. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/common.py +0 -0
  144. {replay_rec-0.21.0rc0/replay/experimental → replay_rec-0.21.1/replay/models/extensions}/__init__.py +0 -0
  145. {replay_rec-0.21.0rc0/replay/experimental/models/dt4rec → replay_rec-0.21.1/replay/models/extensions/ann}/__init__.py +0 -0
  146. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/ann_mixin.py +0 -0
  147. {replay_rec-0.21.0rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.21.1/replay/models/extensions/ann/entities}/__init__.py +0 -0
  148. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  149. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  150. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  151. {replay_rec-0.21.0rc0/replay/experimental/scenarios/two_stages → replay_rec-0.21.1/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
  152. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  153. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  154. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  155. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  156. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  157. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  158. {replay_rec-0.21.0rc0/replay/experimental/utils → replay_rec-0.21.1/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
  159. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  160. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  161. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  162. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  163. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  164. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  165. {replay_rec-0.21.0rc0/replay/models/extensions → replay_rec-0.21.1/replay/models/extensions/ann/index_stores}/__init__.py +0 -0
  166. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  167. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  168. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  169. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  170. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  171. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/extensions/ann/utils.py +0 -0
  172. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/kl_ucb.py +0 -0
  173. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/knn.py +0 -0
  174. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/lin_ucb.py +0 -0
  175. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/__init__.py +0 -0
  176. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/loss/__init__.py +0 -0
  177. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/loss/sce.py +0 -0
  178. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  179. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  180. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/__init__.py +0 -0
  181. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  182. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
  183. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  184. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/model.py +0 -0
  185. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  186. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  187. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/__init__.py +0 -0
  188. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/base_compiled_model.py +0 -0
  189. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +0 -0
  190. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/sasrec_compiled.py +0 -0
  191. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  192. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  193. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
  194. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  195. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
  196. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  197. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/model.py +0 -0
  198. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/optimization/__init__.py +0 -0
  199. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/optimization/optuna_mixin.py +0 -0
  200. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/optimization/optuna_objective.py +0 -0
  201. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/pop_rec.py +0 -0
  202. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/query_pop_rec.py +0 -0
  203. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/random_rec.py +0 -0
  204. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/slim.py +0 -0
  205. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/thompson_sampling.py +0 -0
  206. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/ucb.py +0 -0
  207. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/wilson.py +0 -0
  208. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/word2vec.py +0 -0
  209. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/__init__.py +0 -0
  210. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/agg.py +0 -0
  211. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/attention.py +0 -0
  212. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/embedding.py +0 -0
  213. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/ffn.py +0 -0
  214. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/head.py +0 -0
  215. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/__init__.py +0 -0
  216. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/callback/__init__.py +0 -0
  217. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/module.py +0 -0
  218. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/optimizer.py +0 -0
  219. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/__init__.py +0 -0
  220. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/_base.py +0 -0
  221. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/seen_items.py +0 -0
  222. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/scheduler.py +0 -0
  223. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/__init__.py +0 -0
  224. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/bce.py +0 -0
  225. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/ce.py +0 -0
  226. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/logout_ce.py +0 -0
  227. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/mask.py +0 -0
  228. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/normalization.py +0 -0
  229. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/output.py +0 -0
  230. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/__init__.py +0 -0
  231. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/__init__.py +0 -0
  232. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/agg.py +0 -0
  233. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/diff_transformer.py +0 -0
  234. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/transformer.py +0 -0
  235. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/__init__.py +0 -0
  236. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/model.py +0 -0
  237. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/__init__.py +0 -0
  238. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/copy.py +0 -0
  239. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/grouping.py +0 -0
  240. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/negative_sampling.py +0 -0
  241. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/next_token.py +0 -0
  242. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/rename.py +0 -0
  243. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/reshape.py +0 -0
  244. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/sequence_roll.py +0 -0
  245. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/template/__init__.py +0 -0
  246. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/token_mask.py +0 -0
  247. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/trim.py +0 -0
  248. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/utils.py +0 -0
  249. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/__init__.py +0 -0
  250. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/converter.py +0 -0
  251. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/discretizer.py +0 -0
  252. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/filters.py +0 -0
  253. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/history_based_fp.py +0 -0
  254. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/label_encoder.py +0 -0
  255. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/sessionizer.py +0 -0
  256. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/preprocessing/utils.py +0 -0
  257. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/scenarios/__init__.py +0 -0
  258. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/scenarios/fallback.py +0 -0
  259. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/__init__.py +0 -0
  260. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/base_splitter.py +0 -0
  261. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/cold_user_random_splitter.py +0 -0
  262. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/k_folds.py +0 -0
  263. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/last_n_splitter.py +0 -0
  264. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/new_users_splitter.py +0 -0
  265. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/random_next_n_splitter.py +0 -0
  266. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/random_splitter.py +0 -0
  267. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/ratio_splitter.py +0 -0
  268. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/time_splitter.py +0 -0
  269. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/splitters/two_stage_splitter.py +0 -0
  270. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/__init__.py +0 -0
  271. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/common.py +0 -0
  272. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/dataframe_bucketizer.py +0 -0
  273. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/distributions.py +0 -0
  274. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/model_handler.py +0 -0
  275. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/session_handler.py +0 -0
  276. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/spark_utils.py +0 -0
  277. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/time.py +0 -0
  278. {replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/utils/types.py +0 -0
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: replay-rec
- Version: 0.21.0rc0
+ Version: 0.21.1
  Summary: RecSys Library
  License-Expression: Apache-2.0
  License-File: LICENSE
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
  Classifier: Intended Audience :: Science/Research
  Classifier: Natural Language :: English
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Requires-Dist: d3rlpy (>=2.8.1,<2.9)
- Requires-Dist: implicit (>=0.7.2,<0.8)
- Requires-Dist: lightautoml (>=0.4.1,<0.5)
- Requires-Dist: lightning (>=2.0.2,<=2.4.0)
- Requires-Dist: numba (>=0.50,<1)
+ Provides-Extra: spark
+ Provides-Extra: torch
+ Provides-Extra: torch-cpu
+ Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: lightning ; extra == "torch"
+ Requires-Dist: lightning ; extra == "torch-cpu"
  Requires-Dist: numpy (>=1.20.0,<2)
  Requires-Dist: pandas (>=1.3.5,<2.4.0)
  Requires-Dist: polars (<2.0)
- Requires-Dist: psutil (<=7.0.0)
+ Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
+ Requires-Dist: psutil ; extra == "spark"
  Requires-Dist: pyarrow (<22.0)
- Requires-Dist: pyspark (>=3.0,<3.5)
- Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
- Requires-Dist: sb-obp (>=0.5.10,<0.6)
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
+ Requires-Dist: pyspark ; extra == "spark"
+ Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: pytorch-optimizer ; extra == "torch"
+ Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
  Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
  Requires-Dist: scipy (>=1.8.1,<2.0.0)
  Requires-Dist: setuptools
- Requires-Dist: torch (>=1.8,<3.0.0)
+ Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
+ Requires-Dist: torch ; extra == "torch"
+ Requires-Dist: torch ; extra == "torch-cpu"
  Requires-Dist: tqdm (>=4.67,<5)
  Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/pyproject.toml
@@ -65,19 +65,19 @@ dependencies = [
  "scikit-learn (>=1.6.1,<1.7.0)",
  "pyarrow (<22.0)",
  "tqdm (>=4.67,<5)",
- "torch (>=1.8,<3.0.0)",
- "lightning (>=2.0.2,<=2.4.0)",
- "pytorch-optimizer (>=3.8.0,<4)",
- "lightautoml (>=0.4.1,<0.5)",
- "numba (>=0.50,<1)",
- "sb-obp (>=0.5.10,<0.6)",
- "d3rlpy (>=2.8.1,<2.9)",
- "implicit (>=0.7.2,<0.8)",
- "pyspark (>=3.0,<3.5)",
- "psutil (<=7.0.0)",
+ "pyspark (>=3.0,<3.5); extra == 'spark'",
+ "psutil (<=7.0.0); extra == 'spark'",
+ "torch (>=1.8, <3.0.0); extra == 'torch' or extra == 'torch-cpu'",
+ "pytorch-optimizer (>=3.8.0,<3.9.0); extra == 'torch' or extra == 'torch-cpu'",
+ "lightning (<2.6.0); extra == 'torch' or extra == 'torch-cpu'",
  ]
  dynamic = ["dependencies"]
- version = "0.21.0.preview"
+ version = "0.21.1"
+
+ [project.optional-dependencies]
+ spark = ["pyspark", "psutil"]
+ torch = ["torch", "pytorch-optimizer", "lightning"]
+ torch-cpu = ["torch", "pytorch-optimizer", "lightning"]

  [project.urls]
  homepage = "https://sb-ai-lab.github.io/RePlay/"
@@ -91,6 +91,13 @@ target-version = ["py39", "py310", "py311", "py312"]
  packages = [{include = "replay"}]
  exclude = [
  "replay/conftest.py",
+ "replay/experimental",
+ ]
+
+ [tool.poetry.dependencies]
+ torch = [
+ {markers = "extra == 'torch-cpu' and extra !='torch'", source = "torch-cpu-mirror"},
+ {markers = "extra == 'torch' and extra !='torch-cpu'", source = "PyPI"},
  ]

  [[tool.poetry.source]]
@@ -100,7 +107,7 @@ priority = "explicit"

  [tool.poetry-dynamic-versioning]
  enable = false
- format-jinja = """0.21.0{{ env['PACKAGE_SUFFIX'] }}"""
+ format-jinja = """0.21.1{{ env['PACKAGE_SUFFIX'] }}"""
  vcs = "git"

  [tool.ruff]
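
The two files above move the previously unconditional heavy dependencies (pyspark, psutil, torch, pytorch-optimizer, lightning) behind three new extras, "spark", "torch", and "torch-cpu", so a plain pip install replay-rec now pulls only the core scientific stack, while e.g. pip install "replay-rec[spark]" restores Spark support; the new [tool.poetry.dependencies] markers additionally route the "torch-cpu" extra to a CPU-only mirror source (torch-cpu-mirror) while plain "torch" resolves from PyPI. A minimal sketch for verifying the split, assuming the 0.21.1 wheel is installed in the current environment:

    # Inspect the declared requirements and their environment markers;
    # importlib.metadata.requires() returns them as raw requirement strings.
    from importlib.metadata import requires

    for requirement in requires("replay-rec") or []:
        if "extra ==" in requirement:
            print(requirement)
    # expected to include, per the PKG-INFO hunk above:
    #   pyspark (>=3.0,<3.5) ; extra == "spark"
    #   torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
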
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/__init__.py
@@ -4,4 +4,4 @@
  # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
  import setuptools as _

- __version__ = "0.21.0.preview"
+ __version__ = "0.21.1"
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/data/nn/parquet/parquet_module.py
@@ -94,7 +94,7 @@ class ParquetModule(L.LightningDataModule):
  missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
  if missing_splits:
  msg = (
- f"The following dataset paths aren't provided: {','.join(missing_splits)}."
+ f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
  "Make sure to disable these stages in your Lightning Trainer configuration."
  )
  warnings.warn(msg, stacklevel=2)
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/metrics/torch_metrics_builder.py
@@ -400,7 +400,7 @@ def metrics_to_df(metrics: Mapping[str, float]) -> PandasDataFrame:

  metric_name_and_k = metrics_df["m"].str.split("@", expand=True)
  metrics_df["metric"] = metric_name_and_k[0]
- metrics_df["k"] = metric_name_and_k[1]
+ metrics_df["k"] = metric_name_and_k[1].astype(int)

  pivoted_metrics = metrics_df.pivot(index="metric", columns="k", values="v")
  pivoted_metrics.index.name = None
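
The .astype(int) above fixes column ordering in the pivoted metrics table: the "@k" suffix comes out of str.split as a string, and pandas sorts string column labels lexicographically, so k=10 would land before k=2. A standalone reproduction of the pandas behavior (not the library code itself):

    import pandas as pd

    # With string keys the pivoted columns sort lexicographically ("10" < "2").
    metrics = {"ndcg@2": 0.50, "ndcg@10": 0.30}
    metrics_df = pd.DataFrame(list(metrics.items()), columns=["m", "v"])
    metric_name_and_k = metrics_df["m"].str.split("@", expand=True)
    metrics_df["metric"] = metric_name_and_k[0]

    metrics_df["k"] = metric_name_and_k[1]
    print(metrics_df.pivot(index="metric", columns="k", values="v").columns.tolist())  # ['10', '2']

    metrics_df["k"] = metric_name_and_k[1].astype(int)
    print(metrics_df.pivot(index="metric", columns="k", values="v").columns.tolist())  # [2, 10]
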
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/validation_callback.py
@@ -162,14 +162,24 @@ class ValidationMetricsCallback(lightning.Callback):
  @rank_zero_only
  def print_metrics() -> None:
  metrics = {}
+
  for name, value in trainer.logged_metrics.items():
  if "@" in name:
  metrics[name] = value.item()

- if metrics:
- metrics_df = metrics_to_df(metrics)
+ if not metrics:
+ return

- print(metrics_df) # noqa: T201
- print() # noqa: T201
+ if len(self._dataloaders_size) > 1:
+ for i in range(len(self._dataloaders_size)):
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
+
+ print(suffix) # noqa: T201
+ print(metrics_df, "\n") # noqa: T201
+ else:
+ metrics_df = metrics_to_df(metrics)
+ print(metrics_df, "\n") # noqa: T201

  print_metrics()
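
When validation runs with several dataloaders, Lightning suffixes each logged metric name with its dataloader index (the template lives in the private trainer._results.DATALOADER_SUFFIX attribute, "/dataloader_idx_{}" in recent Lightning releases), and the new branch above regroups metrics by that suffix so each dataloader gets its own table; the identical branch appears again in metrics_callback.py below. A self-contained sketch of the regrouping, with hard-coded names standing in for trainer.logged_metrics:

    # Names mimic what Lightning logs with two validation dataloaders.
    logged_metrics = {
        "ndcg@10/dataloader_idx_0": 0.31,
        "recall@10/dataloader_idx_0": 0.44,
        "ndcg@10/dataloader_idx_1": 0.27,
        "recall@10/dataloader_idx_1": 0.40,
    }
    for i in range(2):
        suffix = "/dataloader_idx_{}".format(i)[1:]  # mirrors DATALOADER_SUFFIX.format(i)[1:]
        per_loader = {k.split("/")[0]: v for k, v in logged_metrics.items() if suffix in k}
        print(suffix, per_loader)
    # dataloader_idx_0 {'ndcg@10': 0.31, 'recall@10': 0.44}
    # dataloader_idx_1 {'ndcg@10': 0.27, 'recall@10': 0.40}
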
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/callback/metrics_callback.py
@@ -2,7 +2,6 @@ from typing import Any, Optional

  import lightning
  import torch
- from lightning.pytorch.utilities.combined_loader import CombinedLoader
  from lightning.pytorch.utilities.rank_zero import rank_zero_only

  from replay.metrics.torch_metrics_builder import (
@@ -64,8 +63,8 @@ class ComputeMetricsCallback(lightning.Callback):
  self._train_column = train_column

  def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
- if isinstance(dataloaders, CombinedLoader):
- return [len(dataloader) for dataloader in dataloaders.flattened] # pragma: no cover
+ if isinstance(dataloaders, list):
+ return [len(dataloader) for dataloader in dataloaders]
  return [len(dataloaders)]

  def on_validation_epoch_start(
@@ -123,7 +122,7 @@ class ComputeMetricsCallback(lightning.Callback):
  batch: dict,
  batch_idx: int,
  dataloader_idx: int = 0,
- ) -> None: # pragma: no cover
+ ) -> None:
  self._batch_end(
  trainer,
  pl_module,
@@ -159,7 +158,7 @@ class ComputeMetricsCallback(lightning.Callback):
  def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
  self._epoch_end(trainer, pl_module)

- def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None: # pragma: no cover
+ def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
  self._epoch_end(trainer, pl_module)

  def _epoch_end(
@@ -170,14 +169,24 @@ class ComputeMetricsCallback(lightning.Callback):
  @rank_zero_only
  def print_metrics() -> None:
  metrics = {}
+
  for name, value in trainer.logged_metrics.items():
  if "@" in name:
  metrics[name] = value.item()

- if metrics:
- metrics_df = metrics_to_df(metrics)
+ if not metrics:
+ return

- print(metrics_df) # noqa: T201
- print() # noqa: T201
+ if len(self._dataloaders_size) > 1:
+ for i in range(len(self._dataloaders_size)):
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
+
+ print(suffix) # noqa: T201
+ print(metrics_df, "\n") # noqa: T201
+ else:
+ metrics_df = metrics_to_df(metrics)
+ print(metrics_df, "\n") # noqa: T201

  print_metrics()
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/lightning/callback/predictions_callback.py
@@ -15,11 +15,11 @@ from replay.utils import (
  SparkDataFrame,
  )

- if PYSPARK_AVAILABLE: # pragma: no cover
+ if PYSPARK_AVAILABLE:
  import pyspark.sql.functions as sf
  from pyspark.sql import SparkSession
  from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
- else: # pragma: no cover
+ else:
  SparkSession = MissingImport

{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/base.py
@@ -85,7 +85,7 @@ class SampledLossBase(torch.nn.Module):
  # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
  negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)

- if negative_labels.dim() == 3: # pragma: no cover
+ if negative_labels.dim() == 3:
  # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
  negative_labels = negative_labels.unsqueeze(-2)
  if num_positives != 1:
@@ -119,7 +119,7 @@ class SampledLossBase(torch.nn.Module):
  positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
  assert positive_labels.size() == (masked_batch_size, 1)

- if negative_labels.dim() != 1: # pragma: no cover
+ if negative_labels.dim() != 1:
  # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
  negative_labels = negative_labels[target_padding_mask]
  assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -183,7 +183,7 @@ def mask_negative_logits(
  if negative_labels_ignore_index >= 0:
  negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)

- if negative_labels.dim() > 1: # pragma: no cover
+ if negative_labels.dim() > 1:
  # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
  negative_labels = negative_labels.unsqueeze(-2)

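
The pragma-no-cover removals above mean these reshaping branches are now exercised by tests; the bracketed comments in the diff document the intended shapes. A dummy-valued walk-through of the same transforms:

    import torch

    # Shape walk-through for the negative-label handling in SampledLossBase.
    batch_size, seq_len, num_negatives = 2, 3, 4
    negative_labels = torch.randint(0, 100, (batch_size, num_negatives))

    # [batch_size, num_negatives] -> [batch_size, seq_len, num_negatives]
    negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
    assert negative_labels.dim() == 3

    # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
    negative_labels = negative_labels.unsqueeze(-2)
    assert negative_labels.shape == (batch_size, seq_len, 1, num_negatives)
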
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/loss/login_ce.py
@@ -74,7 +74,7 @@ class LogInCEBase(SampledLossBase):
  positive_labels = positive_labels[masked_target_padding_mask]
  assert positive_labels.size() == (masked_batch_size, num_positives)

- if negative_labels.dim() > 1: # pragma: no cover
+ if negative_labels.dim() > 1:
  # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
  negative_labels = negative_labels[masked_target_padding_mask]
  assert negative_labels.size() == (masked_batch_size, num_negatives)
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/model.py
@@ -141,7 +141,7 @@ class SasRec(torch.nn.Module):
  feature_type=FeatureType.CATEGORICAL,
  embedding_dim=256,
  padding_value=NUM_UNIQUE_ITEMS,
- cardinality=NUM_UNIQUE_ITEMS+1,
+ cardinality=NUM_UNIQUE_ITEMS,
  feature_hint=FeatureHint.ITEM_ID,
  feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
  ),
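
The docstring fix drops the "+1": under the convention this example now follows, cardinality counts only the real items while padding_value occupies the next free index, which implies the embedding table itself reserves an extra row for padding (an assumption about the library's internals, not stated in the diff). The equivalent raw-torch layout would be:

    import torch

    # Assumed convention: cardinality covers real items only, and the padding
    # token sits at index NUM_UNIQUE_ITEMS, so a raw embedding table needs
    # cardinality + 1 rows.
    NUM_UNIQUE_ITEMS = 1000
    embedding = torch.nn.Embedding(
        num_embeddings=NUM_UNIQUE_ITEMS + 1,
        embedding_dim=256,
        padding_idx=NUM_UNIQUE_ITEMS,
    )
    print(embedding(torch.tensor([0, NUM_UNIQUE_ITEMS])).shape)  # torch.Size([2, 256])
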
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/reader.py
@@ -22,7 +22,6 @@ class FeaturesReader:
  :param schema: the same tensor schema used in TwoTower model.
  :param metadata: A dictionary of feature names that
  associated with its shape and padding_value.\n
- Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
  For details, see the section :ref:`parquet-processing`.
  :param path: path to parquet with dataframe of item features.\n
  **Note:**\n
@@ -30,8 +29,8 @@ class FeaturesReader:
  2. Every feature for item "tower" in `schema` must contain ``feature_sources`` with the names
  of the source features to create correct inverse mapping.
  Also, for each such feature one of the requirements must be met: the ``schema`` for the feature must
- contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
- or hint type FeatureHint.ITEM_ID.
+ contain ``feature_sources`` with a source of type ``FeatureSource.ITEM_FEATURES``
+ or hint type ``FeatureHint.ITEM_ID``.

  """
  item_feature_names = [
@@ -81,8 +80,18 @@ class FeaturesReader:
  self._features = {}

  for k in features.columns:
- dtype = torch.float32 if schema[k].is_num else torch.int64
- feature_tensor = torch.asarray(features[k], dtype=dtype)
+ dtype = np.float32 if schema[k].is_num else np.int64
+ if schema[k].is_list:
+ feature = np.asarray(
+ features[k].to_list(),
+ dtype=dtype,
+ )
+ else:
+ feature = features[k].to_numpy(dtype=dtype)
+ feature_tensor = torch.asarray(
+ feature,
+ dtype=torch.float32 if schema[k].is_num else torch.int64,
+ )
  self._features[k] = feature_tensor

  def __getitem__(self, key: str) -> torch.Tensor:
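
The rewritten loop routes every pandas column through numpy before torch.asarray, because a Series of Python lists carries object dtype that torch cannot consume directly; list-valued features are materialized via to_list() into a rectangular 2-D array, scalar features via to_numpy(). A minimal standalone illustration (hypothetical feature names; list columns are assumed rectangular):

    import numpy as np
    import pandas as pd
    import torch

    # Two conversion paths: plain numeric column vs. list-valued column.
    features = pd.DataFrame({"item_id": [1, 2], "genre_ids": [[0, 3], [1, 4]]})

    scalar_col = torch.asarray(features["item_id"].to_numpy(dtype=np.int64), dtype=torch.int64)
    list_col = torch.asarray(np.asarray(features["genre_ids"].to_list(), dtype=np.int64), dtype=torch.int64)

    print(scalar_col.shape, list_col.shape)  # torch.Size([2]) torch.Size([2, 2])
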
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/template/sasrec.py
@@ -14,7 +14,7 @@ def make_default_sasrec_transforms(

  Generated pipeline expects input dataset to contain the following columns:
  1) Query ID column, specified by ``query_column``.
- 2) Item ID column, specified in the tensor schema.
+ 2) All features specified in the ``tensor_schema``.

  :param tensor_schema: TensorSchema used to infer feature columns.
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
@@ -32,12 +32,12 @@ def make_default_sasrec_transforms(
  ),
  UnsqueezeTransform("target_padding_mask", -1),
  UnsqueezeTransform("positive_labels", -1),
- GroupTransform({"feature_tensors": [item_column]}),
+ GroupTransform({"feature_tensors": tensor_schema.names}),
  ]

  val_transforms = [
  RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
- GroupTransform({"feature_tensors": [item_column]}),
+ GroupTransform({"feature_tensors": tensor_schema.names}),
  ]
  test_transforms = copy.deepcopy(val_transforms)

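
This change (and the matching one in twotower.py below) widens the grouping from the single item column to every feature named in the tensor schema, so batches carrying additional categorical or numerical features survive the default transform pipeline. A hypothetical dict-level illustration of the widened grouping (GroupTransform internals are not shown in the diff; the batch contents are invented):

    # Before the fix only "item_id" would be nested under "feature_tensors";
    # with tensor_schema.names every schema feature is.
    tensor_schema_names = ["item_id", "category_id"]  # assumed schema contents
    batch = {"item_id": [3, 5], "category_id": [1, 1], "padding_mask": [1, 1]}

    grouped = {"feature_tensors": {name: batch.pop(name) for name in tensor_schema_names}, **batch}
    print(grouped)
    # {'feature_tensors': {'item_id': [3, 5], 'category_id': [1, 1]}, 'padding_mask': [1, 1]}
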
{replay_rec-0.21.0rc0 → replay_rec-0.21.1}/replay/nn/transform/template/twotower.py
@@ -13,7 +13,7 @@ def make_default_twotower_transforms(

  Generated pipeline expects input dataset to contain the following columns:
  1) Query ID column, specified by ``query_column``.
- 2) Item ID column, specified in the tensor schema.
+ 2) All features specified in the ``tensor_schema``.

  :param tensor_schema: TensorSchema used to infer feature columns.
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
replay_rec-0.21.0rc0/replay/experimental/metrics/__init__.py
@@ -1,62 +0,0 @@
- """
- Most metrics require dataframe with recommendations
- and dataframe with ground truth values —
- which objects each user interacted with.
-
- - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
- predictions of a recommender system,
- DataFrame with columns ``[user_id, item_id, relevance]``
- - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
- test data, DataFrame with columns
- ``[user_id, item_id, timestamp, relevance]``
-
- Metric is calculated for all users, presented in ``ground_truth``
- for accurate metric calculation in case when the recommender system generated
- recommendation not for all users. It is assumed, that all users,
- we want to calculate metric for, have positive interactions.
-
- But if we have users, who observed the recommendations, but have not responded,
- those users will be ignored and metric will be overestimated.
- For such case we propose additional optional parameter ``ground_truth_users``,
- the dataframe with all users, which should be considered during the metric calculation.
-
- - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
- full list of users to calculate metric for, DataFrame with ``user_id`` column
-
- Every metric is calculated using top ``K`` items for each user.
- It is also possible to calculate metrics
- using multiple values for ``K`` simultaneously.
- In this case the result will be a dictionary and not a number.
-
- Make sure your recommendations do not contain user-item duplicates
- as duplicates could lead to the wrong calculation results.
-
- - k (Union[Iterable[int], int]):
- a single number or a list, specifying the
- truncation length for recommendation list for each user
-
- By default, metrics are averaged by users,
- but you can alternatively use method ``metric.median``.
- Also, you can get the lower bound
- of ``conf_interval`` for a given ``alpha``.
-
- Diversity metrics require extra parameters on initialization stage,
- but do not use ``ground_truth`` parameter.
-
- For each metric, a formula for its calculation is given, because this is
- important for the correct comparison of algorithms, as mentioned in our
- `article <https://arxiv.org/abs/2206.12858>`_.
- """
-
- from replay.experimental.metrics.base_metric import Metric, NCISMetric
- from replay.experimental.metrics.coverage import Coverage
- from replay.experimental.metrics.hitrate import HitRate
- from replay.experimental.metrics.map import MAP
- from replay.experimental.metrics.mrr import MRR
- from replay.experimental.metrics.ncis_precision import NCISPrecision
- from replay.experimental.metrics.ndcg import NDCG
- from replay.experimental.metrics.precision import Precision
- from replay.experimental.metrics.recall import Recall
- from replay.experimental.metrics.rocauc import RocAuc
- from replay.experimental.metrics.surprisal import Surprisal
- from replay.experimental.metrics.unexpectedness import Unexpectedness