replay-rec 0.21.0__tar.gz → 0.21.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (223)
  1. {replay_rec-0.21.0 → replay_rec-0.21.1}/PKG-INFO +1 -1
  2. {replay_rec-0.21.0 → replay_rec-0.21.1}/pyproject.toml +2 -2
  3. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/__init__.py +1 -1
  4. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/parquet_module.py +1 -1
  5. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/torch_metrics_builder.py +1 -1
  6. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/validation_callback.py +14 -4
  7. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/callback/metrics_callback.py +18 -9
  8. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/callback/predictions_callback.py +2 -2
  9. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/base.py +3 -3
  10. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/login_ce.py +1 -1
  11. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/model.py +1 -1
  12. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/reader.py +14 -5
  13. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/template/sasrec.py +3 -3
  14. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/template/twotower.py +1 -1
  15. {replay_rec-0.21.0 → replay_rec-0.21.1}/LICENSE +0 -0
  16. {replay_rec-0.21.0 → replay_rec-0.21.1}/NOTICE +0 -0
  17. {replay_rec-0.21.0 → replay_rec-0.21.1}/README.md +0 -0
  18. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/__init__.py +0 -0
  19. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/dataset.py +0 -0
  20. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/dataset_utils/__init__.py +0 -0
  21. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
  22. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/__init__.py +0 -0
  23. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/__init__.py +0 -0
  24. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/collate.py +0 -0
  25. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/__init__.py +0 -0
  26. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/batches.py +0 -0
  27. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/device.py +0 -0
  28. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/filesystem.py +0 -0
  29. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/constants/metadata.py +0 -0
  30. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/fixed_batch_dataset.py +0 -0
  31. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/__init__.py +0 -0
  32. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/array_1d_column.py +0 -0
  33. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/array_2d_column.py +0 -0
  34. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/column_protocol.py +0 -0
  35. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/indexing.py +0 -0
  36. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/masking.py +0 -0
  37. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/named_columns.py +0 -0
  38. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/numeric_column.py +0 -0
  39. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/impl/utils.py +0 -0
  40. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/__init__.py +0 -0
  41. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/distributed_info.py +0 -0
  42. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/partitioning.py +0 -0
  43. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/replicas.py +0 -0
  44. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/info/worker_info.py +0 -0
  45. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/iterable_dataset.py +0 -0
  46. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/iterator.py +0 -0
  47. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/metadata/__init__.py +0 -0
  48. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/metadata/metadata.py +0 -0
  49. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/parquet_dataset.py +0 -0
  50. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/partitioned_iterable_dataset.py +0 -0
  51. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/utils/__init__.py +0 -0
  52. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/parquet/utils/compute_length.py +0 -0
  53. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/schema.py +0 -0
  54. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/sequence_tokenizer.py +0 -0
  55. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/sequential_dataset.py +0 -0
  56. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/torch_sequential_dataset.py +0 -0
  57. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/nn/utils.py +0 -0
  58. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/schema.py +0 -0
  59. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/spark_schema.py +0 -0
  60. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/utils/__init__.py +0 -0
  61. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/utils/batching.py +0 -0
  62. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/utils/typing/__init__.py +0 -0
  63. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/data/utils/typing/dtype.py +0 -0
  64. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/__init__.py +0 -0
  65. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/base_metric.py +0 -0
  66. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/categorical_diversity.py +0 -0
  67. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/coverage.py +0 -0
  68. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/descriptors.py +0 -0
  69. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/experiment.py +0 -0
  70. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/hitrate.py +0 -0
  71. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/map.py +0 -0
  72. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/mrr.py +0 -0
  73. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/ndcg.py +0 -0
  74. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/novelty.py +0 -0
  75. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/offline_metrics.py +0 -0
  76. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/precision.py +0 -0
  77. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/recall.py +0 -0
  78. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/rocauc.py +0 -0
  79. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/surprisal.py +0 -0
  80. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/metrics/unexpectedness.py +0 -0
  81. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/__init__.py +0 -0
  82. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/als.py +0 -0
  83. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/association_rules.py +0 -0
  84. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/base_neighbour_rec.py +0 -0
  85. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/base_rec.py +0 -0
  86. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/cat_pop_rec.py +0 -0
  87. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/cluster.py +0 -0
  88. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/common.py +0 -0
  89. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/__init__.py +0 -0
  90. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/__init__.py +0 -0
  91. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/ann_mixin.py +0 -0
  92. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/__init__.py +0 -0
  93. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  94. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  95. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  96. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  97. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  98. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  99. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  100. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  101. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  102. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  103. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  104. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  105. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  106. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  107. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  108. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  109. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  110. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  111. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  112. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  113. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  114. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  115. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  116. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/extensions/ann/utils.py +0 -0
  117. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/kl_ucb.py +0 -0
  118. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/knn.py +0 -0
  119. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/lin_ucb.py +0 -0
  120. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/__init__.py +0 -0
  121. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/loss/__init__.py +0 -0
  122. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/loss/sce.py +0 -0
  123. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  124. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  125. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/__init__.py +0 -0
  126. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  127. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
  128. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  129. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/bert4rec/model.py +0 -0
  130. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  131. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  132. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/__init__.py +0 -0
  133. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/base_compiled_model.py +0 -0
  134. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +0 -0
  135. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/compiled/sasrec_compiled.py +0 -0
  136. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  137. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  138. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
  139. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  140. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
  141. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  142. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/nn/sequential/sasrec/model.py +0 -0
  143. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/optimization/__init__.py +0 -0
  144. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/optimization/optuna_mixin.py +0 -0
  145. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/optimization/optuna_objective.py +0 -0
  146. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/pop_rec.py +0 -0
  147. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/query_pop_rec.py +0 -0
  148. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/random_rec.py +0 -0
  149. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/slim.py +0 -0
  150. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/thompson_sampling.py +0 -0
  151. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/ucb.py +0 -0
  152. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/wilson.py +0 -0
  153. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/models/word2vec.py +0 -0
  154. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/__init__.py +0 -0
  155. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/agg.py +0 -0
  156. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/attention.py +0 -0
  157. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/embedding.py +0 -0
  158. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/ffn.py +0 -0
  159. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/head.py +0 -0
  160. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/__init__.py +0 -0
  161. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/callback/__init__.py +0 -0
  162. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/module.py +0 -0
  163. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/optimizer.py +0 -0
  164. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/__init__.py +0 -0
  165. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/_base.py +0 -0
  166. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/postprocessor/seen_items.py +0 -0
  167. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/lightning/scheduler.py +0 -0
  168. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/__init__.py +0 -0
  169. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/bce.py +0 -0
  170. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/ce.py +0 -0
  171. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/loss/logout_ce.py +0 -0
  172. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/mask.py +0 -0
  173. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/normalization.py +0 -0
  174. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/output.py +0 -0
  175. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/__init__.py +0 -0
  176. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/__init__.py +0 -0
  177. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/agg.py +0 -0
  178. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/diff_transformer.py +0 -0
  179. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/sasrec/transformer.py +0 -0
  180. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/__init__.py +0 -0
  181. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/sequential/twotower/model.py +0 -0
  182. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/__init__.py +0 -0
  183. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/copy.py +0 -0
  184. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/grouping.py +0 -0
  185. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/negative_sampling.py +0 -0
  186. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/next_token.py +0 -0
  187. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/rename.py +0 -0
  188. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/reshape.py +0 -0
  189. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/sequence_roll.py +0 -0
  190. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/template/__init__.py +0 -0
  191. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/token_mask.py +0 -0
  192. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/transform/trim.py +0 -0
  193. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/nn/utils.py +0 -0
  194. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/__init__.py +0 -0
  195. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/converter.py +0 -0
  196. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/discretizer.py +0 -0
  197. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/filters.py +0 -0
  198. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/history_based_fp.py +0 -0
  199. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/label_encoder.py +0 -0
  200. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/sessionizer.py +0 -0
  201. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/preprocessing/utils.py +0 -0
  202. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/scenarios/__init__.py +0 -0
  203. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/scenarios/fallback.py +0 -0
  204. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/__init__.py +0 -0
  205. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/base_splitter.py +0 -0
  206. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/cold_user_random_splitter.py +0 -0
  207. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/k_folds.py +0 -0
  208. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/last_n_splitter.py +0 -0
  209. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/new_users_splitter.py +0 -0
  210. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/random_next_n_splitter.py +0 -0
  211. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/random_splitter.py +0 -0
  212. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/ratio_splitter.py +0 -0
  213. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/time_splitter.py +0 -0
  214. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/splitters/two_stage_splitter.py +0 -0
  215. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/__init__.py +0 -0
  216. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/common.py +0 -0
  217. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/dataframe_bucketizer.py +0 -0
  218. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/distributions.py +0 -0
  219. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/model_handler.py +0 -0
  220. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/session_handler.py +0 -0
  221. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/spark_utils.py +0 -0
  222. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/time.py +0 -0
  223. {replay_rec-0.21.0 → replay_rec-0.21.1}/replay/utils/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: replay-rec
