replay-rec 0.16.0rc0__tar.gz → 0.17.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/PKG-INFO +2 -2
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/pyproject.toml +67 -36
- replay_rec-0.17.0rc0/replay/__init__.py +2 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/dataset.py +45 -42
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/__init__.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/schema.py +20 -33
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/sequence_tokenizer.py +217 -87
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/sequential_dataset.py +6 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/torch_sequential_dataset.py +20 -11
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/utils.py +7 -9
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/schema.py +17 -17
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/spark_schema.py +0 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/base_metric.py +63 -123
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/coverage.py +15 -35
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/experiment.py +18 -43
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/hitrate.py +2 -3
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/map.py +2 -3
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/mrr.py +0 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/ncis_precision.py +2 -3
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/ndcg.py +3 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/precision.py +2 -3
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/recall.py +2 -3
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/rocauc.py +6 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/surprisal.py +16 -28
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/unexpectedness.py +16 -14
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/admm_slim.py +11 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/base_neighbour_rec.py +20 -38
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/base_rec.py +59 -149
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/base_torch_rec.py +13 -26
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/cql.py +83 -99
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/ddpg.py +43 -129
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +9 -13
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/dt4rec/gpt1.py +9 -19
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/dt4rec/trainer.py +5 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/dt4rec/utils.py +6 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +123 -64
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/implicit_wrap.py +21 -28
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/lightfm_wrap.py +22 -47
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/mult_vae.py +22 -65
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/neuromf.py +34 -91
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/scala_als.py +25 -40
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/nn/data/schema_builder.py +1 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/preprocessing/data_preparator.py +78 -169
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/preprocessing/padder.py +20 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/preprocessing/sequence_generator.py +11 -21
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +25 -37
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +71 -94
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +31 -32
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +2 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +41 -127
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/utils/logger.py +7 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/utils/model_handler.py +18 -50
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/utils/session_handler.py +1 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/base_metric.py +38 -79
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/categorical_diversity.py +24 -58
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/coverage.py +25 -49
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/descriptors.py +4 -13
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/experiment.py +3 -8
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/hitrate.py +3 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/map.py +3 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/mrr.py +1 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/ndcg.py +4 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/novelty.py +10 -29
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/offline_metrics.py +26 -61
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/precision.py +3 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/recall.py +3 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/rocauc.py +7 -10
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/surprisal.py +13 -30
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/torch_metrics_builder.py +0 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/unexpectedness.py +15 -20
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/__init__.py +1 -2
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/als.py +7 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/association_rules.py +12 -28
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/base_neighbour_rec.py +21 -36
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/base_rec.py +92 -215
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/cat_pop_rec.py +9 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/cluster.py +17 -28
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/ann_mixin.py +7 -12
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay_rec-0.17.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_stores/utils.py +5 -2
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/utils.py +3 -5
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/kl_ucb.py +16 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/knn.py +37 -59
- replay_rec-0.17.0rc0/replay/models/nn/optimizer_utils/__init__.py +4 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/bert4rec/model.py +12 -25
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +8 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +53 -48
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/sasrec/model.py +4 -17
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/pop_rec.py +9 -10
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/query_pop_rec.py +7 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/random_rec.py +10 -18
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/slim.py +8 -13
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/thompson_sampling.py +13 -14
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/ucb.py +11 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/wilson.py +5 -14
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/word2vec.py +24 -69
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/optimization/optuna_objective.py +13 -27
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/__init__.py +1 -2
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/converter.py +2 -7
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/filters.py +67 -142
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/history_based_fp.py +44 -116
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/label_encoder.py +106 -68
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/preprocessing/sessionizer.py +1 -11
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/scenarios/fallback.py +3 -8
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/base_splitter.py +43 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/cold_user_random_splitter.py +18 -31
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/k_folds.py +14 -24
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/last_n_splitter.py +33 -43
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/new_users_splitter.py +31 -55
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/random_splitter.py +16 -23
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/ratio_splitter.py +30 -54
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/time_splitter.py +13 -18
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/two_stage_splitter.py +44 -79
- replay_rec-0.17.0rc0/replay/utils/common.py +65 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/dataframe_bucketizer.py +25 -31
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/distributions.py +3 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/model_handler.py +36 -33
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/session_handler.py +11 -15
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/spark_utils.py +51 -85
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/time.py +8 -22
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/types.py +1 -3
- replay_rec-0.16.0rc0/replay/__init__.py +0 -2
- replay_rec-0.16.0rc0/replay/models/nn/optimizer_utils/__init__.py +0 -9
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/LICENSE +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/NOTICE +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/README.md +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/__init__.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/metrics/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/__init__.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/dt4rec/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/experimental/utils → replay_rec-0.17.0rc0/replay/experimental/scenarios/two_stages}/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions → replay_rec-0.17.0rc0/replay/experimental/utils}/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/metrics/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions/ann → replay_rec-0.17.0rc0/replay/models/extensions}/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions/ann/entities → replay_rec-0.17.0rc0/replay/models/extensions/ann}/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions/ann/index_builders → replay_rec-0.17.0rc0/replay/models/extensions/ann/entities}/__init__.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions/ann/index_inferers → replay_rec-0.17.0rc0/replay/models/extensions/ann/index_builders}/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.16.0rc0/replay/models/extensions/ann/index_stores → replay_rec-0.17.0rc0/replay/models/extensions/ann/index_inferers}/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/optimization/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/scenarios/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/splitters/__init__.py +0 -0
- {replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/utils/__init__.py +1 -1

{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: replay-rec
-Version: 0.16.0rc0
+Version: 0.17.0rc0
 Summary: RecSys Library
 Home-page: https://sb-ai-lab.github.io/RePlay/
 License: Apache-2.0
@@ -35,7 +35,7 @@ Requires-Dist: optuna (>=3.2.0,<3.3.0)
 Requires-Dist: pandas (>=1.3.5,<2.0.0)
 Requires-Dist: polars (>=0.20.7,<0.21.0)
 Requires-Dist: psutil (>=5.9.5,<5.10.0)
-Requires-Dist: pyarrow (>=12.0.1
+Requires-Dist: pyarrow (>=12.0.1)
 Requires-Dist: pyspark (>=3.0,<3.3) ; extra == "spark" or extra == "all"
 Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
 Requires-Dist: sb-obp (>=0.5.7,<0.6.0)

{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/pyproject.toml
RENAMED
@@ -1,16 +1,29 @@
+[build-system]
+requires = [
+    "poetry-core>=1.0.0",
+    "poetry-dynamic-versioning>=1.0.0,<2.0.0",
+]
+build-backend = "poetry_dynamic_versioning.backend"
+
+[tool.black]
+line-length = 120
+target-versions = ["py38", "py39", "py310"]
+
 [tool.poetry]
 name = "replay-rec"
 packages = [{include = "replay"}]
 license = "Apache-2.0"
 description = "RecSys Library"
-authors = [
-
-
-
-
-
-
-
+authors = [
+    "AI Lab",
+    "Alexey Vasilev",
+    "Anna Volodkevich",
+    "Alexey Grishanov",
+    "Yan-Martin Tamm",
+    "Boris Shminke",
+    "Alexander Sidorenko",
+    "Roza Aysina",
+]
 readme = "README.md"
 homepage = "https://sb-ai-lab.github.io/RePlay/"
 repository = "https://github.com/sb-ai-lab/RePlay"
@@ -27,7 +40,7 @@ classifiers = [
 exclude = [
     "replay/conftest.py",
 ]
-version = "0.16.0.preview"
+version = "0.17.0.preview"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1, <3.11"
@@ -37,12 +50,11 @@ polars = "~0.20.7"
 optuna = "~3.2.0"
 scipy = "~1.8.1"
 psutil = "~5.9.5"
-pyspark = {
+pyspark = {version = ">=3.0,<3.3", optional = true}
 scikit-learn = "^1.0.2"
-pyarrow = ">=12.0.1
+pyarrow = ">=12.0.1"
 nmslib = "2.1.1"
 hnswlib = "0.7.0"
-
 torch = "^1.8"
 lightning = "^2.0.2"
 pytorch-ranger = "^0.1.1"
@@ -55,20 +67,20 @@ d3rlpy = "^2.0.4"
 implicit = "~0.7.0"
 gym = "^0.26.0"
 
+[tool.poetry.extras]
+spark = ["pyspark"]
+torch = ["torch", "pytorch-ranger", "lightning"]
+all = ["pyspark", "torch", "pytorch-ranger", "lightning"]
+
 [tool.poetry.group.dev.dependencies]
-# visualization
 jupyter = "~1.0.0"
 jupyterlab = "^3.6.0"
-# testing
 pytest = ">=7.1.0"
 pytest-cov = ">=3.0.0"
 statsmodels = "~0.13.5"
-
-
-
-pylint = "^2.13"
-pycodestyle = "^2.10"
-# docs
+black = ">=23.3.0"
+ruff = ">=0.0.261"
+toml-sort = "^0.23.0"
 sphinx = "5.3.0"
 sphinx-rtd-theme = "1.2.2"
 sphinx-autodoc-typehints = "1.23.0"
@@ -76,26 +88,45 @@ sphinx-enum-extend = "0.1.3"
 myst-parser = "1.0.0"
 ghp-import = "2.1.0"
 docutils = "0.16"
-# stubs
 data-science-types = "0.2.23"
 
-[tool.poetry.extras]
-spark = ["pyspark"]
-torch = ["torch", "pytorch-ranger", "lightning"]
-all = ["pyspark", "torch", "pytorch-ranger", "lightning"]
-
-[build-system]
-requires = [
-    "poetry-core>=1.0.0",
-    "poetry-dynamic-versioning>=1.0.0,<2.0.0",
-]
-build-backend = "poetry_dynamic_versioning.backend"
-
 [tool.poetry-dynamic-versioning]
 enable = false
-format-jinja = """0.16.0{{ env['PACKAGE_SUFFIX'] }}"""
+format-jinja = """0.17.0{{ env['PACKAGE_SUFFIX'] }}"""
 vcs = "git"
 
-[tool.
+[tool.ruff]
+exclude = [".git", ".venv", "__pycache__", "env", "venv", "docs", "projects", "examples"]
+extend-select = ["C90", "T10", "T20", "UP004"]
 line-length = 120
-
+select = ["ARG", "C4", "E", "EM", "ERA", "F", "FLY", "I", "INP", "ISC", "N", "PERF", "PGH", "PIE", "PYI", "Q", "RUF", "SIM", "TID", "W"]
+
+[tool.ruff.flake8-quotes]
+docstring-quotes = "double"
+inline-quotes = "double"
+multiline-quotes = "double"
+
+[tool.ruff.flake8-unused-arguments]
+ignore-variadic-names = false
+
+[tool.ruff.isort]
+combine-as-imports = true
+force-wrap-aliases = true
+
+[tool.ruff.mccabe]
+max-complexity = 13
+
+[tool.ruff.per-file-ignores]
+"*/" = ["PERF203", "RUF001", "RUF002", "RUF012", "E402"]
+"__init__.py" = ["F401"]
+"replay/utils/model_handler.py" = ["F403", "F405"]
+"tests/*" = ["ARG", "E402", "INP", "ISC", "N", "S", "SIM", "F811"]
+"tests/experimental/*" = ["F401", "F811"]
+"replay/experimental/models/extensions/spark_custom_models/als_extension.py" = ["ARG002", "N802", "N803", "N815"]
+
+[tool.tomlsort]
+ignore_case = true
+in_place = true
+no_comments = true
+spaces_indent_inline_array = 4
+trailing_comma_inline_array = true
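
The new `[tool.poetry.extras]` table makes the Spark and Torch stacks optional installs (`pip install replay-rec[spark]`, `replay-rec[torch]`, or `replay-rec[all]`). At import time the library guards these optional dependencies with availability flags from `replay.utils`, visible as `PYSPARK_AVAILABLE` and `TORCH_AVAILABLE` in the diffs below. A minimal sketch of that guard pattern, assuming the flags come from a module lookup (the real `replay/utils` implementation may differ):

```python
import importlib.util

# True only when the corresponding extra is installed; the flag names match
# replay.utils, but computing them via find_spec is an illustrative assumption.
PYSPARK_AVAILABLE = importlib.util.find_spec("pyspark") is not None
TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None

if PYSPARK_AVAILABLE:
    # Spark-only imports stay behind the guard, so a torch-only install
    # of replay-rec does not fail at import time.
    import pyspark.sql.functions as sf  # noqa: F401
```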

{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/dataset.py
RENAMED
@@ -7,21 +7,20 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence
 
 import numpy as np
 
-from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
+
 if PYSPARK_AVAILABLE:
-    import pyspark.sql.functions as
+    import pyspark.sql.functions as sf
     from pyspark.storagelevel import StorageLevel
 
 
-# pylint: disable=too-many-instance-attributes
 class Dataset:
     """
     Universal dataset for feeding data to models.
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         feature_schema: FeatureSchema,
@@ -57,23 +56,23 @@ class Dataset:
         try:
             feature_schema.item_id_column
         except Exception as exception:
-
+            msg = "Item id column is not set."
+            raise ValueError(msg) from exception
 
         try:
             feature_schema.query_id_column
         except Exception as exception:
-
-
-
-
-            and
-
-
-
-            self.query_features is not None
-            and not check_dataframes_types_equal(self._interactions, self.query_features)
+            msg = "Query id column is not set."
+            raise ValueError(msg) from exception
+
+        if self.item_features is not None and not check_dataframes_types_equal(self._interactions, self.item_features):
+            msg = "Interactions and item features should have the same type."
+            raise TypeError(msg)
+        if self.query_features is not None and not check_dataframes_types_equal(
+            self._interactions, self.query_features
         ):
-
+            msg = "Interactions and query features should have the same type."
+            raise TypeError(msg)
 
         self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
             FeatureSource.INTERACTIONS: self.interactions,
@@ -191,6 +190,7 @@ class Dataset:
         return self._feature_schema
 
     if PYSPARK_AVAILABLE:
+
         def persist(self, storage_level: StorageLevel = StorageLevel(True, True, False, True, 1)) -> None:
             """
             Sets the storage level to persist SparkDataFrame for interactions, item_features
@@ -295,7 +295,6 @@ class Dataset:
     def _set_cardinality(self, features_list: Sequence[FeatureInfo]) -> None:
         for feature in features_list:
            if feature.feature_type == FeatureType.CATEGORICAL:
-                # pylint: disable=protected-access
                 feature._set_cardinality_callback(self._get_cardinality(feature))
 
     def _fill_feature_schema(self, feature_schema: FeatureSchema) -> FeatureSchema:
@@ -333,15 +332,14 @@ class Dataset:
 
         for feature in features_list:
             if feature.feature_hint in [FeatureHint.QUERY_ID, FeatureHint.ITEM_ID]:
-                # pylint: disable=protected-access
                 feature._set_feature_source(source=FeatureSource.INTERACTIONS)
                 continue
-            source = source_mapping.get(feature.column)
+            source = source_mapping.get(feature.column)
             if source:
-                # pylint: disable=protected-access
                 feature._set_feature_source(source=source_mapping[feature.column])
             else:
-
+                msg = f"{feature.column} doesn't exist in provided dataframes"
+                raise ValueError(msg)
 
         self._set_cardinality(features_list=features_list)
         return features_list
@@ -362,10 +360,8 @@ class Dataset:
         self._set_cardinality(features_list=unlabeled_columns)
         return unlabeled_columns
 
-    # pylint: disable=no-self-use
     def _set_features_source(self, feature_list: List[FeatureInfo], source: FeatureSource) -> None:
         for feature in feature_list:
-            # pylint: disable=protected-access
             feature._set_feature_source(source)
 
     def _check_ids_consistency(self, hint: FeatureHint) -> None:
@@ -377,8 +373,8 @@ class Dataset:
             self.feature_schema.item_id_column if hint == FeatureHint.ITEM_ID else self.feature_schema.query_id_column
         )
         if self.is_pandas:
-            interactions_unique_ids = set(self.interactions[ids_column].unique())
-            features_df_unique_ids = set(features_df[ids_column].unique())
+            interactions_unique_ids = set(self.interactions[ids_column].unique())
+            features_df_unique_ids = set(features_df[ids_column].unique())
             in_interactions_not_in_features_ids = interactions_unique_ids - features_df_unique_ids
             is_consistent = len(in_interactions_not_in_features_ids) == 0
         elif self.is_spark:
@@ -389,14 +385,18 @@
             .count()
         ) == 0
         else:
-            is_consistent =
-
-
-
-
+            is_consistent = (
+                len(
+                    self.interactions.select(ids_column)
+                    .unique()
+                    .join(features_df.select(ids_column).unique(), on=ids_column, how="anti")
+                )
+                == 0
+            )
 
         if not is_consistent:
-
+            msg = f"There are IDs in the interactions that are missing in the {hint.name} dataframe."
+            raise ValueError(msg)
 
     def _check_column_encoded(
         self, data: DataFrameLike, column: str, source: FeatureSource, cardinality: Optional[int]
@@ -419,26 +419,29 @@
             is_int = data[column].dtype.is_integer()
 
         if not is_int:
-
+            msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
+            raise ValueError(msg)
 
         if self.is_pandas:
-            min_id = data[column].min()
+            min_id = data[column].min()
         elif self.is_spark:
-            min_id = data.agg(
+            min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
         else:
-            min_id = data[column].min()
+            min_id = data[column].min()
         if min_id < 0:
-
+            msg = f"IDs in {source.name}.{column} are not encoded. Min ID is less than 0."
+            raise ValueError(msg)
 
         if self.is_pandas:
-            max_id = data[column].max()
+            max_id = data[column].max()
         elif self.is_spark:
-            max_id = data.agg(
+            max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
         else:
-            max_id = data[column].max()
+            max_id = data[column].max()
 
         if max_id >= cardinality:
-
+            msg = f"IDs in {source.name}.{column} are not encoded. Max ID is more than quantity of IDs."
+            raise ValueError(msg)
 
     def _check_encoded(self) -> None:
         for feature in self.feature_schema.categorical_features.all_features:
@@ -471,11 +474,11 @@
                 feature.cardinality,
             )
         else:
-            data = self._feature_source_map[feature.feature_source]
+            data = self._feature_source_map[feature.feature_source]
             self._check_column_encoded(
                 data,
                 feature.column,
-                feature.feature_source,
+                feature.feature_source,
                 feature.cardinality,
             )
 
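One mechanical change repeats through this file: every `raise SomeError("literal")` becomes a `msg = ...` assignment followed by `raise SomeError(msg)`. That is the fix demanded by ruff's flake8-errmsg rules, enabled above via `"EM"` in the `[tool.ruff] select` list. A toy illustration of the rule (not replay code):

```python
def require_positive(count: int) -> int:
    # EM101 flags a string literal inside a raise; EM102 flags an f-string.
    # Binding the message to a variable first satisfies both rules and keeps
    # the raise line (and the traceback it produces) short.
    if count <= 0:
        msg = f"count must be positive, got {count}"
        raise ValueError(msg)
    return count
```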
{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py
RENAMED
@@ -31,6 +31,7 @@ class DatasetLabelEncoder:
         When set to ``error`` an error will be raised in case an unknown label is present during transform.
         When set to ``use_default_value``, the encoded value of unknown label will be set
         to the value given for the parameter default_value.
+        When set to ``drop``, the unknown labels will be dropped.
         Default: ``error``.
     :param default_value: Default value that will fill the unknown labels after transform.
         When the parameter handle_unknown is set to ``use_default_value``,
@@ -105,7 +106,7 @@ class DatasetLabelEncoder:
         for column, feature_info in dataset.feature_schema.categorical_features.items():
             if column not in self._encoding_rules:
                 warnings.warn(
-                    f"Cannot transform feature '{column}'
+                    f"Cannot transform feature '{column}' as it was not present at the fit stage",
                     LabelEncoderTransformWarning,
                 )
                 continue
@@ -157,10 +158,7 @@
         self._check_if_initialized()
 
         columns_set: Set[str]
-        if isinstance(columns, str)
-            columns_set = set([columns])
-        else:
-            columns_set = set(columns)
+        columns_set = {columns} if isinstance(columns, str) else {*columns}
 
         def get_encoding_rules() -> Iterator[LabelEncodingRule]:
             for column, rule in self._encoding_rules.items():
@@ -200,7 +198,7 @@
         """
         query_id_column = self._features_columns[FeatureHint.QUERY_ID]
         item_id_column = self._features_columns[FeatureHint.ITEM_ID]
-        encoder = self.get_encoder(query_id_column + item_id_column)
+        encoder = self.get_encoder(query_id_column + item_id_column)
         assert encoder is not None
         return encoder
 
@@ -231,7 +229,8 @@
 
     def _check_if_initialized(self) -> None:
         if not self._encoding_rules:
-
+            msg = "Encoder is not initialized"
+            raise ValueError(msg)
 
     def _fill_features_columns(self, feature_info: FeatureSchema) -> None:
         self._features_columns = {
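
The first hunk above documents a third `handle_unknown` strategy, `drop`, next to the existing `error` and `use_default_value`. A plain-Python sketch of what the three strategies mean for labels unseen at fit time (illustration only, not replay internals):

```python
mapping = {"a": 0, "b": 1}   # label -> id learned at fit time
labels = ["a", "b", "c"]     # "c" was never seen during fit

# error: transform raises as soon as "c" is encountered.
# use_default_value: "c" is encoded with the configured default_value, e.g. -1.
encoded = [mapping.get(label, -1) for label in labels]              # [0, 1, -1]
# drop: rows with unknown labels are removed from the output.
dropped = [mapping[label] for label in labels if label in mapping]  # [0, 1]
```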
{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/__init__.py
RENAMED
@@ -3,7 +3,7 @@ from replay.utils import TORCH_AVAILABLE
 if TORCH_AVAILABLE:
     from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
     from .sequence_tokenizer import SequenceTokenizer
-    from .sequential_dataset import PandasSequentialDataset,
+    from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
     from .torch_sequential_dataset import (
         DEFAULT_GROUND_TRUTH_PADDING_VALUE,
         DEFAULT_TRAIN_PADDING_VALUE,
{replay_rec-0.16.0rc0 → replay_rec-0.17.0rc0}/replay/data/nn/schema.py
RENAMED
@@ -11,7 +11,6 @@ from typing import (
     Set,
     Union,
     ValuesView,
-    Callable
 )
 
 import torch
@@ -23,7 +22,6 @@ TensorMap = Mapping[str, torch.Tensor]
 MutableTensorMap = Dict[str, torch.Tensor]
 
 
-# pylint: disable=too-many-instance-attributes
 class TensorFeatureSource:
     """
     Describes source of a feature
@@ -72,7 +70,6 @@ class TensorFeatureInfo:
     Information about a tensor feature.
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         name: str,
@@ -108,15 +105,18 @@ class TensorFeatureInfo:
         self._is_seq = is_seq
 
         if not isinstance(feature_type, FeatureType):
-
+            msg = "Unknown feature type"
+            raise ValueError(msg)
         self._feature_type = feature_type
 
         if feature_type == FeatureType.NUMERICAL and (cardinality or embedding_dim):
-
+            msg = "Cardinality and embedding dimensions are needed only with categorical feature type."
+            raise ValueError(msg)
         self._cardinality = cardinality
 
         if feature_type == FeatureType.CATEGORICAL and tensor_dim:
-
+            msg = "Tensor dimensions is needed only with numerical feature type."
+            raise ValueError(msg)
 
         if feature_type == FeatureType.CATEGORICAL:
             default_embedding_dim = 64
@@ -168,7 +168,8 @@
             return None
 
         if len(source) > 1:
-
+            msg = "Only one element feature sources can be converted to single feature source."
+            raise ValueError(msg)
         assert isinstance(self.feature_sources, list)
         return self.feature_sources[0]
 
@@ -199,35 +200,21 @@
         :returns: Cardinality of the feature.
         """
         if self.feature_type != FeatureType.CATEGORICAL:
-
-
-            )
-        if hasattr(self, "_cardinality_callback") and self._cardinality is None:
-            self._set_cardinality(self._cardinality_callback(self._name))
+            msg = f"Can not get cardinality because feature type of {self.name} column is not categorical."
+            raise RuntimeError(msg)
         return self._cardinality
 
-    # pylint: disable=attribute-defined-outside-init
-    def _set_cardinality_callback(self, callback: Callable) -> None:
-        self._cardinality_callback = callback
-
     def _set_cardinality(self, cardinality: int) -> None:
         self._cardinality = cardinality
 
-    def reset_cardinality(self) -> None:
-        """
-        Reset cardinality of the feature to None.
-        """
-        self._cardinality = None
-
     @property
     def tensor_dim(self) -> Optional[int]:
         """
         :returns: Dimensions of the numerical feature.
         """
         if self.feature_type != FeatureType.NUMERICAL:
-
-
-            )
+            msg = f"Can not get tensor dimensions because feature type of {self.name} feature is not numerical."
+            raise RuntimeError(msg)
         return self._tensor_dim
 
     def _set_tensor_dim(self, tensor_dim: int) -> None:
@@ -239,9 +226,8 @@ class TensorFeatureInfo:
         :returns: Embedding dimensions of the feature.
         """
         if self.feature_type != FeatureType.CATEGORICAL:
-
-
-            )
+            msg = f"Can not get embedding dimensions because feature type of {self.name} feature is not categorical."
+            raise RuntimeError(msg)
         return self._embedding_dim
 
     def _set_embedding_dim(self, embedding_dim: int) -> None:
@@ -278,8 +264,9 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
         :returns: Extract single feature from a schema.
         """
         if len(self._tensor_schema) != 1:
-
-
+            msg = "Only one element tensor schema can be converted to single feature"
+            raise ValueError(msg)
+        return next(iter(self._tensor_schema.values()))
 
     def items(self) -> ItemsView[str, TensorFeatureInfo]:
         return self._tensor_schema.items()
@@ -290,7 +277,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
     def values(self) -> ValuesView[TensorFeatureInfo]:
         return self._tensor_schema.values()
 
-    def get(
+    def get(
         self,
         key: str,
         default: Optional[TensorFeatureInfo] = None,
@@ -377,7 +364,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
     @property
     def names(self) -> Sequence[str]:
         """
-
+        :returns: List of all feature's names.
         """
         return list(self._tensor_schema)
 
@@ -447,7 +434,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
         for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
             filtered_features = list(
                 filter(
-                    lambda x: filtration_func(x, filtration_param),
+                    lambda x: filtration_func(x, filtration_param),
                    filtered_features,
                 )
            )