replay-rec 0.18.1rc0__tar.gz → 0.19.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/PKG-INFO +3 -2
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/pyproject.toml +6 -3
- replay_rec-0.19.0rc0/replay/__init__.py +3 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/schema.py +3 -1
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/surprisal.py +4 -2
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/lin_ucb.py +2 -3
- replay_rec-0.19.0rc0/replay/models/nn/loss/__init__.py +1 -0
- replay_rec-0.19.0rc0/replay/models/nn/loss/sce.py +131 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +36 -4
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/model.py +5 -46
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +27 -3
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/model.py +1 -1
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/filters.py +102 -1
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/label_encoder.py +8 -4
- replay_rec-0.18.1rc0/replay/__init__.py +0 -3
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/LICENSE +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/NOTICE +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/README.md +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/dataset.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/sequence_tokenizer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/sequential_dataset.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/torch_sequential_dataset.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/schema.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/spark_schema.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/base_metric.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/coverage.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/experiment.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/map.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/mrr.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/precision.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/recall.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/admm_slim.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/base_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/base_torch_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/cql.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/ddpg.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/dt4rec/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/hierarchical_recommender.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/lightfm_wrap.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/mult_vae.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/neural_ts.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/neuromf.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/scala_als.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/models/u_lin_ucb.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/nn/data/schema_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/preprocessing/data_preparator.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/preprocessing/padder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/preprocessing/sequence_generator.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/two_stages/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/utils/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/utils/logger.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/utils/model_handler.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/experimental/utils/session_handler.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/base_metric.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/categorical_diversity.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/coverage.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/descriptors.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/experiment.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/hitrate.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/map.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/mrr.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/ndcg.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/novelty.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/offline_metrics.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/precision.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/recall.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/rocauc.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/torch_metrics_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/unexpectedness.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/als.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/association_rules.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/base_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/cat_pop_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/cluster.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/ann_mixin.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/entities/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/extensions/ann/utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/kl_ucb.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/knn.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/optimizer_utils/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/compiled/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/compiled/base_compiled_model.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/compiled/sasrec_compiled.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/pop_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/query_pop_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/random_rec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/slim.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/thompson_sampling.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/ucb.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/wilson.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/word2vec.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/optimization/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/optimization/optuna_objective.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/converter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/discretizer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/history_based_fp.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/sessionizer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/scenarios/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/scenarios/fallback.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/base_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/cold_user_random_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/k_folds.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/last_n_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/new_users_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/random_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/ratio_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/time_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/splitters/two_stage_splitter.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/__init__.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/common.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/distributions.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/model_handler.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/session_handler.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/spark_utils.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/time.py +0 -0
- {replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/utils/types.py +0 -0
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/PKG-INFO RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: replay-rec
-Version: 0.18.1rc0
+Version: 0.19.0rc0
 Summary: RecSys Library
 Home-page: https://sb-ai-lab.github.io/RePlay/
 License: Apache-2.0
@@ -46,7 +46,8 @@ Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "t
 Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
 Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
 Requires-Dist: scipy (>=1.8.1,<2.0.0)
-Requires-Dist: torch (>=1.8
+Requires-Dist: torch (>=1.8,<3.0.0) ; (python_version >= "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
+Requires-Dist: torch (>=1.8,<=2.4.1) ; (python_version >= "3.8" and python_version < "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
 Description-Content-Type: text/markdown
 
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/pyproject.toml RENAMED

@@ -40,7 +40,7 @@ classifiers = [
 exclude = [
     "replay/conftest.py",
 ]
-version = "0.
+version = "0.19.0.preview"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1, <3.12"
@@ -60,7 +60,10 @@ pyspark = [
     {version = ">=3.4,<3.5", python = ">=3.11,<3.12"},
     {version = ">=3.0,<3.5", python = ">=3.8.1,<3.11"},
 ]
-torch =
+torch = [
+    {version = ">=1.8, <3.0.0", python = ">=3.9", optional = true},
+    {version = ">=1.8, <=2.4.1", python = ">=3.8,<3.9", optional = true},
+]
 lightning = ">=2.0.2, <=2.4.0"
 pytorch-ranger = "^0.1.1"
 lightfm = "1.17"
@@ -99,7 +102,7 @@ filelock = "~3.14.0"
 
 [tool.poetry-dynamic-versioning]
 enable = false
-format-jinja = """0.
+format-jinja = """0.19.0{{ env['PACKAGE_SUFFIX'] }}"""
 vcs = "git"
 
 [tool.ruff]
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/data/nn/schema.py RENAMED

@@ -7,6 +7,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    OrderedDict,
     Sequence,
     Set,
     Union,
@@ -262,6 +263,8 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
         """
         :param features_list: list of tensor feature infos.
         """
+        if isinstance(features_list, OrderedDict):
+            features_list = list(features_list.values())
         features_list = [features_list] if not isinstance(features_list, Sequence) else features_list
         self._tensor_schema = {feature.name: feature for feature in features_list}
 
@@ -501,7 +504,6 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
                 filtered_features,
             )
         )
-
         return TensorSchema(filtered_features)
 
     @staticmethod
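In effect, `TensorSchema` can now be built straight from an `OrderedDict` of features rather than only from a list or a single feature. A standalone sketch of the new normalization step, with plain ints standing in for `TensorFeatureInfo` objects:

```python
from collections import OrderedDict
from collections.abc import Sequence

def normalize(features_list):
    # Mirrors the constructor logic added above: an OrderedDict is reduced
    # to its values before the usual single-item/Sequence handling.
    if isinstance(features_list, OrderedDict):
        features_list = list(features_list.values())
    return [features_list] if not isinstance(features_list, Sequence) else features_list

print(normalize(OrderedDict(a=1, b=2)))  # [1, 2] -- without the new branch the dict would be wrapped whole
```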
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/metrics/surprisal.py RENAMED

@@ -129,7 +129,9 @@ class Surprisal(Metric):
         item_weights = train.group_by(self.item_column).agg(
             (np.log2(n_users / pl.col(self.query_column).n_unique()) / np.log2(n_users)).alias("weight")
         )
-        recommendations = recommendations.join(item_weights, on=self.item_column, how="left").
+        recommendations = recommendations.join(item_weights, on=self.item_column, how="left").with_columns(
+            pl.col("weight").fill_null(1.0)
+        )
 
         sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
         return self._rearrange_columns(sorted_by_score_recommendations)
@@ -175,7 +177,7 @@ class Surprisal(Metric):
 
         weights = self._get_recommendation_weights(recommendations, train)
         return self._dict_call(
-            list(
+            list(recommendations),
            pred_item_id=recommendations,
            pred_weight=weights,
        )
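The added `with_columns` gives recommended items that never appear in `train` a weight of 1.0 (the value an item seen by exactly one user would receive) instead of the null that previously leaked out of the left join. A toy polars reproduction of the pattern:

```python
import polars as pl

recs = pl.DataFrame({"item_id": [1, 2, 3]})
item_weights = pl.DataFrame({"item_id": [1, 2], "weight": [0.3, 0.7]})

# The left join leaves item 3 with a null weight; fill_null(1.0) treats it
# as maximally surprising, matching the change above.
out = recs.join(item_weights, on="item_id", how="left").with_columns(pl.col("weight").fill_null(1.0))
print(out)
```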
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/lin_ucb.py RENAMED

@@ -98,9 +98,8 @@ class LinUCB(HybridRecommender):
     The model assumes a linear relationship between user context, item features and action rewards,
     making it efficient for high-dimensional contexts.
 
-    Note:
-
-    to ensure proper convergence and prevent numerical instability (since relationships to learn are linear).
+    Note: It's recommended to scale features to a similar range (e.g., using StandardScaler or MinMaxScaler)
+    to ensure proper convergence and prevent numerical instability (since relationships to learn are linear).
 
     >>> import pandas as pd
     >>> from replay.data.dataset import (
replay_rec-0.19.0rc0/replay/models/nn/loss/__init__.py ADDED

@@ -0,0 +1 @@
+from .sce import ScalableCrossEntropyLoss, SCEParams
replay_rec-0.19.0rc0/replay/models/nn/loss/sce.py ADDED

@@ -0,0 +1,131 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass(frozen=True)
+class SCEParams:
+    """Set of parameters for ScalableCrossEntropyLoss.
+
+    Constructor arguments:
+    :param n_buckets: Number of buckets into which samples will be distributed.
+    :param bucket_size_x: Number of item hidden representations that will be in each bucket.
+    :param bucket_size_y: Number of item embeddings that will be in each bucket.
+    :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
+        Default: ``False``.
+    """
+
+    n_buckets: int
+    bucket_size_x: int
+    bucket_size_y: int
+    mix_x: bool = False
+
+    def _get_not_none_params(self):
+        return [self.n_buckets, self.bucket_size_x, self.bucket_size_y]
+
+
+class ScalableCrossEntropyLoss:
+    def __init__(self, sce_params: SCEParams):
+        """
+        ScalableCrossEntropyLoss for Sequential Recommendations with Large Item Catalogs.
+        Reference article may be found at https://arxiv.org/pdf/2409.18721.
+
+        :param SCEParams: Dataclass with ScalableCrossEntropyLoss parameters.
+            Dataclass contains following values:
+            :param n_buckets: Number of buckets into which samples will be distributed.
+            :param bucket_size_x: Number of item hidden representations that will be in each bucket.
+            :param bucket_size_y: Number of item embeddings that will be in each bucket.
+            :param mix_x: Whether a randomly generated matrix will be multiplied by the model output matrix or not.
+                Default: ``False``.
+        """
+        assert all(
+            param is not None for param in sce_params._get_not_none_params()
+        ), "You should define ``n_buckets``, ``bucket_size_x``, ``bucket_size_y`` when using SCE loss function."
+        self._n_buckets = sce_params.n_buckets
+        self._bucket_size_x = sce_params.bucket_size_x
+        self._bucket_size_y = sce_params.bucket_size_y
+        self._mix_x = sce_params.mix_x
+
+    def __call__(
+        self,
+        embeddings: torch.Tensor,
+        positive_labels: torch.LongTensor,
+        all_embeddings: torch.Tensor,
+        padding_mask: torch.BoolTensor,
+        tokens_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        """
+        ScalableCrossEntropyLoss computation.
+
+        :param embeddings: Matrix of the last transformer block outputs.
+        :param positive_labels: Positive labels.
+        :param all_embeddings: Matrix of all item embeddings.
+        :param padding_mask: Padding mask.
+        :param tokens_mask: Tokens mask (need only for Bert4Rec).
+            Default: ``None``.
+        """
+        masked_tokens = padding_mask if tokens_mask is None else ~(~padding_mask + tokens_mask)
+
+        hd = torch.tensor(embeddings.shape[-1])
+        x = embeddings.view(-1, hd)
+        y = positive_labels.view(-1)
+        w = all_embeddings
+
+        correct_class_logits_ = (x * torch.index_select(w, dim=0, index=y)).sum(dim=1)  # (bs,)
+
+        with torch.no_grad():
+            if self._mix_x:
+                omega = 1 / torch.sqrt(torch.sqrt(hd)) * torch.randn(x.shape[0], self._n_buckets, device=x.device)
+                buckets = omega.T @ x
+                del omega
+            else:
+                buckets = (
+                    1 / torch.sqrt(torch.sqrt(hd)) * torch.randn(self._n_buckets, hd, device=x.device)
+                )  # (n_b, hd)
+
+        with torch.no_grad():
+            x_bucket = buckets @ x.T  # (n_b, hd) x (hd, b) -> (n_b, b)
+            x_bucket[:, ~padding_mask.view(-1)] = float("-inf")
+            _, top_x_bucket = torch.topk(x_bucket, dim=1, k=self._bucket_size_x)  # (n_b, bs_x)
+            del x_bucket
+
+            y_bucket = buckets @ w.T  # (n_b, hd) x (hd, n_cl) -> (n_b, n_cl)
+
+            _, top_y_bucket = torch.topk(y_bucket, dim=1, k=self._bucket_size_y)  # (n_b, bs_y)
+            del y_bucket
+
+        x_bucket = torch.gather(x, 0, top_x_bucket.view(-1, 1).expand(-1, hd)).view(
+            self._n_buckets, self._bucket_size_x, hd
+        )  # (n_b, bs_x, hd)
+        y_bucket = torch.gather(w, 0, top_y_bucket.view(-1, 1).expand(-1, hd)).view(
+            self._n_buckets, self._bucket_size_y, hd
+        )  # (n_b, bs_y, hd)
+
+        wrong_class_logits = x_bucket @ y_bucket.transpose(-1, -2)  # (n_b, bs_x, bs_y)
+        mask = (
+            torch.index_select(y, dim=0, index=top_x_bucket.view(-1)).view(self._n_buckets, self._bucket_size_x)[
+                :, :, None
+            ]
+            == top_y_bucket[:, None, :]
+        )  # (n_b, bs_x, bs_y)
+        wrong_class_logits = wrong_class_logits.masked_fill(mask, float("-inf"))  # (n_b, bs_x, bs_y)
+        correct_class_logits = torch.index_select(correct_class_logits_, dim=0, index=top_x_bucket.view(-1)).view(
+            self._n_buckets, self._bucket_size_x
+        )[
+            :, :, None
+        ]  # (n_b, bs_x, 1)
+        logits = torch.cat((wrong_class_logits, correct_class_logits), dim=2)  # (n_b, bs_x, bs_y + 1)
+
+        loss_ = torch.nn.functional.cross_entropy(
+            logits.view(-1, logits.shape[-1]),
+            (logits.shape[-1] - 1)
+            * torch.ones(logits.shape[0] * logits.shape[1], dtype=torch.int64, device=logits.device),
+            reduction="none",
+        )  # (n_b * bs_x,)
+        loss = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
+        loss.scatter_reduce_(0, top_x_bucket.view(-1), loss_, reduce="amax", include_self=False)
+        loss = loss[(loss != 0) & (masked_tokens).view(-1)]
+        loss = torch.mean(loss)
+
+        return loss
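A minimal sketch of calling the new loss directly, with toy shapes (a batch of 2 sequences of length 8, hidden size 64, a catalog of 1000 items); the bucket sizes are arbitrary illustrations, not tuned values:

```python
import torch

from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams

loss_fn = ScalableCrossEntropyLoss(SCEParams(n_buckets=16, bucket_size_x=8, bucket_size_y=32))

embeddings = torch.randn(2, 8, 64)                 # last transformer block outputs
positive_labels = torch.randint(0, 1000, (2, 8))   # next-item targets
all_embeddings = torch.randn(1000, 64)             # full item embedding matrix
padding_mask = torch.ones(2, 8, dtype=torch.bool)  # no padding in this toy batch

loss = loss_fn(embeddings, positive_labels, all_embeddings, padding_mask)
```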
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py RENAMED

@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Optional, Tuple, Union, cast
+from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
 
 import lightning
 import torch
@@ -27,7 +27,7 @@ class Bert4Rec(lightning.LightningModule):
         pass_per_transformer_block_count: int = 1,
         enable_positional_embedding: bool = True,
         enable_embedding_tying: bool = False,
-        loss_type:
+        loss_type: Literal["BCE", "CE", "CE_restricted"] = "CE",
         loss_sample_count: Optional[int] = None,
         negative_sampling_strategy: str = "global_uniform",
         negatives_sharing: bool = False,
@@ -54,7 +54,7 @@ class Bert4Rec(lightning.LightningModule):
             If `True` - result scores are calculated by dot product of input and output embeddings,
             if `False` - default linear layer is applied to calculate logits for each item.
             Default: ``False``.
-        :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``.
+        :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"CE_restricted"``.
             Default: ``CE``.
         :param loss_sample_count (Optional[int]): Sample count to calculate loss.
             Default: ``None``.
@@ -197,6 +197,8 @@ class Bert4Rec(lightning.LightningModule):
             loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
             loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
+        elif self._loss_type == "CE_restricted":
+            loss_func = self._compute_loss_ce_restricted
         else:
             msg = f"Not supported loss type: {self._loss_type}"
             raise ValueError(msg)
@@ -316,6 +318,20 @@ class Bert4Rec(lightning.LightningModule):
         loss = self._loss(logits, labels_flat)
         return loss
 
+    def _compute_loss_ce_restricted(
+        self,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        tokens_mask: torch.BoolTensor,
+    ) -> torch.Tensor:
+        (logits, labels) = self._get_restricted_logits_for_ce_loss(
+            feature_tensors, positive_labels, padding_mask, tokens_mask
+        )
+
+        loss = self._loss(logits, labels)
+        return loss
+
     def _get_sampled_logits(
         self,
         feature_tensors: TensorMap,
@@ -398,11 +414,27 @@ class Bert4Rec(lightning.LightningModule):
             vocab_size,
         )
 
+    def _get_restricted_logits_for_ce_loss(
+        self,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        tokens_mask: torch.BoolTensor,
+    ):
+        labels_mask = (~padding_mask) + tokens_mask
+        masked_tokens = ~labels_mask
+        positive_labels = cast(
+            torch.LongTensor, torch.masked_select(positive_labels, masked_tokens)
+        )  # (masked_batch_seq_size,)
+        output_emb = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)[masked_tokens]
+        logits = self._model.get_logits(output_emb)
+        return (logits, positive_labels)
+
     def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntropyLoss]:
         if self._loss_type == "BCE":
             return torch.nn.BCEWithLogitsLoss(reduction="sum")
 
-        if self._loss_type == "CE":
+        if self._loss_type == "CE" or self._loss_type == "CE_restricted":
             return torch.nn.CrossEntropyLoss()
 
         msg = "Not supported loss_type"
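The idea behind ``CE_restricted``: full-vocabulary logits are computed only at positions that are real tokens and were masked for prediction, rather than at every sequence position. A self-contained toy version of that masking, with synthetic tensors standing in for the RePlay model objects:

```python
import torch

hidden = torch.randn(2, 5, 16)                     # transformer outputs (batch, seq, hidden)
item_embeddings = torch.randn(100, 16)             # catalog of 100 items
labels = torch.randint(0, 100, (2, 5))             # positive item ids per position
padding_mask = torch.ones(2, 5, dtype=torch.bool)  # all positions are real tokens here
tokens_mask = torch.tensor([[1, 1, 0, 1, 0],
                            [1, 0, 1, 1, 0]], dtype=torch.bool)  # 0 = masked for prediction

loss_positions = ~((~padding_mask) + tokens_mask)    # real AND masked-for-prediction
logits = hidden[loss_positions] @ item_embeddings.T  # (num_masked, 100) instead of (2*5, 100)
loss = torch.nn.functional.cross_entropy(logits, labels[loss_positions])
```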
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/bert4rec/model.py RENAMED

@@ -1,9 +1,9 @@
 import contextlib
-import math
 from abc import ABC, abstractmethod
 from typing import Dict, Optional, Union
 
 import torch
+import torch.nn as nn
 
 from replay.data.nn import TensorFeatureInfo, TensorMap, TensorSchema
 
@@ -379,7 +379,7 @@ class BaseHead(ABC, torch.nn.Module):
         item_embeddings = item_embeddings[item_ids]
         bias = bias[item_ids]
 
-        logits =
+        logits = torch.nn.functional.linear(out_embeddings, item_embeddings, bias)
         return logits
 
     @abstractmethod
@@ -471,11 +471,11 @@ class TransformerBlock(torch.nn.Module):
         super().__init__()
         self.attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=dropout, batch_first=True)
         self.attention_dropout = torch.nn.Dropout(dropout)
-        self.attention_norm = LayerNorm(hidden_size)
+        self.attention_norm = torch.nn.LayerNorm(hidden_size)
 
         self.pff = PositionwiseFeedForward(d_model=hidden_size, d_ff=feed_forward_hidden, dropout=dropout)
         self.pff_dropout = torch.nn.Dropout(dropout)
-        self.pff_norm = LayerNorm(hidden_size)
+        self.pff_norm = torch.nn.LayerNorm(hidden_size)
 
         self.dropout = torch.nn.Dropout(p=dropout)
 
@@ -501,33 +501,6 @@ class TransformerBlock(torch.nn.Module):
         return self.dropout(z)
 
 
-class LayerNorm(torch.nn.Module):
-    """
-    Construct a layernorm module (See citation for details).
-    """
-
-    def __init__(self, features: int, eps: float = 1e-6):
-        """
-        :param features: Number of features.
-        :param eps: A value added to the denominator for numerical stability.
-            Default: ``1e-6``.
-        """
-        super().__init__()
-        self.a_2 = torch.nn.Parameter(torch.ones(features))
-        self.b_2 = torch.nn.Parameter(torch.zeros(features))
-        self.eps = eps
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        :param x: Input tensor.
-
-        :returns: Normalized input tensor.
-        """
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
-
-
 class PositionwiseFeedForward(torch.nn.Module):
     """
     Implements FFN equation.
@@ -544,7 +517,7 @@ class PositionwiseFeedForward(torch.nn.Module):
         self.w_1 = torch.nn.Linear(d_model, d_ff)
         self.w_2 = torch.nn.Linear(d_ff, d_model)
         self.dropout = torch.nn.Dropout(dropout)
-        self.activation = GELU()
+        self.activation = nn.GELU()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -553,17 +526,3 @@ class PositionwiseFeedForward(torch.nn.Module):
         :returns: Position wised output.
         """
         return self.w_2(self.dropout(self.activation(self.w_1(x))))
-
-
-class GELU(torch.nn.Module):
-    """
-    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
-    """
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        :param x: Input tensor.
-
-        :returns: Activated input tensor.
-        """
-        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
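Worth noting: the built-ins are near- but not bit-exact replacements. The deleted LayerNorm normalized by a Bessel-corrected std with eps added outside the square root, and the deleted GELU used the tanh approximation, whereas torch.nn.GELU defaults to the exact erf form. A quick check of the gap (assuming PyTorch >= 1.12 for the ``approximate`` argument):

```python
import torch

x = torch.randn(4, 16)
exact = torch.nn.GELU()(x)                          # erf-based, the new default
tanh_approx = torch.nn.GELU(approximate="tanh")(x)  # same formula as the deleted hand-rolled GELU
print((exact - tanh_approx).abs().max())            # small but nonzero difference
```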
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/lightning.py RENAMED

@@ -1,10 +1,11 @@
 import math
-from typing import Any, Dict, Optional, Tuple, Union, cast
+from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
 
 import lightning
 import torch
 
 from replay.data.nn import TensorMap, TensorSchema
+from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams
 from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
 
 from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch
@@ -29,12 +30,13 @@ class SasRec(lightning.LightningModule):
         dropout_rate: float = 0.2,
         ti_modification: bool = False,
         time_span: int = 256,
-        loss_type:
+        loss_type: Literal["BCE", "CE", "SCE"] = "CE",
         loss_sample_count: Optional[int] = None,
         negative_sampling_strategy: str = "global_uniform",
         negatives_sharing: bool = False,
         optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
         lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
+        sce_params: Optional[SCEParams] = None,
     ):
         """
         :param tensor_schema: Tensor schema of features.
@@ -52,9 +54,10 @@ class SasRec(lightning.LightningModule):
             Default: ``False``.
         :param time_span: Time span value.
             Default: ``256``.
-        :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``.
+        :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"SCE"``.
             Default: ``CE``.
         :param loss_sample_count (Optional[int]): Sample count to calculate loss.
+            Suitable for ``"CE"`` and ``"BCE"`` loss functions.
             Default: ``None``.
         :param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
             Is used when large count of items in dataset.
@@ -66,6 +69,8 @@ class SasRec(lightning.LightningModule):
             Default: ``FatOptimizerFactory``.
         :param lr_scheduler_factory: Learning rate schedule factory.
             Default: ``None``.
+        :param sce_params: Dataclass with SCE parameters. Need to be defined if ``loss_type`` is ``SCE``.
+            Default: ``None``.
         """
         super().__init__()
         self.save_hyperparameters()
@@ -85,9 +90,12 @@ class SasRec(lightning.LightningModule):
         self._negatives_sharing = negatives_sharing
         self._optimizer_factory = optimizer_factory
         self._lr_scheduler_factory = lr_scheduler_factory
+        self._sce_params = sce_params
         self._loss = self._create_loss()
         self._schema = tensor_schema
         assert negative_sampling_strategy in {"global_uniform", "inbatch"}
+        if self._loss_type == "SCE":
+            assert sce_params is not None, "You should define ``sce_params`` when using SCE loss function."
 
         item_count = tensor_schema.item_id_features.item().cardinality
         assert item_count
@@ -197,6 +205,8 @@ class SasRec(lightning.LightningModule):
             loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
             loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
+        elif self._loss_type == "SCE":
+            loss_func = self._compute_loss_scalable_ce
         else:
             msg = f"Not supported loss type: {self._loss_type}"
             raise ValueError(msg)
@@ -314,6 +324,17 @@ class SasRec(lightning.LightningModule):
         loss = self._loss(logits, labels_flat)
         return loss
 
+    def _compute_loss_scalable_ce(
+        self,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        tokens_mask: torch.BoolTensor,  # noqa: ARG002
+    ) -> torch.Tensor:
+        emb = self._model.forward_step(feature_tensors, padding_mask)
+        all_embeddings = self.get_all_embeddings()["item_embedding"]
+        return self._loss(emb, positive_labels, all_embeddings, padding_mask)
+
     def _get_sampled_logits(
         self,
         feature_tensors: TensorMap,
@@ -401,6 +422,9 @@ class SasRec(lightning.LightningModule):
         if self._loss_type == "CE":
             return torch.nn.CrossEntropyLoss()
 
+        if self._loss_type == "SCE":
+            return ScalableCrossEntropyLoss(self._sce_params)
+
         msg = "Not supported loss_type"
         raise NotImplementedError(msg)
 
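A hedged sketch of switching SasRec to the new loss; ``tensor_schema`` is assumed to be a TensorSchema prepared elsewhere (e.g. during data preparation with SequenceTokenizer), and the bucket values are placeholders rather than recommendations:

```python
from replay.models.nn.loss import SCEParams
from replay.models.nn.sequential import SasRec

model = SasRec(
    tensor_schema=tensor_schema,  # assumed to exist; built during data preparation
    loss_type="SCE",
    sce_params=SCEParams(n_buckets=256, bucket_size_x=256, bucket_size_y=256),
)
```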
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/models/nn/sequential/sasrec/model.py RENAMED

@@ -298,7 +298,7 @@ class EmbeddingTyingHead(torch.nn.Module):
         if len(item_embeddings.shape) > 2:  # global_uniform, negative sharing=False, train only
             logits = (item_embeddings * out_embeddings.unsqueeze(-2)).sum(dim=-1)
         else:
-            logits =
+            logits = torch.matmul(out_embeddings, item_embeddings.t())
         return logits
 
 
{replay_rec-0.18.1rc0 → replay_rec-0.19.0rc0}/replay/preprocessing/filters.py RENAMED

@@ -4,7 +4,8 @@ Select or remove data by some criteria
 
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple, Union
+from uuid import uuid4
 
 import numpy as np
 import pandas as pd
@@ -989,3 +990,103 @@ class QuantileItemsFilter(_BaseFilter):
         )
         short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
         return long_tail.select(df.columns).union(short_tail.select(df.columns))
+
+
+class ConsecutiveDuplicatesFilter(_BaseFilter):
+    """Removes consecutive duplicate items from sequential dataset.
+
+    >>> import datetime as dt
+    >>> import pandas as pd
+    >>> from replay.utils.spark_utils import convert2spark
+    >>> interactions = pd.DataFrame({
+    ...     "user_id": ["u0", "u1", "u1", "u0", "u0", "u0", "u1", "u0"],
+    ...     "item_id": ["i0", "i1", "i1", "i2", "i0", "i1", "i2", "i1"],
+    ...     "timestamp": [dt.datetime(2024, 1, 1) + dt.timedelta(days=i) for i in range(8)]
+    ... })
+    >>> interactions = convert2spark(interactions)
+    >>> interactions.show()
+    +-------+-------+-------------------+
+    |user_id|item_id|          timestamp|
+    +-------+-------+-------------------+
+    |     u0|     i0|2024-01-01 00:00:00|
+    |     u1|     i1|2024-01-02 00:00:00|
+    |     u1|     i1|2024-01-03 00:00:00|
+    |     u0|     i2|2024-01-04 00:00:00|
+    |     u0|     i0|2024-01-05 00:00:00|
+    |     u0|     i1|2024-01-06 00:00:00|
+    |     u1|     i2|2024-01-07 00:00:00|
+    |     u0|     i1|2024-01-08 00:00:00|
+    +-------+-------+-------------------+
+    <BLANKLINE>
+
+    >>> ConsecutiveDuplicatesFilter(query_column="user_id").transform(interactions).show()
+    +-------+-------+-------------------+
+    |user_id|item_id|          timestamp|
+    +-------+-------+-------------------+
+    |     u0|     i0|2024-01-01 00:00:00|
+    |     u0|     i2|2024-01-04 00:00:00|
+    |     u0|     i0|2024-01-05 00:00:00|
+    |     u0|     i1|2024-01-06 00:00:00|
+    |     u1|     i1|2024-01-02 00:00:00|
+    |     u1|     i2|2024-01-07 00:00:00|
+    +-------+-------+-------------------+
+    <BLANKLINE>
+    """
+
+    def __init__(
+        self,
+        keep: Literal["first", "last"] = "first",
+        query_column: str = "query_id",
+        item_column: str = "item_id",
+        timestamp_column: str = "timestamp",
+    ) -> None:
+        """
+        :param keep: whether to keep first or last occurrence,
+            Default: ``first``.
+        :param query_column: query column,
+            Default: ``query_id``.
+        :param item_column: item column,
+            Default: ``item_id``.
+        :param timestamp_column: timestamp column,
+            Default: ``timestamp``.
+        """
+        super().__init__()
+        self.query_column = query_column
+        self.item_column = item_column
+        self.timestamp_column = timestamp_column
+
+        if keep not in ("first", "last"):
+            msg = "`keep` must be either 'first' or 'last'"
+            raise ValueError(msg)
+
+        self.bias = 1 if keep == "first" else -1
+        self.temporary_column = f"__shifted_{uuid4().hex[:8]}"
+
+    def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
+        interactions = interactions.sort_values(self.timestamp_column)
+        interactions[self.temporary_column] = interactions.groupby(self.query_column)[self.item_column].shift(
+            periods=self.bias
+        )
+        return (
+            interactions[interactions[self.item_column] != interactions[self.temporary_column]]
+            .drop(self.temporary_column, axis=1)
+            .reset_index(drop=True)
+        )
+
+    def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
+        return (
+            interactions.sort(self.timestamp_column)
+            .with_columns(
+                pl.col(self.item_column).shift(n=self.bias).over(self.query_column).alias(self.temporary_column)
+            )
+            .filter((pl.col(self.item_column) != pl.col(self.temporary_column)).fill_null(True))
+            .drop(self.temporary_column)
+        )
+
+    def _filter_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
+        window = Window.partitionBy(self.query_column).orderBy(self.timestamp_column)
+        return (
+            interactions.withColumn(self.temporary_column, sf.lag(self.item_column, offset=self.bias).over(window))
+            .where((sf.col(self.item_column) != sf.col(self.temporary_column)) | sf.col(self.temporary_column).isNull())
+            .drop(self.temporary_column)
+        )
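The new filter also runs on pandas and polars frames through the same ``transform`` entry point used in the Spark doctest above. A small pandas example (toy data, keeping the first row of each consecutive run):

```python
import pandas as pd

from replay.preprocessing.filters import ConsecutiveDuplicatesFilter

df = pd.DataFrame({
    "user_id": ["u0", "u0", "u0"],
    "item_id": ["i0", "i0", "i1"],  # first two rows are consecutive duplicates
    "timestamp": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
})
print(ConsecutiveDuplicatesFilter(query_column="user_id").transform(df))  # drops the second i0 row
```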