warprec-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warprec/__init__.py +6 -0
- warprec/common/__init__.py +8 -0
- warprec/common/initialize.py +417 -0
- warprec/common/optimizers.py +73 -0
- warprec/data/__init__.py +26 -0
- warprec/data/dataset.py +867 -0
- warprec/data/entities/__init__.py +10 -0
- warprec/data/entities/interactions.py +594 -0
- warprec/data/entities/sessions.py +366 -0
- warprec/data/entities/train_structures/__init__.py +21 -0
- warprec/data/entities/train_structures/custom_collate_fn.py +29 -0
- warprec/data/entities/train_structures/interaction_structures.py +199 -0
- warprec/data/entities/train_structures/session_structures.py +278 -0
- warprec/data/eval_loaders.py +400 -0
- warprec/data/filtering.py +543 -0
- warprec/data/reader/__init__.py +21 -0
- warprec/data/reader/azureblob_reader.py +239 -0
- warprec/data/reader/base_reader.py +587 -0
- warprec/data/reader/local_reader.py +147 -0
- warprec/data/splitting/__init__.py +4 -0
- warprec/data/splitting/splitter.py +248 -0
- warprec/data/splitting/strategies.py +498 -0
- warprec/data/writer/__init__.py +21 -0
- warprec/data/writer/azureblob_writer.py +145 -0
- warprec/data/writer/base_writer.py +565 -0
- warprec/data/writer/local_writer.py +110 -0
- warprec/evaluation/__init__.py +5 -0
- warprec/evaluation/evaluator.py +488 -0
- warprec/evaluation/metrics/__init__.py +19 -0
- warprec/evaluation/metrics/accuracy/__init__.py +26 -0
- warprec/evaluation/metrics/accuracy/auc.py +62 -0
- warprec/evaluation/metrics/accuracy/f1.py +107 -0
- warprec/evaluation/metrics/accuracy/gauc.py +51 -0
- warprec/evaluation/metrics/accuracy/hit_rate.py +24 -0
- warprec/evaluation/metrics/accuracy/lauc.py +57 -0
- warprec/evaluation/metrics/accuracy/map.py +36 -0
- warprec/evaluation/metrics/accuracy/mar.py +36 -0
- warprec/evaluation/metrics/accuracy/mrr.py +27 -0
- warprec/evaluation/metrics/accuracy/ndcg.py +71 -0
- warprec/evaluation/metrics/accuracy/precision.py +24 -0
- warprec/evaluation/metrics/accuracy/recall.py +31 -0
- warprec/evaluation/metrics/base_metric.py +474 -0
- warprec/evaluation/metrics/bias/__init__.py +7 -0
- warprec/evaluation/metrics/bias/aclt.py +77 -0
- warprec/evaluation/metrics/bias/aplt.py +75 -0
- warprec/evaluation/metrics/bias/arp.py +59 -0
- warprec/evaluation/metrics/bias/pop_reo.py +111 -0
- warprec/evaluation/metrics/bias/pop_rsp.py +104 -0
- warprec/evaluation/metrics/coverage/__init__.py +11 -0
- warprec/evaluation/metrics/coverage/item_coverage.py +58 -0
- warprec/evaluation/metrics/coverage/numretrieved.py +31 -0
- warprec/evaluation/metrics/coverage/user_coverage.py +42 -0
- warprec/evaluation/metrics/coverage/user_coverage_at_n.py +43 -0
- warprec/evaluation/metrics/diversity/__init__.py +9 -0
- warprec/evaluation/metrics/diversity/gini_index.py +95 -0
- warprec/evaluation/metrics/diversity/shannon_entropy.py +76 -0
- warprec/evaluation/metrics/diversity/srecall.py +107 -0
- warprec/evaluation/metrics/fairness/__init__.py +21 -0
- warprec/evaluation/metrics/fairness/biasdisparitybd.py +96 -0
- warprec/evaluation/metrics/fairness/biasdisparitybr.py +124 -0
- warprec/evaluation/metrics/fairness/biasdisparitybs.py +124 -0
- warprec/evaluation/metrics/fairness/itemmadranking.py +130 -0
- warprec/evaluation/metrics/fairness/itemmadrating.py +134 -0
- warprec/evaluation/metrics/fairness/reo.py +154 -0
- warprec/evaluation/metrics/fairness/rsp.py +167 -0
- warprec/evaluation/metrics/fairness/usermadranking.py +121 -0
- warprec/evaluation/metrics/fairness/usermadrating.py +98 -0
- warprec/evaluation/metrics/multiobjective/__init__.py +4 -0
- warprec/evaluation/metrics/multiobjective/euclideandistance.py +115 -0
- warprec/evaluation/metrics/multiobjective/hypervolume.py +134 -0
- warprec/evaluation/metrics/novelty/__init__.py +4 -0
- warprec/evaluation/metrics/novelty/efd.py +100 -0
- warprec/evaluation/metrics/novelty/epc.py +100 -0
- warprec/evaluation/metrics/rating/__init__.py +9 -0
- warprec/evaluation/metrics/rating/mae.py +16 -0
- warprec/evaluation/metrics/rating/mse.py +15 -0
- warprec/evaluation/metrics/rating/rmse.py +21 -0
- warprec/evaluation/statistical_significance.py +323 -0
- warprec/pipelines/__init__.py +5 -0
- warprec/pipelines/design.py +132 -0
- warprec/pipelines/eval.py +209 -0
- warprec/pipelines/train.py +559 -0
- warprec/recommenders/__init__.py +47 -0
- warprec/recommenders/base_recommender.py +722 -0
- warprec/recommenders/collaborative_filtering_recommender/__init__.py +13 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/__init__.py +9 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/cdae.py +218 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/ease.py +47 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/elsa.py +142 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/macridvae.py +324 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/multidae.py +216 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/multivae.py +255 -0
- warprec/recommenders/collaborative_filtering_recommender/autoencoder/sansa.py +197 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/__init__.py +213 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/dgcf.py +336 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/egcf.py +332 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/esigcf.py +288 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/gcmc.py +337 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/graph_utils.py +173 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/lightccf.py +216 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcl.py +309 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcn.py +201 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcnpp.py +278 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgode.py +258 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/mixrec.py +346 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/ngcf.py +349 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/rp3beta.py +186 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/sgcl.py +201 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/sgl.py +306 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/ultragcn.py +268 -0
- warprec/recommenders/collaborative_filtering_recommender/graph_based/xsimgcl.py +239 -0
- warprec/recommenders/collaborative_filtering_recommender/knn/__init__.py +4 -0
- warprec/recommenders/collaborative_filtering_recommender/knn/itemknn.py +52 -0
- warprec/recommenders/collaborative_filtering_recommender/knn/userknn.py +92 -0
- warprec/recommenders/collaborative_filtering_recommender/latent_factor/__init__.py +6 -0
- warprec/recommenders/collaborative_filtering_recommender/latent_factor/admmslim.py +95 -0
- warprec/recommenders/collaborative_filtering_recommender/latent_factor/bpr.py +150 -0
- warprec/recommenders/collaborative_filtering_recommender/latent_factor/fism.py +219 -0
- warprec/recommenders/collaborative_filtering_recommender/latent_factor/slim.py +74 -0
- warprec/recommenders/collaborative_filtering_recommender/neural/__init__.py +4 -0
- warprec/recommenders/collaborative_filtering_recommender/neural/convncf.py +221 -0
- warprec/recommenders/collaborative_filtering_recommender/neural/neumf.py +230 -0
- warprec/recommenders/content_based_recommender/__init__.py +3 -0
- warprec/recommenders/content_based_recommender/vsm.py +147 -0
- warprec/recommenders/context_aware_recommender/__init__.py +10 -0
- warprec/recommenders/context_aware_recommender/afm.py +412 -0
- warprec/recommenders/context_aware_recommender/dcn.py +353 -0
- warprec/recommenders/context_aware_recommender/dcnv2.py +418 -0
- warprec/recommenders/context_aware_recommender/deepfm.py +365 -0
- warprec/recommenders/context_aware_recommender/fm.py +251 -0
- warprec/recommenders/context_aware_recommender/nfm.py +314 -0
- warprec/recommenders/context_aware_recommender/wideanddeep.py +310 -0
- warprec/recommenders/context_aware_recommender/xdeepfm.py +468 -0
- warprec/recommenders/hybrid_recommender/__init__.py +6 -0
- warprec/recommenders/hybrid_recommender/addease.py +59 -0
- warprec/recommenders/hybrid_recommender/attributeitemknn.py +52 -0
- warprec/recommenders/hybrid_recommender/attributeuserknn.py +121 -0
- warprec/recommenders/hybrid_recommender/cease.py +54 -0
- warprec/recommenders/layers.py +159 -0
- warprec/recommenders/loops.py +102 -0
- warprec/recommenders/losses.py +196 -0
- warprec/recommenders/lr_scheduler_wrapper.py +163 -0
- warprec/recommenders/proxy.py +146 -0
- warprec/recommenders/sequential_recommender/__init__.py +23 -0
- warprec/recommenders/sequential_recommender/bert4rec.py +279 -0
- warprec/recommenders/sequential_recommender/caser.py +274 -0
- warprec/recommenders/sequential_recommender/core.py +281 -0
- warprec/recommenders/sequential_recommender/fossil.py +345 -0
- warprec/recommenders/sequential_recommender/gru4rec.py +225 -0
- warprec/recommenders/sequential_recommender/gsasrec.py +303 -0
- warprec/recommenders/sequential_recommender/lightsans.py +399 -0
- warprec/recommenders/sequential_recommender/linrec.py +325 -0
- warprec/recommenders/sequential_recommender/narm.py +255 -0
- warprec/recommenders/sequential_recommender/sasrec.py +254 -0
- warprec/recommenders/similarities.py +80 -0
- warprec/recommenders/trainer/__init__.py +27 -0
- warprec/recommenders/trainer/objectives.py +751 -0
- warprec/recommenders/trainer/scheduler_wrapper.py +76 -0
- warprec/recommenders/trainer/search_algorithm_wrapper.py +110 -0
- warprec/recommenders/trainer/trainer.py +758 -0
- warprec/recommenders/unpersonalized_recommender/__init__.py +4 -0
- warprec/recommenders/unpersonalized_recommender/pop.py +87 -0
- warprec/recommenders/unpersonalized_recommender/random.py +45 -0
- warprec/run.py +66 -0
- warprec/utils/__init__.py +64 -0
- warprec/utils/callback.py +98 -0
- warprec/utils/config/__init__.py +67 -0
- warprec/utils/config/common.py +465 -0
- warprec/utils/config/config.py +434 -0
- warprec/utils/config/dashboard_configuration.py +84 -0
- warprec/utils/config/evaluation_configuration.py +206 -0
- warprec/utils/config/general_configuration.py +193 -0
- warprec/utils/config/model_configuration.py +706 -0
- warprec/utils/config/reader_configuration.py +270 -0
- warprec/utils/config/recommender_model_config/__init__.py +15 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/__init__.py +25 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/autoencoder_config.py +461 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/graph_based_config.py +1301 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/knn_config.py +62 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/latent_factor_config.py +219 -0
- warprec/utils/config/recommender_model_config/collaborative_filtering_config/neural_config.py +384 -0
- warprec/utils/config/recommender_model_config/content_based_config.py +45 -0
- warprec/utils/config/recommender_model_config/context_aware_config.py +708 -0
- warprec/utils/config/recommender_model_config/hybrid_config.py +133 -0
- warprec/utils/config/recommender_model_config/sequential_model_config.py +1465 -0
- warprec/utils/config/recommender_model_config/unpersonalized_config.py +13 -0
- warprec/utils/config/search_space_wrapper.py +109 -0
- warprec/utils/config/splitter_configuration.py +204 -0
- warprec/utils/config/writer_configuration.py +131 -0
- warprec/utils/enums.py +283 -0
- warprec/utils/helpers.py +136 -0
- warprec/utils/logger/__init__.py +3 -0
- warprec/utils/logger/logger.py +126 -0
- warprec/utils/registry.py +128 -0
- warprec-1.0.0.dist-info/LICENSE +21 -0
- warprec-1.0.0.dist-info/METADATA +220 -0
- warprec-1.0.0.dist-info/RECORD +198 -0
- warprec-1.0.0.dist-info/WHEEL +4 -0
warprec/common/initialize.py
ADDED
@@ -0,0 +1,417 @@
+# pylint: disable=too-many-branches, too-many-statements
+from typing import Tuple, List, Optional, Dict, Union, Any
+
+from narwhals.dataframe import DataFrame
+
+from warprec.data import Dataset
+from warprec.data.reader import Reader
+from warprec.data.splitting import Splitter
+from warprec.data.filtering import apply_filtering
+from warprec.recommenders.base_recommender import ContextRecommenderUtils
+from warprec.utils.config import (
+    TrainConfiguration,
+    DesignConfiguration,
+    EvalConfiguration,
+)
+from warprec.utils.callback import WarpRecCallback
+from warprec.utils.registry import model_registry
+from warprec.utils.logger import logger
+
+
+def initialize_datasets(
+    reader: Reader,
+    callback: WarpRecCallback,
+    config: Union[TrainConfiguration, DesignConfiguration, EvalConfiguration],
+) -> Tuple[Dataset, Dataset | None, List[Dataset]]:
+    """Initialize datasets based on the configuration. This is a common operation
+    used in both training and design scripts.
+
+    Args:
+        reader (Reader): The initialized reader object that will be used to read data.
+        callback (WarpRecCallback): The callback object for handling events during initialization.
+        config (Union[TrainConfiguration, DesignConfiguration, EvalConfiguration]): The configuration
+            object containing all necessary settings for data loading, filtering, and splitting.
+
+    Returns:
+        Tuple[Dataset, Dataset | None, List[Dataset]]: A tuple containing the main
+            dataset, an optional validation dataset, and a list of datasets for cross-validation folds.
+
+    Raises:
+        ValueError: If the data type specified in the configuration is not supported.
+    """
+    # Dataset loading
+    main_dataset: Dataset | None = None
+    val_data: List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any] | None = None
+    train_data: DataFrame[Any] | None = None
+    test_data: DataFrame[Any] | None = None
+    side_data = None
+    user_cluster = None
+    item_cluster = None
+    splitter = Splitter()
+    if config.reader.loading_strategy == "dataset":
+        file_format = config.reader.file_format
+
+        match file_format:
+            case "tabular":
+                data = reader.read_tabular(
+                    **config.reader.model_dump(exclude=["labels", "dtypes"]),  # type: ignore[arg-type]
+                    column_names=config.reader.column_names(),
+                    dtypes=config.reader.column_dtype(),
+                )
+            case "parquet":
+                data = reader.read_parquet(
+                    **config.reader.model_dump(exclude=["labels", "dtypes"]),  # type: ignore[arg-type]
+                )
+            case _:
+                raise ValueError(f"File format '{file_format}' not supported.")
+        data = callback.on_data_reading(data)
+
+        # Check for optional filtering
+        if config.filtering is not None:
+            filters = config.get_filters()
+            data = apply_filtering(data, filters)
+
+        # Dataset splitting
+        if config.splitter:
+            if config.reader.data_type == "transaction":
+                # Gather splitting configurations
+                test_configuration = config.splitter.test_splitting.model_dump()
+                val_configuration = config.splitter.validation_splitting.model_dump()
+
+                # Add a tag to distinguish test and validation keys
+                test_configuration = {
+                    f"test_{key}": value for key, value in test_configuration.items()
+                }
+                val_configuration = {
+                    f"val_{key}": value for key, value in val_configuration.items()
+                }
+
+                # Compute splitting
+                train_data, val_data, test_data = splitter.split_transaction(
+                    data,
+                    **config.reader.labels.model_dump(
+                        exclude=["cluster_label", "context_labels"]  # type: ignore[arg-type]
+                    ),
+                    **test_configuration,
+                    **val_configuration,
+                )
+
+            else:
+                raise ValueError("Data type not yet supported.")
+
+    elif config.reader.loading_strategy == "split":
+        if config.reader.data_type == "transaction":
+            file_format = config.reader.split.file_format
+
+            match file_format:
+                case "tabular":
+                    train_data, val_data, test_data = reader.read_tabular_split(
+                        **config.reader.split.model_dump(),
+                        column_names=config.reader.column_names(),
+                        dtypes=config.reader.column_dtype(),
+                    )
+                case "parquet":
+                    train_data, val_data, test_data = reader.read_parquet_split(
+                        **config.reader.split.model_dump(),
+                        column_names=config.reader.column_names(),
+                    )
+                case _:
+                    raise ValueError(f"File format '{file_format}' not supported.")
+
+            # Filter out test and validation data not aligned with the training set
+            def _align_set_on_train(
+                train_set: DataFrame[Any], eval_set: DataFrame[Any]
+            ):
+                eval_transactions = len(eval_set)
+                eval_set = splitter.filter_sets(
+                    train_set,
+                    eval_set,
+                    **config.reader.labels.model_dump(
+                        exclude=[
+                            "rating_label",
+                            "timestamp_label",
+                            "cluster_label",
+                            "context_labels",
+                        ]  # type: ignore[arg-type]
+                    ),
+                )
+                if len(eval_set) == 0:
+                    raise ValueError(
+                        "Aligning the split with the training set produced an empty set. "
+                        "Please check the consistency of your splits and the filtering process."
+                    )
+                if len(eval_set) < eval_transactions:
+                    logger.attention(
+                        f"Eval set was not aligned with the training set. Filtered out {eval_transactions - len(eval_set)} transactions."
+                    )
+                return eval_set
+
+            test_data = _align_set_on_train(train_data, test_data)
+            if val_data is not None:
+                if isinstance(val_data, list):
+                    for idx, (val_train, val_set) in enumerate(val_data):
+                        # Keep the (train, validation) pair structure of each fold
+                        val_data[idx] = (
+                            val_train,
+                            _align_set_on_train(train_data, val_set),
+                        )
+                else:
+                    val_data = _align_set_on_train(train_data, val_data)
+        else:
+            raise ValueError("Data type not yet supported.")
+
+    # Side information reading
+    if config.reader.side:
+        file_format = config.reader.side.file_format
+
+        match file_format:
+            case "tabular":
+                side_data = reader.read_tabular(
+                    **config.reader.side.model_dump(),
+                )
+            case "parquet":
+                side_data = reader.read_parquet(
+                    **config.reader.side.model_dump(),
+                )
+            case _:
+                raise ValueError(f"File format '{file_format}' not supported.")
+
+    # Cluster information reading
+    if config.reader.clustering:
+
+        def _read_cluster_data_clean(
+            specific_config: dict,
+            common_cluster_label: str,
+            common_cluster_type: str,
+            reader: Reader,
+        ) -> DataFrame[Any]:
+            """Reads clustering data using a pre-prepared specific configuration (user or item).
+
+            Args:
+                specific_config (dict): Specific configuration for user or item.
+                common_cluster_label (str): Common label for the cluster column.
+                common_cluster_type (str): Common data type for the cluster column.
+                reader (Reader): Object or module with the read_tabular method.
+
+            Returns:
+                DataFrame[Any]: A DataFrame containing the cluster data.
+
+            Raises:
+                ValueError: If the file format is not supported.
+            """
+
+            # Define column names
+            column_names = [
+                specific_config["id_label"],
+                common_cluster_label,
+            ]
+
+            # Define data types (and map them to column names)
+            dtypes_list = [
+                specific_config["id_type"],
+                common_cluster_type,
+            ]
+            dtype_map = zip(column_names, dtypes_list)
+
+            # Read data using the custom reader
+            file_format = specific_config["file_format"]
+
+            match file_format:
+                case "tabular":
+                    cluster_data = reader.read_tabular(
+                        local_path=specific_config["local_path"],
+                        blob_name=specific_config["blob_name"],
+                        column_names=column_names,
+                        dtypes=dtype_map,
+                        sep=specific_config["sep"],
+                        header=specific_config["header"],
+                    )
+
+                case "parquet":
+                    cluster_data = reader.read_parquet(
+                        local_path=specific_config["local_path"],
+                        blob_name=specific_config["blob_name"],
+                        column_names=column_names,
+                    )
+
+                case _:
+                    raise ValueError(f"File format '{file_format}' not supported.")
+
+            return cluster_data
+
+        # Common clustering information
+        common_cluster_label = config.reader.labels.cluster_label
+        common_cluster_type = config.reader.dtypes.cluster_type
+
+        # User-specific clustering information
+        user_config = {
+            "id_label": config.reader.labels.user_id_label,
+            "id_type": config.reader.dtypes.user_id_type,
+            "local_path": config.reader.clustering.user_local_path,
+            "blob_name": config.reader.clustering.user_azure_blob_name,
+            "file_format": config.reader.clustering.user_file_format,
+            "sep": config.reader.clustering.user_sep,
+            "header": config.reader.clustering.user_header,
+        }
+
+        # Item-specific clustering information
+        item_config = {
+            "id_label": config.reader.labels.item_id_label,
+            "id_type": config.reader.dtypes.item_id_type,
+            "local_path": config.reader.clustering.item_local_path,
+            "blob_name": config.reader.clustering.item_azure_blob_name,
+            "file_format": config.reader.clustering.item_file_format,
+            "sep": config.reader.clustering.item_sep,
+            "header": config.reader.clustering.item_header,
+        }
+
+        # Read user clustering data
+        user_cluster = _read_cluster_data_clean(
+            specific_config=user_config,
+            common_cluster_label=common_cluster_label,
+            common_cluster_type=common_cluster_type,
+            reader=reader,
+        )
+
+        # Read item clustering data
+        item_cluster = _read_cluster_data_clean(
+            specific_config=item_config,
+            common_cluster_label=common_cluster_label,
+            common_cluster_type=common_cluster_type,
+            reader=reader,
+        )
+
+    # Dataset common information
+    common_params: Dict[str, Any] = {
+        "side_data": side_data,
+        "user_cluster": user_cluster,
+        "item_cluster": item_cluster,
+        "batch_size": config.evaluation.batch_size,
+        "rating_type": config.reader.rating_type,
+        "user_id_label": config.reader.labels.user_id_label,
+        "item_id_label": config.reader.labels.item_id_label,
+        "rating_label": config.reader.labels.rating_label,
+        "timestamp_label": config.reader.labels.timestamp_label,
+        "cluster_label": config.reader.labels.cluster_label,
+        "context_labels": config.reader.labels.context_labels,
+    }
+
+    logger.msg("Creating main dataset")
+    main_dataset = Dataset(
+        train_data,
+        test_data,
+        **common_params,
+    )
+
+    # Handle validation data
+    val_dataset: Dataset | None = None
+    fold_dataset: List[Dataset] = []
+    if val_data is not None:
+        if not isinstance(val_data, list):
+            # CASE 2: Train/Validation/Test
+            logger.msg("Creating validation dataset")
+            val_dataset = Dataset(
+                train_data,
+                val_data,
+                evaluation_set="Validation",
+                **common_params,
+            )
+        else:
+            # CASE 3: Cross-Validation
+            n_folds = len(val_data)
+            for idx, fold in enumerate(val_data):
+                logger.msg(f"Creating fold dataset {idx + 1}/{n_folds}")
+                val_train, val_set = fold
+                fold_dataset.append(
+                    Dataset(
+                        val_train,
+                        val_set,
+                        evaluation_set="Validation",
+                        **common_params,
+                    )
+                )
+
+    # Callback on dataset creation
+    callback.on_dataset_creation(
+        main_dataset=main_dataset,
+        val_dataset=val_dataset,
+        validation_folds=fold_dataset,
+    )
+
+    return main_dataset, val_dataset, fold_dataset
+
+
+def dataset_preparation(
+    main_dataset: Dataset,
+    fold_dataset: Optional[List[Dataset]],
+    config: TrainConfiguration,
+):
+    """Prepare the dataloaders inside the datasets that will be passed to Ray
+    during HPO. Precomputing these dataloaders before the optimization starts
+    avoids recomputing the same dataloader multiple times.
+
+    Args:
+        main_dataset (Dataset): The main dataset of the train/test split.
+        fold_dataset (Optional[List[Dataset]]): The list of validation datasets
+            of the train/val splits.
+        config (TrainConfiguration): The configuration file used for the experiment.
+    """
+
+    def prepare_evaluation_loaders(
+        dataset: Dataset, has_classic: bool, has_context: bool
+    ):
+        """Utility function to prepare the evaluation dataloaders
+        for a given dataset based on the evaluation strategy.
+
+        Args:
+            dataset (Dataset): The dataset to prepare.
+            has_classic (bool): Whether the experiment has a classic recommender.
+            has_context (bool): Whether the experiment has a context-aware recommender.
+
+        Raises:
+            ValueError: If both flags are False.
+        """
+        if not has_classic and not has_context:
+            raise ValueError(
+                "Something went wrong. No valid model found during evaluation "
+                "initialization."
+            )
+        strategy = config.evaluation.strategy
+
+        # Initialize the classic evaluation structures
+        if has_classic:
+            if strategy == "full":
+                dataset.get_evaluation_dataloader()
+            elif strategy == "sampled":
+                dataset.get_sampled_evaluation_dataloader(
+                    num_negatives=config.evaluation.num_negatives,
+                    seed=config.evaluation.seed,
+                )
+
+        # Initialize the contextual evaluation structures
+        if has_context:
+            if strategy == "full":
+                dataset.get_contextual_evaluation_dataloader()
+            elif strategy == "sampled":
+                dataset.get_sampled_contextual_evaluation_dataloader(
+                    num_negatives=config.evaluation.num_negatives,
+                    seed=config.evaluation.seed,
+                )
+
+    logger.msg("Preparing main dataset inner structures for evaluation.")
+
+    model_classes = [
+        model_registry.get_class(model_name) for model_name in config.models.keys()
+    ]
+    has_classic = any(
+        not issubclass(model_class, ContextRecommenderUtils)
+        for model_class in model_classes
+    )
+    has_context = any(
+        issubclass(model_class, ContextRecommenderUtils)
+        for model_class in model_classes
+    )
+
+    prepare_evaluation_loaders(main_dataset, has_classic, has_context)
+    if fold_dataset is not None and isinstance(fold_dataset, list):
+        for i, dataset in enumerate(fold_dataset):
+            logger.msg(
+                f"Preparing fold dataset {i + 1}/{len(fold_dataset)} inner structures for evaluation."
+            )
+            prepare_evaluation_loaders(dataset, has_classic, has_context)
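For orientation, here is a minimal sketch of how the two helpers above compose in a training entry point. Only initialize_datasets and dataset_preparation are confirmed by this diff; the module path warprec.common.initialize and the way the Reader, callback, and TrainConfiguration are constructed elsewhere are assumptions for illustration.

    from warprec.common.initialize import initialize_datasets, dataset_preparation
    from warprec.data.reader import Reader
    from warprec.utils.callback import WarpRecCallback
    from warprec.utils.config import TrainConfiguration

    def prepare_experiment(
        reader: Reader, callback: WarpRecCallback, config: TrainConfiguration
    ):
        # Read, filter, and split the raw data into train/test (and validation) datasets.
        main_dataset, val_dataset, folds = initialize_datasets(reader, callback, config)
        # Precompute evaluation dataloaders once, before handing datasets to Ray for HPO.
        dataset_preparation(main_dataset, folds or None, config)
        return main_dataset, val_dataset, folds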
warprec/common/optimizers.py
ADDED
@@ -0,0 +1,73 @@
+import torch
+from torch import nn
+from torch.optim import Optimizer
+
+from warprec.recommenders.base_recommender import IterativeRecommender
+
+
+def standard_optimizer(model: IterativeRecommender) -> Optimizer:
+    """Pre-construct the standard optimizer used within WarpRec.
+
+    The standard approach uses the Adam optimizer and separates parameters into two groups:
+    1. Decay Group (Adam handles L2):
+        - Dense layer weights (Linear, Conv).
+        - Structural embeddings (e.g., Positional Embeddings).
+    2. No-Decay Group:
+        - Sparse Entity Embeddings (User/Item) -> handled manually by EmbLoss.
+        - Biases -> standard DL practice (no decay).
+        - LayerNorm weights -> standard Transformer practice (no decay).
+
+    Args:
+        model (IterativeRecommender): The model on which the optimization
+            will be performed.
+
+    Returns:
+        Optimizer: The PyTorch optimizer adapted to the model parameters.
+    """
+    # Identify parameters that belong to nn.Embedding modules
+    embedding_param_ids = set()
+    for module in model.modules():
+        if isinstance(module, nn.Embedding):
+            for param in module.parameters():
+                embedding_param_ids.add(id(param))
+
+    # Separate parameters into groups
+    decay_params = []
+    no_decay_params = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # We disable Adam's weight decay for:
+        # A. Biases (standard practice)
+        # B. LayerNorm parameters (standard Transformer practice)
+        # C. Sparse Embeddings (User/Item), because we use EmbLoss for them.
+        # EXCEPTION: Positional Embeddings should have weight decay applied by Adam.
+
+        is_bias = "bias" in name
+        is_layernorm = "layernorm" in name or "norm" in name
+        is_embedding = id(param) in embedding_param_ids
+        is_positional = "position" in name  # Heuristic to catch position_embedding
+
+        if is_bias or is_layernorm or (is_embedding and not is_positional):
+            no_decay_params.append(param)
+        else:
+            # Linear weights, Conv weights, and Positional Embeddings go here
+            decay_params.append(param)
+
+    # Finalize the Optimizer with the correct groups
+    decay = getattr(model, "weight_decay", 0.0)
+
+    optimizer_grouped_parameters = [
+        {
+            "params": decay_params,
+            "weight_decay": decay,
+        },
+        {
+            "params": no_decay_params,
+            "weight_decay": 0.0,
+        },
+    ]
+
+    return torch.optim.Adam(optimizer_grouped_parameters, lr=model.learning_rate)
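To make the grouping heuristic in standard_optimizer concrete, the following self-contained sketch replays the same classification rules on a toy module. TinySasLike is hypothetical and not part of warprec; the outputs noted in the comments follow directly from the name- and type-based rules above.

    from torch import nn

    class TinySasLike(nn.Module):
        """Hypothetical toy model mixing the parameter kinds the grouping distinguishes."""

        def __init__(self):
            super().__init__()
            self.item_embedding = nn.Embedding(100, 16)     # sparse entity embedding
            self.position_embedding = nn.Embedding(50, 16)  # structural embedding
            self.proj = nn.Linear(16, 16)                   # dense weight + bias
            self.layernorm = nn.LayerNorm(16)               # norm weight + bias

    model = TinySasLike()
    # Same nn.Embedding detection as in standard_optimizer
    embedding_ids = {
        id(p)
        for m in model.modules()
        if isinstance(m, nn.Embedding)
        for p in m.parameters()
    }
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        is_bias = "bias" in name
        is_norm = "layernorm" in name or "norm" in name
        is_emb = id(param) in embedding_ids
        is_pos = "position" in name
        if is_bias or is_norm or (is_emb and not is_pos):
            no_decay.append(name)
        else:
            decay.append(name)

    print(decay)     # ['position_embedding.weight', 'proj.weight']
    print(no_decay)  # ['item_embedding.weight', 'proj.bias', 'layernorm.weight', 'layernorm.bias']

Note how the positional embedding lands in the decay group despite being an nn.Embedding: the "position" name check deliberately overrides the embedding exemption, since only the sparse user/item embeddings are regularized by EmbLoss.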
warprec/data/__init__.py
ADDED
@@ -0,0 +1,26 @@
+from .dataset import Dataset
+from .eval_loaders import (
+    EvaluationDataLoader,
+    ContextualEvaluationDataLoader,
+    SampledEvaluationDataLoader,
+    SampledContextualEvaluationDataLoader,
+)
+from . import entities
+from . import reader
+from . import splitting
+from . import writer
+from .filtering import Filter, apply_filtering
+
+__all__ = [
+    "Dataset",
+    "EvaluationDataLoader",
+    "ContextualEvaluationDataLoader",
+    "SampledEvaluationDataLoader",
+    "SampledContextualEvaluationDataLoader",
+    "entities",
+    "reader",
+    "splitting",
+    "writer",
+    "Filter",
+    "apply_filtering",
+]
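The re-exports above define the public surface of warprec.data. As a brief usage note, only the names listed in __all__ are confirmed by this diff; anything beyond these imports would come from the submodules themselves.

    # Top-level data API, as re-exported by warprec/data/__init__.py
    from warprec.data import Dataset, Filter, apply_filtering
    from warprec.data import entities, reader, splitting, writer  # subpackages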