replay-rec 0.17.1rc0__tar.gz → 0.18.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/PKG-INFO +13 -11
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/pyproject.toml +18 -15
- replay_rec-0.18.0rc0/replay/__init__.py +3 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset.py +3 -2
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/schema.py +5 -5
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/__init__.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/base_metric.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_rec.py +7 -7
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/cql.py +2 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/ddpg.py +6 -4
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/lightfm_wrap.py +2 -2
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/mult_vae.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/neuromf.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/scala_als.py +2 -2
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/data_preparator.py +2 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/padder.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/model_handler.py +7 -2
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/__init__.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/als.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_rec.py +7 -7
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/model.py +5 -112
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/model.py +8 -5
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/optimization/optuna_objective.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/converter.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/filters.py +19 -18
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/history_based_fp.py +5 -5
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/label_encoder.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/scenarios/__init__.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/last_n_splitter.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/time_splitter.py +1 -1
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/two_stage_splitter.py +8 -6
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/distributions.py +1 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/session_handler.py +3 -3
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/spark_utils.py +2 -2
- replay_rec-0.17.1rc0/replay/__init__.py +0 -2
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/LICENSE +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/NOTICE +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/README.md +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/sequence_tokenizer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/sequential_dataset.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/torch_sequential_dataset.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/schema.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/spark_schema.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/coverage.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/experiment.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/map.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/mrr.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/precision.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/recall.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/admm_slim.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_torch_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/dt4rec/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/nn/data/schema_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/sequence_generator.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/logger.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/session_handler.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/base_metric.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/categorical_diversity.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/coverage.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/descriptors.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/experiment.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/hitrate.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/map.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/mrr.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/ndcg.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/novelty.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/offline_metrics.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/precision.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/recall.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/rocauc.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/surprisal.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/torch_metrics_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/unexpectedness.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/association_rules.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/cat_pop_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/cluster.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/ann_mixin.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/utils.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/kl_ucb.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/knn.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/pop_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/query_pop_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/random_rec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/slim.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/thompson_sampling.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/ucb.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/wilson.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/word2vec.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/optimization/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/preprocessing/sessionizer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/scenarios/fallback.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/base_splitter.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/cold_user_random_splitter.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/k_folds.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/new_users_splitter.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/random_splitter.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/splitters/ratio_splitter.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/__init__.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/common.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/model_handler.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/time.py +0 -0
- {replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/utils/types.py +0 -0
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/PKG-INFO
RENAMED
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: replay-rec
-Version: 0.17.1rc0
+Version: 0.18.0rc0
 Summary: RecSys Library
 Home-page: https://sb-ai-lab.github.io/RePlay/
 License: Apache-2.0
 Author: AI Lab
-Requires-Python: >=3.8.1,<3.11
+Requires-Python: >=3.8.1,<3.12
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Intended Audience :: Developers
@@ -16,32 +16,34 @@ Classifier: Operating System :: Unix
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Provides-Extra: all
 Provides-Extra: spark
 Provides-Extra: torch
 Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
+Requires-Dist: fixed-install-nmslib (==2.1.2)
 Requires-Dist: gym (>=0.26.0,<0.27.0)
-Requires-Dist: hnswlib (==0.7.0)
+Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
 Requires-Dist: implicit (>=0.7.0,<0.8.0)
 Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
 Requires-Dist: lightfm (==1.17)
-Requires-Dist: lightning (>=2.0.2
+Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "all"
 Requires-Dist: llvmlite (>=0.32.1)
-Requires-Dist: nmslib (==2.1.1)
 Requires-Dist: numba (>=0.50)
 Requires-Dist: numpy (>=1.20.0)
 Requires-Dist: optuna (>=3.2.0,<3.3.0)
 Requires-Dist: pandas (>=1.3.5,<=2.2.2)
-Requires-Dist: polars (>=0.
-Requires-Dist: psutil (>=
+Requires-Dist: polars (>=1.0.0,<1.1.0)
+Requires-Dist: psutil (>=6.0.0,<6.1.0)
 Requires-Dist: pyarrow (>=12.0.1)
-Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
+Requires-Dist: pyspark (>=3.0,<3.5) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
+Requires-Dist: pyspark (>=3.4,<3.5) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
 Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
-Requires-Dist: sb-obp (>=0.5.
+Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
 Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
-Requires-Dist: scipy (>=1.8.1,<
-Requires-Dist: torch (>=1.8
+Requires-Dist: scipy (>=1.8.1,<2.0.0)
+Requires-Dist: torch (>=1.8,<=2.4.0) ; extra == "torch" or extra == "all"
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
 Description-Content-Type: text/markdown
 
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/pyproject.toml
RENAMED
@@ -7,7 +7,7 @@ build-backend = "poetry_dynamic_versioning.backend"
 
 [tool.black]
 line-length = 120
-target-versions = ["py38", "py39", "py310"]
+target-versions = ["py38", "py39", "py310", "py311"]
 
 [tool.poetry]
 name = "replay-rec"
@@ -40,29 +40,32 @@ classifiers = [
 exclude = [
     "replay/conftest.py",
 ]
-version = "0.17.1.preview"
+version = "0.18.0.preview"
 
 [tool.poetry.dependencies]
-python = ">=3.8.1, <3.11"
+python = ">=3.8.1, <3.12"
 numpy = ">=1.20.0"
-pandas = ">=1.3.5
-polars = "~0.
+pandas = ">=1.3.5, <=2.2.2"
+polars = "~1.0.0"
 optuna = "~3.2.0"
-scipy = "
-psutil = "~
-pyspark = {version = ">=3.0,<3.5", optional = true}
+scipy = "^1.8.1"
+psutil = "~6.0.0"
 scikit-learn = "^1.0.2"
 pyarrow = ">=12.0.1"
-nmslib = "2.1.1"
-hnswlib = "0.7.0"
-
-
+fixed-install-nmslib = "2.1.2"
+hnswlib = "^0.7.0"
+pyspark = [
+    {version = ">=3.4,<3.5", python = ">=3.11,<3.12"},
+    {version = ">=3.0,<3.5", python = ">=3.8.1,<3.11"},
+]
+torch = ">=1.8, <=2.4.0"
+lightning = ">=2.0.2, <=2.4.0"
 pytorch-ranger = "^0.1.1"
 lightfm = "1.17"
 lightautoml = "~0.3.1"
 numba = ">=0.50"
 llvmlite = ">=0.32.1"
-sb-obp = "^0.5.
+sb-obp = "^0.5.8"
 d3rlpy = "^2.0.4"
 implicit = "~0.7.0"
 gym = "^0.26.0"
@@ -77,7 +80,7 @@ jupyter = "~1.0.0"
 jupyterlab = "^3.6.0"
 pytest = ">=7.1.0"
 pytest-cov = ">=3.0.0"
-statsmodels = "~0.
+statsmodels = "~0.14.0"
 black = ">=23.3.0"
 ruff = ">=0.0.261"
 toml-sort = "^0.23.0"
@@ -92,7 +95,7 @@ data-science-types = "0.2.23"
 
 [tool.poetry-dynamic-versioning]
 enable = false
-format-jinja = """0.17.1{{ env['PACKAGE_SUFFIX'] }}"""
+format-jinja = """0.18.0{{ env['PACKAGE_SUFFIX'] }}"""
 vcs = "git"
 
 [tool.ruff]
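The pyproject change above splits `pyspark` into two Python-version-gated constraints, which surface in the built metadata as the environment markers shown in the PKG-INFO hunk. A minimal sketch of how an installer decides which constraint applies, assuming the third-party `packaging` library (not a replay-rec dependency) is available:

```python
# Sketch: evaluating the environment markers that gate the two pyspark
# constraints from the PKG-INFO hunk above (extras markers omitted).
from packaging.markers import Marker

old_python = Marker('python_full_version >= "3.8.1" and python_version < "3.11"')
new_python = Marker('python_version >= "3.11" and python_version < "3.12"')

# Evaluate against a hypothetical Python 3.11 environment instead of the
# current interpreter by passing an explicit environment dict.
env = {"python_version": "3.11", "python_full_version": "3.11.9"}
print(old_python.evaluate(env))  # False -> pyspark >=3.0,<3.5 does not apply
print(new_python.evaluate(env))  # True  -> pyspark >=3.4,<3.5 is selected
```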
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset.py
RENAMED
@@ -1,6 +1,7 @@
 """
 ``Dataset`` universal dataset class for manipulating interactions and feed data to models.
 """
+
 from __future__ import annotations
 
 import json
@@ -606,7 +607,7 @@ class Dataset:
         if self.is_pandas:
             min_id = data[column].min()
         elif self.is_spark:
-            min_id = data.agg(sf.min(column).alias("min_index")).
+            min_id = data.agg(sf.min(column).alias("min_index")).first()[0]
         else:
             min_id = data[column].min()
         if min_id < 0:
@@ -616,7 +617,7 @@ class Dataset:
         if self.is_pandas:
             max_id = data[column].max()
         elif self.is_spark:
-            max_id = data.agg(sf.max(column).alias("max_index")).
+            max_id = data.agg(sf.max(column).alias("max_index")).first()[0]
         else:
             max_id = data[column].max()
 
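The recurring change in this and many later hunks is a switch to `DataFrame.first()[0]` for reading a scalar out of a one-row Spark aggregate: `first()` returns a single `Row` (or `None` on an empty frame), and positional indexing extracts the column value. A minimal, self-contained sketch of the new pattern, assuming a local Spark session and toy data:

```python
# Sketch of the scalar-aggregation pattern adopted across this diff:
# agg(...) yields a one-row DataFrame; first() returns that Row, and [0]
# extracts the single column as a plain Python value.
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
data = spark.createDataFrame([(1,), (5,), (3,)], ["item_idx"])

min_id = data.agg(sf.min("item_idx").alias("min_index")).first()[0]
max_id = data.agg(sf.max("item_idx").alias("max_index")).first()[0]
print(min_id, max_id)  # 1 5
```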
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py
RENAMED
@@ -4,6 +4,7 @@ Contains classes for encoding categorical data
 ``LabelEncoderTransformWarning`` new category of warning for DatasetLabelEncoder.
 ``DatasetLabelEncoder`` to encode categorical features in `Dataset` objects.
 """
+
 import warnings
 from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union
 
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/data/nn/schema.py
RENAMED
@@ -418,11 +418,11 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
             "feature_type": feature.feature_type.name,
             "is_seq": feature.is_seq,
             "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
-            "feature_sources":
-                {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
-
-
-
+            "feature_sources": (
+                [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
+                if feature.feature_sources
+                else None
+            ),
             "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
             "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
             "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
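The schema.py hunk guards `feature_sources` with a conditional expression, so a feature without sources serializes to JSON `null` instead of embedding a bare generator in the dict. A reduced sketch of the guarded-field pattern, with a hypothetical `Source` class standing in for the real feature-source type:

```python
# Sketch of the guarded serialization pattern from the schema.py hunk.
# `Source` is a hypothetical stand-in, not the replay class.
import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Source:
    source: str
    column: str
    index: Optional[int] = None

def serialize_sources(feature_sources: Optional[List[Source]]) -> Optional[list]:
    # Build the list only when sources exist; otherwise return an explicit
    # None, which json.dumps renders as null.
    return (
        [{"source": x.source, "column": x.column, "index": x.index} for x in feature_sources]
        if feature_sources
        else None
    )

print(json.dumps(serialize_sources([Source("INTERACTIONS", "item_id")])))
print(json.dumps(serialize_sources(None)))  # null
```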
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/metrics/__init__.py
RENAMED
@@ -47,6 +47,7 @@ For each metric, a formula for its calculation is given, because this is
 important for the correct comparison of algorithms, as mentioned in our
 `article <https://arxiv.org/abs/2206.12858>`_.
 """
+
 from replay.experimental.metrics.base_metric import Metric, NCISMetric
 from replay.experimental.metrics.coverage import Coverage
 from replay.experimental.metrics.hitrate import HitRate
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/base_rec.py
RENAMED
@@ -86,8 +86,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         self.fit_items = sf.broadcast(items)
         self._num_users = self.fit_users.count()
         self._num_items = self.fit_items.count()
-        self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).
-        self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).
+        self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
+        self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
         self._fit(log, user_features, item_features)
 
     @abstractmethod
@@ -122,7 +122,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         # count maximal number of items seen by users
         max_seen = 0
         if num_seen.count() > 0:
-            max_seen = num_seen.select(sf.max("seen_count")).
+            max_seen = num_seen.select(sf.max("seen_count")).first()[0]
 
         # crop recommendations to first k + max_seen items for each user
         recs = recs.withColumn(
@@ -335,7 +335,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         setattr(
             self,
             f"_{entity}_dim_size",
-            getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).
+            getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
         )
         return getattr(self, f"_{entity}_dim_size")
 
@@ -1088,7 +1088,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         Calculating a fill value a the minimal relevance
         calculated during model training multiplied by weight.
         """
-        return item_popularity.select(sf.min("relevance")).
+        return item_popularity.select(sf.min("relevance")).first()[0] * weight
 
     @staticmethod
     def _check_relevance(log: SparkDataFrame):
@@ -1113,7 +1113,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         max_hist_len = (
             (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
             .select(sf.max("items_count"))
-            .
+            .first()[0]
         )
         # all users have empty history
         if max_hist_len is None:
@@ -1146,7 +1146,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         users = users.join(user_to_num_items, on="user_idx", how="left")
         users = users.fillna(0, "num_items")
         # 'selected_item_popularity' truncation by k + max_seen
-        max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).
+        max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
         selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
         return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
 
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/cql.py
RENAMED
@@ -1,6 +1,7 @@
 """
 Using CQL implementation from `d3rlpy` package.
 """
+
 import io
 import logging
 import tempfile
@@ -402,6 +403,7 @@ class MdpDatasetBuilder:
         top_k (int): the number of top user items to learn predicting.
         action_randomization_scale (float): the scale of action randomization gaussian noise.
     """
+
     logger: logging.Logger
     top_k: int
     action_randomization_scale: float
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/ddpg.py
RENAMED
@@ -704,13 +704,15 @@ class DDPG(Recommender):
         :param data: pandas DataFrame
         """
         data = data[["user_idx", "item_idx", "relevance"]]
-
+        users = data["user_idx"].values.tolist()
+        items = data["item_idx"].values.tolist()
+        scores = data["relevance"].values.tolist()
 
-        user_num =
-        item_num =
+        user_num = max(users) + 1
+        item_num = max(items) + 1
 
         train_mat = defaultdict(float)
-        for user, item, rel in
+        for user, item, rel in zip(users, items, scores):
             train_mat[user, item] = rel
         train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
         dict.update(train_matrix, train_mat)
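The ddpg.py hunk rebuilds the interaction matrix from explicit column lists, sizing it as `max index + 1` because the indices are zero-based. A runnable toy version of the new code path, with illustrative data:

```python
# Toy reconstruction of the sparse-matrix build from the ddpg.py hunk.
from collections import defaultdict

import numpy as np
import pandas as pd
import scipy.sparse as sp

data = pd.DataFrame(
    {"user_idx": [0, 0, 2], "item_idx": [1, 3, 0], "relevance": [1.0, 0.5, 1.0]}
)
users = data["user_idx"].values.tolist()
items = data["item_idx"].values.tolist()
scores = data["relevance"].values.tolist()

# Zero-based indices: a max index of 2 means 3 rows are needed.
user_num = max(users) + 1
item_num = max(items) + 1

train_mat = defaultdict(float)
for user, item, rel in zip(users, items, scores):
    train_mat[user, item] = rel

# dok_matrix subclasses dict, so dict.update bulk-inserts the
# (row, col) -> value entries without per-item __setitem__ checks.
train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
dict.update(train_matrix, train_mat)
print(train_matrix.shape, train_matrix.nnz)  # (3, 4) 3
```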
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/lightfm_wrap.py
RENAMED
@@ -98,12 +98,12 @@ class LightFMWrap(HybridRecommender):
         fit_dim = getattr(self, f"_{entity}_dim")
         matrix_height = max(
             fit_dim,
-            log_ids_list.select(sf.max(idx_col_name)).
+            log_ids_list.select(sf.max(idx_col_name)).first()[0] + 1,
         )
         if not feature_table.rdd.isEmpty():
             matrix_height = max(
                 matrix_height,
-                feature_table.select(sf.max(idx_col_name)).
+                feature_table.select(sf.max(idx_col_name)).first()[0] + 1,
             )
 
         features_np = (
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/models/scala_als.py
RENAMED
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
             .groupBy("user_idx")
             .agg(sf.count("user_idx").alias("num_seen"))
             .select(sf.max("num_seen"))
-            .
+            .first()[0]
         )
         max_seen = max_seen_in_log if max_seen_in_log is not None else 0
 
@@ -280,7 +280,7 @@ class ScalaALSWrap(ALSWrap, ANNMixin):
             .groupBy("user_idx")
             .agg(sf.count("user_idx").alias("num_seen"))
             .select(sf.max("num_seen"))
-            .
+            .first()[0]
         )
         max_seen = max_seen_in_log if max_seen_in_log is not None else 0
 
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/data_preparator.py
RENAMED
@@ -6,6 +6,7 @@ Contains classes for data preparation and categorical features transformation.
 ``ToNumericFeatureTransformer`` leaves only numerical features
 by one-hot encoding of some features and deleting the others.
 """
+
 import json
 import logging
 import string
@@ -699,7 +700,7 @@ if PYSPARK_AVAILABLE:
                 return
 
             cat_feat_values_dict = {
-                name: (spark_df.select(sf.collect_set(sf.col(name))).
+                name: (spark_df.select(sf.collect_set(sf.col(name))).first()[0]) for name in self.cat_cols_list
             }
             self.expressions_list = [
                 sf.when(sf.col(col_name) == cur_name, 1)
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/preprocessing/padder.py
RENAMED
@@ -179,7 +179,7 @@ class Padder:
         self, df_transformed: SparkDataFrame, col: str, pad_value: Union[str, float, List, None]
     ) -> SparkDataFrame:
         if self.array_size == -1:
-            max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).
+            max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).first()[0]
         else:
             max_array_size = self.array_size
 
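In the padder.py hunk, `array_size == -1` means "pad to the longest sequence in the data", computed as the maximum `size()` of the array column. A minimal sketch of that derivation, assuming a local Spark session and an illustrative `item_ids` column:

```python
# Sketch: deriving the pad-to length as the longest array in a column,
# as in the padder.py hunk (array_size == -1 case).
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([([1, 2],), ([3, 4, 5],), ([6],)], ["item_ids"])

max_array_size = df.agg(sf.max(sf.size("item_ids")).alias("max_array_len")).first()[0]
print(max_array_size)  # 3
```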
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py
RENAMED
@@ -383,7 +383,7 @@ class TwoStagesScenario(HybridRecommender):
             log_to_filter_cached.groupBy("user_idx")
             .agg(sf.count("item_idx").alias("num_positives"))
             .select(sf.max("num_positives"))
-            .
+            .first()[0]
         )
 
         pred = model._predict(
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/experimental/utils/model_handler.py
RENAMED
@@ -170,8 +170,13 @@ def load_indexer(path: str) -> Indexer:
 
     indexer = Indexer(**args)
 
-
-
+    if user_type.endswith("()"):
+        user_type = user_type[:-2]
+        item_type = item_type[:-2]
+    user_type = getattr(st, user_type)
+    item_type = getattr(st, item_type)
+    indexer.user_type = user_type()
+    indexer.item_type = item_type()
 
     indexer.user_indexer = StringIndexerModel.load(join(path, "user_indexer"))
     indexer.item_indexer = StringIndexerModel.load(join(path, "item_indexer"))
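The load_indexer change restores `user_type` and `item_type` from their saved string form: a trailing "()" (a repr such as "IntegerType()") is stripped, the name is resolved against `pyspark.sql.types` via `getattr`, and a fresh type instance is created. A reduced sketch of that round-trip, where the `st` alias follows the hunk and the `saved` values are illustrative:

```python
# Sketch: rebuilding a pyspark DataType from its saved string name, as in
# the load_indexer hunk above.
import pyspark.sql.types as st

def restore_type(saved: str) -> st.DataType:
    # The saved form may be a repr like "IntegerType()"; strip the call syntax.
    if saved.endswith("()"):
        saved = saved[:-2]
    # Resolve the class by name and instantiate a fresh DataType.
    return getattr(st, saved)()

print(type(restore_type("IntegerType()")))  # <class 'pyspark.sql.types.IntegerType'>
print(type(restore_type("StringType")))     # <class 'pyspark.sql.types.StringType'>
```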
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/metrics/__init__.py
RENAMED
@@ -42,6 +42,7 @@ For each metric, a formula for its calculation is given, because this is
 important for the correct comparison of algorithms, as mentioned in our
 `article <https://arxiv.org/abs/2206.12858>`_.
 """
+
 from .base_metric import Metric
 from .categorical_diversity import CategoricalDiversity
 from .coverage import Coverage
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/als.py
RENAMED
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
             .groupBy(self.query_column)
             .agg(sf.count(self.query_column).alias("num_seen"))
             .select(sf.max("num_seen"))
-            .
+            .first()[0]
         )
         max_seen = max_seen_in_interactions if max_seen_in_interactions is not None else 0
 
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/base_rec.py
RENAMED
@@ -401,8 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         self.fit_items = sf.broadcast(items)
         self._num_queries = self.fit_queries.count()
         self._num_items = self.fit_items.count()
-        self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).
-        self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).
+        self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).first()[0] + 1
+        self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).first()[0] + 1
         self._fit(dataset)
 
     @abstractmethod
@@ -431,7 +431,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         # count maximal number of items seen by queries
         max_seen = 0
         if num_seen.count() > 0:
-            max_seen = num_seen.select(sf.max("seen_count")).
+            max_seen = num_seen.select(sf.max("seen_count")).first()[0]
 
         # crop recommendations to first k + max_seen items for each query
         recs = recs.withColumn(
@@ -708,7 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         setattr(
             self,
             dim_size,
-            fit_entities.agg({column: "max"}).
+            fit_entities.agg({column: "max"}).first()[0] + 1,
         )
         return getattr(self, dim_size)
 
@@ -1426,7 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         Calculating a fill value a the minimal rating
         calculated during model training multiplied by weight.
         """
-        return item_popularity.select(sf.min(rating_column)).
+        return item_popularity.select(sf.min(rating_column)).first()[0] * weight
 
     @staticmethod
     def _check_rating(dataset: Dataset):
@@ -1460,7 +1460,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
                 .agg(sf.countDistinct(item_column).alias("items_count"))
             )
             .select(sf.max("items_count"))
-            .
+            .first()[0]
         )
         # all queries have empty history
         if max_hist_len is None:
@@ -1495,7 +1495,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         queries = queries.join(query_to_num_items, on=self.query_column, how="left")
         queries = queries.fillna(0, "num_items")
         # 'selected_item_popularity' truncation by k + max_seen
-        max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).
+        max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
         selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
         return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
 
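In the last base_rec.py hunk, `sf.coalesce(sf.max("num_items"), sf.lit(0))` guards the aggregate: `max` over an empty frame yields NULL, and coalesce substitutes 0 so the downstream `k + max_seen` arithmetic stays valid. A minimal sketch of that behavior, assuming a local Spark session:

```python
# Sketch: sf.max over an empty DataFrame is NULL; coalesce supplies the 0
# default used by the base_rec.py hunk before computing k + max_seen.
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf
from pyspark.sql.types import IntegerType, StructField, StructType

spark = SparkSession.builder.master("local[1]").getOrCreate()
empty = spark.createDataFrame([], StructType([StructField("num_items", IntegerType())]))

max_seen = empty.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
print(max_seen)  # 0 instead of None
```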
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py
RENAMED
@@ -32,9 +32,9 @@ class NmslibFilterIndexInferer(IndexInferer):
         index = index_store.load_index(
             init_index=lambda: create_nmslib_index_instance(index_params),
             load_index=lambda index, path: index.loadIndex(path, load_data=True),
-            configure_index=lambda index:
-
-
+            configure_index=lambda index: (
+                index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+            ),
         )
 
         # max number of items to retrieve per batch
{replay_rec-0.17.1rc0 → replay_rec-0.18.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py
RENAMED
@@ -30,9 +30,9 @@ class NmslibIndexInferer(IndexInferer):
         index = index_store.load_index(
             init_index=lambda: create_nmslib_index_instance(index_params),
             load_index=lambda index, path: index.loadIndex(path, load_data=True),
-            configure_index=lambda index:
-
-
+            configure_index=lambda index: (
+                index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+            ),
         )
 
         user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)