3
- Version: 0.21.0
3
+ Version: 0.21.1
4
4
  Summary: RecSys Library
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
@@ -72,7 +72,7 @@ dependencies = [
72
72
  "lightning (<2.6.0); extra == 'torch' or extra == 'torch-cpu'",
73
73
  ]
74
74
  dynamic = ["dependencies"]
75
- version = "0.21.0"
75
+ version = "0.21.1"
76
76
 
77
77
  [project.optional-dependencies]
78
78
  spark = ["pyspark", "psutil"]
@@ -107,7 +107,7 @@ priority = "explicit"
107
107
 
108
108
  [tool.poetry-dynamic-versioning]
109
109
  enable = false
110
- format-jinja = """0.21.0{{ env['PACKAGE_SUFFIX'] }}"""
110
+ format-jinja = """0.21.1{{ env['PACKAGE_SUFFIX'] }}"""
111
111
  vcs = "git"
112
112
 
113
113
  [tool.ruff]
@@ -4,4 +4,4 @@
4
4
  # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
5
5
  import setuptools as _
6
6
 
7
- __version__ = "0.21.0"
7
+ __version__ = "0.21.1"
@@ -94,7 +94,7 @@ class ParquetModule(L.LightningDataModule):
94
94
  missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
95
95
  if missing_splits:
