replay-rec 0.19.0rc0__tar.gz → 0.20.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/PKG-INFO +58 -42
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/README.md +32 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/pyproject.toml +56 -70
- replay_rec-0.20.0rc0/replay/__init__.py +7 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset.py +19 -18
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/schema.py +9 -18
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/sequence_tokenizer.py +54 -47
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/sequential_dataset.py +16 -11
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/torch_sequential_dataset.py +18 -16
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/utils.py +3 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/schema.py +3 -12
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/base_metric.py +6 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/coverage.py +5 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/experiment.py +2 -2
- replay_rec-0.20.0rc0/replay/experimental/models/__init__.py +50 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/admm_slim.py +59 -7
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_neighbour_rec.py +6 -10
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_rec.py +58 -12
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/base_torch_rec.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/cql.py +6 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/ddpg.py +47 -38
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/utils.py +4 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +5 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/lightfm_wrap.py +4 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/mult_vae.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/neural_ts.py +13 -13
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/neuromf.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/scala_als.py +14 -17
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/nn/data/schema_builder.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/data_preparator.py +13 -13
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/padder.py +7 -7
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/sequence_generator.py +7 -7
- replay_rec-0.20.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +5 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +3 -5
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +18 -18
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/session_handler.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/base_metric.py +12 -11
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/categorical_diversity.py +8 -8
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/coverage.py +11 -15
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/experiment.py +6 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/hitrate.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/map.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/mrr.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/ndcg.py +1 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/novelty.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/offline_metrics.py +18 -18
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/precision.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/recall.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/rocauc.py +1 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/surprisal.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/torch_metrics_builder.py +13 -12
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/unexpectedness.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/__init__.py +19 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/als.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/association_rules.py +5 -7
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/base_neighbour_rec.py +8 -10
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/base_rec.py +54 -302
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/cat_pop_rec.py +4 -2
- replay_rec-0.20.0rc0/replay/models/common.py +69 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/ann_mixin.py +31 -25
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/utils.py +4 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/knn.py +18 -17
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/lin_ucb.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/model.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +14 -14
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay_rec-0.20.0rc0/replay/models/nn/sequential/compiled/__init__.py +15 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/base_compiled_model.py +8 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +2 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +1 -1
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/model.py +9 -9
- replay_rec-0.20.0rc0/replay/models/optimization/__init__.py +14 -0
- replay_rec-0.20.0rc0/replay/models/optimization/optuna_mixin.py +279 -0
- {replay_rec-0.19.0rc0/replay → replay_rec-0.20.0rc0/replay/models}/optimization/optuna_objective.py +13 -15
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/slim.py +4 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/ucb.py +2 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/word2vec.py +9 -14
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/discretizer.py +9 -9
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/filters.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/history_based_fp.py +7 -7
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/label_encoder.py +9 -8
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/scenarios/fallback.py +4 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/base_splitter.py +3 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/cold_user_random_splitter.py +17 -11
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/k_folds.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/last_n_splitter.py +27 -20
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/new_users_splitter.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/random_splitter.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/ratio_splitter.py +10 -10
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/time_splitter.py +6 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/two_stage_splitter.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/__init__.py +7 -2
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/common.py +5 -3
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/model_handler.py +11 -31
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/session_handler.py +4 -4
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/spark_utils.py +8 -7
- replay_rec-0.20.0rc0/replay/utils/types.py +50 -0
- replay_rec-0.20.0rc0/replay/utils/warnings.py +26 -0
- replay_rec-0.19.0rc0/replay/__init__.py +0 -3
- replay_rec-0.19.0rc0/replay/experimental/models/__init__.py +0 -13
- replay_rec-0.19.0rc0/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay_rec-0.19.0rc0/replay/models/nn/sequential/compiled/__init__.py +0 -5
- replay_rec-0.19.0rc0/replay/optimization/__init__.py +0 -5
- replay_rec-0.19.0rc0/replay/utils/types.py +0 -38
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/LICENSE +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/NOTICE +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/nn/__init__.py +6 -6
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/spark_schema.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/map.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/mrr.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/precision.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/recall.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/hierarchical_recommender.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/models/u_lin_ucb.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/scenarios/two_stages/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/logger.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/experimental/utils/model_handler.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/metrics/descriptors.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/cluster.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/kl_ucb.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/loss/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/loss/sce.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/pop_rec.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/query_pop_rec.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/random_rec.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/thompson_sampling.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/models/wilson.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/converter.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/preprocessing/sessionizer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/scenarios/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/splitters/__init__.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/distributions.py +0 -0
- {replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/utils/time.py +0 -0
|
@@ -1,53 +1,38 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: replay-rec
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.20.0rc0
|
|
4
4
|
Summary: RecSys Library
|
|
5
|
-
|
|
6
|
-
License:
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
License-File: NOTICE
|
|
7
8
|
Author: AI Lab
|
|
8
|
-
Requires-Python: >=3.
|
|
9
|
+
Requires-Python: >=3.9, <3.13
|
|
10
|
+
Classifier: Operating System :: Unix
|
|
9
11
|
Classifier: Development Status :: 4 - Beta
|
|
10
12
|
Classifier: Environment :: Console
|
|
11
13
|
Classifier: Intended Audience :: Developers
|
|
12
14
|
Classifier: Intended Audience :: Science/Research
|
|
13
|
-
Classifier: License :: OSI Approved :: Apache Software License
|
|
14
15
|
Classifier: Natural Language :: English
|
|
15
|
-
Classifier: Operating System :: Unix
|
|
16
|
-
Classifier: Programming Language :: Python :: 3
|
|
17
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
18
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
20
16
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
Requires-Dist:
|
|
26
|
-
Requires-Dist:
|
|
27
|
-
Requires-Dist:
|
|
28
|
-
Requires-Dist:
|
|
29
|
-
Requires-Dist:
|
|
30
|
-
Requires-Dist:
|
|
31
|
-
Requires-Dist:
|
|
32
|
-
Requires-Dist:
|
|
33
|
-
Requires-Dist:
|
|
34
|
-
Requires-Dist:
|
|
35
|
-
Requires-Dist:
|
|
36
|
-
Requires-Dist:
|
|
37
|
-
Requires-Dist:
|
|
38
|
-
Requires-Dist:
|
|
39
|
-
|
|
40
|
-
Requires-Dist: polars (>=1.0.0,<1.1.0)
|
|
41
|
-
Requires-Dist: psutil (>=6.0.0,<6.1.0)
|
|
42
|
-
Requires-Dist: pyarrow (>=12.0.1)
|
|
43
|
-
Requires-Dist: pyspark (>=3.0,<3.5) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
|
|
44
|
-
Requires-Dist: pyspark (>=3.4,<3.5) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
|
|
45
|
-
Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "torch-openvino" or extra == "all"
|
|
46
|
-
Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
|
|
47
|
-
Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
|
|
48
|
-
Requires-Dist: scipy (>=1.8.1,<2.0.0)
|
|
49
|
-
Requires-Dist: torch (>=1.8,<3.0.0) ; (python_version >= "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
|
|
50
|
-
Requires-Dist: torch (>=1.8,<=2.4.1) ; (python_version >= "3.8" and python_version < "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
|
|
17
|
+
Requires-Dist: d3rlpy (>=2.8.1,<2.9)
|
|
18
|
+
Requires-Dist: implicit (>=0.7.2,<0.8)
|
|
19
|
+
Requires-Dist: lightautoml (>=0.4.1,<0.5)
|
|
20
|
+
Requires-Dist: lightning (>=2.0.2,<=2.4.0)
|
|
21
|
+
Requires-Dist: numba (>=0.50,<1)
|
|
22
|
+
Requires-Dist: numpy (>=1.20.0,<2)
|
|
23
|
+
Requires-Dist: pandas (>=1.3.5,<2.4.0)
|
|
24
|
+
Requires-Dist: polars (<2.0)
|
|
25
|
+
Requires-Dist: psutil (<=7.0.0)
|
|
26
|
+
Requires-Dist: pyarrow (<22.0)
|
|
27
|
+
Requires-Dist: pyspark (>=3.0,<3.5)
|
|
28
|
+
Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
|
|
29
|
+
Requires-Dist: sb-obp (>=0.5.10,<0.6)
|
|
30
|
+
Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
|
|
31
|
+
Requires-Dist: scipy (>=1.13.1,<1.14)
|
|
32
|
+
Requires-Dist: setuptools
|
|
33
|
+
Requires-Dist: torch (>=1.8,<3.0.0)
|
|
34
|
+
Requires-Dist: tqdm (>=4.67,<5)
|
|
35
|
+
Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
|
|
51
36
|
Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
|
|
52
37
|
Description-Content-Type: text/markdown
|
|
53
38
|
|
|
@@ -216,7 +201,6 @@ pip install replay-rec==XX.YY.ZZrc0
|
|
|
216
201
|
In addition to the core package, several extras are also provided, including:
|
|
217
202
|
- `[spark]`: Install PySpark functionality
|
|
218
203
|
- `[torch]`: Install PyTorch and Lightning functionality
|
|
219
|
-
- `[all]`: `[spark]` `[torch]`
|
|
220
204
|
|
|
221
205
|
Example:
|
|
222
206
|
```bash
|
|
@@ -227,9 +211,41 @@ pip install replay-rec[spark]
|
|
|
227
211
|
pip install replay-rec[spark]==XX.YY.ZZrc0
|
|
228
212
|
```
|
|
229
213
|
|
|
214
|
+
Additionally, `replay-rec[torch]` may be installed with a CPU-only version of `torch` by providing its respective index URL during installation:
|
|
215
|
+
```bash
|
|
216
|
+
# Install package with the CPU version of torch
|
|
217
|
+
pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
|
|
230
221
|
To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).
|
|
231
222
|
|
|
232
223
|
|
|
224
|
+
### Optional features
|
|
225
|
+
RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
|
|
226
|
+
|
|
227
|
+
1) Hyperparameter search via Optuna:
|
|
228
|
+
```bash
|
|
229
|
+
pip install optuna
|
|
230
|
+
```
|
|
231
|
+
|
|
232
|
+
2) Model compilation via OpenVINO:
|
|
233
|
+
```bash
|
|
234
|
+
pip install openvino onnx
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
3) Vector database and hierarchical search support:
|
|
238
|
+
```bash
|
|
239
|
+
pip install hnswlib fixed-install-nmslib
|
|
240
|
+
```
|
|
241
|
+
|
|
242
|
+
4) (Experimental) LightFM model support:
|
|
243
|
+
```bash
|
|
244
|
+
pip install lightfm
|
|
245
|
+
```
|
|
246
|
+
> **_NOTE_** : LightFM is not officially supported for Python 3.12 due to discontinued maintenance of the library. If you wish to install it locally, you'll have to use a patched fork of LightFM, such as the [one used internally](https://github.com/daviddavo/lightfm).
|
|
247
|
+
|
|
248
|
+
|
|
233
249
|
<a name="examples"></a>
|
|
234
250
|
## 📑 Resources
|
|
235
251
|
|
|
@@ -163,7 +163,6 @@ pip install replay-rec==XX.YY.ZZrc0
|
|
|
163
163
|
In addition to the core package, several extras are also provided, including:
|
|
164
164
|
- `[spark]`: Install PySpark functionality
|
|
165
165
|
- `[torch]`: Install PyTorch and Lightning functionality
|
|
166
|
-
- `[all]`: `[spark]` `[torch]`
|
|
167
166
|
|
|
168
167
|
Example:
|
|
169
168
|
```bash
|
|
@@ -174,9 +173,41 @@ pip install replay-rec[spark]
|
|
|
174
173
|
pip install replay-rec[spark]==XX.YY.ZZrc0
|
|
175
174
|
```
|
|
176
175
|
|
|
176
|
+
Additionally, `replay-rec[torch]` may be installed with a CPU-only version of `torch` by providing its respective index URL during installation:
|
|
177
|
+
```bash
|
|
178
|
+
# Install package with the CPU version of torch
|
|
179
|
+
pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
|
|
180
|
+
```
|
|
181
|
+
|
|
182
|
+
|
|
177
183
|
To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).
|
|
178
184
|
|
|
179
185
|
|
|
186
|
+
### Optional features
|
|
187
|
+
RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
|
|
188
|
+
|
|
189
|
+
1) Hyperparameter search via Optuna:
|
|
190
|
+
```bash
|
|
191
|
+
pip install optuna
|
|
192
|
+
```
|
|
193
|
+
|
|
194
|
+
2) Model compilation via OpenVINO:
|
|
195
|
+
```bash
|
|
196
|
+
pip install openvino onnx
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
3) Vector database and hierarchical search support:
|
|
200
|
+
```bash
|
|
201
|
+
pip install hnswlib fixed-install-nmslib
|
|
202
|
+
```
|
|
203
|
+
|
|
204
|
+
4) (Experimental) LightFM model support:
|
|
205
|
+
```bash
|
|
206
|
+
pip install lightfm
|
|
207
|
+
```
|
|
208
|
+
> **_NOTE_** : LightFM is not officially supported for Python 3.12 due to discontinued maintenance of the library. If you wish to install it locally, you'll have to use a patched fork of LightFM, such as the [one used internally](https://github.com/daviddavo/lightfm).
|
|
209
|
+
|
|
210
|
+
|
|
180
211
|
<a name="examples"></a>
|
|
181
212
|
## 📑 Resources
|
|
182
213
|
|
|
@@ -1,35 +1,28 @@
|
|
|
1
1
|
[build-system]
|
|
2
2
|
requires = [
|
|
3
|
-
"poetry-core>=
|
|
3
|
+
"poetry-core>=2.0.0",
|
|
4
4
|
"poetry-dynamic-versioning>=1.0.0,<2.0.0",
|
|
5
|
+
"setuptools",
|
|
5
6
|
]
|
|
6
7
|
build-backend = "poetry_dynamic_versioning.backend"
|
|
7
8
|
|
|
8
|
-
[
|
|
9
|
-
line-length = 120
|
|
10
|
-
target-versions = ["py38", "py39", "py310", "py311"]
|
|
11
|
-
|
|
12
|
-
[tool.poetry]
|
|
9
|
+
[project]
|
|
13
10
|
name = "replay-rec"
|
|
14
|
-
packages = [{include = "replay"}]
|
|
15
11
|
license = "Apache-2.0"
|
|
16
12
|
description = "RecSys Library"
|
|
17
13
|
authors = [
|
|
18
|
-
"AI Lab",
|
|
19
|
-
"Alexey Vasilev",
|
|
20
|
-
"Anna Volodkevich",
|
|
21
|
-
"Alexey Grishanov",
|
|
22
|
-
"Yan-Martin Tamm",
|
|
23
|
-
"Boris Shminke",
|
|
24
|
-
"Alexander Sidorenko",
|
|
25
|
-
"Roza Aysina",
|
|
14
|
+
{name = "AI Lab"},
|
|
15
|
+
{name = "Alexey Vasilev"},
|
|
16
|
+
{name = "Anna Volodkevich"},
|
|
17
|
+
{name = "Alexey Grishanov"},
|
|
18
|
+
{name = "Yan-Martin Tamm"},
|
|
19
|
+
{name = "Boris Shminke"},
|
|
20
|
+
{name = "Alexander Sidorenko"},
|
|
21
|
+
{name = "Roza Aysina"},
|
|
26
22
|
]
|
|
27
23
|
readme = "README.md"
|
|
28
|
-
homepage = "https://sb-ai-lab.github.io/RePlay/"
|
|
29
|
-
repository = "https://github.com/sb-ai-lab/RePlay"
|
|
30
24
|
classifiers = [
|
|
31
25
|
"Operating System :: Unix",
|
|
32
|
-
"Intended Audience :: Science/Research",
|
|
33
26
|
"Development Status :: 4 - Beta",
|
|
34
27
|
"Environment :: Console",
|
|
35
28
|
"Intended Audience :: Developers",
|
|
@@ -37,51 +30,46 @@ classifiers = [
|
|
|
37
30
|
"Natural Language :: English",
|
|
38
31
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
39
32
|
]
|
|
40
|
-
|
|
41
|
-
|
|
33
|
+
requires-python = ">=3.9, <3.13"
|
|
34
|
+
dependencies = [
|
|
35
|
+
"setuptools",
|
|
36
|
+
"numpy (>=1.20.0,<2)",
|
|
37
|
+
"pandas (>=1.3.5,<2.4.0)",
|
|
38
|
+
"polars (<2.0)",
|
|
39
|
+
"scipy (>=1.13.1,<1.14)",
|
|
40
|
+
"scikit-learn (>=1.6.1,<1.7.0)",
|
|
41
|
+
"pyarrow (<22.0)",
|
|
42
|
+
"tqdm (>=4.67,<5)",
|
|
43
|
+
"torch (>=1.8,<3.0.0)",
|
|
44
|
+
"lightning (>=2.0.2,<=2.4.0)",
|
|
45
|
+
"pytorch-optimizer (>=3.8.0,<4)",
|
|
46
|
+
"lightautoml (>=0.4.1,<0.5)",
|
|
47
|
+
"numba (>=0.50,<1)",
|
|
48
|
+
"sb-obp (>=0.5.10,<0.6)",
|
|
49
|
+
"d3rlpy (>=2.8.1,<2.9)",
|
|
50
|
+
"implicit (>=0.7.2,<0.8)",
|
|
51
|
+
"pyspark (>=3.0,<3.5)",
|
|
52
|
+
"psutil (<=7.0.0)",
|
|
42
53
|
]
|
|
43
|
-
|
|
54
|
+
dynamic = ["dependencies"]
|
|
55
|
+
version = "0.20.0.preview"
|
|
44
56
|
|
|
45
|
-
[
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
pandas = ">=1.3.5, <=2.2.2"
|
|
49
|
-
polars = "~1.0.0"
|
|
50
|
-
optuna = "~3.2.0"
|
|
51
|
-
scipy = "^1.8.1"
|
|
52
|
-
psutil = "~6.0.0"
|
|
53
|
-
scikit-learn = "^1.0.2"
|
|
54
|
-
pyarrow = ">=12.0.1"
|
|
55
|
-
openvino = {version = "~2024.3.0", optional = true}
|
|
56
|
-
onnx = {version = "~1.16.2", optional = true}
|
|
57
|
-
fixed-install-nmslib = "2.1.2"
|
|
58
|
-
hnswlib = "^0.7.0"
|
|
59
|
-
pyspark = [
|
|
60
|
-
{version = ">=3.4,<3.5", python = ">=3.11,<3.12"},
|
|
61
|
-
{version = ">=3.0,<3.5", python = ">=3.8.1,<3.11"},
|
|
62
|
-
]
|
|
63
|
-
torch = [
|
|
64
|
-
{version = ">=1.8, <3.0.0", python = ">=3.9", optional = true},
|
|
65
|
-
{version = ">=1.8, <=2.4.1", python = ">=3.8,<3.9", optional = true},
|
|
66
|
-
]
|
|
67
|
-
lightning = ">=2.0.2, <=2.4.0"
|
|
68
|
-
pytorch-ranger = "^0.1.1"
|
|
69
|
-
lightfm = "1.17"
|
|
70
|
-
lightautoml = "~0.3.1"
|
|
71
|
-
numba = ">=0.50"
|
|
72
|
-
llvmlite = ">=0.32.1"
|
|
73
|
-
sb-obp = "^0.5.8"
|
|
74
|
-
d3rlpy = "^2.0.4"
|
|
75
|
-
implicit = "~0.7.0"
|
|
76
|
-
gym = "^0.26.0"
|
|
57
|
+
[project.urls]
|
|
58
|
+
homepage = "https://sb-ai-lab.github.io/RePlay/"
|
|
59
|
+
repository = "https://github.com/sb-ai-lab/RePlay"
|
|
77
60
|
|
|
78
|
-
[tool.
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
61
|
+
[tool.black]
|
|
62
|
+
line-length = 120
|
|
63
|
+
target-version = ["py39", "py310", "py311", "py312"]
|
|
64
|
+
|
|
65
|
+
[tool.poetry]
|
|
66
|
+
packages = [{include = "replay"}]
|
|
67
|
+
exclude = [
|
|
68
|
+
"replay/conftest.py",
|
|
69
|
+
]
|
|
83
70
|
|
|
84
71
|
[tool.poetry.group.dev.dependencies]
|
|
72
|
+
coverage-conditional-plugin = "^0.9.0"
|
|
85
73
|
jupyter = "~1.0.0"
|
|
86
74
|
jupyterlab = "^3.6.0"
|
|
87
75
|
pytest = ">=7.1.0"
|
|
@@ -102,31 +90,29 @@ filelock = "~3.14.0"
|
|
|
102
90
|
|
|
103
91
|
[tool.poetry-dynamic-versioning]
|
|
104
92
|
enable = false
|
|
105
|
-
format-jinja = """0.
|
|
93
|
+
format-jinja = """0.20.0{{ env['PACKAGE_SUFFIX'] }}"""
|
|
106
94
|
vcs = "git"
|
|
107
95
|
|
|
108
96
|
[tool.ruff]
|
|
109
97
|
exclude = [".git", ".venv", "__pycache__", "env", "venv", "docs", "projects", "examples"]
|
|
110
|
-
extend-select = ["C90", "T10", "T20", "UP004"]
|
|
111
98
|
line-length = 120
|
|
99
|
+
|
|
100
|
+
[tool.ruff.lint]
|
|
112
101
|
select = ["ARG", "C4", "E", "EM", "ERA", "F", "FLY", "I", "INP", "ISC", "N", "PERF", "PGH", "PIE", "PYI", "Q", "RUF", "SIM", "TID", "W"]
|
|
102
|
+
extend-select = ["C90", "T10", "T20", "UP004"]
|
|
103
|
+
ignore = ["SIM115"]
|
|
104
|
+
mccabe = {max-complexity = 13}
|
|
105
|
+
isort = {combine-as-imports = true, force-wrap-aliases = true}
|
|
113
106
|
|
|
114
|
-
[tool.ruff.flake8-quotes]
|
|
107
|
+
[tool.ruff.lint.flake8-quotes]
|
|
115
108
|
docstring-quotes = "double"
|
|
116
109
|
inline-quotes = "double"
|
|
117
110
|
multiline-quotes = "double"
|
|
118
111
|
|
|
119
|
-
[tool.ruff.flake8-unused-arguments]
|
|
112
|
+
[tool.ruff.lint.flake8-unused-arguments]
|
|
120
113
|
ignore-variadic-names = false
|
|
121
114
|
|
|
122
|
-
[tool.ruff.
|
|
123
|
-
combine-as-imports = true
|
|
124
|
-
force-wrap-aliases = true
|
|
125
|
-
|
|
126
|
-
[tool.ruff.mccabe]
|
|
127
|
-
max-complexity = 13
|
|
128
|
-
|
|
129
|
-
[tool.ruff.per-file-ignores]
|
|
115
|
+
[tool.ruff.lint.per-file-ignores]
|
|
130
116
|
"*/" = ["PERF203", "RUF001", "RUF002", "RUF012", "E402"]
|
|
131
117
|
"__init__.py" = ["F401"]
|
|
132
118
|
"replay/utils/model_handler.py" = ["F403", "F405"]
|
|
@@ -5,8 +5,9 @@
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
7
|
import json
|
|
8
|
+
from collections.abc import Iterable, Sequence
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from typing import Callable,
|
|
10
|
+
from typing import Callable, Optional, Union
|
|
10
11
|
|
|
11
12
|
import numpy as np
|
|
12
13
|
from pandas import read_parquet as pd_read_parquet
|
|
@@ -315,7 +316,7 @@ class Dataset:
|
|
|
315
316
|
:returns: Loaded Dataset.
|
|
316
317
|
"""
|
|
317
318
|
base_path = Path(path).with_suffix(".replay").resolve()
|
|
318
|
-
with open(base_path / "init_args.json"
|
|
319
|
+
with open(base_path / "init_args.json") as file:
|
|
319
320
|
dataset_dict = json.loads(file.read())
|
|
320
321
|
|
|
321
322
|
if dataframe_type not in ["pandas", "spark", "polars", None]:
|
|
@@ -436,14 +437,14 @@ class Dataset:
|
|
|
436
437
|
)
|
|
437
438
|
|
|
438
439
|
def _get_feature_source_map(self):
|
|
439
|
-
self._feature_source_map:
|
|
440
|
+
self._feature_source_map: dict[FeatureSource, DataFrameLike] = {
|
|
440
441
|
FeatureSource.INTERACTIONS: self.interactions,
|
|
441
442
|
FeatureSource.QUERY_FEATURES: self.query_features,
|
|
442
443
|
FeatureSource.ITEM_FEATURES: self.item_features,
|
|
443
444
|
}
|
|
444
445
|
|
|
445
446
|
def _get_ids_source_map(self):
|
|
446
|
-
self._ids_feature_map:
|
|
447
|
+
self._ids_feature_map: dict[FeatureHint, DataFrameLike] = {
|
|
447
448
|
FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
|
|
448
449
|
FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
|
|
449
450
|
}
|
|
@@ -499,10 +500,10 @@ class Dataset:
|
|
|
499
500
|
)
|
|
500
501
|
return FeatureSchema(features_list=features_list + filled_features)
|
|
501
502
|
|
|
502
|
-
def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) ->
|
|
503
|
+
def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
|
|
503
504
|
features_list = list(feature_schema.all_features)
|
|
504
505
|
|
|
505
|
-
source_mapping:
|
|
506
|
+
source_mapping: dict[str, FeatureSource] = {}
|
|
506
507
|
for source in FeatureSource:
|
|
507
508
|
dataframe = self._feature_source_map[source]
|
|
508
509
|
if dataframe is not None:
|
|
@@ -524,7 +525,7 @@ class Dataset:
|
|
|
524
525
|
self._set_cardinality(features_list=features_list)
|
|
525
526
|
return features_list
|
|
526
527
|
|
|
527
|
-
def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) ->
|
|
528
|
+
def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
|
|
528
529
|
set_source_dataframe_columns = set(self._feature_source_map[source].columns)
|
|
529
530
|
set_labeled_dataframe_columns = set(feature_schema.columns)
|
|
530
531
|
unlabeled_columns = set_source_dataframe_columns - set_labeled_dataframe_columns
|
|
@@ -534,13 +535,13 @@ class Dataset:
|
|
|
534
535
|
]
|
|
535
536
|
return unlabeled_features_list
|
|
536
537
|
|
|
537
|
-
def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) ->
|
|
538
|
+
def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
|
|
538
539
|
unlabeled_columns = self._get_unlabeled_columns(source=source, feature_schema=feature_schema)
|
|
539
540
|
self._set_features_source(feature_list=unlabeled_columns, source=source)
|
|
540
541
|
self._set_cardinality(features_list=unlabeled_columns)
|
|
541
542
|
return unlabeled_columns
|
|
542
543
|
|
|
543
|
-
def _set_features_source(self, feature_list:
|
|
544
|
+
def _set_features_source(self, feature_list: list[FeatureInfo], source: FeatureSource) -> None:
|
|
544
545
|
for feature in feature_list:
|
|
545
546
|
feature._set_feature_source(source)
|
|
546
547
|
|
|
@@ -610,9 +611,9 @@ class Dataset:
|
|
|
610
611
|
if self.is_pandas:
|
|
611
612
|
try:
|
|
612
613
|
data[column] = data[column].astype(int)
|
|
613
|
-
except Exception:
|
|
614
|
+
except Exception as exc:
|
|
614
615
|
msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
|
|
615
|
-
raise ValueError(msg)
|
|
616
|
+
raise ValueError(msg) from exc
|
|
616
617
|
|
|
617
618
|
if self.is_pandas:
|
|
618
619
|
is_int = np.issubdtype(dict(data.dtypes)[column], int)
|
|
@@ -775,10 +776,10 @@ def check_dataframes_types_equal(dataframe: DataFrameLike, other: DataFrameLike)
|
|
|
775
776
|
|
|
776
777
|
:returns: True if dataframes have same type.
|
|
777
778
|
"""
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
779
|
+
return any(
|
|
780
|
+
[
|
|
781
|
+
isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame),
|
|
782
|
+
isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame),
|
|
783
|
+
isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame),
|
|
784
|
+
]
|
|
785
|
+
)
|
{replay_rec-0.19.0rc0 → replay_rec-0.20.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py
RENAMED
|
@@ -6,7 +6,8 @@ Contains classes for encoding categorical data
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import warnings
|
|
9
|
-
from
|
|
9
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
10
|
+
from typing import Optional, Union
|
|
10
11
|
|
|
11
12
|
from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
|
|
12
13
|
from replay.preprocessing import LabelEncoder, LabelEncodingRule, SequenceEncodingRule
|
|
@@ -45,9 +46,9 @@ class DatasetLabelEncoder:
|
|
|
45
46
|
"""
|
|
46
47
|
self._handle_unknown_rule = handle_unknown_rule
|
|
47
48
|
self._default_value_rule = default_value_rule
|
|
48
|
-
self._encoding_rules:
|
|
49
|
+
self._encoding_rules: dict[str, LabelEncodingRule] = {}
|
|
49
50
|
|
|
50
|
-
self._features_columns:
|
|
51
|
+
self._features_columns: dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
|
|
51
52
|
|
|
52
53
|
def fit(self, dataset: Dataset) -> "DatasetLabelEncoder":
|
|
53
54
|
"""
|
|
@@ -161,7 +162,7 @@ class DatasetLabelEncoder:
|
|
|
161
162
|
"""
|
|
162
163
|
self._check_if_initialized()
|
|
163
164
|
|
|
164
|
-
columns_set:
|
|
165
|
+
columns_set: set[str]
|
|
165
166
|
columns_set = {columns} if isinstance(columns, str) else {*columns}
|
|
166
167
|
|
|
167
168
|
def get_encoding_rules() -> Iterator[LabelEncodingRule]:
|
|
@@ -1,17 +1,8 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, Sequence, ValuesView
|
|
1
3
|
from typing import (
|
|
2
|
-
Dict,
|
|
3
|
-
ItemsView,
|
|
4
|
-
Iterable,
|
|
5
|
-
Iterator,
|
|
6
|
-
KeysView,
|
|
7
|
-
List,
|
|
8
|
-
Mapping,
|
|
9
4
|
Optional,
|
|
10
|
-
OrderedDict,
|
|
11
|
-
Sequence,
|
|
12
|
-
Set,
|
|
13
5
|
Union,
|
|
14
|
-
ValuesView,
|
|
15
6
|
)
|
|
16
7
|
|
|
17
8
|
import torch
|
|
@@ -20,7 +11,7 @@ from replay.data import FeatureHint, FeatureSource, FeatureType
|
|
|
20
11
|
|
|
21
12
|
# Alias
|
|
22
13
|
TensorMap = Mapping[str, torch.Tensor]
|
|
23
|
-
MutableTensorMap =
|
|
14
|
+
MutableTensorMap = dict[str, torch.Tensor]
|
|
24
15
|
|
|
25
16
|
|
|
26
17
|
class TensorFeatureSource:
|
|
@@ -79,7 +70,7 @@ class TensorFeatureInfo:
|
|
|
79
70
|
feature_type: FeatureType,
|
|
80
71
|
is_seq: bool = False,
|
|
81
72
|
feature_hint: Optional[FeatureHint] = None,
|
|
82
|
-
feature_sources: Optional[
|
|
73
|
+
feature_sources: Optional[list[TensorFeatureSource]] = None,
|
|
83
74
|
cardinality: Optional[int] = None,
|
|
84
75
|
padding_value: int = 0,
|
|
85
76
|
embedding_dim: Optional[int] = None,
|
|
@@ -154,13 +145,13 @@ class TensorFeatureInfo:
|
|
|
154
145
|
self._feature_hint = hint
|
|
155
146
|
|
|
156
147
|
@property
|
|
157
|
-
def feature_sources(self) -> Optional[
|
|
148
|
+
def feature_sources(self) -> Optional[list[TensorFeatureSource]]:
|
|
158
149
|
"""
|
|
159
150
|
:returns: List of sources feature came from.
|
|
160
151
|
"""
|
|
161
152
|
return self._feature_sources
|
|
162
153
|
|
|
163
|
-
def _set_feature_sources(self, sources:
|
|
154
|
+
def _set_feature_sources(self, sources: list[TensorFeatureSource]) -> None:
|
|
164
155
|
self._feature_sources = sources
|
|
165
156
|
|
|
166
157
|
@property
|
|
@@ -276,7 +267,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
|
|
|
276
267
|
|
|
277
268
|
:returns: New tensor schema of given features.
|
|
278
269
|
"""
|
|
279
|
-
features:
|
|
270
|
+
features: set[TensorFeatureInfo] = set()
|
|
280
271
|
for feature_name in features_to_keep:
|
|
281
272
|
features.add(self._tensor_schema[feature_name])
|
|
282
273
|
return TensorSchema(list(features))
|
|
@@ -432,7 +423,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
|
|
|
432
423
|
return None
|
|
433
424
|
return rating_features.item().name
|
|
434
425
|
|
|
435
|
-
def _get_object_args(self) ->
|
|
426
|
+
def _get_object_args(self) -> dict:
|
|
436
427
|
"""
|
|
437
428
|
Returns list of features represented as dictionaries.
|
|
438
429
|
"""
|
|
@@ -456,7 +447,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
|
|
|
456
447
|
return features
|
|
457
448
|
|
|
458
449
|
@classmethod
|
|
459
|
-
def _create_object_by_args(cls, args:
|
|
450
|
+
def _create_object_by_args(cls, args: dict) -> "TensorSchema":
|
|
460
451
|
features_list = []
|
|
461
452
|
for feature_data in args:
|
|
462
453
|
feature_data["feature_sources"] = (
|