replay-rec 0.20.0rc0__tar.gz → 0.20.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/PKG-INFO +17 -11
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/pyproject.toml +24 -12
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/__init__.py +1 -1
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequence_tokenizer.py +10 -3
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/sequential_dataset.py +18 -14
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/torch_sequential_dataset.py +12 -12
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/lin_ucb.py +55 -9
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/dataset.py +3 -16
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/dataset.py +3 -16
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/__init__.py +0 -1
- replay_rec-0.20.0rc0/replay/experimental/metrics/__init__.py +0 -62
- replay_rec-0.20.0rc0/replay/experimental/metrics/base_metric.py +0 -603
- replay_rec-0.20.0rc0/replay/experimental/metrics/coverage.py +0 -97
- replay_rec-0.20.0rc0/replay/experimental/metrics/experiment.py +0 -175
- replay_rec-0.20.0rc0/replay/experimental/metrics/hitrate.py +0 -26
- replay_rec-0.20.0rc0/replay/experimental/metrics/map.py +0 -30
- replay_rec-0.20.0rc0/replay/experimental/metrics/mrr.py +0 -18
- replay_rec-0.20.0rc0/replay/experimental/metrics/ncis_precision.py +0 -31
- replay_rec-0.20.0rc0/replay/experimental/metrics/ndcg.py +0 -49
- replay_rec-0.20.0rc0/replay/experimental/metrics/precision.py +0 -22
- replay_rec-0.20.0rc0/replay/experimental/metrics/recall.py +0 -25
- replay_rec-0.20.0rc0/replay/experimental/metrics/rocauc.py +0 -49
- replay_rec-0.20.0rc0/replay/experimental/metrics/surprisal.py +0 -90
- replay_rec-0.20.0rc0/replay/experimental/metrics/unexpectedness.py +0 -76
- replay_rec-0.20.0rc0/replay/experimental/models/__init__.py +0 -50
- replay_rec-0.20.0rc0/replay/experimental/models/admm_slim.py +0 -257
- replay_rec-0.20.0rc0/replay/experimental/models/base_neighbour_rec.py +0 -200
- replay_rec-0.20.0rc0/replay/experimental/models/base_rec.py +0 -1386
- replay_rec-0.20.0rc0/replay/experimental/models/base_torch_rec.py +0 -234
- replay_rec-0.20.0rc0/replay/experimental/models/cql.py +0 -454
- replay_rec-0.20.0rc0/replay/experimental/models/ddpg.py +0 -932
- replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/trainer.py +0 -127
- replay_rec-0.20.0rc0/replay/experimental/models/dt4rec/utils.py +0 -264
- replay_rec-0.20.0rc0/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay_rec-0.20.0rc0/replay/experimental/models/hierarchical_recommender.py +0 -331
- replay_rec-0.20.0rc0/replay/experimental/models/implicit_wrap.py +0 -131
- replay_rec-0.20.0rc0/replay/experimental/models/lightfm_wrap.py +0 -303
- replay_rec-0.20.0rc0/replay/experimental/models/mult_vae.py +0 -332
- replay_rec-0.20.0rc0/replay/experimental/models/neural_ts.py +0 -986
- replay_rec-0.20.0rc0/replay/experimental/models/neuromf.py +0 -406
- replay_rec-0.20.0rc0/replay/experimental/models/scala_als.py +0 -293
- replay_rec-0.20.0rc0/replay/experimental/models/u_lin_ucb.py +0 -115
- replay_rec-0.20.0rc0/replay/experimental/nn/data/__init__.py +0 -1
- replay_rec-0.20.0rc0/replay/experimental/nn/data/schema_builder.py +0 -102
- replay_rec-0.20.0rc0/replay/experimental/preprocessing/__init__.py +0 -3
- replay_rec-0.20.0rc0/replay/experimental/preprocessing/data_preparator.py +0 -839
- replay_rec-0.20.0rc0/replay/experimental/preprocessing/padder.py +0 -229
- replay_rec-0.20.0rc0/replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay_rec-0.20.0rc0/replay/experimental/scenarios/__init__.py +0 -1
- replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay_rec-0.20.0rc0/replay/experimental/utils/logger.py +0 -24
- replay_rec-0.20.0rc0/replay/experimental/utils/model_handler.py +0 -186
- replay_rec-0.20.0rc0/replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.0rc0/replay/models/extensions/ann/__init__.py +0 -0
- replay_rec-0.20.0rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
- replay_rec-0.20.0rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
- replay_rec-0.20.0rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
- replay_rec-0.20.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- replay_rec-0.20.0rc0/replay/utils/warnings.py +0 -26
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/LICENSE +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/NOTICE +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/README.md +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/schema.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/nn/utils.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/schema.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/data/spark_schema.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/base_metric.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/categorical_diversity.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/coverage.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/descriptors.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/experiment.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/hitrate.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/map.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/mrr.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/ndcg.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/novelty.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/offline_metrics.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/precision.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/recall.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/rocauc.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/surprisal.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/torch_metrics_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/metrics/unexpectedness.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/als.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/association_rules.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/base_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/cat_pop_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/cluster.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/common.py +0 -0
- {replay_rec-0.20.0rc0/replay/experimental → replay_rec-0.20.1/replay/models/extensions}/__init__.py +0 -0
- {replay_rec-0.20.0rc0/replay/experimental/models/dt4rec → replay_rec-0.20.1/replay/models/extensions/ann}/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/ann_mixin.py +0 -0
- {replay_rec-0.20.0rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.20.1/replay/models/extensions/ann/entities}/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
- {replay_rec-0.20.0rc0/replay/experimental/scenarios/two_stages → replay_rec-0.20.1/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
- {replay_rec-0.20.0rc0/replay/experimental/utils → replay_rec-0.20.1/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
- {replay_rec-0.20.0rc0/replay/models/extensions → replay_rec-0.20.1/replay/models/extensions/ann/index_stores}/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/index_stores/utils.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/extensions/ann/utils.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/kl_ucb.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/knn.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/loss/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/loss/sce.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/optimizer_utils/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/bert4rec/model.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/base_compiled_model.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/compiled/sasrec_compiled.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/nn/sequential/sasrec/model.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/optuna_mixin.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/optimization/optuna_objective.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/pop_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/query_pop_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/random_rec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/slim.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/thompson_sampling.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/ucb.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/wilson.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/models/word2vec.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/converter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/discretizer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/filters.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/history_based_fp.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/label_encoder.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/preprocessing/sessionizer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/scenarios/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/scenarios/fallback.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/__init__.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/base_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/cold_user_random_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/k_folds.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/last_n_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/new_users_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/random_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/ratio_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/time_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/splitters/two_stage_splitter.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/common.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/dataframe_bucketizer.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/distributions.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/model_handler.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/session_handler.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/spark_utils.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/time.py +0 -0
- {replay_rec-0.20.0rc0 → replay_rec-0.20.1}/replay/utils/types.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: replay-rec
|
|
3
|
-
Version: 0.20.
|
|
3
|
+
Version: 0.20.1
|
|
4
4
|
Summary: RecSys Library
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
License-File: LICENSE
|
|
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
|
|
|
14
14
|
Classifier: Intended Audience :: Science/Research
|
|
15
15
|
Classifier: Natural Language :: English
|
|
16
16
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
Requires-Dist: lightning (
|
|
21
|
-
Requires-Dist:
|
|
17
|
+
Provides-Extra: spark
|
|
18
|
+
Provides-Extra: torch
|
|
19
|
+
Provides-Extra: torch-cpu
|
|
20
|
+
Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
|
|
21
|
+
Requires-Dist: lightning ; extra == "torch"
|
|
22
|
+
Requires-Dist: lightning ; extra == "torch-cpu"
|
|
22
23
|
Requires-Dist: numpy (>=1.20.0,<2)
|
|
23
24
|
Requires-Dist: pandas (>=1.3.5,<2.4.0)
|
|
24
25
|
Requires-Dist: polars (<2.0)
|
|
25
|
-
Requires-Dist: psutil (<=7.0.0)
|
|
26
|
+
Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
|
|
27
|
+
Requires-Dist: psutil ; extra == "spark"
|
|
26
28
|
Requires-Dist: pyarrow (<22.0)
|
|
27
|
-
Requires-Dist: pyspark (>=3.0,<3.5)
|
|
28
|
-
Requires-Dist:
|
|
29
|
-
Requires-Dist:
|
|
29
|
+
Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
|
|
30
|
+
Requires-Dist: pyspark ; extra == "spark"
|
|
31
|
+
Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
|
|
32
|
+
Requires-Dist: pytorch-optimizer ; extra == "torch"
|
|
33
|
+
Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
|
|
30
34
|
Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
|
|
31
35
|
Requires-Dist: scipy (>=1.13.1,<1.14)
|
|
32
36
|
Requires-Dist: setuptools
|
|
33
|
-
Requires-Dist: torch (>=1.8,<3.0.0)
|
|
37
|
+
Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
|
|
38
|
+
Requires-Dist: torch ; extra == "torch"
|
|
39
|
+
Requires-Dist: torch ; extra == "torch-cpu"
|
|
34
40
|
Requires-Dist: tqdm (>=4.67,<5)
|
|
35
41
|
Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
|
|
36
42
|
Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
|
|
@@ -40,19 +40,19 @@ dependencies = [
|
|
|
40
40
|
"scikit-learn (>=1.6.1,<1.7.0)",
|
|
41
41
|
"pyarrow (<22.0)",
|
|
42
42
|
"tqdm (>=4.67,<5)",
|
|
43
|
-
"
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"sb-obp (>=0.5.10,<0.6)",
|
|
49
|
-
"d3rlpy (>=2.8.1,<2.9)",
|
|
50
|
-
"implicit (>=0.7.2,<0.8)",
|
|
51
|
-
"pyspark (>=3.0,<3.5)",
|
|
52
|
-
"psutil (<=7.0.0)",
|
|
43
|
+
"pyspark (>=3.0,<3.5); extra == 'spark'",
|
|
44
|
+
"psutil (<=7.0.0); extra == 'spark'",
|
|
45
|
+
"torch (>=1.8, <3.0.0); extra == 'torch' or extra == 'torch-cpu'",
|
|
46
|
+
"pytorch-optimizer (>=3.8.0,<3.9.0); extra == 'torch' or extra == 'torch-cpu'",
|
|
47
|
+
"lightning (<2.6.0); extra == 'torch' or extra == 'torch-cpu'",
|
|
53
48
|
]
|
|
54
49
|
dynamic = ["dependencies"]
|
|
55
|
-
version = "0.20.
|
|
50
|
+
version = "0.20.1"
|
|
51
|
+
|
|
52
|
+
[project.optional-dependencies]
|
|
53
|
+
spark = ["pyspark", "psutil"]
|
|
54
|
+
torch = ["torch", "pytorch-optimizer", "lightning"]
|
|
55
|
+
torch-cpu = ["torch", "pytorch-optimizer", "lightning"]
|
|
56
56
|
|
|
57
57
|
[project.urls]
|
|
58
58
|
homepage = "https://sb-ai-lab.github.io/RePlay/"
|
|
@@ -66,6 +66,13 @@ target-version = ["py39", "py310", "py311", "py312"]
|
|
|
66
66
|
packages = [{include = "replay"}]
|
|
67
67
|
exclude = [
|
|
68
68
|
"replay/conftest.py",
|
|
69
|
+
"replay/experimental",
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
[tool.poetry.dependencies]
|
|
73
|
+
torch = [
|
|
74
|
+
{markers = "extra == 'torch-cpu' and extra !='torch'", source = "torch-cpu-mirror"},
|
|
75
|
+
{markers = "extra == 'torch' and extra !='torch-cpu'", source = "PyPI"},
|
|
69
76
|
]
|
|
70
77
|
|
|
71
78
|
[tool.poetry.group.dev.dependencies]
|
|
@@ -88,9 +95,14 @@ docutils = "0.16"
|
|
|
88
95
|
data-science-types = "0.2.23"
|
|
89
96
|
filelock = "~3.14.0"
|
|
90
97
|
|
|
98
|
+
[[tool.poetry.source]]
|
|
99
|
+
name = "torch-cpu-mirror"
|
|
100
|
+
url = "https://download.pytorch.org/whl/cpu"
|
|
101
|
+
priority = "explicit"
|
|
102
|
+
|
|
91
103
|
[tool.poetry-dynamic-versioning]
|
|
92
104
|
enable = false
|
|
93
|
-
format-jinja = """0.20.
|
|
105
|
+
format-jinja = """0.20.1{{ env['PACKAGE_SUFFIX'] }}"""
|
|
94
106
|
vcs = "git"
|
|
95
107
|
|
|
96
108
|
[tool.ruff]
|
|
@@ -15,7 +15,6 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, Feat
|
|
|
15
15
|
from replay.data.dataset_utils import DatasetLabelEncoder
|
|
16
16
|
from replay.preprocessing import LabelEncoder, LabelEncodingRule
|
|
17
17
|
from replay.preprocessing.label_encoder import HandleUnknownStrategies
|
|
18
|
-
from replay.utils import deprecation_warning
|
|
19
18
|
|
|
20
19
|
if TYPE_CHECKING:
|
|
21
20
|
from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
|
|
@@ -406,7 +405,6 @@ class SequenceTokenizer:
|
|
|
406
405
|
tensor_feature._set_cardinality(dataset_feature.cardinality)
|
|
407
406
|
|
|
408
407
|
@classmethod
|
|
409
|
-
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
410
408
|
def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
|
|
411
409
|
"""
|
|
412
410
|
Load tokenizer object from the given path.
|
|
@@ -450,12 +448,16 @@ class SequenceTokenizer:
|
|
|
450
448
|
tokenizer._encoder._features_columns = encoder_features_columns
|
|
451
449
|
tokenizer._encoder._encoding_rules = tokenizer_dict["encoder"]["encoding_rules"]
|
|
452
450
|
else:
|
|
451
|
+
warnings.warn(
|
|
452
|
+
"with `use_pickle` equals to `True` will be deprecated in future versions",
|
|
453
|
+
DeprecationWarning,
|
|
454
|
+
stacklevel=2,
|
|
455
|
+
)
|
|
453
456
|
with open(path, "rb") as file:
|
|
454
457
|
tokenizer = pickle.load(file)
|
|
455
458
|
|
|
456
459
|
return tokenizer
|
|
457
460
|
|
|
458
|
-
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
459
461
|
def save(self, path: str, use_pickle: bool = False) -> None:
|
|
460
462
|
"""
|
|
461
463
|
Save the tokenizer to the given path.
|
|
@@ -496,6 +498,11 @@ class SequenceTokenizer:
|
|
|
496
498
|
with open(base_path / "init_args.json", "w+") as file:
|
|
497
499
|
json.dump(tokenizer_dict, file)
|
|
498
500
|
else:
|
|
501
|
+
warnings.warn(
|
|
502
|
+
"with `use_pickle` equals to `True` will be deprecated in future versions",
|
|
503
|
+
DeprecationWarning,
|
|
504
|
+
stacklevel=2,
|
|
505
|
+
)
|
|
499
506
|
with open(path, "wb") as file:
|
|
500
507
|
pickle.dump(self, file)
|
|
501
508
|
|
|
@@ -110,17 +110,27 @@ class SequentialDataset(abc.ABC):
|
|
|
110
110
|
|
|
111
111
|
sequential_dict = {}
|
|
112
112
|
sequential_dict["_class_name"] = self.__class__.__name__
|
|
113
|
-
|
|
113
|
+
|
|
114
|
+
df = SequentialDataset._convert_array_to_list(self._sequences)
|
|
115
|
+
df.reset_index().to_parquet(base_path / "sequences.parquet")
|
|
114
116
|
sequential_dict["init_args"] = {
|
|
115
117
|
"tensor_schema": self._tensor_schema._get_object_args(),
|
|
116
118
|
"query_id_column": self._query_id_column,
|
|
117
119
|
"item_id_column": self._item_id_column,
|
|
118
|
-
"sequences_path": "sequences.
|
|
120
|
+
"sequences_path": "sequences.parquet",
|
|
119
121
|
}
|
|
120
122
|
|
|
121
123
|
with open(base_path / "init_args.json", "w+") as file:
|
|
122
124
|
json.dump(sequential_dict, file)
|
|
123
125
|
|
|
126
|
+
@staticmethod
|
|
127
|
+
def _convert_array_to_list(df):
|
|
128
|
+
return df.map(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
|
|
129
|
+
|
|
130
|
+
@staticmethod
|
|
131
|
+
def _convert_list_to_array(df):
|
|
132
|
+
return df.map(lambda x: np.array(x) if isinstance(x, list) else x)
|
|
133
|
+
|
|
124
134
|
|
|
125
135
|
class PandasSequentialDataset(SequentialDataset):
|
|
126
136
|
"""
|
|
@@ -149,7 +159,7 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
149
159
|
if sequences.index.name != query_id_column:
|
|
150
160
|
sequences = sequences.set_index(query_id_column)
|
|
151
161
|
|
|
152
|
-
self._sequences = sequences
|
|
162
|
+
self._sequences = SequentialDataset._convert_list_to_array(sequences)
|
|
153
163
|
|
|
154
164
|
def __len__(self) -> int:
|
|
155
165
|
return len(self._sequences)
|
|
@@ -206,7 +216,8 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
206
216
|
with open(base_path / "init_args.json") as file:
|
|
207
217
|
sequential_dict = json.loads(file.read())
|
|
208
218
|
|
|
209
|
-
sequences = pd.
|
|
219
|
+
sequences = pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"])
|
|
220
|
+
sequences = cls._convert_array_to_list(sequences)
|
|
210
221
|
dataset = cls(
|
|
211
222
|
tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
|
|
212
223
|
query_id_column=sequential_dict["init_args"]["query_id_column"],
|
|
@@ -258,18 +269,11 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
258
269
|
|
|
259
270
|
def _convert_polars_to_pandas(self, df: PolarsDataFrame) -> PandasDataFrame:
|
|
260
271
|
pandas_df = PandasDataFrame(df.to_dict(as_series=False))
|
|
261
|
-
|
|
262
|
-
for column in pandas_df.select_dtypes(include="object").columns:
|
|
263
|
-
if isinstance(pandas_df[column].iloc[0], list):
|
|
264
|
-
pandas_df[column] = pandas_df[column].apply(lambda x: np.array(x))
|
|
265
|
-
|
|
272
|
+
pandas_df = SequentialDataset._convert_list_to_array(pandas_df)
|
|
266
273
|
return pandas_df
|
|
267
274
|
|
|
268
275
|
def _convert_pandas_to_polars(self, df: PandasDataFrame) -> PolarsDataFrame:
|
|
269
|
-
|
|
270
|
-
if isinstance(df[column].iloc[0], np.ndarray):
|
|
271
|
-
df[column] = df[column].apply(lambda x: x.tolist())
|
|
272
|
-
|
|
276
|
+
df = SequentialDataset._convert_array_to_list(df)
|
|
273
277
|
return pl.from_dict(df.to_dict("list"))
|
|
274
278
|
|
|
275
279
|
@classmethod
|
|
@@ -290,7 +294,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
290
294
|
with open(base_path / "init_args.json") as file:
|
|
291
295
|
sequential_dict = json.loads(file.read())
|
|
292
296
|
|
|
293
|
-
sequences = pl.
|
|
297
|
+
sequences = pl.from_pandas(pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"]))
|
|
294
298
|
dataset = cls(
|
|
295
299
|
tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
|
|
296
300
|
query_id_column=sequential_dict["init_args"]["query_id_column"],
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
from collections.abc import Generator, Sequence
|
|
2
3
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
|
|
3
4
|
|
|
@@ -5,8 +6,6 @@ import numpy as np
|
|
|
5
6
|
import torch
|
|
6
7
|
from torch.utils.data import Dataset as TorchDataset
|
|
7
8
|
|
|
8
|
-
from replay.utils import deprecation_warning
|
|
9
|
-
|
|
10
9
|
if TYPE_CHECKING:
|
|
11
10
|
from .schema import TensorFeatureInfo, TensorMap, TensorSchema
|
|
12
11
|
from .sequential_dataset import SequentialDataset
|
|
@@ -29,16 +28,12 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
29
28
|
Torch dataset for sequential recommender models
|
|
30
29
|
"""
|
|
31
30
|
|
|
32
|
-
@deprecation_warning(
|
|
33
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
34
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
35
|
-
)
|
|
36
31
|
def __init__(
|
|
37
32
|
self,
|
|
38
33
|
sequential: "SequentialDataset",
|
|
39
34
|
max_sequence_length: int,
|
|
40
35
|
sliding_window_step: Optional[int] = None,
|
|
41
|
-
padding_value: int =
|
|
36
|
+
padding_value: Optional[int] = None,
|
|
42
37
|
) -> None:
|
|
43
38
|
"""
|
|
44
39
|
:param sequential: sequential dataset
|
|
@@ -53,6 +48,15 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
53
48
|
self._sequential = sequential
|
|
54
49
|
self._max_sequence_length = max_sequence_length
|
|
55
50
|
self._sliding_window_step = sliding_window_step
|
|
51
|
+
if padding_value is not None:
|
|
52
|
+
warnings.warn(
|
|
53
|
+
"`padding_value` parameter will be removed in future versions. "
|
|
54
|
+
"Instead, you should specify `padding_value` for each column in TensorSchema",
|
|
55
|
+
DeprecationWarning,
|
|
56
|
+
stacklevel=2,
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
padding_value = 0
|
|
56
60
|
self._padding_value = padding_value
|
|
57
61
|
self._index2sequence_map = self._build_index2sequence_map()
|
|
58
62
|
|
|
@@ -177,17 +181,13 @@ class TorchSequentialValidationDataset(TorchDataset):
|
|
|
177
181
|
Torch dataset for sequential recommender models that additionally stores ground truth
|
|
178
182
|
"""
|
|
179
183
|
|
|
180
|
-
@deprecation_warning(
|
|
181
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
182
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
183
|
-
)
|
|
184
184
|
def __init__(
|
|
185
185
|
self,
|
|
186
186
|
sequential: "SequentialDataset",
|
|
187
187
|
ground_truth: "SequentialDataset",
|
|
188
188
|
train: "SequentialDataset",
|
|
189
189
|
max_sequence_length: int,
|
|
190
|
-
padding_value: int =
|
|
190
|
+
padding_value: Optional[int] = None,
|
|
191
191
|
sliding_window_step: Optional[int] = None,
|
|
192
192
|
label_feature_name: Optional[str] = None,
|
|
193
193
|
):
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from
|
|
2
|
+
from os.path import join
|
|
3
|
+
from typing import Optional, Union
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
@@ -8,7 +9,11 @@ from tqdm import tqdm
|
|
|
8
9
|
|
|
9
10
|
from replay.data.dataset import Dataset
|
|
10
11
|
from replay.utils import SparkDataFrame
|
|
11
|
-
from replay.utils.spark_utils import
|
|
12
|
+
from replay.utils.spark_utils import (
|
|
13
|
+
convert2spark,
|
|
14
|
+
load_pickled_from_parquet,
|
|
15
|
+
save_picklable_to_parquet,
|
|
16
|
+
)
|
|
12
17
|
|
|
13
18
|
from .base_rec import HybridRecommender
|
|
14
19
|
|
|
@@ -177,6 +182,7 @@ class LinUCB(HybridRecommender):
|
|
|
177
182
|
_study = None # field required for proper optuna's optimization
|
|
178
183
|
linucb_arms: list[Union[DisjointArm, HybridArm]] # initialize only when working within fit method
|
|
179
184
|
rel_matrix: np.array # matrix with relevance scores from predict method
|
|
185
|
+
_num_items: int # number of items/arms
|
|
180
186
|
|
|
181
187
|
def __init__(
|
|
182
188
|
self,
|
|
@@ -195,7 +201,7 @@ class LinUCB(HybridRecommender):
|
|
|
195
201
|
|
|
196
202
|
@property
|
|
197
203
|
def _init_args(self):
|
|
198
|
-
return {"is_hybrid": self.is_hybrid}
|
|
204
|
+
return {"is_hybrid": self.is_hybrid, "eps": self.eps, "alpha": self.alpha}
|
|
199
205
|
|
|
200
206
|
def _verify_features(self, dataset: Dataset):
|
|
201
207
|
if dataset.query_features is None:
|
|
@@ -230,6 +236,7 @@ class LinUCB(HybridRecommender):
|
|
|
230
236
|
self._num_items = item_features.shape[0]
|
|
231
237
|
self._user_dim_size = user_features.shape[1] - 1
|
|
232
238
|
self._item_dim_size = item_features.shape[1] - 1
|
|
239
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
233
240
|
|
|
234
241
|
# now initialize an arm object for each potential arm instance
|
|
235
242
|
if self.is_hybrid:
|
|
@@ -248,11 +255,14 @@ class LinUCB(HybridRecommender):
|
|
|
248
255
|
]
|
|
249
256
|
|
|
250
257
|
for i in tqdm(range(self._num_items)):
|
|
251
|
-
B = log.loc[
|
|
252
|
-
|
|
253
|
-
|
|
258
|
+
B = log.loc[ # noqa: N806
|
|
259
|
+
(log[feature_schema.item_id_column] == i)
|
|
260
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
261
|
+
]
|
|
254
262
|
if not B.empty:
|
|
255
263
|
# if we have at least one user interacting with the hand i
|
|
264
|
+
idxs_list = B[feature_schema.query_id_column].values
|
|
265
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
256
266
|
cur_usrs = scs.csr_matrix(
|
|
257
267
|
user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
|
|
258
268
|
.drop(columns=[feature_schema.query_id_column])
|
|
@@ -284,11 +294,14 @@ class LinUCB(HybridRecommender):
|
|
|
284
294
|
]
|
|
285
295
|
|
|
286
296
|
for i in range(self._num_items):
|
|
287
|
-
B = log.loc[
|
|
288
|
-
|
|
289
|
-
|
|
297
|
+
B = log.loc[ # noqa: N806
|
|
298
|
+
(log[feature_schema.item_id_column] == i)
|
|
299
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
300
|
+
]
|
|
290
301
|
if not B.empty:
|
|
291
302
|
# if we have at least one user interacting with the hand i
|
|
303
|
+
idxs_list = B[feature_schema.query_id_column].values # noqa: F841
|
|
304
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
292
305
|
cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
|
|
293
306
|
columns=[feature_schema.query_id_column]
|
|
294
307
|
)
|
|
@@ -318,8 +331,10 @@ class LinUCB(HybridRecommender):
|
|
|
318
331
|
user_features = dataset.query_features
|
|
319
332
|
item_features = dataset.item_features
|
|
320
333
|
big_k = min(oversample * k, item_features.shape[0])
|
|
334
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
321
335
|
|
|
322
336
|
users = users.toPandas()
|
|
337
|
+
users = users[users[feature_schema.query_id_column].isin(self._user_idxs_list)]
|
|
323
338
|
num_user_pred = users.shape[0]
|
|
324
339
|
rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)
|
|
325
340
|
|
|
@@ -404,3 +419,34 @@ class LinUCB(HybridRecommender):
|
|
|
404
419
|
warnings.warn(warn_msg)
|
|
405
420
|
dataset.to_spark()
|
|
406
421
|
return convert2spark(res_df)
|
|
422
|
+
|
|
423
|
+
def _save_model(self, path: str, additional_params: Optional[dict] = None):
|
|
424
|
+
super()._save_model(path, additional_params)
|
|
425
|
+
|
|
426
|
+
save_picklable_to_parquet(self.linucb_arms, join(path, "linucb_arms.dump"))
|
|
427
|
+
|
|
428
|
+
if self.is_hybrid:
|
|
429
|
+
linucb_hybrid_shared_params = {
|
|
430
|
+
"A_0": self.A_0,
|
|
431
|
+
"A_0_inv": self.A_0_inv,
|
|
432
|
+
"b_0": self.b_0,
|
|
433
|
+
"beta": self.beta,
|
|
434
|
+
}
|
|
435
|
+
save_picklable_to_parquet(
|
|
436
|
+
linucb_hybrid_shared_params,
|
|
437
|
+
join(path, "linucb_hybrid_shared_params.dump"),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
def _load_model(self, path: str):
|
|
441
|
+
super()._load_model(path)
|
|
442
|
+
|
|
443
|
+
loaded_linucb_arms = load_pickled_from_parquet(join(path, "linucb_arms.dump"))
|
|
444
|
+
self.linucb_arms = loaded_linucb_arms
|
|
445
|
+
self._num_items = len(loaded_linucb_arms)
|
|
446
|
+
|
|
447
|
+
if self.is_hybrid:
|
|
448
|
+
loaded_linucb_hybrid_shared_params = load_pickled_from_parquet(
|
|
449
|
+
join(path, "linucb_hybrid_shared_params.dump")
|
|
450
|
+
)
|
|
451
|
+
for param, value in loaded_linucb_hybrid_shared_params.items():
|
|
452
|
+
setattr(self, param, value)
|
|
@@ -12,7 +12,6 @@ from replay.data.nn import (
|
|
|
12
12
|
TorchSequentialDataset,
|
|
13
13
|
TorchSequentialValidationDataset,
|
|
14
14
|
)
|
|
15
|
-
from replay.utils import deprecation_warning
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
class Bert4RecTrainingBatch(NamedTuple):
|
|
@@ -89,10 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
89
88
|
Dataset that generates samples to train BERT-like model
|
|
90
89
|
"""
|
|
91
90
|
|
|
92
|
-
@deprecation_warning(
|
|
93
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
94
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
95
|
-
)
|
|
96
91
|
def __init__(
|
|
97
92
|
self,
|
|
98
93
|
sequential: SequentialDataset,
|
|
@@ -101,7 +96,7 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
101
96
|
sliding_window_step: Optional[int] = None,
|
|
102
97
|
label_feature_name: Optional[str] = None,
|
|
103
98
|
custom_masker: Optional[Bert4RecMasker] = None,
|
|
104
|
-
padding_value: int =
|
|
99
|
+
padding_value: Optional[int] = None,
|
|
105
100
|
) -> None:
|
|
106
101
|
"""
|
|
107
102
|
:param sequential: Sequential dataset with training data.
|
|
@@ -181,15 +176,11 @@ class Bert4RecPredictionDataset(TorchDataset):
|
|
|
181
176
|
Dataset that generates samples to infer BERT-like model
|
|
182
177
|
"""
|
|
183
178
|
|
|
184
|
-
@deprecation_warning(
|
|
185
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
186
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
187
|
-
)
|
|
188
179
|
def __init__(
|
|
189
180
|
self,
|
|
190
181
|
sequential: SequentialDataset,
|
|
191
182
|
max_sequence_length: int,
|
|
192
|
-
padding_value: int =
|
|
183
|
+
padding_value: Optional[int] = None,
|
|
193
184
|
) -> None:
|
|
194
185
|
"""
|
|
195
186
|
:param sequential: Sequential dataset with data to make predictions at.
|
|
@@ -239,17 +230,13 @@ class Bert4RecValidationDataset(TorchDataset):
|
|
|
239
230
|
Dataset that generates samples to infer and validate BERT-like model
|
|
240
231
|
"""
|
|
241
232
|
|
|
242
|
-
@deprecation_warning(
|
|
243
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
244
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
245
|
-
)
|
|
246
233
|
def __init__(
|
|
247
234
|
self,
|
|
248
235
|
sequential: SequentialDataset,
|
|
249
236
|
ground_truth: SequentialDataset,
|
|
250
237
|
train: SequentialDataset,
|
|
251
238
|
max_sequence_length: int,
|
|
252
|
-
padding_value: int =
|
|
239
|
+
padding_value: Optional[int] = None,
|
|
253
240
|
label_feature_name: Optional[str] = None,
|
|
254
241
|
):
|
|
255
242
|
"""
|
|
@@ -51,7 +51,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
51
51
|
|
|
52
52
|
def _compute_scores(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> torch.Tensor:
|
|
53
53
|
flat_seen_item_ids = self._get_flat_seen_item_ids(query_ids)
|
|
54
|
-
return self._fill_item_ids(scores, flat_seen_item_ids, -np.inf)
|
|
54
|
+
return self._fill_item_ids(scores.clone(), flat_seen_item_ids, -np.inf)
|
|
55
55
|
|
|
56
56
|
def _fill_item_ids(
|
|
57
57
|
self,
|
|
@@ -10,7 +10,6 @@ from replay.data.nn import (
|
|
|
10
10
|
TorchSequentialDataset,
|
|
11
11
|
TorchSequentialValidationDataset,
|
|
12
12
|
)
|
|
13
|
-
from replay.utils import deprecation_warning
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class SasRecTrainingBatch(NamedTuple):
|
|
@@ -31,17 +30,13 @@ class SasRecTrainingDataset(TorchDataset):
|
|
|
31
30
|
Dataset that generates samples to train SasRec-like model
|
|
32
31
|
"""
|
|
33
32
|
|
|
34
|
-
@deprecation_warning(
|
|
35
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
36
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
37
|
-
)
|
|
38
33
|
def __init__(
|
|
39
34
|
self,
|
|
40
35
|
sequential: SequentialDataset,
|
|
41
36
|
max_sequence_length: int,
|
|
42
37
|
sequence_shift: int = 1,
|
|
43
38
|
sliding_window_step: Optional[None] = None,
|
|
44
|
-
padding_value: int =
|
|
39
|
+
padding_value: Optional[int] = None,
|
|
45
40
|
label_feature_name: Optional[str] = None,
|
|
46
41
|
) -> None:
|
|
47
42
|
"""
|
|
@@ -127,15 +122,11 @@ class SasRecPredictionDataset(TorchDataset):
|
|
|
127
122
|
Dataset that generates samples to infer SasRec-like model
|
|
128
123
|
"""
|
|
129
124
|
|
|
130
|
-
@deprecation_warning(
|
|
131
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
132
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
133
|
-
)
|
|
134
125
|
def __init__(
|
|
135
126
|
self,
|
|
136
127
|
sequential: SequentialDataset,
|
|
137
128
|
max_sequence_length: int,
|
|
138
|
-
padding_value: int =
|
|
129
|
+
padding_value: Optional[int] = None,
|
|
139
130
|
) -> None:
|
|
140
131
|
"""
|
|
141
132
|
:param sequential: Sequential dataset with data to make predictions at.
|
|
@@ -179,17 +170,13 @@ class SasRecValidationDataset(TorchDataset):
|
|
|
179
170
|
Dataset that generates samples to infer and validate SasRec-like model
|
|
180
171
|
"""
|
|
181
172
|
|
|
182
|
-
@deprecation_warning(
|
|
183
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
184
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
185
|
-
)
|
|
186
173
|
def __init__(
|
|
187
174
|
self,
|
|
188
175
|
sequential: SequentialDataset,
|
|
189
176
|
ground_truth: SequentialDataset,
|
|
190
177
|
train: SequentialDataset,
|
|
191
178
|
max_sequence_length: int,
|
|
192
|
-
padding_value: int =
|
|
179
|
+
padding_value: Optional[int] = None,
|
|
193
180
|
label_feature_name: Optional[str] = None,
|
|
194
181
|
):
|
|
195
182
|
"""
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Most metrics require dataframe with recommendations
|
|
3
|
-
and dataframe with ground truth values —
|
|
4
|
-
which objects each user interacted with.
|
|
5
|
-
|
|
6
|
-
- recommendations (Union[pandas.DataFrame, spark.DataFrame]):
|
|
7
|
-
predictions of a recommender system,
|
|
8
|
-
DataFrame with columns ``[user_id, item_id, relevance]``
|
|
9
|
-
- ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
|
|
10
|
-
test data, DataFrame with columns
|
|
11
|
-
``[user_id, item_id, timestamp, relevance]``
|
|
12
|
-
|
|
13
|
-
Metric is calculated for all users, presented in ``ground_truth``
|
|
14
|
-
for accurate metric calculation in case when the recommender system generated
|
|
15
|
-
recommendation not for all users. It is assumed, that all users,
|
|
16
|
-
we want to calculate metric for, have positive interactions.
|
|
17
|
-
|
|
18
|
-
But if we have users, who observed the recommendations, but have not responded,
|
|
19
|
-
those users will be ignored and metric will be overestimated.
|
|
20
|
-
For such case we propose additional optional parameter ``ground_truth_users``,
|
|
21
|
-
the dataframe with all users, which should be considered during the metric calculation.
|
|
22
|
-
|
|
23
|
-
- ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
|
|
24
|
-
full list of users to calculate metric for, DataFrame with ``user_id`` column
|
|
25
|
-
|
|
26
|
-
Every metric is calculated using top ``K`` items for each user.
|
|
27
|
-
It is also possible to calculate metrics
|
|
28
|
-
using multiple values for ``K`` simultaneously.
|
|
29
|
-
In this case the result will be a dictionary and not a number.
|
|
30
|
-
|
|
31
|
-
Make sure your recommendations do not contain user-item duplicates
|
|
32
|
-
as duplicates could lead to the wrong calculation results.
|
|
33
|
-
|
|
34
|
-
- k (Union[Iterable[int], int]):
|
|
35
|
-
a single number or a list, specifying the
|
|
36
|
-
truncation length for recommendation list for each user
|
|
37
|
-
|
|
38
|
-
By default, metrics are averaged by users,
|
|
39
|
-
but you can alternatively use method ``metric.median``.
|
|
40
|
-
Also, you can get the lower bound
|
|
41
|
-
of ``conf_interval`` for a given ``alpha``.
|
|
42
|
-
|
|
43
|
-
Diversity metrics require extra parameters on initialization stage,
|
|
44
|
-
but do not use ``ground_truth`` parameter.
|
|
45
|
-
|
|
46
|
-
For each metric, a formula for its calculation is given, because this is
|
|
47
|
-
important for the correct comparison of algorithms, as mentioned in our
|
|
48
|
-
`article <https://arxiv.org/abs/2206.12858>`_.
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
from replay.experimental.metrics.base_metric import Metric, NCISMetric
|
|
52
|
-
from replay.experimental.metrics.coverage import Coverage
|
|
53
|
-
from replay.experimental.metrics.hitrate import HitRate
|
|
54
|
-
from replay.experimental.metrics.map import MAP
|
|
55
|
-
from replay.experimental.metrics.mrr import MRR
|
|
56
|
-
from replay.experimental.metrics.ncis_precision import NCISPrecision
|
|
57
|
-
from replay.experimental.metrics.ndcg import NDCG
|
|
58
|
-
from replay.experimental.metrics.precision import Precision
|
|
59
|
-
from replay.experimental.metrics.recall import Recall
|
|
60
|
-
from replay.experimental.metrics.rocauc import RocAuc
|
|
61
|
-
from replay.experimental.metrics.surprisal import Surprisal
|
|
62
|
-
from replay.experimental.metrics.unexpectedness import Unexpectedness
|