96
96
  msg = (
97
- f"The following dataset paths aren't provided: {','.join(missing_splits)}."
97
+ f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
98
98
  "Make sure to disable these stages in your Lightning Trainer configuration."
99
99
  )
100
100
  warnings.warn(msg, stacklevel=2)
@@ -400,7 +400,7 @@ def metrics_to_df(metrics: Mapping[str, float]) -> PandasDataFrame:
400
400
 
401
401
  metric_name_and_k = metrics_df["m"].str.split("@", expand=True)
402
402
  metrics_df["metric"] = metric_name_and_k[0]
403
- metrics_df["k"] = metric_name_and_k[1]
403
+ metrics_df["k"] = metric_name_and_k[1].astype(int)
404
404
 
405
405
  pivoted_metrics = metrics_df.pivot(index="metric", columns="k", values="v")
406
406
  pivoted_metrics.index.name = None
@@ -162,14 +162,24 @@ class ValidationMetricsCallback(lightning.Callback):
162
162
  @rank_zero_only
163
163
  def print_metrics() -> None:
164
164
  metrics = {}
165
+
165
166
  for name, value in trainer.logged_metrics.items():
166
167
  if "@" in name:
167
168
  metrics[name] = value.item()
168
169
 
169
- if metrics:
170
- metrics_df = metrics_to_df(metrics)
170
+ if not metrics:
171
+ return
171
172
 
172
- print(metrics_df) # noqa: T201
173
- print() # noqa: T201
173
+ if len(self._dataloaders_size) > 1:
174
+ for i in range(len(self._dataloaders_size)):
175
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
176
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
177
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
178
+
179
+ print(suffix) # noqa: T201
180
+ print(metrics_df, "\n") # noqa: T201
181
+ else:
182
+ metrics_df = metrics_to_df(metrics)
183
+ print(metrics_df, "\n") # noqa: T201
174
184
 
175
185
  print_metrics()
@@ -2,7 +2,6 @@ from typing import Any, Optional
2
2
 
3
3
  import lightning
4
4
  import torch
5
- from lightning.pytorch.utilities.combined_loader import CombinedLoader
6
5
  from lightning.pytorch.utilities.rank_zero import rank_zero_only
7
6
 
