replay-rec 0.20.3__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +603 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +50 -0
  18. replay/experimental/models/admm_slim.py +257 -0
  19. replay/experimental/models/base_neighbour_rec.py +200 -0
  20. replay/experimental/models/base_rec.py +1386 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +932 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +264 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/hierarchical_recommender.py +331 -0
  32. replay/experimental/models/implicit_wrap.py +131 -0
  33. replay/experimental/models/lightfm_wrap.py +303 -0
  34. replay/experimental/models/mult_vae.py +332 -0
  35. replay/experimental/models/neural_ts.py +986 -0
  36. replay/experimental/models/neuromf.py +406 -0
  37. replay/experimental/models/scala_als.py +293 -0
  38. replay/experimental/models/u_lin_ucb.py +115 -0
  39. replay/experimental/nn/data/__init__.py +1 -0
  40. replay/experimental/nn/data/schema_builder.py +102 -0
  41. replay/experimental/preprocessing/__init__.py +3 -0
  42. replay/experimental/preprocessing/data_preparator.py +839 -0
  43. replay/experimental/preprocessing/padder.py +229 -0
  44. replay/experimental/preprocessing/sequence_generator.py +208 -0
  45. replay/experimental/scenarios/__init__.py +1 -0
  46. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  47. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  48. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  49. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  50. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  51. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  52. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  53. replay/experimental/utils/__init__.py +0 -0
  54. replay/experimental/utils/logger.py +24 -0
  55. replay/experimental/utils/model_handler.py +186 -0
  56. replay/experimental/utils/session_handler.py +44 -0
  57. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  58. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +61 -6
  59. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,229 @@
+ from collections.abc import Iterable
+ from typing import Optional, Union
+
+ from pandas.api.types import is_object_dtype
+
+ from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, SparkDataFrame
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import functions as sf
+
+
+ class Padder:
+     """
+     Pad array columns in dataframe.
+
+     >>> import pandas as pd
+     >>> pad_interactions = pd.DataFrame({
+     ...    "user_id": [1, 1, 1, 1, 2, 2, 3, 3, 3],
+     ...    "timestamp": [[1], [1, 2], [1, 2, 4], [1, 2, 4, 6], [4, 7, 12],
+     ...                  [4, 7, 12, 126], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6],
+     ...                  [1, 2, 3, 4, 5, 6, 7]],
+     ...    "item_id": [['a'], ['a', 'b'], ['a', 'b', 'd'], ['a', 'b', 'd', 'f'], ['d', 'e', 'm'],
+     ...                ['d', 'e', 'm', 'g'], ['a', 'b', 'c', 'd', 'a'], ['a', 'b', 'c', 'd', 'a', 'f'],
+     ...                ['a', 'b', 'c', 'd', 'a', 'f', 'e']]
+     ... })
+     >>> pad_interactions
+        user_id              timestamp                item_id
+     0        1                    [1]                    [a]
+     1        1                 [1, 2]                 [a, b]
+     2        1              [1, 2, 4]              [a, b, d]
+     3        1           [1, 2, 4, 6]           [a, b, d, f]
+     4        2             [4, 7, 12]              [d, e, m]
+     5        2        [4, 7, 12, 126]           [d, e, m, g]
+     6        3        [1, 2, 3, 4, 5]        [a, b, c, d, a]
+     7        3     [1, 2, 3, 4, 5, 6]     [a, b, c, d, a, f]
+     8        3  [1, 2, 3, 4, 5, 6, 7]  [a, b, c, d, a, f, e]
+     >>> Padder(
+     ...    pad_columns=["item_id", "timestamp"],
+     ...    padding_side="right",
+     ...    padding_value=["[PAD]", 0],
+     ...    array_size=5,
+     ...    cut_array=True,
+     ...    cut_side="right"
+     ... ).transform(pad_interactions)
+        user_id           timestamp                          item_id
+     0        1     [1, 0, 0, 0, 0]  [a, [PAD], [PAD], [PAD], [PAD]]
+     1        1     [1, 2, 0, 0, 0]      [a, b, [PAD], [PAD], [PAD]]
+     2        1     [1, 2, 4, 0, 0]          [a, b, d, [PAD], [PAD]]
+     3        1     [1, 2, 4, 6, 0]              [a, b, d, f, [PAD]]
+     4        2    [4, 7, 12, 0, 0]          [d, e, m, [PAD], [PAD]]
+     5        2  [4, 7, 12, 126, 0]              [d, e, m, g, [PAD]]
+     6        3     [1, 2, 3, 4, 5]                  [a, b, c, d, a]
+     7        3     [2, 3, 4, 5, 6]                  [b, c, d, a, f]
+     8        3     [3, 4, 5, 6, 7]                  [c, d, a, f, e]
+     <BLANKLINE>
+     """
+
+     def __init__(
+         self,
+         pad_columns: Union[str, list[str]],
+         padding_side: Optional[str] = "right",
+         padding_value: Union[str, float, list, None] = 0,
+         array_size: Optional[int] = None,
+         cut_array: Optional[bool] = True,
+         cut_side: Optional[str] = "right",
+     ):
+         """
+         :param pad_columns: Name of columns to pad.
+         :param padding_side: side of the array to which padding values are added. Can be "right" or "left",
+             default: ``right``.
+         :param padding_value: value used to fill the missing places,
+             default: ``0``.
+         :param array_size: target array size,
+             default: ``None``.
+         :param cut_array: whether to cut arrays longer than ``array_size``,
+             default: ``True``.
+         :param cut_side: side of the array to cut from when it is longer than ``array_size``. Can be "right" or "left",
+             default: ``right``.
+         """
+         self.pad_columns = (
+             pad_columns if (isinstance(pad_columns, Iterable) and not isinstance(pad_columns, str)) else [pad_columns]
+         )
+         if padding_side not in (
+             "right",
+             "left",
+         ):
+             msg = f"padding_side value {padding_side} is not implemented. Should be 'right' or 'left'"
+             raise ValueError(msg)
+
+         self.padding_side = padding_side
+         self.padding_value = (
+             padding_value
+             if (isinstance(padding_value, Iterable) and not isinstance(padding_value, str))
+             else [padding_value]
+         )
+         if len(self.padding_value) == 1 and len(self.pad_columns) > 1:
+             self.padding_value = self.padding_value * len(self.pad_columns)
+         if len(self.pad_columns) != len(self.padding_value):
+             msg = "pad_columns and padding_value should have the same length"
+             raise ValueError(msg)
+
+         self.array_size = array_size
+         if self.array_size is not None:
+             if self.array_size < 1 or not isinstance(self.array_size, int):
+                 msg = "array_size should be a positive integer"
+                 raise ValueError(msg)
+         else:
+             self.array_size = -1
+
+         self.cut_array = cut_array
+         self.cut_side = cut_side
+
+     def transform(self, interactions: DataFrameLike) -> DataFrameLike:
+         """Pad dataframe.
+
+         :param interactions: DataFrame with array columns listed in ``pad_columns``.
+
+         :returns: DataFrame with padded array columns.
+
+         """
+         df_transformed = interactions
+         is_spark = isinstance(interactions, SparkDataFrame)
+         column_dtypes = dict(df_transformed.dtypes)
+
+         for col, pad_value in zip(self.pad_columns, self.padding_value):
+             if col not in df_transformed.columns:
+                 msg = f"Column {col} not in DataFrame columns."
+                 raise ValueError(msg)
+             if is_spark and not column_dtypes[col].startswith("array"):
+                 msg = f"Column {col} should have ArrayType to be padded."
+                 raise ValueError(msg)
+             if not is_spark and not is_object_dtype(df_transformed[col]):
+                 msg = f"Column {col} should have object dtype to be padded."
+                 raise ValueError(msg)
+
+             if is_spark:
+                 df_transformed = self._transform_spark(df_transformed, col, pad_value)
+             else:
+                 df_transformed = self._transform_pandas(df_transformed, col, pad_value)
+
+         return df_transformed
+
+     def _transform_pandas(
+         self, df_transformed: PandasDataFrame, col: str, pad_value: Union[str, float, list, None]
+     ) -> PandasDataFrame:
+         max_array_size = df_transformed[col].str.len().max() if self.array_size == -1 else self.array_size
+
+         def right_cut(sample: list) -> list:
+             # fmt: off
+             return sample[-min(len(sample), max_array_size):]
+             # fmt: on
+
+         def left_cut(sample: list) -> list:
+             # fmt: off
+             return sample[:min(len(sample), max_array_size)]
+             # fmt: on
+
+         res = df_transformed.copy(deep=True)
+         res[col] = res[col].apply(lambda sample: sample if isinstance(sample, list) else [])
+         cut_col_name = f"{col}_cut"
+         if self.cut_array:
+             cut_func = right_cut if self.cut_side == "right" else left_cut
+
+             res[cut_col_name] = res[col].apply(cut_func)
+         else:
+             res[cut_col_name] = res[col]
+
+         paddings = res[cut_col_name].apply(lambda x: [pad_value for _ in range(max_array_size - len(x))])
+         if self.padding_side == "right":
+             res[col] = res[cut_col_name] + paddings
+         else:
+             res[col] = paddings + res[cut_col_name]
+
+         res.drop(columns=[cut_col_name], inplace=True)
+
+         return res
+
+     def _transform_spark(
+         self, df_transformed: SparkDataFrame, col: str, pad_value: Union[str, float, list, None]
+     ) -> SparkDataFrame:
+         if self.array_size == -1:
+             max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).first()[0]
+         else:
+             max_array_size = self.array_size
+
+         df_transformed = df_transformed.withColumn(col, sf.coalesce(col, sf.array()))
+         insert_value = pad_value if not isinstance(pad_value, str) else "'" + pad_value + "'"
+
+         cut_col_name = f"{col}_cut"
+         drop_cols = [cut_col_name, "pre_zeros", "zeros"]
+
+         if self.cut_array:
+             slice_col_name = f"{col}_slice_point"
+             drop_cols += [slice_col_name]
+
+             if self.cut_side == "right":
+                 slice_func = (-1) * sf.least(sf.size(col), sf.lit(max_array_size))
+                 cut_func = sf.when(
+                     sf.size(col) > 0, sf.expr(f"slice({col}, {slice_col_name}, {max_array_size})")
+                 ).otherwise(sf.array())
+             else:
+                 slice_func = sf.least(sf.size(col), sf.lit(max_array_size))
+                 cut_func = sf.when(sf.size(col) > 0, sf.expr(f"slice({col}, 1, {slice_col_name})")).otherwise(
+                     sf.array()
+                 )
+
+             df_transformed = df_transformed.withColumn(slice_col_name, slice_func).withColumn(cut_col_name, cut_func)
+
+         else:
+             df_transformed = df_transformed.withColumn(cut_col_name, sf.col(col))
+
+         if self.padding_side == "right":
+             concat_func = sf.concat(sf.col(cut_col_name), sf.col("zeros"))
+         else:
+             concat_func = sf.concat(sf.col("zeros"), sf.col(cut_col_name))
+
+         df_transformed = (
+             df_transformed.withColumn(
+                 "pre_zeros",
+                 sf.sequence(sf.lit(0), sf.greatest(sf.lit(max_array_size) - sf.size(sf.col(cut_col_name)), sf.lit(0))),
+             )
+             .withColumn(
+                 "zeros", sf.expr(f"transform(slice(pre_zeros, 1, size(pre_zeros) - 1), element -> {insert_value})")
+             )
+             .withColumn(col, concat_func)
+             .drop(*drop_cols)
+         )
+
+         return df_transformed
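The doctest above exercises only the pandas path. Below is a minimal, hypothetical sketch of the Spark path, assuming PySpark is installed and importing Padder from the module added in this diff (replay/experimental/preprocessing/padder.py); the session setup, data, and column names are illustrative, not part of the package.

# Hedged sketch: pad a Spark ArrayType column with the Padder class added above.
from pyspark.sql import SparkSession

from replay.experimental.preprocessing.padder import Padder

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(1, ["a"]), (1, ["a", "b"]), (2, ["d", "e", "m", "g"])],
    schema=["user_id", "item_id"],
)
# Defaults apply: padding_side="right", cut_array=True, cut_side="right".
padded = Padder(pad_columns="item_id", padding_value="[PAD]", array_size=3).transform(df)
padded.show(truncate=False)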
@@ -0,0 +1,208 @@
+ from typing import Optional, Union
+
+ import pandas as pd
+
+ from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, SparkDataFrame
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import (
+         Column,
+         Window,
+         functions as sf,
+     )
+
+
+ class SequenceGenerator:
+     """
+     Creating sequences for sequential models.
+
+     E.g., if ``u1`` has the purchase sequence ``<i1, i2, i3, i4>``,
+     then after processing three cases will be generated.
+
+     ``u1, <i1> | i2``
+
+     (Which means that, given user_id ``u1`` and item_seq ``<i1>``,
+     the model needs to predict the next item ``i2``.)
+
+     The other cases are below:
+
+     ``u1, <i1, i2> | i3``
+
+     ``u1, <i1, i2, i3> | i4``
+
+     >>> import pandas as pd
+     >>> time_interactions = pd.DataFrame({
+     ...    "user_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
+     ...    "item_id": [3, 7, 10, 5, 8, 11, 4, 9, 2, 5],
+     ...    "timestamp": [1, 2, 3, 3, 2, 1, 3, 12, 1, 4]
+     ... })
+     >>> time_interactions
+        user_id  item_id  timestamp
+     0        1        3          1
+     1        1        7          2
+     2        1       10          3
+     3        2        5          3
+     4        2        8          2
+     5        2       11          1
+     6        3        4          3
+     7        3        9         12
+     8        3        2          1
+     9        3        5          4
+     >>> sequences = (
+     ...     SequenceGenerator(
+     ...         groupby_column="user_id", transform_columns=["item_id", "timestamp"]
+     ...     ).transform(time_interactions)
+     ... )
+     >>> sequences
+        user_id  item_id_list  timestamp_list  label_item_id  label_timestamp
+     0        1           [3]             [1]              7                2
+     1        1        [3, 7]          [1, 2]             10                3
+     2        2           [5]             [3]              8                2
+     3        2        [5, 8]          [3, 2]             11                1
+     4        3           [4]             [3]              9               12
+     5        3        [4, 9]         [3, 12]              2                1
+     6        3     [4, 9, 2]      [3, 12, 1]              5                4
+     <BLANKLINE>
+     """
+
+     def __init__(
+         self,
+         groupby_column: Union[str, list[str]],
+         orderby_column: Union[str, list[str], None] = None,
+         transform_columns: Union[None, str, list[str]] = None,
+         len_window: int = 50,
+         sequence_prefix: Optional[str] = None,
+         sequence_suffix: Optional[str] = "_list",
+         label_prefix: Optional[str] = "label_",
+         label_suffix: Optional[str] = None,
+         get_list_len: Optional[bool] = False,
+         list_len_column: str = "list_len",
+     ):
+         """
+         :param groupby_column: Name of column to group by.
+         :param orderby_column: Columns to sort by. If ``None``,
+             values are not ordered,
+             default: ``None``.
+         :param transform_columns: Names of interaction columns to process. If ``None``,
+             all columns except the grouping ones are processed,
+             default: ``None``.
+         :param len_window: Max length of a sequence, must be positive,
+             default: ``50``.
+         :param sequence_prefix: prefix added to column name after creating sequences,
+             default: ``None``.
+         :param sequence_suffix: suffix added to column name after creating sequences,
+             default: ``_list``.
+         :param label_prefix: prefix added to label column after creating sequences,
+             default: ``label_``.
+         :param label_suffix: suffix added to label column after creating sequences,
+             default: ``None``.
+         :param get_list_len: whether to add a column with the length of the processed list,
+             default: ``False``.
+         :param list_len_column: List length column name. Used if ``get_list_len`` is ``True``,
+             default: ``list_len``.
+         """
+         self.groupby_column = groupby_column if not isinstance(groupby_column, str) else [groupby_column]
+         self.orderby_column: Union[list, Column, None]
+         if orderby_column is None:
+             self.orderby_column = None
+         else:
+             self.orderby_column = orderby_column if not isinstance(orderby_column, str) else [orderby_column]
+
+         self.transform_columns = transform_columns
+         self.len_window = len_window
+
+         self.sequence_prefix = "" if sequence_prefix is None else sequence_prefix
+         self.sequence_suffix = "" if sequence_suffix is None else sequence_suffix
+
+         self.label_prefix = "" if label_prefix is None else label_prefix
+         self.label_suffix = "" if label_suffix is None else label_suffix
+
+         self.get_list_len = get_list_len
+         self.list_len_column = list_len_column
+
+     def transform(self, interactions: DataFrameLike) -> DataFrameLike:
+         """Create sequences from given interactions.
+
+         :param interactions: DataFrame.
+
+         :returns: DataFrame with transformed interactions; sequential interactions are collected into lists.
+
+         """
+         if self.transform_columns is None:
+             self.transform_columns = list(set(interactions.columns).difference(self.groupby_column))
+         else:
+             self.transform_columns = (
+                 self.transform_columns if not isinstance(self.transform_columns, str) else [self.transform_columns]
+             )
+
+         if isinstance(interactions, SparkDataFrame):
+             return self._transform_spark(interactions)
+
+         return self._transform_pandas(interactions)
+
+     def _transform_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
+         assert self.transform_columns is not None
+         processed_interactions = interactions.copy(deep=True)
+
+         def seq_rolling(col: pd.Series) -> list:
+             return [window.to_list()[:-1] for window in col.rolling(self.len_window + 1)]
+
+         for transform_col in self.transform_columns:
+             if self.orderby_column is not None:
+                 processed_interactions.sort_values(by=self.orderby_column, inplace=True)
+             else:
+                 processed_interactions.sort_values(by=self.groupby_column, inplace=True)
+
+             processed_interactions[self.sequence_prefix + transform_col + self.sequence_suffix] = [
+                 item
+                 for sublist in processed_interactions.groupby(self.groupby_column, sort=False)[transform_col].apply(
+                     seq_rolling
+                 )
+                 for item in sublist
+             ]
+             processed_interactions[self.label_prefix + transform_col + self.label_suffix] = processed_interactions[
+                 transform_col
+             ]
+
+         first_transformed_col = self.sequence_prefix + self.transform_columns[0] + self.sequence_suffix
+         processed_interactions = processed_interactions[processed_interactions[first_transformed_col].str.len() > 0]
+
+         transformed_columns = [self.sequence_prefix + x + self.sequence_suffix for x in self.transform_columns]
+         label_columns = [self.label_prefix + x + self.label_suffix for x in self.transform_columns]
+         select_columns = self.groupby_column + transformed_columns + label_columns
+
+         if self.get_list_len:
+             processed_interactions[self.list_len_column] = processed_interactions[first_transformed_col].str.len()
+             select_columns += [self.list_len_column]
+
+         processed_interactions.reset_index(inplace=True)
+
+         return processed_interactions[select_columns]
+
+     def _transform_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
+         assert self.transform_columns is not None
+         processed_interactions = interactions
+         orderby_column: Union[Column, list]
+         orderby_column = sf.lit(1) if self.orderby_column is None else self.orderby_column
+
+         window = Window.partitionBy(self.groupby_column).orderBy(orderby_column).rowsBetween(-self.len_window, -1)
+         for transform_col in self.transform_columns:
+             processed_interactions = processed_interactions.withColumn(
+                 self.sequence_prefix + transform_col + self.sequence_suffix,
+                 sf.collect_list(transform_col).over(window),
+             ).withColumn(self.label_prefix + transform_col + self.label_suffix, sf.col(transform_col))
+
+         first_transformed_col = self.sequence_prefix + self.transform_columns[0] + self.sequence_suffix
+         processed_interactions = processed_interactions.filter(sf.size(first_transformed_col) > 0)
+
+         transformed_columns = [self.sequence_prefix + x + self.sequence_suffix for x in self.transform_columns]
+         label_columns = [self.label_prefix + x + self.label_suffix for x in self.transform_columns]
+         select_columns = self.groupby_column + transformed_columns + label_columns
+
+         if self.get_list_len:
+             processed_interactions = processed_interactions.withColumn(
+                 self.list_len_column, sf.size(first_transformed_col)
+             )
+             select_columns += [self.list_len_column]
+
+         return processed_interactions.select(select_columns)
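As with Padder, the doctest shows only the default configuration. Here is a small, hypothetical pandas sketch of the ordering and list-length options, importing from the module added above (replay/experimental/preprocessing/sequence_generator.py); the data is illustrative.

# Hedged sketch: ordered sequences plus an explicit sequence-length column.
import pandas as pd

from replay.experimental.preprocessing.sequence_generator import SequenceGenerator

interactions = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2],
    "item_id": [3, 7, 10, 5, 8],
    "timestamp": [1, 2, 3, 2, 1],
})
generator = SequenceGenerator(
    groupby_column="user_id",
    orderby_column="timestamp",   # sort by timestamp before building sequences
    transform_columns="item_id",  # only item_id is turned into sequences
    get_list_len=True,            # adds a "list_len" column with the sequence length
)
print(generator.transform(interactions))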
@@ -0,0 +1 @@
+ from replay.experimental.scenarios.two_stages.two_stages_scenario import TwoStagesScenario
@@ -0,0 +1,8 @@
+ """
+ This module contains the Learner class for training RePlay models on a bandit dataset.
+ The format of the bandit dataset should be the same as in OpenBanditPipeline.
+ The Learner class has `fit` and `predict` methods, which are wrappers for the corresponding
+ methods of the RePlay model. The `optimize` method is based on optimization over the CTR estimated by OBP.
+ """
+
+ from replay.experimental.scenarios.obp_wrapper.replay_offline import OBPOfflinePolicyLearner
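The bandit dataset format referenced here is the feedback dict produced by OpenBanditPipeline. A minimal sketch of obtaining such a dict with OBP's synthetic dataset follows; the learner construction itself is outside this diff and is not shown, and the dataset parameters are illustrative.

# Hedged sketch: build an OBP-format bandit feedback dict for the wrapper.
import numpy as np
from obp.dataset import SyntheticBanditDataset

dataset = SyntheticBanditDataset(n_actions=10, dim_context=5, random_state=12345)
bandit_feedback_train = dataset.obtain_batch_bandit_feedback(n_rounds=1000)

# Keys consumed by the wrapper's fit/predict calls (see obp_optuna_objective.py below):
# "n_rounds", "action", "reward", "context", "action_context".
timestamp = np.arange(bandit_feedback_train["n_rounds"])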
@@ -0,0 +1,74 @@
+ from functools import partial
+ from typing import Any, Optional
+
+ import numpy as np
+ from obp.ope import DirectMethod, DoublyRobust, InverseProbabilityWeighting, OffPolicyEvaluation
+ from optuna import Trial
+
+ from replay.experimental.scenarios.obp_wrapper.utils import get_est_rewards_by_reg
+ from replay.models.optimization.optuna_objective import ObjectiveWrapper, suggest_params
+
+
+ def obp_objective_calculator(
+     trial: Trial,
+     search_space: dict[str, list[Optional[Any]]],
+     bandit_feedback_train: dict[str, np.ndarray],
+     bandit_feedback_val: dict[str, np.ndarray],
+     learner,
+     criterion: str,
+     k: int,
+ ) -> float:
+     """
+     Sample parameters and calculate the criterion value.
+     :param trial: optuna trial
+     :param search_space: hyperparameter search space
+     :param bandit_feedback_train: dict with bandit training data
+     :param bandit_feedback_val: dict with bandit validation data
+     :param criterion: optimization metric
+     :param k: length of a recommendation list
+     :return: criterion value
+     """
+
+     params_for_trial = suggest_params(trial, search_space)
+     learner.replay_model.set_params(**params_for_trial)
+
+     timestamp = np.arange(bandit_feedback_train["n_rounds"])
+
+     learner.fit(
+         action=bandit_feedback_train["action"],
+         reward=bandit_feedback_train["reward"],
+         timestamp=timestamp,
+         context=bandit_feedback_train["context"],
+         action_context=bandit_feedback_train["action_context"],
+     )
+
+     action_dist = learner.predict(bandit_feedback_val["n_rounds"], bandit_feedback_val["context"])
+
+     ope_estimator = None
+     if criterion == "ipw":
+         ope_estimator = InverseProbabilityWeighting()
+     elif criterion == "dm":
+         ope_estimator = DirectMethod()
+     elif criterion == "dr":
+         ope_estimator = DoublyRobust()
+     else:
+         msg = f"There is no criterion with name {criterion}"
+         raise NotImplementedError(msg)
+
+     ope = OffPolicyEvaluation(bandit_feedback=bandit_feedback_val, ope_estimators=[ope_estimator])
+
+     estimated_rewards_by_reg_model = None
+     if criterion in ("dm", "dr"):
+         estimated_rewards_by_reg_model = get_est_rewards_by_reg(
+             learner.n_actions, k, bandit_feedback_train, bandit_feedback_val
+         )
+
+     estimated_policy_value = ope.estimate_policy_values(
+         action_dist=action_dist,
+         estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,
+     )[criterion]
+
+     return estimated_policy_value
+
+
+ OBPObjective = partial(ObjectiveWrapper, objective_calculator=obp_objective_calculator)
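For reference, the criterion name passed in above is also the key under which OffPolicyEvaluation reports its estimate. A standalone sketch of that final OPE step, assuming `action_dist` and `bandit_feedback_val` are already available in the shapes OBP expects:

# Hedged sketch: evaluate a policy with the "ipw" estimator and read its value back.
from obp.ope import InverseProbabilityWeighting, OffPolicyEvaluation

ope = OffPolicyEvaluation(
    bandit_feedback=bandit_feedback_val,
    ope_estimators=[InverseProbabilityWeighting()],
)
# estimate_policy_values returns a dict keyed by estimator name, "ipw" here.
print(ope.estimate_policy_values(action_dist=action_dist)["ipw"])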