warprec 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. warprec/__init__.py +6 -0
  2. warprec/common/__init__.py +8 -0
  3. warprec/common/initialize.py +417 -0
  4. warprec/common/optimizers.py +73 -0
  5. warprec/data/__init__.py +26 -0
  6. warprec/data/dataset.py +867 -0
  7. warprec/data/entities/__init__.py +10 -0
  8. warprec/data/entities/interactions.py +594 -0
  9. warprec/data/entities/sessions.py +366 -0
  10. warprec/data/entities/train_structures/__init__.py +21 -0
  11. warprec/data/entities/train_structures/custom_collate_fn.py +29 -0
  12. warprec/data/entities/train_structures/interaction_structures.py +199 -0
  13. warprec/data/entities/train_structures/session_structures.py +278 -0
  14. warprec/data/eval_loaders.py +400 -0
  15. warprec/data/filtering.py +543 -0
  16. warprec/data/reader/__init__.py +21 -0
  17. warprec/data/reader/azureblob_reader.py +239 -0
  18. warprec/data/reader/base_reader.py +587 -0
  19. warprec/data/reader/local_reader.py +147 -0
  20. warprec/data/splitting/__init__.py +4 -0
  21. warprec/data/splitting/splitter.py +248 -0
  22. warprec/data/splitting/strategies.py +498 -0
  23. warprec/data/writer/__init__.py +21 -0
  24. warprec/data/writer/azureblob_writer.py +145 -0
  25. warprec/data/writer/base_writer.py +565 -0
  26. warprec/data/writer/local_writer.py +110 -0
  27. warprec/evaluation/__init__.py +5 -0
  28. warprec/evaluation/evaluator.py +488 -0
  29. warprec/evaluation/metrics/__init__.py +19 -0
  30. warprec/evaluation/metrics/accuracy/__init__.py +26 -0
  31. warprec/evaluation/metrics/accuracy/auc.py +62 -0
  32. warprec/evaluation/metrics/accuracy/f1.py +107 -0
  33. warprec/evaluation/metrics/accuracy/gauc.py +51 -0
  34. warprec/evaluation/metrics/accuracy/hit_rate.py +24 -0
  35. warprec/evaluation/metrics/accuracy/lauc.py +57 -0
  36. warprec/evaluation/metrics/accuracy/map.py +36 -0
  37. warprec/evaluation/metrics/accuracy/mar.py +36 -0
  38. warprec/evaluation/metrics/accuracy/mrr.py +27 -0
  39. warprec/evaluation/metrics/accuracy/ndcg.py +71 -0
  40. warprec/evaluation/metrics/accuracy/precision.py +24 -0
  41. warprec/evaluation/metrics/accuracy/recall.py +31 -0
  42. warprec/evaluation/metrics/base_metric.py +474 -0
  43. warprec/evaluation/metrics/bias/__init__.py +7 -0
  44. warprec/evaluation/metrics/bias/aclt.py +77 -0
  45. warprec/evaluation/metrics/bias/aplt.py +75 -0
  46. warprec/evaluation/metrics/bias/arp.py +59 -0
  47. warprec/evaluation/metrics/bias/pop_reo.py +111 -0
  48. warprec/evaluation/metrics/bias/pop_rsp.py +104 -0
  49. warprec/evaluation/metrics/coverage/__init__.py +11 -0
  50. warprec/evaluation/metrics/coverage/item_coverage.py +58 -0
  51. warprec/evaluation/metrics/coverage/numretrieved.py +31 -0
  52. warprec/evaluation/metrics/coverage/user_coverage.py +42 -0
  53. warprec/evaluation/metrics/coverage/user_coverage_at_n.py +43 -0
  54. warprec/evaluation/metrics/diversity/__init__.py +9 -0
  55. warprec/evaluation/metrics/diversity/gini_index.py +95 -0
  56. warprec/evaluation/metrics/diversity/shannon_entropy.py +76 -0
  57. warprec/evaluation/metrics/diversity/srecall.py +107 -0
  58. warprec/evaluation/metrics/fairness/__init__.py +21 -0
  59. warprec/evaluation/metrics/fairness/biasdisparitybd.py +96 -0
  60. warprec/evaluation/metrics/fairness/biasdisparitybr.py +124 -0
  61. warprec/evaluation/metrics/fairness/biasdisparitybs.py +124 -0
  62. warprec/evaluation/metrics/fairness/itemmadranking.py +130 -0
  63. warprec/evaluation/metrics/fairness/itemmadrating.py +134 -0
  64. warprec/evaluation/metrics/fairness/reo.py +154 -0
  65. warprec/evaluation/metrics/fairness/rsp.py +167 -0
  66. warprec/evaluation/metrics/fairness/usermadranking.py +121 -0
  67. warprec/evaluation/metrics/fairness/usermadrating.py +98 -0
  68. warprec/evaluation/metrics/multiobjective/__init__.py +4 -0
  69. warprec/evaluation/metrics/multiobjective/euclideandistance.py +115 -0
  70. warprec/evaluation/metrics/multiobjective/hypervolume.py +134 -0
  71. warprec/evaluation/metrics/novelty/__init__.py +4 -0
  72. warprec/evaluation/metrics/novelty/efd.py +100 -0
  73. warprec/evaluation/metrics/novelty/epc.py +100 -0
  74. warprec/evaluation/metrics/rating/__init__.py +9 -0
  75. warprec/evaluation/metrics/rating/mae.py +16 -0
  76. warprec/evaluation/metrics/rating/mse.py +15 -0
  77. warprec/evaluation/metrics/rating/rmse.py +21 -0
  78. warprec/evaluation/statistical_significance.py +323 -0
  79. warprec/pipelines/__init__.py +5 -0
  80. warprec/pipelines/design.py +132 -0
  81. warprec/pipelines/eval.py +209 -0
  82. warprec/pipelines/train.py +559 -0
  83. warprec/recommenders/__init__.py +47 -0
  84. warprec/recommenders/base_recommender.py +722 -0
  85. warprec/recommenders/collaborative_filtering_recommender/__init__.py +13 -0
  86. warprec/recommenders/collaborative_filtering_recommender/autoencoder/__init__.py +9 -0
  87. warprec/recommenders/collaborative_filtering_recommender/autoencoder/cdae.py +218 -0
  88. warprec/recommenders/collaborative_filtering_recommender/autoencoder/ease.py +47 -0
  89. warprec/recommenders/collaborative_filtering_recommender/autoencoder/elsa.py +142 -0
  90. warprec/recommenders/collaborative_filtering_recommender/autoencoder/macridvae.py +324 -0
  91. warprec/recommenders/collaborative_filtering_recommender/autoencoder/multidae.py +216 -0
  92. warprec/recommenders/collaborative_filtering_recommender/autoencoder/multivae.py +255 -0
  93. warprec/recommenders/collaborative_filtering_recommender/autoencoder/sansa.py +197 -0
  94. warprec/recommenders/collaborative_filtering_recommender/graph_based/__init__.py +213 -0
  95. warprec/recommenders/collaborative_filtering_recommender/graph_based/dgcf.py +336 -0
  96. warprec/recommenders/collaborative_filtering_recommender/graph_based/egcf.py +332 -0
  97. warprec/recommenders/collaborative_filtering_recommender/graph_based/esigcf.py +288 -0
  98. warprec/recommenders/collaborative_filtering_recommender/graph_based/gcmc.py +337 -0
  99. warprec/recommenders/collaborative_filtering_recommender/graph_based/graph_utils.py +173 -0
  100. warprec/recommenders/collaborative_filtering_recommender/graph_based/lightccf.py +216 -0
  101. warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcl.py +309 -0
  102. warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcn.py +201 -0
  103. warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgcnpp.py +278 -0
  104. warprec/recommenders/collaborative_filtering_recommender/graph_based/lightgode.py +258 -0
  105. warprec/recommenders/collaborative_filtering_recommender/graph_based/mixrec.py +346 -0
  106. warprec/recommenders/collaborative_filtering_recommender/graph_based/ngcf.py +349 -0
  107. warprec/recommenders/collaborative_filtering_recommender/graph_based/rp3beta.py +186 -0
  108. warprec/recommenders/collaborative_filtering_recommender/graph_based/sgcl.py +201 -0
  109. warprec/recommenders/collaborative_filtering_recommender/graph_based/sgl.py +306 -0
  110. warprec/recommenders/collaborative_filtering_recommender/graph_based/ultragcn.py +268 -0
  111. warprec/recommenders/collaborative_filtering_recommender/graph_based/xsimgcl.py +239 -0
  112. warprec/recommenders/collaborative_filtering_recommender/knn/__init__.py +4 -0
  113. warprec/recommenders/collaborative_filtering_recommender/knn/itemknn.py +52 -0
  114. warprec/recommenders/collaborative_filtering_recommender/knn/userknn.py +92 -0
  115. warprec/recommenders/collaborative_filtering_recommender/latent_factor/__init__.py +6 -0
  116. warprec/recommenders/collaborative_filtering_recommender/latent_factor/admmslim.py +95 -0
  117. warprec/recommenders/collaborative_filtering_recommender/latent_factor/bpr.py +150 -0
  118. warprec/recommenders/collaborative_filtering_recommender/latent_factor/fism.py +219 -0
  119. warprec/recommenders/collaborative_filtering_recommender/latent_factor/slim.py +74 -0
  120. warprec/recommenders/collaborative_filtering_recommender/neural/__init__.py +4 -0
  121. warprec/recommenders/collaborative_filtering_recommender/neural/convncf.py +221 -0
  122. warprec/recommenders/collaborative_filtering_recommender/neural/neumf.py +230 -0
  123. warprec/recommenders/content_based_recommender/__init__.py +3 -0
  124. warprec/recommenders/content_based_recommender/vsm.py +147 -0
  125. warprec/recommenders/context_aware_recommender/__init__.py +10 -0
  126. warprec/recommenders/context_aware_recommender/afm.py +412 -0
  127. warprec/recommenders/context_aware_recommender/dcn.py +353 -0
  128. warprec/recommenders/context_aware_recommender/dcnv2.py +418 -0
  129. warprec/recommenders/context_aware_recommender/deepfm.py +365 -0
  130. warprec/recommenders/context_aware_recommender/fm.py +251 -0
  131. warprec/recommenders/context_aware_recommender/nfm.py +314 -0
  132. warprec/recommenders/context_aware_recommender/wideanddeep.py +310 -0
  133. warprec/recommenders/context_aware_recommender/xdeepfm.py +468 -0
  134. warprec/recommenders/hybrid_recommender/__init__.py +6 -0
  135. warprec/recommenders/hybrid_recommender/addease.py +59 -0
  136. warprec/recommenders/hybrid_recommender/attributeitemknn.py +52 -0
  137. warprec/recommenders/hybrid_recommender/attributeuserknn.py +121 -0
  138. warprec/recommenders/hybrid_recommender/cease.py +54 -0
  139. warprec/recommenders/layers.py +159 -0
  140. warprec/recommenders/loops.py +102 -0
  141. warprec/recommenders/losses.py +196 -0
  142. warprec/recommenders/lr_scheduler_wrapper.py +163 -0
  143. warprec/recommenders/proxy.py +146 -0
  144. warprec/recommenders/sequential_recommender/__init__.py +23 -0
  145. warprec/recommenders/sequential_recommender/bert4rec.py +279 -0
  146. warprec/recommenders/sequential_recommender/caser.py +274 -0
  147. warprec/recommenders/sequential_recommender/core.py +281 -0
  148. warprec/recommenders/sequential_recommender/fossil.py +345 -0
  149. warprec/recommenders/sequential_recommender/gru4rec.py +225 -0
  150. warprec/recommenders/sequential_recommender/gsasrec.py +303 -0
  151. warprec/recommenders/sequential_recommender/lightsans.py +399 -0
  152. warprec/recommenders/sequential_recommender/linrec.py +325 -0
  153. warprec/recommenders/sequential_recommender/narm.py +255 -0
  154. warprec/recommenders/sequential_recommender/sasrec.py +254 -0
  155. warprec/recommenders/similarities.py +80 -0
  156. warprec/recommenders/trainer/__init__.py +27 -0
  157. warprec/recommenders/trainer/objectives.py +751 -0
  158. warprec/recommenders/trainer/scheduler_wrapper.py +76 -0
  159. warprec/recommenders/trainer/search_algorithm_wrapper.py +110 -0
  160. warprec/recommenders/trainer/trainer.py +758 -0
  161. warprec/recommenders/unpersonalized_recommender/__init__.py +4 -0
  162. warprec/recommenders/unpersonalized_recommender/pop.py +87 -0
  163. warprec/recommenders/unpersonalized_recommender/random.py +45 -0
  164. warprec/run.py +66 -0
  165. warprec/utils/__init__.py +64 -0
  166. warprec/utils/callback.py +98 -0
  167. warprec/utils/config/__init__.py +67 -0
  168. warprec/utils/config/common.py +465 -0
  169. warprec/utils/config/config.py +434 -0
  170. warprec/utils/config/dashboard_configuration.py +84 -0
  171. warprec/utils/config/evaluation_configuration.py +206 -0
  172. warprec/utils/config/general_configuration.py +193 -0
  173. warprec/utils/config/model_configuration.py +706 -0
  174. warprec/utils/config/reader_configuration.py +270 -0
  175. warprec/utils/config/recommender_model_config/__init__.py +15 -0
  176. warprec/utils/config/recommender_model_config/collaborative_filtering_config/__init__.py +25 -0
  177. warprec/utils/config/recommender_model_config/collaborative_filtering_config/autoencoder_config.py +461 -0
  178. warprec/utils/config/recommender_model_config/collaborative_filtering_config/graph_based_config.py +1301 -0
  179. warprec/utils/config/recommender_model_config/collaborative_filtering_config/knn_config.py +62 -0
  180. warprec/utils/config/recommender_model_config/collaborative_filtering_config/latent_factor_config.py +219 -0
  181. warprec/utils/config/recommender_model_config/collaborative_filtering_config/neural_config.py +384 -0
  182. warprec/utils/config/recommender_model_config/content_based_config.py +45 -0
  183. warprec/utils/config/recommender_model_config/context_aware_config.py +708 -0
  184. warprec/utils/config/recommender_model_config/hybrid_config.py +133 -0
  185. warprec/utils/config/recommender_model_config/sequential_model_config.py +1465 -0
  186. warprec/utils/config/recommender_model_config/unpersonalized_config.py +13 -0
  187. warprec/utils/config/search_space_wrapper.py +109 -0
  188. warprec/utils/config/splitter_configuration.py +204 -0
  189. warprec/utils/config/writer_configuration.py +131 -0
  190. warprec/utils/enums.py +283 -0
  191. warprec/utils/helpers.py +136 -0
  192. warprec/utils/logger/__init__.py +3 -0
  193. warprec/utils/logger/logger.py +126 -0
  194. warprec/utils/registry.py +128 -0
  195. warprec-1.0.0.dist-info/LICENSE +21 -0
  196. warprec-1.0.0.dist-info/METADATA +220 -0
  197. warprec-1.0.0.dist-info/RECORD +198 -0
  198. warprec-1.0.0.dist-info/WHEEL +4 -0