8
7
  from replay.metrics.torch_metrics_builder import (
@@ -64,8 +63,8 @@ class ComputeMetricsCallback(lightning.Callback):
64
63
  self._train_column = train_column
65
64
 
66
65
  def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
67
- if isinstance(dataloaders, CombinedLoader):
68
- return [len(dataloader) for dataloader in dataloaders.flattened] # pragma: no cover
66
+ if isinstance(dataloaders, list):
67
+ return [len(dataloader) for dataloader in dataloaders]
69
68
  return [len(dataloaders)]
70
69
 
71
70
  def on_validation_epoch_start(
@@ -123,7 +122,7 @@ class ComputeMetricsCallback(lightning.Callback):
123
122
  batch: dict,
124
123
  batch_idx: int,
125
124
  dataloader_idx: int = 0,
126
- ) -> None: # pragma: no cover
125
+ ) -> None:
127
126
  self._batch_end(
128
127
  trainer,
129
128
  pl_module,
@@ -159,7 +158,7 @@ class ComputeMetricsCallback(lightning.Callback):
159
158
  def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
160
159
  self._epoch_end(trainer, pl_module)
161
160
 
162
- def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None: # pragma: no cover
161
+ def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
163
162
  self._epoch_end(trainer, pl_module)
164
163
 
165
164
  def _epoch_end(
@@ -170,14 +169,24 @@ class ComputeMetricsCallback(lightning.Callback):
170
169
  @rank_zero_only
171
170
  def print_metrics() -> None:
172
171
  metrics = {}
172
+
173
173
  for name, value in trainer.logged_metrics.items():
174
174
  if "@" in name:
175
175
  metrics[name] = value.item()
176
176
 
177
- if metrics:
178
- metrics_df = metrics_to_df(metrics)
177
+ if not metrics:
178
+ return
179
179
 
180
- print(metrics_df) # noqa: T201
181
- print() # noqa: T201
180
+ if len(self._dataloaders_size) > 1:
181
+ for i in range(len(self._dataloaders_size)):
182
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
183
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
184
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
185
+
186
+ print(suffix) # noqa: T201
187
+ print(metrics_df, "\n") # noqa: T201
188
+ else:
189
+ metrics_df = metrics_to_df(metrics)
190
+ print(metrics_df, "\n") # noqa: T201
182
191
 
183
192
  print_metrics()
@@ -15,11 +15,11 @@ from replay.utils import (
15
15
  SparkDataFrame,
16
16
  )
17
17
 
18
- if PYSPARK_AVAILABLE: # pragma: no cover
18
+ if PYSPARK_AVAILABLE:
19
19
  import pyspark.sql.functions as sf
20
20
  from pyspark.sql import SparkSession
21
21
  from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
22
- else: # pragma: no cover
22
+ else:
23
23
  SparkSession = MissingImport
24
24
 
25
25
 
@@ -85,7 +85,7 @@ class SampledLossBase(torch.nn.Module):
85
85
  # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
86
86
  negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
87
87
 
88
- if negative_labels.dim() == 3: # pragma: no cover
88
+ if negative_labels.dim() == 3:
89
89
  # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
90
90
  negative_labels = negative_labels.unsqueeze(-2)
91
91
  if num_positives != 1:
@@ -119,7 +119,7 @@ class SampledLossBase(torch.nn.Module):
119
119
  positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
120
120
  assert positive_labels.size() == (masked_batch_size, 1)
121
121
 
122
- if negative_labels.dim() != 1: # pragma: no cover
122
+ if negative_labels.dim() != 1:
123
123
  # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
124
124
  negative_labels = negative_labels[target_padding_mask]
125
125
  assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -183,7 +183,7 @@ def mask_negative_logits(
183
183
  if negative_labels_ignore_index >= 0:
184
184
  negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)
185
185
 
186
- if negative_labels.dim() > 1: # pragma: no cover
186
+ if negative_labels.dim() > 1:
187
187
  # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
188
188
  negative_labels = negative_labels.unsqueeze(-2)
189
189
 
@@ -74,7 +74,7 @@ class LogInCEBase(SampledLossBase):
74
74
  positive_labels = positive_labels[masked_target_padding_mask]
75
75
  assert positive_labels.size() == (masked_batch_size, num_positives)
76
76
 
77
- if negative_labels.dim() > 1: # pragma: no cover
77
+ if negative_labels.dim() > 1:
78
78
  # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
79
79
  negative_labels = negative_labels[masked_target_padding_mask]
80
80
  assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -141,7 +141,7 @@ class SasRec(torch.nn.Module):
141
141
  feature_type=FeatureType.CATEGORICAL,
142
142
  embedding_dim=256,
143
143
  padding_value=NUM_UNIQUE_ITEMS,
144
- cardinality=NUM_UNIQUE_ITEMS+1,
144
+ cardinality=NUM_UNIQUE_ITEMS,
145
145
  feature_hint=FeatureHint.ITEM_ID,
146
146
  feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
147
147
  ),
@@ -22,7 +22,6 @@ class FeaturesReader:
22
22
  :param schema: the same tensor schema used in TwoTower model.
23
23
  :param metadata: A dictionary of feature names that
24
24
  associated with its shape and padding_value.\n
25
- Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
26
25
  For details, see the section :ref:`parquet-processing`.
27
26
  :param path: path to parquet with dataframe of item features.\n
28
27
  **Note:**\n
@@ -30,8 +29,8 @@ class FeaturesReader:
30
29
  2. Every feature for item "tower" in `schema` must contain ``feature_sources`` with the names
31
30
  of the source features to create correct inverse mapping.
32
31
  Also, for each such feature one of the requirements must be met: the ``schema`` for the feature must
33
- contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
34
- or hint type FeatureHint.ITEM_ID.
32
+ contain ``feature_sources`` with a source of type ``FeatureSource.ITEM_FEATURES``
33
+ or hint type ``FeatureHint.ITEM_ID``.
35
34
 
36
35
  """
37
36
  item_feature_names = [
@@ -81,8 +80,18 @@ class FeaturesReader:
81
80
  self._features = {}
82
81
 
83
82
  for k in features.columns:
84
- dtype = torch.float32 if schema[k].is_num else torch.int64
85
- feature_tensor = torch.asarray(features[k], dtype=dtype)
83
+ dtype = np.float32 if schema[k].is_num else np.int64
84
+ if schema[k].is_list:
85
+ feature = np.asarray(
86
+ features[k].to_list(),
87
+ dtype=dtype,
88
+ )
89
+ else:
90
+ feature = features[k].to_numpy(dtype=dtype)
91
+ feature_tensor = torch.asarray(
92
+ feature,
93
+ dtype=torch.float32 if schema[k].is_num else torch.int64,
94
+ )
86
95
  self._features[k] = feature_tensor
87
96
 
88
97
  def __getitem__(self, key: str) -> torch.Tensor:
@@ -14,7 +14,7 @@ def make_default_sasrec_transforms(
14
14
 
15
15
  Generated pipeline expects input dataset to contain the following columns:
16
16
  1) Query ID column, specified by ``query_column``.
17
- 2) Item ID column, specified in the tensor schema.
17
+ 2) All features specified in the ``tensor_schema``.
18
18
 
19
19
  :param tensor_schema: TensorSchema used to infer feature columns.
20
20
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
@@ -32,12 +32,12 @@ def make_default_sasrec_transforms(
32
32
  ),
33
33
  UnsqueezeTransform("target_padding_mask", -1),
34
34
  UnsqueezeTransform("positive_labels", -1),
35
- GroupTransform({"feature_tensors": [item_column]}),
35
+ GroupTransform({"feature_tensors": tensor_schema.names}),
36
36
  ]
37
37
 
38
38
  val_transforms = [
39
39
  RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
40
- GroupTransform({"feature_tensors": [item_column]}),
40
+ GroupTransform({"feature_tensors": tensor_schema.names}),
41
41
  ]
42
42
  test_transforms = copy.deepcopy(val_transforms)
43
43
 
@@ -13,7 +13,7 @@ def make_default_twotower_transforms(
13
13
 
14
14
  Generated pipeline expects input dataset to contain the following columns:
15
15
  1) Query ID column, specified by ``query_column``.
16
- 2) Item ID column, specified in the tensor schema.
16
+ 2) All features specified in the ``tensor_schema``.
17
17
 
18
18
  :param tensor_schema: TensorSchema used to infer feature columns.
19
19
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
File without changes
File without changes
File without changes