replay-rec 0.17.0rc0.tar.gz → 0.17.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/PKG-INFO +3 -11
  2. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/pyproject.toml +11 -15
  3. replay_rec-0.17.1/replay/__init__.py +2 -0
  4. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/dataset.py +246 -20
  5. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/schema.py +42 -0
  6. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/sequence_tokenizer.py +17 -47
  7. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/sequential_dataset.py +76 -2
  8. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/filters.py +169 -4
  9. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/base_splitter.py +1 -1
  10. replay_rec-0.17.1/replay/utils/common.py +167 -0
  11. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/spark_utils.py +13 -6
  12. replay_rec-0.17.0rc0/NOTICE +0 -41
  13. replay_rec-0.17.0rc0/replay/__init__.py +0 -2
  14. replay_rec-0.17.0rc0/replay/experimental/metrics/__init__.py +0 -61
  15. replay_rec-0.17.0rc0/replay/experimental/metrics/base_metric.py +0 -601
  16. replay_rec-0.17.0rc0/replay/experimental/metrics/coverage.py +0 -97
  17. replay_rec-0.17.0rc0/replay/experimental/metrics/experiment.py +0 -175
  18. replay_rec-0.17.0rc0/replay/experimental/metrics/hitrate.py +0 -26
  19. replay_rec-0.17.0rc0/replay/experimental/metrics/map.py +0 -30
  20. replay_rec-0.17.0rc0/replay/experimental/metrics/mrr.py +0 -18
  21. replay_rec-0.17.0rc0/replay/experimental/metrics/ncis_precision.py +0 -31
  22. replay_rec-0.17.0rc0/replay/experimental/metrics/ndcg.py +0 -49
  23. replay_rec-0.17.0rc0/replay/experimental/metrics/precision.py +0 -22
  24. replay_rec-0.17.0rc0/replay/experimental/metrics/recall.py +0 -25
  25. replay_rec-0.17.0rc0/replay/experimental/metrics/rocauc.py +0 -49
  26. replay_rec-0.17.0rc0/replay/experimental/metrics/surprisal.py +0 -90
  27. replay_rec-0.17.0rc0/replay/experimental/metrics/unexpectedness.py +0 -76
  28. replay_rec-0.17.0rc0/replay/experimental/models/__init__.py +0 -10
  29. replay_rec-0.17.0rc0/replay/experimental/models/admm_slim.py +0 -205
  30. replay_rec-0.17.0rc0/replay/experimental/models/base_neighbour_rec.py +0 -204
  31. replay_rec-0.17.0rc0/replay/experimental/models/base_rec.py +0 -1271
  32. replay_rec-0.17.0rc0/replay/experimental/models/base_torch_rec.py +0 -234
  33. replay_rec-0.17.0rc0/replay/experimental/models/cql.py +0 -452
  34. replay_rec-0.17.0rc0/replay/experimental/models/ddpg.py +0 -921
  35. replay_rec-0.17.0rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
  36. replay_rec-0.17.0rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
  37. replay_rec-0.17.0rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
  38. replay_rec-0.17.0rc0/replay/experimental/models/dt4rec/utils.py +0 -265
  39. replay_rec-0.17.0rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  40. replay_rec-0.17.0rc0/replay/experimental/models/implicit_wrap.py +0 -131
  41. replay_rec-0.17.0rc0/replay/experimental/models/lightfm_wrap.py +0 -302
  42. replay_rec-0.17.0rc0/replay/experimental/models/mult_vae.py +0 -331
  43. replay_rec-0.17.0rc0/replay/experimental/models/neuromf.py +0 -405
  44. replay_rec-0.17.0rc0/replay/experimental/models/scala_als.py +0 -296
  45. replay_rec-0.17.0rc0/replay/experimental/nn/data/__init__.py +0 -1
  46. replay_rec-0.17.0rc0/replay/experimental/nn/data/schema_builder.py +0 -55
  47. replay_rec-0.17.0rc0/replay/experimental/preprocessing/__init__.py +0 -3
  48. replay_rec-0.17.0rc0/replay/experimental/preprocessing/data_preparator.py +0 -838
  49. replay_rec-0.17.0rc0/replay/experimental/preprocessing/padder.py +0 -229
  50. replay_rec-0.17.0rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
  51. replay_rec-0.17.0rc0/replay/experimental/scenarios/__init__.py +0 -1
  52. replay_rec-0.17.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  53. replay_rec-0.17.0rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  54. replay_rec-0.17.0rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  55. replay_rec-0.17.0rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  56. replay_rec-0.17.0rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
  57. replay_rec-0.17.0rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  58. replay_rec-0.17.0rc0/replay/experimental/utils/logger.py +0 -24
  59. replay_rec-0.17.0rc0/replay/experimental/utils/model_handler.py +0 -181
  60. replay_rec-0.17.0rc0/replay/experimental/utils/session_handler.py +0 -44
  61. replay_rec-0.17.0rc0/replay/models/extensions/ann/__init__.py +0 -0
  62. replay_rec-0.17.0rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
  63. replay_rec-0.17.0rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
  64. replay_rec-0.17.0rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
  65. replay_rec-0.17.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
  66. replay_rec-0.17.0rc0/replay/utils/common.py +0 -65
  67. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/LICENSE +0 -0
  68. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/README.md +0 -0
  69. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/__init__.py +0 -0
  70. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/dataset_utils/__init__.py +0 -0
  71. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
  72. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/__init__.py +0 -0
  73. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/torch_sequential_dataset.py +0 -0
  74. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/utils.py +0 -0
  75. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/schema.py +0 -0
  76. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/spark_schema.py +0 -0
  77. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/__init__.py +0 -0
  78. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/base_metric.py +0 -0
  79. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/categorical_diversity.py +0 -0
  80. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/coverage.py +0 -0
  81. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/descriptors.py +0 -0
  82. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/experiment.py +0 -0
  83. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/hitrate.py +0 -0
  84. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/map.py +0 -0
  85. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/mrr.py +0 -0
  86. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/ndcg.py +0 -0
  87. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/novelty.py +0 -0
  88. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/offline_metrics.py +0 -0
  89. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/precision.py +0 -0
  90. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/recall.py +0 -0
  91. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/rocauc.py +0 -0
  92. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/surprisal.py +0 -0
  93. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/torch_metrics_builder.py +0 -0
  94. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/metrics/unexpectedness.py +0 -0
  95. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/__init__.py +0 -0
  96. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/als.py +0 -0
  97. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/association_rules.py +0 -0
  98. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/base_neighbour_rec.py +0 -0
  99. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/base_rec.py +0 -0
  100. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/cat_pop_rec.py +0 -0
  101. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/cluster.py +0 -0
  102. {replay_rec-0.17.0rc0/replay/experimental → replay_rec-0.17.1/replay/models/extensions}/__init__.py +0 -0
  103. {replay_rec-0.17.0rc0/replay/experimental/models/dt4rec → replay_rec-0.17.1/replay/models/extensions/ann}/__init__.py +0 -0
  104. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/ann_mixin.py +0 -0
  105. {replay_rec-0.17.0rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.17.1/replay/models/extensions/ann/entities}/__init__.py +0 -0
  106. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
  107. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
  108. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
  109. {replay_rec-0.17.0rc0/replay/experimental/scenarios/two_stages → replay_rec-0.17.1/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
  110. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
  111. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
  112. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
  113. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
  114. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
  115. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
  116. {replay_rec-0.17.0rc0/replay/experimental/utils → replay_rec-0.17.1/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
  117. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
  118. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
  119. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
  120. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
  121. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
  122. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
  123. {replay_rec-0.17.0rc0/replay/models/extensions → replay_rec-0.17.1/replay/models/extensions/ann/index_stores}/__init__.py +0 -0
  124. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
  125. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
  126. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
  127. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
  128. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/index_stores/utils.py +0 -0
  129. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/extensions/ann/utils.py +0 -0
  130. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/kl_ucb.py +0 -0
  131. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/knn.py +0 -0
  132. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/__init__.py +0 -0
  133. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/optimizer_utils/__init__.py +0 -0
  134. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
  135. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/__init__.py +0 -0
  136. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
  137. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
  138. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
  139. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/bert4rec/model.py +0 -0
  140. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
  141. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
  142. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
  143. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
  144. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
  145. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
  146. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
  147. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
  148. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
  149. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/nn/sequential/sasrec/model.py +0 -0
  150. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/pop_rec.py +0 -0
  151. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/query_pop_rec.py +0 -0
  152. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/random_rec.py +0 -0
  153. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/slim.py +0 -0
  154. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/thompson_sampling.py +0 -0
  155. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/ucb.py +0 -0
  156. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/wilson.py +0 -0
  157. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/models/word2vec.py +0 -0
  158. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/optimization/__init__.py +0 -0
  159. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/optimization/optuna_objective.py +0 -0
  160. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/__init__.py +0 -0
  161. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/converter.py +0 -0
  162. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/history_based_fp.py +0 -0
  163. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/label_encoder.py +0 -0
  164. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/preprocessing/sessionizer.py +0 -0
  165. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/scenarios/__init__.py +0 -0
  166. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/scenarios/fallback.py +0 -0
  167. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/__init__.py +0 -0
  168. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/cold_user_random_splitter.py +0 -0
  169. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/k_folds.py +0 -0
  170. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/last_n_splitter.py +0 -0
  171. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/new_users_splitter.py +0 -0
  172. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/random_splitter.py +0 -0
  173. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/ratio_splitter.py +0 -0
  174. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/time_splitter.py +0 -0
  175. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/splitters/two_stage_splitter.py +0 -0
  176. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/__init__.py +0 -0
  177. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/dataframe_bucketizer.py +0 -0
  178. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/distributions.py +0 -0
  179. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/model_handler.py +0 -0
  180. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/session_handler.py +0 -0
  181. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/time.py +0 -0
  182. {replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/utils/types.py +0 -0
{replay_rec-0.17.0rc0 → replay_rec-0.17.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: replay-rec
- Version: 0.17.0rc0
+ Version: 0.17.1
  Summary: RecSys Library
  Home-page: https://sb-ai-lab.github.io/RePlay/
  License: Apache-2.0
@@ -20,25 +20,17 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: all
  Provides-Extra: spark
  Provides-Extra: torch
- Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
- Requires-Dist: gym (>=0.26.0,<0.27.0)
  Requires-Dist: hnswlib (==0.7.0)
- Requires-Dist: implicit (>=0.7.0,<0.8.0)
- Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
- Requires-Dist: lightfm (==1.17)
  Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
- Requires-Dist: llvmlite (>=0.32.1)
  Requires-Dist: nmslib (==2.1.1)
- Requires-Dist: numba (>=0.50)
  Requires-Dist: numpy (>=1.20.0)
  Requires-Dist: optuna (>=3.2.0,<3.3.0)
- Requires-Dist: pandas (>=1.3.5,<2.0.0)
+ Requires-Dist: pandas (>=1.3.5,<=2.2.2)
  Requires-Dist: polars (>=0.20.7,<0.21.0)
  Requires-Dist: psutil (>=5.9.5,<5.10.0)
  Requires-Dist: pyarrow (>=12.0.1)
- Requires-Dist: pyspark (>=3.0,<3.3) ; extra == "spark" or extra == "all"
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
  Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
- Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
  Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
  Requires-Dist: scipy (>=1.8.1,<1.9.0)
  Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
{replay_rec-0.17.0rc0 → replay_rec-0.17.1}/pyproject.toml

@@ -39,33 +39,26 @@ classifiers = [
  ]
  exclude = [
      "replay/conftest.py",
+     "replay/experimental",
  ]
- version = "0.17.0.preview"
+ version = "0.17.1"

  [tool.poetry.dependencies]
  python = ">=3.8.1, <3.11"
  numpy = ">=1.20.0"
- pandas = "^1.3.5"
+ pandas = ">=1.3.5,<=2.2.2"
  polars = "~0.20.7"
  optuna = "~3.2.0"
  scipy = "~1.8.1"
  psutil = "~5.9.5"
- pyspark = {version = ">=3.0,<3.3", optional = true}
+ pyspark = {version = ">=3.0,<3.5", optional = true}
  scikit-learn = "^1.0.2"
  pyarrow = ">=12.0.1"
+ torch = {version = "^1.8", optional = true}
+ lightning = {version = "^2.0.2", optional = true}
+ pytorch-ranger = {version = "^0.1.1", optional = true}
  nmslib = "2.1.1"
  hnswlib = "0.7.0"
- torch = "^1.8"
- lightning = "^2.0.2"
- pytorch-ranger = "^0.1.1"
- lightfm = "1.17"
- lightautoml = "~0.3.1"
- numba = ">=0.50"
- llvmlite = ">=0.32.1"
- sb-obp = "^0.5.7"
- d3rlpy = "^2.0.4"
- implicit = "~0.7.0"
- gym = "^0.26.0"

  [tool.poetry.extras]
  spark = ["pyspark"]
@@ -92,7 +85,7 @@ data-science-types = "0.2.23"

  [tool.poetry-dynamic-versioning]
  enable = false
- format-jinja = """0.17.0{{ env['PACKAGE_SUFFIX'] }}"""
+ format-jinja = """0.17.1{{ env['PACKAGE_SUFFIX'] }}"""
  vcs = "git"

  [tool.ruff]
@@ -123,6 +116,9 @@ max-complexity = 13
  "tests/*" = ["ARG", "E402", "INP", "ISC", "N", "S", "SIM", "F811"]
  "tests/experimental/*" = ["F401", "F811"]
  "replay/experimental/models/extensions/spark_custom_models/als_extension.py" = ["ARG002", "N802", "N803", "N815"]
+ "replay/data/nn/sequence_tokenizer.py" = ["ARG003"]
+ "replay/splitters/base_splitter.py" = ["ARG003"]
+ "replay/data/nn/sequential_dataset.py" = ["ARG003"]

  [tool.tomlsort]
  ignore_case = true
replay_rec-0.17.1/replay/__init__.py (new file)

@@ -0,0 +1,2 @@
+ """ RecSys library """
+ __version__ = "0.17.1"
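The new top-level __init__.py exposes the package version. A minimal sanity check (a sketch; assumes replay-rec 0.17.1 is installed and importable as replay):

    import replay

    # Should print "0.17.1" for this release.
    print(replay.__version__)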
{replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/dataset.py

@@ -3,11 +3,22 @@
  """
  from __future__ import annotations

- from typing import Callable, Dict, Iterable, List, Optional, Sequence
+ import json
+ from pathlib import Path
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

  import numpy as np
-
- from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+ from pandas import read_parquet as pd_read_parquet
+ from polars import read_parquet as pl_read_parquet
+
+ from replay.utils import (
+     PYSPARK_AVAILABLE,
+     DataFrameLike,
+     PandasDataFrame,
+     PolarsDataFrame,
+     SparkDataFrame,
+ )
+ from replay.utils.session_handler import get_spark_session

  from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType

@@ -47,9 +58,7 @@ class Dataset:
          self._query_features = query_features
          self._item_features = item_features

-         self.is_pandas = isinstance(interactions, PandasDataFrame)
-         self.is_spark = isinstance(interactions, SparkDataFrame)
-         self.is_polars = isinstance(interactions, PolarsDataFrame)
+         self._assign_df_type()

          self._categorical_encoded = categorical_encoded

@@ -74,16 +83,8 @@ class Dataset:
              msg = "Interactions and query features should have the same type."
              raise TypeError(msg)

-         self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
-             FeatureSource.INTERACTIONS: self.interactions,
-             FeatureSource.QUERY_FEATURES: self.query_features,
-             FeatureSource.ITEM_FEATURES: self.item_features,
-         }
-
-         self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
-             FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
-             FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
-         }
+         self._get_feature_source_map()
+         self._get_ids_source_map()

          self._feature_schema = self._fill_feature_schema(feature_schema)

@@ -92,7 +93,6 @@ class Dataset:
              self._check_ids_consistency(hint=FeatureHint.QUERY_ID)
          if self.item_features is not None:
              self._check_ids_consistency(hint=FeatureHint.ITEM_ID)
-
          if self._categorical_encoded:
              self._check_encoded()

@@ -189,6 +189,157 @@ class Dataset:
          """
          return self._feature_schema

+     def _get_df_type(self) -> str:
+         """
+         :returns: Stored dataframe type.
+         """
+         if self.is_spark:
+             return "spark"
+         if self.is_pandas:
+             return "pandas"
+         if self.is_polars:
+             return "polars"
+         msg = "No known dataframe types are provided"
+         raise ValueError(msg)
+
+     def _to_parquet(self, df: DataFrameLike, path: Path) -> None:
+         """
+         Save the content of the dataframe in parquet format to the provided path.
+
+         :param df: Dataframe to save.
+         :param path: Path to save the dataframe to.
+         """
+         if self.is_spark:
+             path = str(path)
+             df = df.withColumn("idx", sf.monotonically_increasing_id())
+             df.write.mode("overwrite").parquet(path)
+         elif self.is_pandas:
+             df.to_parquet(path)
+         elif self.is_polars:
+             df.write_parquet(path)
+         else:
+             msg = """
+                 _to_parquet() can only be used to save polars|pandas|spark dataframes;
+                 No known dataframe types are provided
+             """
+             raise TypeError(msg)
+
+     @staticmethod
+     def _read_parquet(path: Path, mode: str) -> Union[SparkDataFrame, PandasDataFrame, PolarsDataFrame]:
+         """
+         Read the parquet file as dataframe.
+
+         :param path: The parquet file path.
+         :param mode: Dataframe type. Can be spark|pandas|polars.
+         :returns: The dataframe read from the file.
+         """
+         if mode == "spark":
+             path = str(path)
+             spark_session = get_spark_session()
+             df = spark_session.read.parquet(path)
+             if "idx" in df.columns:
+                 df = df.orderBy("idx").drop("idx")
+             return df
+         if mode == "pandas":
+             df = pd_read_parquet(path)
+             if "idx" in df.columns:
+                 df = df.set_index("idx").reset_index(drop=True)
+             return df
+         if mode == "polars":
+             df = pl_read_parquet(path, use_pyarrow=True)
+             if "idx" in df.columns:
+                 df = df.sort("idx").drop("idx")
+             return df
+         msg = f"_read_parquet() can only be used to read polars|pandas|spark dataframes, not {mode}"
+         raise TypeError(msg)
+
+     def save(self, path: str) -> None:
+         """
+         Save the Dataset to the provided path.
+
+         :param path: Path to save the Dataset to.
+         """
+         dataset_dict = {}
+         dataset_dict["_class_name"] = self.__class__.__name__
+
+         interactions_type = self._get_df_type()
+         dataset_dict["init_args"] = {
+             "feature_schema": [],
+             "interactions": interactions_type,
+             "item_features": (interactions_type if self.item_features is not None else None),
+             "query_features": (interactions_type if self.query_features is not None else None),
+             "check_consistency": False,
+             "categorical_encoded": self._categorical_encoded,
+         }
+
+         for feature in self.feature_schema.all_features:
+             dataset_dict["init_args"]["feature_schema"].append(
+                 {
+                     "column": feature.column,
+                     "feature_type": feature.feature_type.name,
+                     "feature_hint": (feature.feature_hint.name if feature.feature_hint else None),
+                 }
+             )
+
+         base_path = Path(path).with_suffix(".replay").resolve()
+         base_path.mkdir(parents=True, exist_ok=True)
+
+         with open(base_path / "init_args.json", "w+") as file:
+             json.dump(dataset_dict, file)
+
+         df_data = {
+             "interactions": self.interactions,
+             "item_features": self.item_features,
+             "query_features": self.query_features,
+         }
+
+         for df_name, df in df_data.items():
+             if df is not None:
+                 df_path = base_path / f"{df_name}.parquet"
+                 self._to_parquet(df, df_path)
+
+     @classmethod
+     def load(
+         cls,
+         path: str,
+         dataframe_type: Optional[str] = None,
+     ) -> Dataset:
+         """
+         Load the Dataset from the provided path.
+
+         :param path: The file path.
+         :param dataframe_type: Dataframe type to use to store internal data.
+             Can be spark|pandas|polars|None.
+             If not provided, automatically set to the type used when the Dataset was saved.
+         :returns: Loaded Dataset.
+         """
+         base_path = Path(path).with_suffix(".replay").resolve()
+         with open(base_path / "init_args.json", "r") as file:
+             dataset_dict = json.loads(file.read())
+
+         if dataframe_type not in ["pandas", "spark", "polars", None]:
+             msg = f"Argument dataframe_type can be spark|pandas|polars|None, not {dataframe_type}"
+             raise ValueError(msg)
+
+         feature_schema_data = dataset_dict["init_args"]["feature_schema"]
+         features_list = []
+         for feature_data in feature_schema_data:
+             f_type = feature_data["feature_type"]
+             f_hint = feature_data["feature_hint"]
+             feature_data["feature_type"] = FeatureType[f_type] if f_type else None
+             feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
+             features_list.append(FeatureInfo(**feature_data))
+         dataset_dict["init_args"]["feature_schema"] = FeatureSchema(features_list)
+
+         for df_name in ["interactions", "query_features", "item_features"]:
+             df_type = dataset_dict["init_args"][df_name]
+             if df_type:
+                 df_type = dataframe_type or df_type
+                 load_path = base_path / f"{df_name}.parquet"
+                 dataset_dict["init_args"][df_name] = cls._read_parquet(load_path, df_type)
+         dataset = cls(**dataset_dict["init_args"])
+         return dataset
+
      if PYSPARK_AVAILABLE:

          def persist(self, storage_level: StorageLevel = StorageLevel(True, True, False, True, 1)) -> None:
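Together, save() and load() give Dataset a self-describing on-disk format: a "<path>.replay" directory holding init_args.json plus one parquet file per stored dataframe. A minimal usage sketch based on the signatures above (feature_schema and interactions are placeholder objects, not defined here):

    from replay.data import Dataset

    # Hypothetical inputs: a FeatureSchema and a pandas interactions frame.
    dataset = Dataset(feature_schema=feature_schema, interactions=interactions)

    # Writes my_dataset.replay/init_args.json and my_dataset.replay/interactions.parquet.
    dataset.save("my_dataset")

    # Reload, optionally switching the backing dataframe type at load time.
    restored = Dataset.load("my_dataset", dataframe_type="polars")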
@@ -283,6 +434,24 @@ class Dataset:
              categorical_encoded=self._categorical_encoded,
          )

+     def _get_feature_source_map(self):
+         self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
+             FeatureSource.INTERACTIONS: self.interactions,
+             FeatureSource.QUERY_FEATURES: self.query_features,
+             FeatureSource.ITEM_FEATURES: self.item_features,
+         }
+
+     def _get_ids_source_map(self):
+         self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
+             FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
+             FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
+         }
+
+     def _assign_df_type(self):
+         self.is_pandas = isinstance(self.interactions, PandasDataFrame)
+         self.is_spark = isinstance(self.interactions, SparkDataFrame)
+         self.is_polars = isinstance(self.interactions, PolarsDataFrame)
+
      def _get_cardinality(self, feature: FeatureInfo) -> Callable:
          def callback(column: str) -> int:
              if feature.feature_hint in [FeatureHint.ITEM_ID, FeatureHint.QUERY_ID]:
@@ -381,7 +550,11 @@ class Dataset:
              is_consistent = (
                  self.interactions.select(ids_column)
                  .distinct()
-                 .join(features_df.select(ids_column).distinct(), on=[ids_column], how="leftanti")
+                 .join(
+                     features_df.select(ids_column).distinct(),
+                     on=[ids_column],
+                     how="leftanti",
+                 )
                  .count()
              ) == 0
          else:
@@ -389,7 +562,11 @@ class Dataset:
                  len(
                      self.interactions.select(ids_column)
                      .unique()
-                     .join(features_df.select(ids_column).unique(), on=ids_column, how="anti")
+                     .join(
+                         features_df.select(ids_column).unique(),
+                         on=ids_column,
+                         how="anti",
+                     )
                  )
                  == 0
              )
@@ -399,7 +576,11 @@ class Dataset:
              raise ValueError(msg)

      def _check_column_encoded(
-         self, data: DataFrameLike, column: str, source: FeatureSource, cardinality: Optional[int]
+         self,
+         data: DataFrameLike,
+         column: str,
+         source: FeatureSource,
+         cardinality: Optional[int],
      ) -> None:
          """
          Checks that IDs are encoded:
@@ -482,6 +663,51 @@ class Dataset:
                  feature.cardinality,
              )

+     def to_pandas(self) -> None:
+         """
+         Convert internally stored dataframes to pandas.DataFrame.
+         """
+         from replay.utils.common import convert2pandas
+
+         self._interactions = convert2pandas(self._interactions)
+         if self._query_features is not None:
+             self._query_features = convert2pandas(self._query_features)
+         if self._item_features is not None:
+             self._item_features = convert2pandas(self.item_features)
+         self._get_feature_source_map()
+         self._get_ids_source_map()
+         self._assign_df_type()
+
+     def to_spark(self):
+         """
+         Convert internally stored dataframes to pyspark.sql.DataFrame.
+         """
+         from replay.utils.common import convert2spark
+
+         self._interactions = convert2spark(self._interactions)
+         if self._query_features is not None:
+             self._query_features = convert2spark(self._query_features)
+         if self._item_features is not None:
+             self._item_features = convert2spark(self._item_features)
+         self._get_feature_source_map()
+         self._get_ids_source_map()
+         self._assign_df_type()
+
+     def to_polars(self):
+         """
+         Convert internally stored dataframes to polars.DataFrame.
+         """
+         from replay.utils.common import convert2polars
+
+         self._interactions = convert2polars(self._interactions)
+         if self._query_features is not None:
+             self._query_features = convert2polars(self._query_features)
+         if self._item_features is not None:
+             self._item_features = convert2polars(self._item_features)
+         self._get_feature_source_map()
+         self._get_ids_source_map()
+         self._assign_df_type()
+

  def nunique(data: DataFrameLike, column: str) -> int:
      """
{replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/schema.py

@@ -408,6 +408,48 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
              return None
          return rating_features.item().name

+     def _get_object_args(self) -> Dict:
+         """
+         Returns a list of features represented as dictionaries.
+         """
+         features = [
+             {
+                 "name": feature.name,
+                 "feature_type": feature.feature_type.name,
+                 "is_seq": feature.is_seq,
+                 "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
+                 "feature_sources": [
+                     {"source": x.source.name, "column": x.column, "index": x.index}
+                     for x in feature.feature_sources
+                 ]
+                 if feature.feature_sources
+                 else None,
+                 "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
+                 "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
+                 "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
+             }
+             for feature in self.all_features
+         ]
+         return features
+
+     @classmethod
+     def _create_object_by_args(cls, args: Dict) -> "TensorSchema":
+         features_list = []
+         for feature_data in args:
+             feature_data["feature_sources"] = (
+                 [
+                     TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
+                     for x in feature_data["feature_sources"]
+                 ]
+                 if feature_data["feature_sources"]
+                 else None
+             )
+             f_type = feature_data["feature_type"]
+             f_hint = feature_data["feature_hint"]
+             feature_data["feature_type"] = FeatureType[f_type] if f_type else None
+             feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
+             features_list.append(TensorFeatureInfo(**feature_data))
+         return TensorSchema(features_list)
+
      def filter(
          self,
          name: Optional[str] = None,
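These private helpers centralize the schema-to-dict round trip that SequenceTokenizer.save/load previously inlined (see the sequence_tokenizer.py hunks below). An illustrative round trip, assuming an existing tensor_schema (internal API, subject to change):

    import json

    # _get_object_args() yields plain dicts that json.dumps can handle.
    payload = json.dumps(tensor_schema._get_object_args())

    # Rebuild an equivalent TensorSchema from the parsed payload.
    restored = TensorSchema._create_object_by_args(json.loads(payload))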
{replay_rec-0.17.0rc0 → replay_rec-0.17.1}/replay/data/nn/sequence_tokenizer.py

@@ -24,7 +24,10 @@ SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]

  class SequenceTokenizer:
      """
-     Data tokenizer for transformers
+     Data tokenizer for transformers.
+     Encodes all categorical features (the ones marked as FeatureType.CATEGORICAL in
+     the FeatureSchema) and stores all data as item sequences (sorted by time if a
+     feature of type FeatureHint.TIMESTAMP is provided, unsorted otherwise).
      """

      def __init__(
@@ -278,17 +281,17 @@ class SequenceTokenizer:
          ]

          for tensor_feature in tensor_schema.values():
-             source = tensor_feature.feature_source
-             assert source is not None
+             for source in tensor_feature.feature_sources:
+                 assert source is not None

-             # Some columns already added to encoder, skip them
-             if source.column in features_subset:
-                 continue
+                 # Some columns already added to encoder, skip them
+                 if source.column in features_subset:
+                     continue

-             if isinstance(source.source, FeatureSource):
-                 features_subset.append(source.column)
-             else:
-                 assert False, "Unknown tensor feature source"
+                 if isinstance(source.source, FeatureSource):
+                     features_subset.append(source.column)
+                 else:
+                     assert False, "Unknown tensor feature source"

          return set(features_subset)

@@ -404,7 +407,7 @@ class SequenceTokenizer:

      @classmethod
      @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
-     def load(cls, path: str, use_pickle: bool = False) -> "SequenceTokenizer":
+     def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
          """
          Load tokenizer object from the given path.
@@ -422,18 +425,7 @@ class SequenceTokenizer:

          # load tensor_schema, tensor_features
          tensor_schema_data = tokenizer_dict["init_args"]["tensor_schema"]
-         features_list = []
-         for feature_data in tensor_schema_data:
-             feature_data["feature_sources"] = [
-                 TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
-                 for x in feature_data["feature_sources"]
-             ]
-             f_type = feature_data["feature_type"]
-             f_hint = feature_data["feature_hint"]
-             feature_data["feature_type"] = FeatureType[f_type] if f_type else None
-             feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
-             features_list.append(TensorFeatureInfo(**feature_data))
-         tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema(features_list)
+         tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema._create_object_by_args(tensor_schema_data)

          # Load encoder columns and rules
          types = list(FeatureHint) + list(FeatureSource)
@@ -447,7 +439,7 @@ class SequenceTokenizer:
              rule_data = rules_dict[rule]
              if rule_data["mapping"] and rule_data["is_int"]:
                  rule_data["mapping"] = {int(key): value for key, value in rule_data["mapping"].items()}
-                 del rule_data["is_int"]
+             del rule_data["is_int"]

              tokenizer_dict["encoder"]["encoding_rules"][rule] = LabelEncodingRule(**rule_data)

@@ -478,31 +470,9 @@ class SequenceTokenizer:
              "allow_collect_to_master": self._allow_collect_to_master,
              "handle_unknown_rule": self._encoder._handle_unknown_rule,
              "default_value_rule": self._encoder._default_value_rule,
-             "tensor_schema": [],
+             "tensor_schema": self._tensor_schema._get_object_args(),
          }

-         # save tensor schema
-         for feature in list(self._tensor_schema.values()):
-             tokenizer_dict["init_args"]["tensor_schema"].append(
-                 {
-                     "name": feature.name,
-                     "feature_type": feature.feature_type.name,
-                     "is_seq": feature.is_seq,
-                     "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
-                     "feature_sources": [
-                         {"source": x.source.name, "column": x.column, "index": x.index}
-                         for x in feature.feature_sources
-                     ]
-                     if feature.feature_sources
-                     else None,
-                     "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
-                     "embedding_dim": feature.embedding_dim
-                     if feature.feature_type == FeatureType.CATEGORICAL
-                     else None,
-                     "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
-                 }
-             )
-
          # save DatasetLabelEncoder
          tokenizer_dict["encoder"] = {
              "features_columns": {key.name: value for key, value in self._encoder._features_columns.items()},