warprec/__init__.py ADDED
@@ -0,0 +1,6 @@
1
"""Top-level WarpRec package.

Re-exports the ``data``, ``evaluation``, ``recommenders`` and ``utils``
subpackages so they can be accessed directly as attributes of ``warprec``.
"""

from . import data
from . import evaluation
from . import recommenders
from . import utils

# Explicit public API of the package.
__all__ = ["data", "evaluation", "recommenders", "utils"]
@@ -0,0 +1,8 @@
1
"""Common helpers shared by the WarpRec pipeline scripts.

Re-exports the dataset initialization/preparation helpers and the
standard optimizer factory.
"""

from .initialize import initialize_datasets, dataset_preparation
from .optimizers import standard_optimizer

# Explicit public API of ``warprec.common``.
__all__ = [
    "initialize_datasets",
    "dataset_preparation",
    "standard_optimizer",
]
@@ -0,0 +1,417 @@
1
+ # pylint: disable=too-many-branches, too-many-statements
2
+ from typing import Tuple, List, Optional, Dict, Union, Any
3
+
4
+ from narwhals.dataframe import DataFrame
5
+
6
+ from warprec.data import Dataset
7
+ from warprec.data.reader import Reader
8
+ from warprec.data.splitting import Splitter
9
+ from warprec.data.filtering import apply_filtering
10
+ from warprec.recommenders.base_recommender import ContextRecommenderUtils
11
+ from warprec.utils.config import (
12
+ TrainConfiguration,
13
+ DesignConfiguration,
14
+ EvalConfiguration,
15
+ )
16
+ from warprec.utils.callback import WarpRecCallback
17
+ from warprec.utils.registry import model_registry
18
+ from warprec.utils.logger import logger
19
+
20
+
21
+ def initialize_datasets(
22
+ reader: Reader,
23
+ callback: WarpRecCallback,
24
+ config: Union[TrainConfiguration, DesignConfiguration, EvalConfiguration],
25
+ ) -> Tuple[Dataset, Dataset | None, List[Dataset]]:
26
+ """Initialize datasets based on the configuration. This is a common operation
27
+ used in both training and design scripts.
28
+
29
+ Args:
30
+ reader (Reader): The initialized reader object that will be used to read data.
31
+ callback (WarpRecCallback): The callback object for handling events during initialization.
32
+ config (Union[TrainConfiguration, DesignConfiguration, EvalConfiguration]): The configuration
33
+ object containing all necessary settings for data loading, filtering, and splitting.
34
+
35
+ Returns:
36
+ Tuple[Dataset, Dataset | None, List[Dataset]]: A tuple containing the main
37
+ dataset, an optional validation dataset, and a list of datasets for cross-validation folds.
38
+
39
+ Raises:
40
+ ValueError: If the data type specified in the configuration is not supported.
41
+ """
42
+ # Dataset loading
43
+ main_dataset: Dataset = None
44
+ val_data: List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any] = None
45
+ train_data: DataFrame[Any] = None
46
+ test_data: DataFrame[Any] = None
47
+ side_data = None
48
+ user_cluster = None
49
+ item_cluster = None
50
+ splitter = Splitter()
51
+ if config.reader.loading_strategy == "dataset":
52
+ file_format = config.reader.file_format
53
+
54
+ match file_format:
55
+ case "tabular":
56
+ data = reader.read_tabular(
57
+ **config.reader.model_dump(exclude=["labels", "dtypes"]), # type: ignore[arg-type]
58
+ column_names=config.reader.column_names(),
59
+ dtypes=config.reader.column_dtype(),
60
+ )
61
+ case "parquet":
62
+ data = reader.read_parquet(
63
+ **config.reader.model_dump(exclude=["labels", "dtypes"]), # type: ignore[arg-type]
64
+ )
65
+ case _:
66
+ raise ValueError(f"File format '{file_format}'not supported.")
67
+ data = callback.on_data_reading(data)
68
+
69
+ # Check for optional filtering
70
+ if config.filtering is not None:
71
+ filters = config.get_filters()
72
+ data = apply_filtering(data, filters)
73
+
74
+ # Splitter testing
75
+ if config.splitter:
76
+ if config.reader.data_type == "transaction":
77
+ # Gather splitting configurations
78
+ test_configuration = config.splitter.test_splitting.model_dump()
79
+ val_configuration = config.splitter.validation_splitting.model_dump()
80
+
81
+ # Add tag to distinguish test and validation keys
82
+ test_configuration = {
83
+ f"test_{key}": value for key, value in test_configuration.items()
84
+ }
85
+ val_configuration = {
86
+ f"val_{key}": value for key, value in val_configuration.items()
87
+ }
88
+
89
+ # Compute splitting
90
+ train_data, val_data, test_data = splitter.split_transaction(
91
+ data,
92
+ **config.reader.labels.model_dump(
93
+ exclude=["cluster_label", "context_labels"] # type: ignore[arg-type]
94
+ ),
95
+ **test_configuration,
96
+ **val_configuration,
97
+ )
98
+
99
+ else:
100
+ raise ValueError("Data type not yet supported.")
101
+
102
+ elif config.reader.loading_strategy == "split":
103
+ if config.reader.data_type == "transaction":
104
+ file_format = config.reader.split.file_format
105
+
106
+ match file_format:
107
+ case "tabular":
108
+ train_data, val_data, test_data = reader.read_tabular_split(
109
+ **config.reader.split.model_dump(),
110
+ column_names=config.reader.column_names(),
111
+ dtypes=config.reader.column_dtype(),
112
+ )
113
+ case "parquet":
114
+ train_data, val_data, test_data = reader.read_parquet_split(
115
+ **config.reader.split.model_dump(),
116
+ column_names=config.reader.column_names(),
117
+ )
118
+ case _:
119
+ raise ValueError(f"File format '{file_format}'not supported.")
120
+
121
+ # Filter out train and validation data if not aligned with the training set
122
+ def _align_set_on_train(
123
+ train_set: DataFrame[Any], eval_set: DataFrame[Any]
124
+ ):
125
+ eval_transactions = len(eval_set)
126
+ eval_set = splitter.filter_sets(
127
+ train_set,
128
+ eval_set,
129
+ **config.reader.labels.model_dump(
130
+ exclude=[
131
+ "rating_label",
132
+ "timestamp_label",
133
+ "cluster_label",
134
+ "context_labels",
135
+ ] # type: ignore[arg-type]
136
+ ),
137
+ )
138
+ if len(eval_set) == 0:
139
+ raise ValueError(
140
+ "After aligning the split with the training set, it resulted in an empty set. Please check the consistency of your splits and the filtering process."
141
+ )
142
+ if len(eval_set) < eval_transactions:
143
+ logger.attention(
144
+ f"Eval set was not aligned with the training set. Filtered out {eval_transactions - len(eval_set)} transactions."
145
+ )
146
+ return eval_set
147
+
148
+ test_data = _align_set_on_train(train_data, test_data)
149
+ if val_data is not None:
150
+ if isinstance(val_data, list):
151
+ for idx, (val_train, val_set) in enumerate(val_data):
152
+ val_data[idx] = _align_set_on_train(train_data, val_set)
153
+ else:
154
+ val_data = _align_set_on_train(train_data, val_data)
155
+ else:
156
+ raise ValueError("Data type not yet supported.")
157
+
158
+ # Side information reading
159
+ if config.reader.side:
160
+ file_format = config.reader.split.file_format
161
+
162
+ match file_format:
163
+ case "tabular":
164
+ side_data = reader.read_tabular(
165
+ **config.reader.side.model_dump(),
166
+ )
167
+ case "parquet":
168
+ side_data = reader.read_parquet(
169
+ **config.reader.side.model_dump(),
170
+ )
171
+ case _:
172
+ raise ValueError(f"File format '{file_format}'not supported.")
173
+
174
+ # Cluster information reading
175
+ if config.reader.clustering:
176
+
177
+ def _read_cluster_data_clean(
178
+ specific_config: dict,
179
+ common_cluster_label: str,
180
+ common_cluster_type: str,
181
+ reader: Reader,
182
+ ) -> DataFrame[Any]:
183
+ """Reads clustering data using a pre-prepared specific configuration (User or Item).
184
+
185
+ Args:
186
+ specific_config (dict): Specific configurations for user or item.
187
+ common_cluster_label (str): Common label for the cluster column.
188
+ common_cluster_type (str): Common data type for the cluster column.
189
+ reader (Reader): Object or module with the read_tabular method.
190
+
191
+ Returns:
192
+ DataFrame[Any]: A DataFrame containing the cluster data.
193
+
194
+ Raises:
195
+ ValueError: If the file format is not supported.
196
+ """
197
+
198
+ # Define column names
199
+ column_names = [
200
+ specific_config["id_label"],
201
+ common_cluster_label,
202
+ ]
203
+
204
+ # Define data types (and map them to column names)
205
+ dtypes_list = [
206
+ specific_config["id_type"],
207
+ common_cluster_type,
208
+ ]
209
+ dtype_map = zip(column_names, dtypes_list)
210
+
211
+ # Read data using the custom reader
212
+ file_format = specific_config["file_format"]
213
+
214
+ match file_format:
215
+ case "tabular":
216
+ cluster_data = reader.read_tabular(
217
+ local_path=specific_config["local_path"],
218
+ blob_name=specific_config["blob_name"],
219
+ column_names=column_names,
220
+ dtypes=dtype_map,
221
+ sep=specific_config["sep"],
222
+ header=specific_config["header"],
223
+ )
224
+
225
+ case "parquet":
226
+ cluster_data = reader.read_parquet(
227
+ local_path=specific_config["local_path"],
228
+ blob_name=specific_config["blob_name"],
229
+ column_names=column_names,
230
+ )
231
+
232
+ case _:
233
+ raise ValueError(f"File format '{file_format}'not supported.")
234
+
235
+ return cluster_data
236
+
237
+ # Common clustering information
238
+ common_cluster_label = config.reader.labels.cluster_label
239
+ common_cluster_type = config.reader.dtypes.cluster_type
240
+
241
+ # User specific clustering information
242
+ user_config = {
243
+ "id_label": config.reader.labels.user_id_label,
244
+ "id_type": config.reader.dtypes.user_id_type,
245
+ "local_path": config.reader.clustering.user_local_path,
246
+ "blob_name": config.reader.clustering.user_azure_blob_name,
247
+ "file_format": config.reader.clustering.user_file_format,
248
+ "sep": config.reader.clustering.user_sep,
249
+ "header": config.reader.clustering.user_header,
250
+ }
251
+
252
+ # Item specific clustering information
253
+ item_config = {
254
+ "id_label": config.reader.labels.item_id_label,
255
+ "id_type": config.reader.dtypes.item_id_type,
256
+ "local_path": config.reader.clustering.item_local_path,
257
+ "blob_name": config.reader.clustering.item_azure_blob_name,
258
+ "file_format": config.reader.clustering.item_file_format,
259
+ "sep": config.reader.clustering.item_sep,
260
+ "header": config.reader.clustering.item_header,
261
+ }
262
+
263
+ # Read user clustering data
264
+ user_cluster = _read_cluster_data_clean(
265
+ specific_config=user_config,
266
+ common_cluster_label=common_cluster_label,
267
+ common_cluster_type=common_cluster_type,
268
+ reader=reader,
269
+ )
270
+
271
+ # Read item clustering data
272
+ item_cluster = _read_cluster_data_clean(
273
+ specific_config=item_config,
274
+ common_cluster_label=common_cluster_label,
275
+ common_cluster_type=common_cluster_type,
276
+ reader=reader,
277
+ )
278
+
279
+ # Dataset common information
280
+ common_params: Dict[str, Any] = {
281
+ "side_data": side_data,
282
+ "user_cluster": user_cluster,
283
+ "item_cluster": item_cluster,
284
+ "batch_size": config.evaluation.batch_size,
285
+ "rating_type": config.reader.rating_type,
286
+ "user_id_label": config.reader.labels.user_id_label,
287
+ "item_id_label": config.reader.labels.item_id_label,
288
+ "rating_label": config.reader.labels.rating_label,
289
+ "timestamp_label": config.reader.labels.timestamp_label,
290
+ "cluster_label": config.reader.labels.cluster_label,
291
+ "context_labels": config.reader.labels.context_labels,
292
+ }
293
+
294
+ logger.msg("Creating main dataset")
295
+ main_dataset = Dataset(
296
+ train_data,
297
+ test_data,
298
+ **common_params,
299
+ )
300
+
301
+ # Handle validation data
302
+ val_dataset: Dataset = None
303
+ fold_dataset: List[Dataset] = []
304
+ if val_data is not None:
305
+ if not isinstance(val_data, list):
306
+ # CASE 2: Train/Validation/Test
307
+ logger.msg("Creating validation dataset")
308
+ val_dataset = Dataset(
309
+ train_data,
310
+ val_data,
311
+ evaluation_set="Validation",
312
+ **common_params,
313
+ )
314
+ else:
315
+ # CASE 3: Cross-Validation
316
+ n_folds = len(val_data)
317
+ for idx, fold in enumerate(val_data):
318
+ logger.msg(f"Creating fold dataset {idx + 1}/{n_folds}")
319
+ val_train, val_set = fold
320
+ fold_dataset.append(
321
+ Dataset(
322
+ val_train,
323
+ val_set,
324
+ evaluation_set="Validation",
325
+ **common_params,
326
+ )
327
+ )
328
+
329
+ # Callback on dataset creation
330
+ callback.on_dataset_creation(
331
+ main_dataset=main_dataset,
332
+ val_dataset=val_dataset,
333
+ validation_folds=fold_dataset,
334
+ )
335
+
336
+ return main_dataset, val_dataset, fold_dataset
337
+
338
+
339
def dataset_preparation(
    main_dataset: Dataset,
    fold_dataset: Optional[List[Dataset]],
    config: TrainConfiguration,
) -> None:
    """This method prepares the dataloaders inside the dataset
    that will be passed to Ray during HPO. It is important to
    precompute these dataloaders before starting the optimization to
    avoid multiple computations of the same dataloader.

    Args:
        main_dataset (Dataset): The main dataset of train/test split.
        fold_dataset (Optional[List[Dataset]]): The list of validation datasets
            of train/val splits.
        config (TrainConfiguration): The configuration file used for the experiment.
    """

    def prepare_evaluation_loaders(
        dataset: Dataset, has_classic: bool, has_context: bool
    ) -> None:
        """Utility function to prepare the evaluation dataloaders
        for a given dataset based on the evaluation strategy.

        Args:
            dataset (Dataset): The dataset to prepare.
            has_classic (bool): Whether or not the experiment has a classic recommender.
            has_context (bool): Whether or not the experiment has a context recommender.

        Raises:
            ValueError: If both the flags are False.
        """
        if not has_classic and not has_context:
            raise ValueError(
                "Something went wrong. No correct model found during evaluation "
                "initialization."
            )
        strategy = config.evaluation.strategy

        # Initialize the classic evaluation structures.
        # Any strategy other than "full"/"sampled" leaves the loaders untouched.
        if has_classic:
            if strategy == "full":
                dataset.get_evaluation_dataloader()
            elif strategy == "sampled":
                dataset.get_sampled_evaluation_dataloader(
                    num_negatives=config.evaluation.num_negatives,
                    seed=config.evaluation.seed,
                )

        # Initialize the contextual evaluation structures
        if has_context:
            if strategy == "full":
                dataset.get_contextual_evaluation_dataloader()
            elif strategy == "sampled":
                dataset.get_sampled_contextual_evaluation_dataloader(
                    num_negatives=config.evaluation.num_negatives,
                    seed=config.evaluation.seed,
                )

    logger.msg("Preparing main dataset inner structures for evaluation.")

    # Resolve the configured model names to their classes so we can tell
    # which kinds of evaluation structures are actually needed.
    model_classes = [
        model_registry.get_class(model_name) for model_name in config.models.keys()
    ]
    # A "classic" recommender is any model that is NOT a context-aware one.
    has_classic = any(
        not issubclass(model_class, ContextRecommenderUtils)
        for model_class in model_classes
    )
    has_context = any(
        issubclass(model_class, ContextRecommenderUtils)
        for model_class in model_classes
    )

    prepare_evaluation_loaders(main_dataset, has_classic, has_context)
    if fold_dataset is not None and isinstance(fold_dataset, list):
        # Cross-validation: each fold carries its own loaders.
        for i, dataset in enumerate(fold_dataset):
            logger.msg(
                f"Preparing fold dataset {i + 1}/{len(fold_dataset)} inner structures for evaluation."
            )
            prepare_evaluation_loaders(dataset, has_classic, has_context)
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.optim import Optimizer
4
+
5
+ from warprec.recommenders.base_recommender import IterativeRecommender
6
+
7
+
8
def standard_optimizer(model: "IterativeRecommender") -> Optimizer:
    """Pre-construct the standard optimizer used within WarpRec.

    The standard approach uses Adam optimizer and separates parameters into two groups:
    1. Decay Group (Adam handles L2):
        - Dense layers weights (Linear, Conv).
        - Structural embeddings (e.g., Positional Embeddings).
    2. No-Decay Group:
        - Sparse Entity Embeddings (User/Item) -> Handled manually by EmbLoss.
        - Biases -> Standard DL practice (no decay).
        - Normalization layer parameters (e.g., LayerNorm) -> Standard
          Transformer practice (no decay).

    Args:
        model (IterativeRecommender): The model on which the optimization
            will be performed. Must expose a ``learning_rate`` attribute;
            ``weight_decay`` is optional and defaults to 0.0.

    Returns:
        Optimizer: The PyTorch optimizer adapted to model parameters.
    """
    # Identify parameters by owning-module type. This is more reliable than
    # name matching alone: a LayerNorm stored as e.g. ``self.ln`` would not be
    # caught by a "norm"-in-name heuristic.
    embedding_param_ids = set()
    norm_param_ids = set()
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            embedding_param_ids.update(id(p) for p in module.parameters())
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            norm_param_ids.update(id(p) for p in module.parameters())

    # Separate parameters into groups
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # We disable Adam's weight decay for:
        # A. Biases (standard practice)
        # B. Normalization parameters (standard Transformer practice)
        # C. Sparse Embeddings (User/Item), because we use EmbLoss for them.
        # EXCEPTION: Positional Embeddings should have weight decay applied by Adam.
        lowered = name.lower()
        is_bias = "bias" in lowered
        # Module-type detection first; keep the name heuristic as a fallback
        # for norm-like parameters not owned by nn.LayerNorm/nn.GroupNorm.
        is_norm = id(param) in norm_param_ids or "norm" in lowered
        is_embedding = id(param) in embedding_param_ids
        is_positional = "position" in lowered  # Heuristic to catch position_embedding

        if is_bias or is_norm or (is_embedding and not is_positional):
            no_decay_params.append(param)
        else:
            # Linear weights, Conv weights, and Positional Embeddings go here
            decay_params.append(param)

    # Finalize the Optimizer with correct groups; models without a
    # ``weight_decay`` attribute fall back to no decay.
    decay = getattr(model, "weight_decay", 0.0)

    optimizer_grouped_parameters = [
        {
            "params": decay_params,
            "weight_decay": decay,
        },
        {
            "params": no_decay_params,
            "weight_decay": 0.0,
        },
    ]

    return torch.optim.Adam(optimizer_grouped_parameters, lr=model.learning_rate)
@@ -0,0 +1,26 @@
1
"""WarpRec data package.

Re-exports the ``Dataset`` abstraction, the evaluation dataloaders,
the filtering utilities, and the ``entities``/``reader``/``splitting``/
``writer`` subpackages.
"""

from .dataset import Dataset
from .eval_loaders import (
    EvaluationDataLoader,
    ContextualEvaluationDataLoader,
    SampledEvaluationDataLoader,
    SampledContextualEvaluationDataLoader,
)
from . import entities
from . import reader
from . import splitting
from . import writer
from .filtering import Filter, apply_filtering

# Explicit public API of ``warprec.data``.
__all__ = [
    "Dataset",
    "EvaluationDataLoader",
    "ContextualEvaluationDataLoader",
    "SampledEvaluationDataLoader",
    "SampledContextualEvaluationDataLoader",
    "entities",
    "reader",
    "splitting",
    "writer",
    "Filter",
    "apply_filtering",
]