replay-rec 0.20.3rc0.tar.gz → 0.21.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (274)
  1. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/PKG-INFO +18 -12
  2. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/README.md +1 -1
  3. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/pyproject.toml +49 -32
  4. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/__init__.py +1 -1
  5. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/dataset.py +11 -0
  6. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/__init__.py +3 -0
  7. replay_rec-0.21.0/replay/data/nn/parquet/__init__.py +22 -0
  8. replay_rec-0.21.0/replay/data/nn/parquet/collate.py +29 -0
  9. replay_rec-0.21.0/replay/data/nn/parquet/constants/batches.py +8 -0
  10. replay_rec-0.21.0/replay/data/nn/parquet/constants/device.py +3 -0
  11. replay_rec-0.21.0/replay/data/nn/parquet/constants/filesystem.py +3 -0
  12. replay_rec-0.21.0/replay/data/nn/parquet/constants/metadata.py +5 -0
  13. replay_rec-0.21.0/replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  14. replay_rec-0.21.0/replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  15. replay_rec-0.21.0/replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  16. replay_rec-0.21.0/replay/data/nn/parquet/impl/column_protocol.py +17 -0
  17. replay_rec-0.21.0/replay/data/nn/parquet/impl/indexing.py +123 -0
  18. replay_rec-0.21.0/replay/data/nn/parquet/impl/masking.py +20 -0
  19. replay_rec-0.21.0/replay/data/nn/parquet/impl/named_columns.py +100 -0
  20. replay_rec-0.21.0/replay/data/nn/parquet/impl/numeric_column.py +110 -0
  21. replay_rec-0.21.0/replay/data/nn/parquet/impl/utils.py +17 -0
  22. replay_rec-0.21.0/replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay_rec-0.21.0/replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay_rec-0.21.0/replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay_rec-0.21.0/replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay_rec-0.21.0/replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay_rec-0.21.0/replay/data/nn/parquet/iterator.py +61 -0
  28. replay_rec-0.21.0/replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay_rec-0.21.0/replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay_rec-0.21.0/replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay_rec-0.21.0/replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay_rec-0.21.0/replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay_rec-0.21.0/replay/data/nn/parquet/utils/compute_length.py +66 -0
  34. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/schema.py +12 -14
  35. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/sequence_tokenizer.py +5 -0
  36. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/sequential_dataset.py +4 -0
  37. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/torch_sequential_dataset.py +5 -0
  38. replay_rec-0.21.0/replay/data/utils/batching.py +69 -0
  39. replay_rec-0.21.0/replay/data/utils/typing/dtype.py +65 -0
  40. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/torch_metrics_builder.py +20 -14
  41. replay_rec-0.21.0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  42. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/loss/sce.py +2 -7
  43. replay_rec-0.21.0/replay/models/nn/optimizer_utils/__init__.py +9 -0
  44. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  45. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  46. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  47. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/bert4rec/model.py +11 -11
  48. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  49. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  50. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  51. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  52. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  53. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/postprocessors/_base.py +5 -0
  54. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  55. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/sasrec/dataset.py +81 -26
  56. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/sasrec/lightning.py +86 -24
  57. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/sasrec/model.py +14 -9
  58. replay_rec-0.21.0/replay/nn/__init__.py +8 -0
  59. replay_rec-0.21.0/replay/nn/agg.py +109 -0
  60. replay_rec-0.21.0/replay/nn/attention.py +158 -0
  61. replay_rec-0.21.0/replay/nn/embedding.py +283 -0
  62. replay_rec-0.21.0/replay/nn/ffn.py +135 -0
  63. replay_rec-0.21.0/replay/nn/head.py +49 -0
  64. replay_rec-0.21.0/replay/nn/lightning/__init__.py +1 -0
  65. replay_rec-0.21.0/replay/nn/lightning/callback/__init__.py +9 -0
  66. replay_rec-0.21.0/replay/nn/lightning/callback/metrics_callback.py +183 -0
  67. replay_rec-0.21.0/replay/nn/lightning/callback/predictions_callback.py +314 -0
  68. replay_rec-0.21.0/replay/nn/lightning/module.py +123 -0
  69. replay_rec-0.21.0/replay/nn/lightning/optimizer.py +60 -0
  70. replay_rec-0.21.0/replay/nn/lightning/postprocessor/__init__.py +2 -0
  71. replay_rec-0.21.0/replay/nn/lightning/postprocessor/_base.py +51 -0
  72. replay_rec-0.21.0/replay/nn/lightning/postprocessor/seen_items.py +83 -0
  73. replay_rec-0.21.0/replay/nn/lightning/scheduler.py +91 -0
  74. replay_rec-0.21.0/replay/nn/loss/__init__.py +22 -0
  75. replay_rec-0.21.0/replay/nn/loss/base.py +197 -0
  76. replay_rec-0.21.0/replay/nn/loss/bce.py +216 -0
  77. replay_rec-0.21.0/replay/nn/loss/ce.py +317 -0
  78. replay_rec-0.21.0/replay/nn/loss/login_ce.py +373 -0
  79. replay_rec-0.21.0/replay/nn/loss/logout_ce.py +230 -0
  80. replay_rec-0.21.0/replay/nn/mask.py +87 -0
  81. replay_rec-0.21.0/replay/nn/normalization.py +9 -0
  82. replay_rec-0.21.0/replay/nn/output.py +37 -0
  83. replay_rec-0.21.0/replay/nn/sequential/__init__.py +9 -0
  84. replay_rec-0.21.0/replay/nn/sequential/sasrec/__init__.py +7 -0
  85. replay_rec-0.21.0/replay/nn/sequential/sasrec/agg.py +53 -0
  86. replay_rec-0.21.0/replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  87. replay_rec-0.21.0/replay/nn/sequential/sasrec/model.py +377 -0
  88. replay_rec-0.21.0/replay/nn/sequential/sasrec/transformer.py +107 -0
  89. replay_rec-0.21.0/replay/nn/sequential/twotower/__init__.py +2 -0
  90. replay_rec-0.21.0/replay/nn/sequential/twotower/model.py +674 -0
  91. replay_rec-0.21.0/replay/nn/sequential/twotower/reader.py +89 -0
  92. replay_rec-0.21.0/replay/nn/transform/__init__.py +22 -0
  93. replay_rec-0.21.0/replay/nn/transform/copy.py +38 -0
  94. replay_rec-0.21.0/replay/nn/transform/grouping.py +39 -0
  95. replay_rec-0.21.0/replay/nn/transform/negative_sampling.py +182 -0
  96. replay_rec-0.21.0/replay/nn/transform/next_token.py +100 -0
  97. replay_rec-0.21.0/replay/nn/transform/rename.py +33 -0
  98. replay_rec-0.21.0/replay/nn/transform/reshape.py +41 -0
  99. replay_rec-0.21.0/replay/nn/transform/sequence_roll.py +48 -0
  100. replay_rec-0.21.0/replay/nn/transform/template/__init__.py +2 -0
  101. replay_rec-0.21.0/replay/nn/transform/template/sasrec.py +53 -0
  102. replay_rec-0.21.0/replay/nn/transform/template/twotower.py +22 -0
  103. replay_rec-0.21.0/replay/nn/transform/token_mask.py +69 -0
  104. replay_rec-0.21.0/replay/nn/transform/trim.py +51 -0
  105. replay_rec-0.21.0/replay/nn/utils.py +28 -0
  106. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/filters.py +128 -0
  107. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/label_encoder.py +36 -33
  108. replay_rec-0.21.0/replay/preprocessing/utils.py +209 -0
  109. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/__init__.py +1 -0
  110. replay_rec-0.21.0/replay/splitters/random_next_n_splitter.py +224 -0
  111. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/common.py +10 -4
  112. replay_rec-0.20.3rc0/replay/experimental/metrics/__init__.py +0 -62
  113. replay_rec-0.20.3rc0/replay/experimental/metrics/base_metric.py +0 -603
  114. replay_rec-0.20.3rc0/replay/experimental/metrics/coverage.py +0 -97
  115. replay_rec-0.20.3rc0/replay/experimental/metrics/experiment.py +0 -175
  116. replay_rec-0.20.3rc0/replay/experimental/metrics/hitrate.py +0 -26
  117. replay_rec-0.20.3rc0/replay/experimental/metrics/map.py +0 -30
  118. replay_rec-0.20.3rc0/replay/experimental/metrics/mrr.py +0 -18
  119. replay_rec-0.20.3rc0/replay/experimental/metrics/ncis_precision.py +0 -31
  120. replay_rec-0.20.3rc0/replay/experimental/metrics/ndcg.py +0 -49
  121. replay_rec-0.20.3rc0/replay/experimental/metrics/precision.py +0 -22
  122. replay_rec-0.20.3rc0/replay/experimental/metrics/recall.py +0 -25
  123. replay_rec-0.20.3rc0/replay/experimental/metrics/rocauc.py +0 -49
  124. replay_rec-0.20.3rc0/replay/experimental/metrics/surprisal.py +0 -90
  125. replay_rec-0.20.3rc0/replay/experimental/metrics/unexpectedness.py +0 -76
  126. replay_rec-0.20.3rc0/replay/experimental/models/__init__.py +0 -50
  127. replay_rec-0.20.3rc0/replay/experimental/models/admm_slim.py +0 -257
  128. replay_rec-0.20.3rc0/replay/experimental/models/base_neighbour_rec.py +0 -200
  129. replay_rec-0.20.3rc0/replay/experimental/models/base_rec.py +0 -1386
  130. replay_rec-0.20.3rc0/replay/experimental/models/base_torch_rec.py +0 -234
  131. replay_rec-0.20.3rc0/replay/experimental/models/cql.py +0 -454
  132. replay_rec-0.20.3rc0/replay/experimental/models/ddpg.py +0 -932
  133. replay_rec-0.20.3rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
  134. replay_rec-0.20.3rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
  135. replay_rec-0.20.3rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
  136. replay_rec-0.20.3rc0/replay/experimental/models/dt4rec/utils.py +0 -264
  137. replay_rec-0.20.3rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  138. replay_rec-0.20.3rc0/replay/experimental/models/hierarchical_recommender.py +0 -331
  139. replay_rec-0.20.3rc0/replay/experimental/models/implicit_wrap.py +0 -131
  140. replay_rec-0.20.3rc0/replay/experimental/models/lightfm_wrap.py +0 -303
  141. replay_rec-0.20.3rc0/replay/experimental/models/mult_vae.py +0 -332
  142. replay_rec-0.20.3rc0/replay/experimental/models/neural_ts.py +0 -986
  143. replay_rec-0.20.3rc0/replay/experimental/models/neuromf.py +0 -406
  144. replay_rec-0.20.3rc0/replay/experimental/models/scala_als.py +0 -293
  145. replay_rec-0.20.3rc0/replay/experimental/models/u_lin_ucb.py +0 -115
  146. replay_rec-0.20.3rc0/replay/experimental/nn/data/__init__.py +0 -1
  147. replay_rec-0.20.3rc0/replay/experimental/nn/data/schema_builder.py +0 -102
  148. replay_rec-0.20.3rc0/replay/experimental/preprocessing/__init__.py +0 -3
  149. replay_rec-0.20.3rc0/replay/experimental/preprocessing/data_preparator.py +0 -839
  150. replay_rec-0.20.3rc0/replay/experimental/preprocessing/padder.py +0 -229
  151. replay_rec-0.20.3rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
  152. replay_rec-0.20.3rc0/replay/experimental/scenarios/__init__.py +0 -1
  153. replay_rec-0.20.3rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  154. replay_rec-0.20.3rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  155. replay_rec-0.20.3rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  156. replay_rec-0.20.3rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  157. replay_rec-0.20.3rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
  158. replay_rec-0.20.3rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  159. replay_rec-0.20.3rc0/replay/experimental/utils/logger.py +0 -24
  160. replay_rec-0.20.3rc0/replay/experimental/utils/model_handler.py +0 -186
  161. replay_rec-0.20.3rc0/replay/experimental/utils/session_handler.py +0 -44
  162. replay_rec-0.20.3rc0/replay/models/nn/optimizer_utils/__init__.py +0 -4
  163. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/LICENSE +0 -0
  164. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/NOTICE +0 -0
  165. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/__init__.py +0 -0
  166. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/dataset_utils/__init__.py +0 -0
  167. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
  168. {replay_rec-0.20.3rc0/replay/experimental → replay_rec-0.21.0/replay/data/nn/parquet/constants}/__init__.py +0 -0
  169. {replay_rec-0.20.3rc0/replay/experimental/models/dt4rec → replay_rec-0.21.0/replay/data/nn/parquet/impl}/__init__.py +0 -0
  170. {replay_rec-0.20.3rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.21.0/replay/data/nn/parquet/info}/__init__.py +0 -0
  171. {replay_rec-0.20.3rc0/replay/experimental/scenarios/two_stages → replay_rec-0.21.0/replay/data/nn/parquet/utils}/__init__.py +0 -0
  172. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/utils.py +0 -0
  173. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/schema.py +0 -0
  174. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/spark_schema.py +0 -0
  175. {replay_rec-0.20.3rc0/replay/experimental → replay_rec-0.21.0/replay/data}/utils/__init__.py +0 -0
  176. {replay_rec-0.20.3rc0/replay/models/extensions → replay_rec-0.21.0/replay/data/utils/typing}/__init__.py +0 -0
  177. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/__init__.py +0 -0
  178. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/base_metric.py +0 -0
  179. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/categorical_diversity.py +0 -0
  180. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/coverage.py +0 -0
  181. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/descriptors.py +0 -0
  182. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/experiment.py +0 -0
  183. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/hitrate.py +0 -0
  184. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/map.py +0 -0
  185. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/mrr.py +0 -0
  186. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/ndcg.py +0 -0
  187. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/novelty.py +0 -0
  188. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/offline_metrics.py +0 -0
  189. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/precision.py +0 -0
  190. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/recall.py +0 -0
  191. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/rocauc.py +0 -0
  192. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/surprisal.py +0 -0
  193. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/metrics/unexpectedness.py +0 -0
  194. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/__init__.py +0 -0
  195. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/als.py +0 -0
  196. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/association_rules.py +0 -0
  197. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/base_neighbour_rec.py +0 -0
  198. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/base_rec.py +0 -0
  199. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/cat_pop_rec.py +0 -0
  200. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/cluster.py +0 -0
  201. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/common.py +0 -0
  202. {replay_rec-0.20.3rc0/replay/models/extensions/ann → replay_rec-0.21.0/replay/models/extensions}/__init__.py +0 -0
  203. {replay_rec-0.20.3rc0/replay/models/extensions/ann/entities → replay_rec-0.21.0/replay/models/extensions/ann}/__init__.py +0 -0
  204. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/ann_mixin.py +0 -0
  205. {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_builders → replay_rec-0.21.0/replay/models/extensions/ann/entities}/__init__.py +0 -0
  206. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  207. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  208. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  209. {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_inferers → replay_rec-0.21.0/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
  210. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  211. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  212. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  213. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  214. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  215. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  216. {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_stores → replay_rec-0.21.0/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
  217. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  218. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  219. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  220. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  221. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  222. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  223. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  224. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  225. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  226. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  227. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  228. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/extensions/ann/utils.py +0 -0
  229. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/kl_ucb.py +0 -0
  230. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/knn.py +0 -0
  231. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/lin_ucb.py +0 -0
  232. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/__init__.py +0 -0
  233. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/loss/__init__.py +0 -0
  234. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/__init__.py +0 -0
  235. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  236. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  237. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/compiled/__init__.py +0 -0
  238. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  239. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  240. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/optimization/__init__.py +0 -0
  241. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/optimization/optuna_mixin.py +0 -0
  242. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/optimization/optuna_objective.py +0 -0
  243. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/pop_rec.py +0 -0
  244. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/query_pop_rec.py +0 -0
  245. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/random_rec.py +0 -0
  246. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/slim.py +0 -0
  247. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/thompson_sampling.py +0 -0
  248. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/ucb.py +0 -0
  249. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/wilson.py +0 -0
  250. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/models/word2vec.py +0 -0
  251. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/__init__.py +0 -0
  252. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/converter.py +0 -0
  253. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/discretizer.py +0 -0
  254. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/history_based_fp.py +0 -0
  255. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/preprocessing/sessionizer.py +0 -0
  256. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/scenarios/__init__.py +0 -0
  257. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/scenarios/fallback.py +0 -0
  258. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/base_splitter.py +0 -0
  259. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/cold_user_random_splitter.py +0 -0
  260. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/k_folds.py +0 -0
  261. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/last_n_splitter.py +0 -0
  262. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/new_users_splitter.py +0 -0
  263. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/random_splitter.py +0 -0
  264. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/ratio_splitter.py +0 -0
  265. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/time_splitter.py +0 -0
  266. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/splitters/two_stage_splitter.py +0 -0
  267. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/__init__.py +0 -0
  268. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/dataframe_bucketizer.py +0 -0
  269. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/distributions.py +0 -0
  270. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/model_handler.py +0 -0
  271. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/session_handler.py +0 -0
  272. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/spark_utils.py +0 -0
  273. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/time.py +0 -0
  274. {replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/utils/types.py +0 -0
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: replay-rec
-Version: 0.20.3rc0
+Version: 0.21.0
 Summary: RecSys Library
 License-Expression: Apache-2.0
 License-File: LICENSE
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
 Classifier: Natural Language :: English
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Dist: d3rlpy (>=2.8.1,<2.9)
-Requires-Dist: implicit (>=0.7.2,<0.8)
-Requires-Dist: lightautoml (>=0.4.1,<0.5)
-Requires-Dist: lightning (>=2.0.2,<=2.4.0)
-Requires-Dist: numba (>=0.50,<1)
+Provides-Extra: spark
+Provides-Extra: torch
+Provides-Extra: torch-cpu
+Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: lightning ; extra == "torch"
+Requires-Dist: lightning ; extra == "torch-cpu"
 Requires-Dist: numpy (>=1.20.0,<2)
 Requires-Dist: pandas (>=1.3.5,<2.4.0)
 Requires-Dist: polars (<2.0)
-Requires-Dist: psutil (<=7.0.0)
+Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
+Requires-Dist: psutil ; extra == "spark"
 Requires-Dist: pyarrow (<22.0)
-Requires-Dist: pyspark (>=3.0,<3.5)
-Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
-Requires-Dist: sb-obp (>=0.5.10,<0.6)
+Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
+Requires-Dist: pyspark ; extra == "spark"
+Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: pytorch-optimizer ; extra == "torch"
+Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
 Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
 Requires-Dist: scipy (>=1.8.1,<2.0.0)
 Requires-Dist: setuptools
-Requires-Dist: torch (>=1.8,<2.9.0)
+Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: torch ; extra == "torch"
+Requires-Dist: torch ; extra == "torch-cpu"
 Requires-Dist: tqdm (>=4.67,<5)
 Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
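
Note: with 0.21.0 the Spark and Torch stacks move from core requirements to optional extras, and `d3rlpy`, `implicit`, `lightautoml`, `numba`, and `sb-obp` drop out of the core dependency list entirely, so a bare `pip install replay-rec` now pulls in neither PySpark nor PyTorch. A sketch of the resulting install commands, with the extras names taken from the `Provides-Extra` lines above:

```bash
pip install replay-rec              # core only: no pyspark, no torch
pip install "replay-rec[spark]"     # adds pyspark and psutil
pip install "replay-rec[torch]"     # adds torch, pytorch-optimizer, lightning
```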
@@ -231,7 +237,7 @@ pip install optuna
 
 2) Model compilation via OpenVINO:
 ```bash
-pip install openvino onnx
+pip install openvino onnx onnxscript
 ```
 
 3) Vector database and hierarchical search support:
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/README.md
@@ -193,7 +193,7 @@ pip install optuna
 
 2) Model compilation via OpenVINO:
 ```bash
-pip install openvino onnx
+pip install openvino onnx onnxscript
 ```
 
 3) Vector database and hierarchical search support:
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/pyproject.toml
@@ -1,11 +1,36 @@
 [build-system]
 requires = [
-    "poetry-core>=2.0.0",
+    "poetry-core>=2.2.1",
     "poetry-dynamic-versioning>=1.0.0,<2.0.0",
     "setuptools",
 ]
 build-backend = "poetry_dynamic_versioning.backend"
 
+[dependency-groups]
+dev = [
+    "coverage-conditional-plugin (>=0.9, <1)",
+    "jupyter (>=1.0, <1.1)",
+    "jupyterlab (>=3.6, <4)",
+    "pyarrow-stubs",
+    "pytest (>=7.1.0)",
+    "pytest-mock (>3.15, <4.0)",
+    "pytest-cov (>=3.0)",
+    "statsmodels (>=0.14, <0.15)",
+    "black (>=23.3.0)",
+    "ruff (>=0.0.261)",
+    "hypothesis",
+    "toml-sort (>=0.23, <0.24)",
+    "sphinx (==5.3.0)",
+    "sphinx-rtd-theme (==1.2.2)",
+    "sphinx-autodoc-typehints (==1.23.0)",
+    "sphinx-enum-extend (==0.1.3)",
+    "myst-parser (==1.0.0)",
+    "ghp-import (==2.1.0)",
+    "docutils (==0.16)",
+    "data-science-types (==0.2.23)",
+    "filelock (>=3.14, <3.15)",
+]
+
 [project]
 name = "replay-rec"
 license = "Apache-2.0"
@@ -40,19 +65,19 @@ dependencies = [
     "scikit-learn (>=1.6.1,<1.7.0)",
     "pyarrow (<22.0)",
     "tqdm (>=4.67,<5)",
-    "torch (>=1.8,<2.9.0)",
-    "lightning (>=2.0.2,<=2.4.0)",
-    "pytorch-optimizer (>=3.8.0,<4)",
-    "lightautoml (>=0.4.1,<0.5)",
-    "numba (>=0.50,<1)",
-    "sb-obp (>=0.5.10,<0.6)",
-    "d3rlpy (>=2.8.1,<2.9)",
-    "implicit (>=0.7.2,<0.8)",
-    "pyspark (>=3.0,<3.5)",
-    "psutil (<=7.0.0)",
+    "pyspark (>=3.0,<3.5); extra == 'spark'",
+    "psutil (<=7.0.0); extra == 'spark'",
+    "torch (>=1.8, <3.0.0); extra == 'torch' or extra == 'torch-cpu'",
+    "pytorch-optimizer (>=3.8.0,<3.9.0); extra == 'torch' or extra == 'torch-cpu'",
+    "lightning (<2.6.0); extra == 'torch' or extra == 'torch-cpu'",
 ]
 dynamic = ["dependencies"]
-version = "0.20.3.preview"
+version = "0.21.0"
+
+[project.optional-dependencies]
+spark = ["pyspark", "psutil"]
+torch = ["torch", "pytorch-optimizer", "lightning"]
+torch-cpu = ["torch", "pytorch-optimizer", "lightning"]
 
 [project.urls]
 homepage = "https://sb-ai-lab.github.io/RePlay/"
@@ -66,31 +91,23 @@ target-version = ["py39", "py310", "py311", "py312"]
 packages = [{include = "replay"}]
 exclude = [
     "replay/conftest.py",
+    "replay/experimental",
+]
+
+[tool.poetry.dependencies]
+torch = [
+    {markers = "extra == 'torch-cpu' and extra !='torch'", source = "torch-cpu-mirror"},
+    {markers = "extra == 'torch' and extra !='torch-cpu'", source = "PyPI"},
 ]
 
-[tool.poetry.group.dev.dependencies]
-coverage-conditional-plugin = "^0.9.0"
-jupyter = "~1.0.0"
-jupyterlab = "^3.6.0"
-pytest = ">=7.1.0"
-pytest-cov = ">=3.0.0"
-statsmodels = "~0.14.0"
-black = ">=23.3.0"
-ruff = ">=0.0.261"
-toml-sort = "^0.23.0"
-sphinx = "5.3.0"
-sphinx-rtd-theme = "1.2.2"
-sphinx-autodoc-typehints = "1.23.0"
-sphinx-enum-extend = "0.1.3"
-myst-parser = "1.0.0"
-ghp-import = "2.1.0"
-docutils = "0.16"
-data-science-types = "0.2.23"
-filelock = "~3.14.0"
+[[tool.poetry.source]]
+name = "torch-cpu-mirror"
+url = "https://download.pytorch.org/whl/cpu"
+priority = "explicit"
 
 [tool.poetry-dynamic-versioning]
 enable = false
-format-jinja = """0.20.3{{ env['PACKAGE_SUFFIX'] }}"""
+format-jinja = """0.21.0{{ env['PACKAGE_SUFFIX'] }}"""
 vcs = "git"
 
 [tool.ruff]
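
The `torch-cpu` extra is paired with an explicit Poetry source pointing at the PyTorch CPU wheel index, so CPU-only `torch` builds resolve from that mirror instead of PyPI. For pip users, a rough equivalent (an assumption for illustration, not something this diff prescribes) is to point pip at the same index:

```bash
# Sketch: pull CPU-only torch wheels from the index named in torch-cpu-mirror
pip install torch --index-url https://download.pytorch.org/whl/cpu
```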
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/__init__.py
@@ -4,4 +4,4 @@
 # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
 import setuptools as _
 
-__version__ = "0.20.3.preview"
+__version__ = "0.21.0"
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/dataset.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from pathlib import Path
 from typing import Callable, Optional, Union
@@ -45,6 +46,7 @@ class Dataset:
     ):
         """
         :param feature_schema: mapping of columns names and feature infos.
+            All features not specified in the schema will be assumed numerical by default.
         :param interactions: dataframe with interactions.
         :param query_features: dataframe with query features,
             defaults: ```None```.
@@ -498,6 +500,15 @@ class Dataset:
             source=FeatureSource.QUERY_FEATURES,
             feature_schema=updated_feature_schema,
         )
+
+        if filled_features:
+            msg = (
+                "The following features are present in the dataset but have not been specified "
+                f"by the feature schema: {[(info.column, info.feature_source.value) for info in filled_features]}. "
+                "These features will be interpreted as NUMERICAL."
+            )
+            warnings.warn(msg, stacklevel=2)
+
         return FeatureSchema(features_list=features_list + filled_features)
 
     def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
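
The new warning fires whenever `Dataset` auto-fills schema entries for columns that were not declared. A minimal sketch of silencing it, assuming the usual `Dataset(feature_schema=..., interactions=...)` construction from the docstring above (`schema` and `interactions` are placeholders):

```python
import warnings

# Hypothetical construction; suppress only the new NUMERICAL-fallback warning.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=".*interpreted as NUMERICAL.*")
    dataset = Dataset(feature_schema=schema, interactions=interactions)
```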
{replay_rec-0.20.3rc0 → replay_rec-0.21.0}/replay/data/nn/__init__.py
@@ -1,6 +1,7 @@
 from replay.utils import TORCH_AVAILABLE
 
 if TORCH_AVAILABLE:
+    from .parquet import ParquetDataset, ParquetModule
     from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
     from .sequence_tokenizer import SequenceTokenizer
     from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
@@ -18,6 +19,8 @@ if TORCH_AVAILABLE:
         "DEFAULT_TRAIN_PADDING_VALUE",
         "MutableTensorMap",
         "PandasSequentialDataset",
+        "ParquetDataset",
+        "ParquetModule",
         "PolarsSequentialDataset",
         "SequenceTokenizer",
         "SequentialDataset",
replay_rec-0.21.0/replay/data/nn/parquet/__init__.py
@@ -0,0 +1,22 @@
+"""
+Implementation of ``ParquetDataset`` and its internals.
+
+``ParquetDataset`` is a combination of a PyTorch-compatible dataset and sampler which enables
+training and inference of models on datasets of arbitrary size by leveraging PyArrow
+Datasets to perform batch-wise reading and processing of data from disk.
+
+``ParquetDataset`` includes support for PyTorch's distributed training framework as well as
+access to remotely stored data via PyArrow's filesystem configs.
+"""
+
+from .info.replicas import DEFAULT_REPLICAS_INFO, ReplicasInfo, ReplicasInfoProtocol
+from .parquet_dataset import ParquetDataset
+from .parquet_module import ParquetModule
+
+__all__ = [
+    "DEFAULT_REPLICAS_INFO",
+    "ParquetDataset",
+    "ParquetModule",
+    "ReplicasInfo",
+    "ReplicasInfoProtocol",
+]
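
As the `replay/data/nn/__init__.py` hunk above shows, the two public entry points are also re-exported at package level, so downstream code can import them without reaching into the `parquet` subpackage:

```python
# Both import paths are equivalent after this release:
from replay.data.nn import ParquetDataset, ParquetModule
from replay.data.nn.parquet import ParquetDataset, ParquetModule
```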
replay_rec-0.21.0/replay/data/nn/parquet/collate.py
@@ -0,0 +1,29 @@
+from collections.abc import Sequence
+
+import torch
+
+from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralValue
+
+
+def dict_collate(batch: Sequence[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+    """Simple collate function that concatenates a sequence of tensor dicts into a single tensor dict."""
+    return {k: torch.cat([d[k] for d in batch], dim=0) for k in batch[0]}
+
+
+def general_collate(batch: Sequence[GeneralBatch]) -> GeneralBatch:
+    """General collate function that merges a sequence of (possibly nested) tensor dicts into one nested dict."""
+    result = {}
+    test_sample = batch[0]
+
+    if len(batch) == 1:
+        return test_sample
+
+    for key, test_value in test_sample.items():
+        values: Sequence[GeneralValue] = [sample[key] for sample in batch]
+        if torch.is_tensor(test_value):
+            result[key] = torch.cat(values, dim=0)
+        else:
+            assert isinstance(test_value, dict)
+            result[key] = general_collate(values)
+
+    return result
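
A small sketch of what `general_collate` does with nested batches (toy tensors, illustrative shapes only):

```python
import torch

b1 = {"ids": torch.tensor([[1, 2]]), "extra": {"len": torch.tensor([2])}}
b2 = {"ids": torch.tensor([[3, 4]]), "extra": {"len": torch.tensor([1])}}

out = general_collate([b1, b2])
# Tensors are concatenated along dim 0, recursing into nested dicts:
# out["ids"]          -> tensor([[1, 2], [3, 4]])
# out["extra"]["len"] -> tensor([2, 1])
```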
replay_rec-0.21.0/replay/data/nn/parquet/constants/batches.py
@@ -0,0 +1,8 @@
+from typing import Callable, Union
+
+import torch
+from typing_extensions import TypeAlias
+
+GeneralValue: TypeAlias = Union[torch.Tensor, "GeneralBatch"]
+GeneralBatch: TypeAlias = dict[str, GeneralValue]
+GeneralCollateFn: TypeAlias = Callable[[GeneralBatch], GeneralBatch]
replay_rec-0.21.0/replay/data/nn/parquet/constants/device.py
@@ -0,0 +1,3 @@
+import torch
+
+DEFAULT_DEVICE = torch.device("cpu")
replay_rec-0.21.0/replay/data/nn/parquet/constants/filesystem.py
@@ -0,0 +1,3 @@
+import pyarrow.fs as fs
+
+DEFAULT_FILESYSTEM = fs.LocalFileSystem()
replay_rec-0.21.0/replay/data/nn/parquet/constants/metadata.py
@@ -0,0 +1,5 @@
+SHAPE_FLAG = "shape"
+PADDING_FLAG = "padding"
+DEFAULT_PADDING = -1
+SEQUENCE_LENGTH_FLAG = "sequence_length"
+PADDING_FLAG = "padding"
replay_rec-0.21.0/replay/data/nn/parquet/fixed_batch_dataset.py
@@ -0,0 +1,157 @@
+import warnings
+from collections.abc import Iterator
+from typing import Callable, Optional, Protocol, cast
+
+import torch
+from torch.utils.data import IterableDataset
+
+from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
+from replay.data.nn.parquet.impl.masking import DEFAULT_COLLATE_FN
+
+
+def get_batch_size(batch: GeneralBatch, strict: bool = False) -> int:
+    """
+    Retrieves the size of the ``batch`` object.
+
+    :param batch: Batch object.
+    :param strict: If ``True``, performs additional validation. Default: ``False``.
+
+    :raises ValueError: If a size mismatch is found in the batch during a strict check.
+
+    :return: Batch size.
+    """
+    batch_size: Optional[int] = None
+
+    for key, value in batch.items():
+        new_batch_size: int
+
+        if torch.is_tensor(value):
+            new_batch_size = value.size(0)
+        else:
+            assert isinstance(value, dict)
+            new_batch_size = get_batch_size(value, strict)
+
+        if batch_size is None:
+            batch_size = new_batch_size
+
+        if strict:
+            if batch_size != new_batch_size:
+                msg = f"Batch size mismatch {key}: {batch_size} != {new_batch_size}"
+                raise ValueError(msg)
+        else:
+            break
+    assert batch_size is not None
+    return cast(int, batch_size)
+
+
+def split_batches(batch: GeneralBatch, split: int) -> tuple[GeneralBatch, GeneralBatch]:
+    left: GeneralBatch = {}
+    right: GeneralBatch = {}
+
+    for key, value in batch.items():
+        if torch.is_tensor(value):
+            sub_left = value[:split, ...]
+            sub_right = value[split:, ...]
+        else:
+            sub_left, sub_right = split_batches(value, split)
+        left[key], right[key] = sub_left, sub_right
+
+    return (left, right)
+
+
+class DatasetProtocol(Protocol):
+    def __iter__(self) -> Iterator[GeneralBatch]: ...
+    @property
+    def batch_size(self) -> int: ...
+
+
+class FixedBatchSizeDataset(IterableDataset):
+    """
+    Wrapper for arbitrary datasets that fetches batches of fixed size.
+    Concatenates batches from the wrapped dataset until it reaches the specified size.
+    The last batch may be smaller than the specified size.
+    """
+
+    def __init__(
+        self,
+        dataset: DatasetProtocol,
+        batch_size: Optional[int] = None,
+        collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
+        strict_checks: bool = False,
+    ) -> None:
+        """
+        :param dataset: An iterable object that returns batches.
+            Generally a subclass of ``torch.utils.data.IterableDataset``.
+        :param batch_size: Desired batch size. If ``None``, will search for batch size in ``dataset.batch_size``.
+            Default: ``None``.
+        :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
+        :param strict_checks: If ``True``, additional batch size checks will be performed.
+            May affect performance. Default: ``False``.
+
+        :raises ValueError: If an invalid batch size was provided.
+        """
+        super().__init__()
+
+        self.dataset: DatasetProtocol = dataset
+
+        if batch_size is None:
+            assert hasattr(dataset, "batch_size")
+            batch_size = self.dataset.batch_size
+
+        assert isinstance(batch_size, int)
+        int_batch_size: int = cast(int, batch_size)
+
+        if int_batch_size < 1:
+            msg = f"Insufficient batch size. Got {int_batch_size=}"
+            raise ValueError(msg)
+
+        if int_batch_size < 2:
+            warnings.warn(f"Low batch size. Got {int_batch_size=}. This may cause performance issues.", stacklevel=2)
+
+        self.collate_fn: Callable = collate_fn
+        self.batch_size: int = int_batch_size
+        self.strict_checks: bool = strict_checks
+
+    def get_batch_size(self, batch: GeneralBatch) -> int:
+        return get_batch_size(batch, strict=self.strict_checks)
+
+    def __iter__(self) -> Iterator[GeneralBatch]:
+        iterator: Iterator[GeneralBatch] = iter(self.dataset)
+
+        buffer: list[GeneralBatch] = []
+        buffer_size: int = 0
+
+        while True:
+            while buffer_size < self.batch_size:
+                try:
+                    batch: GeneralBatch = next(iterator)
+                    size: int = self.get_batch_size(batch)
+
+                    buffer.append(batch)
+                    buffer_size += size
+                except StopIteration:
+                    break
+
+            if buffer_size == 0:
+                break
+
+            joined: GeneralBatch = self.collate_fn(buffer)
+            assert buffer_size == self.get_batch_size(joined)
+
+            if self.batch_size < buffer_size:
+                left, right = split_batches(joined, self.batch_size)
+                residue: int = buffer_size - self.batch_size
+                assert residue == self.get_batch_size(right)
+
+                buffer_size = residue
+                buffer = [right]
+
+                yield left
+            else:
+                buffer_size = 0
+                buffer = []
+
+                yield joined
+
+        assert buffer_size == 0
+        assert len(buffer) == 0
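
A runnable sketch of the re-batching behaviour, using `general_collate` from `collate.py` above as the merge function; the toy source class and its values are illustrative, not part of the package:

```python
import torch

from replay.data.nn.parquet.collate import general_collate
from replay.data.nn.parquet.fixed_batch_dataset import FixedBatchSizeDataset


class ToySource:
    """Yields ragged batches of 3, 3 and 1 rows."""

    batch_size = 3  # satisfies DatasetProtocol

    def __iter__(self):
        for chunk in torch.arange(7).split(3):
            yield {"ids": chunk}


fixed = FixedBatchSizeDataset(ToySource(), batch_size=2, collate_fn=general_collate)
# Rechunked into batches of exactly 2 rows, with a smaller final batch:
# [0, 1], [2, 3], [4, 5], [6]
print([batch["ids"].tolist() for batch in fixed])
```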
replay_rec-0.21.0/replay/data/nn/parquet/impl/array_1d_column.py
@@ -0,0 +1,140 @@
+from typing import Any, Union
+
+import pyarrow as pa
+import pyarrow.compute as pc
+import torch
+
+from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+from replay.data.nn.parquet.metadata import (
+    Metadata,
+    get_1d_array_columns,
+    get_padding,
+    get_shape,
+)
+from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+from .column_protocol import OutputType
+from .indexing import get_mask, get_offsets
+from .utils import ensure_mutable
+
+
+class Array1DColumn:
+    """
+    Representation of a 1D array column, containing a
+    list of numbers of varying length in each of its rows.
+    """
+
+    def __init__(
+        self,
+        data: torch.Tensor,
+        lengths: torch.LongTensor,
+        shape: Union[int, list[int]],
+        padding: Any = DEFAULT_PADDING,
+    ) -> None:
+        """
+        :param data: A tensor containing column data.
+        :param lengths: A tensor containing lengths of each individual row array.
+        :param shape: An integer or list of integers representing the target array shapes.
+        :param padding: Padding value used to fill null values and match the target shape.
+            Default: value of ``DEFAULT_PADDING``.
+
+        :raises ValueError: If the shape provided is not one-dimensional.
+        """
+        if isinstance(shape, list) and len(shape) > 1:
+            msg = f"Array1DColumn accepts a shape of size (1,) only. Got {shape=}"
+            raise ValueError(msg)
+
+        self.padding = padding
+        self.data = data
+        self.offsets = get_offsets(lengths)
+        self.shape = shape[0] if isinstance(shape, list) else shape
+        assert self.length == torch.numel(lengths)
+
+    @property
+    def length(self) -> int:
+        return torch.numel(self.offsets) - 1
+
+    def __len__(self) -> int:
+        return self.length
+
+    @property
+    def device(self) -> torch.device:
+        assert self.data.device == self.offsets.device
+        return self.offsets.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.data.dtype
+
+    def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+        indices = indices.to(device=self.device)
+        mask, output = get_mask(indices, self.offsets, self.shape)
+
+        # TODO: Test this for both 1d and 2d arrays. Add same check in 2d arrays
+        if self.data.numel() == 0:
+            mask = torch.zeros((indices.size(0), self.shape), dtype=torch.bool, device=self.device)
+            output = torch.ones((indices.size(0), self.shape), dtype=torch.bool, device=self.device) * self.padding
+            return mask, output
+
+        unmasked_values = torch.take(self.data, output)
+        masked_values = torch.where(mask, unmasked_values, self.padding)
+        assert masked_values.device == self.device
+        assert masked_values.dtype == self.dtype
+        return (mask, masked_values)
+
+
+def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Converts a PyArrow list array into PyTorch tensors.
+
+    :param array: Original PyArrow array.
+    :param device: Target device to send the resulting tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+    :return: A tuple of tensors ``(lengths, values)`` obtained from the original array.
+    """
+    flatten = pc.list_flatten(array)
+    lengths = pc.list_value_length(array).cast(pa.int64())
+
+    # Copying to be mutable
+    flatten_torch = torch.asarray(
+        ensure_mutable(flatten.to_numpy()),
+        device=device,
+        dtype=pyarrow_to_torch(flatten.type),
+    )
+
+    # Copying to be mutable
+    lengths_torch = torch.asarray(
+        ensure_mutable(lengths.to_numpy()),
+        device=device,
+        dtype=torch.int64,
+    )
+    return (lengths_torch, flatten_torch)
+
+
+def to_array_1d_columns(
+    data: pa.RecordBatch,
+    metadata: Metadata,
+    device: torch.device = DEFAULT_DEVICE,
+) -> dict[str, Array1DColumn]:
+    """
+    Converts a PyArrow batch of data to a set of ``Array1DColumn``s.
+    This function filters only those columns matching its format from the full batch.
+
+    :param data: A PyArrow batch of column data.
+    :param metadata: Metadata containing information about columns' formats.
+    :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``.
+
+    :return: A dict mapping column names to the dataset's 1D array columns.
+    """
+    result: dict[str, Array1DColumn] = {}
+
+    for column_name in get_1d_array_columns(metadata):
+        lengths, torch_array = to_torch(data.column(column_name), device=device)
+        result[column_name] = Array1DColumn(
+            data=torch_array,
+            lengths=lengths,
+            padding=get_padding(metadata, column_name),
+            shape=get_shape(metadata, column_name),
+        )
+    return result
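
A short sketch of `to_torch` on a ragged PyArrow list array (values are illustrative):

```python
import pyarrow as pa

from replay.data.nn.parquet.impl.array_1d_column import to_torch

ragged = pa.array([[1, 2, 3], [4], []])
lengths, values = to_torch(ragged)
# lengths -> tensor([3, 1, 0])   (per-row list lengths, int64)
# values  -> tensor([1, 2, 3, 4]) (rows flattened into one tensor)
```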