replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/experimental/models/dt4rec/trainer.py +0 -128
@@ -1,128 +0,0 @@
-import logging
-
-from .utils import matrix2df
-import pandas as pd
-from tqdm import tqdm
-
-from replay.utils import TORCH_AVAILABLE
-if TORCH_AVAILABLE:
-    import torch
-    from torch.nn import functional as F
-
-
-logger = logging.getLogger(__name__)
-
-
-# pylint: disable=too-few-public-methods
-class TrainerConfig:
-    """
-    Config holder for trainer
-    """
-
-    epochs = 1
-    lr_scheduler = None
-
-    def __init__(self, **kwargs):
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-    def update(self, **kwargs):
-        """
-        Arguments setter
-        """
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-
-# pylint: disable=too-many-instance-attributes
-class Trainer:
-    """
-    Trainer for DT4Rec
-    """
-
-    grad_norm_clip = 1.0
-
-    # pylint: disable=too-many-arguments
-    def __init__(
-        self,
-        model,
-        train_dataloader,
-        tconf,
-        val_dataloader=None,
-        experiment=None,
-        use_cuda=True,
-    ):
-        self.model = model
-        self.train_dataloader = train_dataloader
-        self.optimizer = tconf.optimizer
-        self.epochs = tconf.epochs
-        self.lr_scheduler = tconf.lr_scheduler
-        assert (val_dataloader is None) == (experiment is None)
-        self.val_dataloader = val_dataloader
-        self.experiment = experiment
-
-        # take over whatever gpus are on the system
-        self.device = "cpu"
-        if use_cuda and torch.cuda.is_available():
-            self.device = torch.cuda.current_device()
-            self.model = torch.nn.DataParallel(self.model).to(self.device)
-
-    def _move_batch(self, batch):
-        return [elem.to(self.device) for elem in batch]
-
-    def _train_epoch(self, epoch):
-        self.model.train()
-
-        losses = []
-        pbar = tqdm(
-            enumerate(self.train_dataloader),
-            total=len(self.train_dataloader),
-        )
-
-        for iter_, batch in pbar:
-            # place data on the correct device
-            states, actions, rtgs, timesteps, users = self._move_batch(batch)
-            targets = actions
-
-            # forward the model
-            logits = self.model(states, actions, rtgs, timesteps, users)
-
-            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)).mean()
-            losses.append(loss.item())
-
-            # backprop and update the parameters
-            self.model.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clip)
-            self.optimizer.step()
-            if self.lr_scheduler is not None:
-                self.lr_scheduler.step()
-
-            # report progress
-            if self.lr_scheduler is not None:
-                current_lr = self.lr_scheduler.get_lr()
-            else:
-                current_lr = self.optimizer.param_groups[-1]["lr"]
-            pbar.set_description(f"epoch {epoch+1} iter {iter_}: train loss {loss.item():.5f}, lr {current_lr}")
-
-    def _evaluation_epoch(self, epoch):
-        self.model.eval()
-        ans_df = pd.DataFrame(columns=["user_idx", "item_idx", "relevance"])
-        val_items = self.val_dataloader.dataset.val_items
-        with torch.no_grad():
-            for batch in tqdm(self.val_dataloader):
-                states, actions, rtgs, timesteps, users = self._move_batch(batch)
-                logits = self.model(states, actions, rtgs, timesteps, users)
-                items_relevances = logits[:, -1, :][:, val_items]
-                ans_df = ans_df.append(matrix2df(items_relevances, users.squeeze(), val_items))
-            self.experiment.add_result(f"epoch: {epoch}", ans_df)
-            self.experiment.results.to_csv("results.csv")
-
-    def train(self):
-        """
-        Run training loop
-        """
-        for epoch in range(self.epochs):
-            self._train_epoch(epoch)
-            if self.experiment is not None:
-                self._evaluation_epoch(epoch)
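
For readers tracking what the 0.17.0 wheel drops: the removed Trainer is driven by a TrainerConfig that must provide at least optimizer, epochs, and lr_scheduler, plus a DataLoader whose batches are (states, actions, rtgs, timesteps, users) tuples. Below is a minimal wiring sketch against the 0.16.0rc0 experimental package; DummyModel and the handcrafted trajectory are illustrative placeholders, not part of the removed module.

import torch
from torch import nn
from torch.utils.data import DataLoader

# Importable only from replay-rec 0.16.0rc0; the whole dt4rec subpackage is gone in 0.17.0.
from replay.experimental.models.dt4rec.trainer import Trainer, TrainerConfig
from replay.experimental.models.dt4rec.utils import Collator, StateActionReturnDataset


class DummyModel(nn.Module):
    """Illustrative stand-in with the same forward signature as the DT4Rec GPT model."""

    def __init__(self, n_items, d_model=32):
        super().__init__()
        self.emb = nn.Embedding(n_items + 1, d_model)
        self.head = nn.Linear(d_model, n_items + 1)

    def forward(self, states, actions, rtgs, timesteps, users):
        # actions: (batch, seq_len, 1) item ids -> logits of shape (batch, seq_len, n_items + 1)
        return self.head(self.emb(actions.squeeze(-1)))


n_items = 10
item_pad = n_items  # padding id used by Collator
# Normally built with create_dataset(); a tiny handcrafted trajectory keeps the sketch self-contained.
user_trajectory = [
    {
        "states": [[item_pad, item_pad, item_pad], [item_pad, item_pad, 1]],
        "actions": [1, 2],
        "rewards": [1, 0],
        "rtgs": [1, 0],
    }
]

loader = DataLoader(
    StateActionReturnDataset(user_trajectory, trajectory_len=30),
    batch_size=2,
    collate_fn=Collator(item_pad),
)
model = DummyModel(n_items)
tconf = TrainerConfig(
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3),
    epochs=1,
    lr_scheduler=None,
)
# val_dataloader and experiment must be passed together or both omitted (see the assert above).
Trainer(model, loader, tconf, use_cuda=False).train()
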
replay/experimental/models/dt4rec/utils.py +0 -274
@@ -1,274 +0,0 @@
-# pylint: disable=invalid-name
-
-import bisect
-import random
-from typing import List, Union
-
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-
-from replay.utils import TORCH_AVAILABLE
-if TORCH_AVAILABLE:
-    import torch
-    from torch.optim import Optimizer
-    from torch.optim.lr_scheduler import _LRScheduler
-    from torch.utils.data import Dataset
-
-
-def set_seed(seed):
-    """
-    Set random seed in all dependencies
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
-class StateActionReturnDataset(Dataset):
-    """
-    Create Dataset from user trajectories
-    """
-
-    def __init__(self, user_trajectory, trajectory_len):
-        self.user_trajectory = user_trajectory
-        self.trajectory_len = trajectory_len
-
-        self.len = 0
-        self.prefix_lens = [0]
-        for trajectory in self.user_trajectory:
-            # print(f'{trajectory=}')
-            self.len += max(1, len(trajectory["actions"]) - 30 + 1)
-            self.prefix_lens.append(self.len)
-
-    def __len__(self):
-        return self.len
-
-    def __getitem__(self, idx):
-        user_num = bisect.bisect_right(self.prefix_lens, idx) - 1
-        start = idx - self.prefix_lens[user_num]
-
-        user = self.user_trajectory[user_num]
-        end = min(len(user["actions"]), start + self.trajectory_len)
-        states = torch.tensor(np.array(user["states"][start:end]), dtype=torch.float32)
-        actions = torch.tensor(user["actions"][start:end], dtype=torch.long)
-        rtgs = torch.tensor(user["rtgs"][start:end], dtype=torch.float32)
-        # strange logic, but it works
-        timesteps = start
-
-        return states, actions, rtgs, timesteps, user_num
-
-
-class ValidateDataset(Dataset):
-    """
-    Dataset for Validation
-    """
-
-    def __init__(self, user_trajectory, max_context_len, val_users, val_items):
-        self.user_trajectory = user_trajectory
-        self.max_context_len = max_context_len
-        self.val_users = val_users
-        self.val_items = val_items
-
-    def __len__(self):
-        return len(self.val_users)
-
-    def __getitem__(self, idx):
-        user_idx = self.val_users[idx]
-        user = self.user_trajectory[user_idx]
-        if len(user["actions"]) <= self.max_context_len:
-            start = 0
-            end = -1
-        else:
-            end = -1
-            start = end - self.max_context_len
-
-        states = torch.tensor(
-            np.array(user["states"][start - (start < 0) : end]),
-            dtype=torch.float32,
-        )
-        actions = torch.tensor(user["actions"][start:end], dtype=torch.long)
-        rtgs = torch.zeros((end - start + 1 if start < 0 else len(user["actions"])))
-        rtgs[start:end] = torch.tensor(user["rtgs"][start:end], dtype=torch.float32)
-        rtgs[end] = 10
-        timesteps = len(user["actions"]) + start if start < 0 else 0
-
-        return states, actions, rtgs, timesteps, user_idx
-
-
-def pad_sequence(
-    sequences: Union[torch.Tensor, List[torch.Tensor]],
-    batch_first: bool = False,
-    padding_value: float = 0.0,
-    pos: str = "right",
-) -> torch.Tensor:
-    """
-    Pad sequence
-    """
-    if pos == "right":
-        padded_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first, padding_value)
-    elif pos == "left":
-        sequences = tuple(map(lambda s: s.flip(0), sequences))
-        padded_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first, padding_value)
-        _seq_dim = padded_sequence.dim()
-        padded_sequence = padded_sequence.flip(-_seq_dim + batch_first)
-    else:
-        raise ValueError(f"pos should be either 'right' or 'left', but got {pos}")
-    return padded_sequence
-
-
-# pylint: disable=too-few-public-methods
-class Collator:
-    """
-    Callable class to merge several items to one batch
-    """
-
-    def __init__(self, item_pad):
-        self.item_pad = item_pad
-
-    def __call__(self, batch):
-        states, actions, rtgs, timesteps, users_num = zip(*batch)
-
-        return (
-            pad_sequence(
-                states,
-                batch_first=True,
-                padding_value=self.item_pad,
-                pos="left",
-            ),
-            pad_sequence(
-                actions,
-                batch_first=True,
-                padding_value=self.item_pad,
-                pos="left",
-            ).unsqueeze(-1),
-            pad_sequence(rtgs, batch_first=True, padding_value=0, pos="left").unsqueeze(-1),
-            torch.tensor(timesteps).unsqueeze(-1).unsqueeze(-1),
-            torch.tensor(users_num).unsqueeze(-1),
-        )
-
-
-def matrix2df(matrix, users=None, items=None):
-    """
-    Create DataFrame from matrix
-    """
-    HEADER = ["user_idx", "item_idx", "relevance"]
-    if users is None:
-        users = np.arange(matrix.shape[0])
-    else:
-        users = np.array(users.cpu())
-    if items is None:
-        items = np.arange(matrix.shape[1])
-    x1 = np.repeat(users, len(items))
-    x2 = np.tile(items, len(users))
-    x3 = np.array(matrix.cpu()).flatten()
-
-    return pd.DataFrame(np.array([x1, x2, x3]).T, columns=HEADER)
-
-
-class WarmUpScheduler(_LRScheduler):
-    """
-    Implementation of WarmUp
-    """
-
-    # pylint: disable=too-many-arguments
-    def __init__(
-        self,
-        optimizer: Optimizer,
-        dim_embed: int,
-        warmup_steps: int,
-        last_epoch: int = -1,
-        verbose: bool = False,
-    ) -> None:
-        self.dim_embed = dim_embed
-        self.warmup_steps = warmup_steps
-        self.num_param_groups = len(optimizer.param_groups)
-
-        super().__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self) -> float:
-        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
-        return [lr] * self.num_param_groups
-
-
-def calc_lr(step, dim_embed, warmup_steps):
-    """
-    Learning rate calculation
-    """
-    return dim_embed ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))
-
-
-# pylint: disable=too-many-arguments
-def create_dataset(
-    df, user_num, item_pad, time_col="timestamp", user_col="user_idx", item_col="item_idx", relevance_col="relevance"
-):
-    """
-    Create dataset from DataFrame
-    """
-    user_trajectory = [{} for _ in range(user_num)]
-    df = df.sort_values(by=time_col)
-    for user_idx in tqdm(range(user_num)):
-        user_trajectory[user_idx]["states"] = [[item_pad, item_pad, item_pad]]
-        user_trajectory[user_idx]["actions"] = []
-        user_trajectory[user_idx]["rewards"] = []
-
-        user = user_trajectory[user_idx]
-        user_df = df[df[user_col] == user_idx]
-        for _, row in user_df.iterrows():
-            action = row[item_col]
-            user["actions"].append(action)
-            if row[relevance_col] > 3:
-                user["rewards"].append(1)
-                user["states"].append([user["states"][-1][1], user["states"][-1][2], action])
-            else:
-                user["rewards"].append(0)
-                user["states"].append(user["states"][-1])
-
-        user["rtgs"] = np.cumsum(user["rewards"][::-1])[::-1]
-        for key in user:
-            user[key] = np.array(user[key])
-
-    return user_trajectory
-
-
-# For debug
-# pylint: disable=too-many-locals
-def fast_create_dataset(
-    df,
-    user_num,
-    item_pad,
-    time_field="timestamp",
-    user_field="user_idx",
-    item_field="item_idx",
-    relevance_field="relevance",
-):
-    """
-    Create dataset from DataFrame
-    """
-    user_trajectory = [{} for _ in range(user_num)]
-    df = df.sort_values(by=time_field)
-    for user_idx in tqdm(range(user_num)):
-        user_trajectory[user_idx]["states"] = [[item_pad, item_pad, item_pad]]
-        user_trajectory[user_idx]["actions"] = []
-        user_trajectory[user_idx]["rewards"] = []
-
-        user = user_trajectory[user_idx]
-        user_df = df[df[user_field] == user_idx]
-        for idx, (_, row) in enumerate(user_df.iterrows()):
-            if idx >= 35:
-                break
-            action = row[item_field]
-            user["actions"].append(action)
-            if row[relevance_field] > 3:
-                user["rewards"].append(1)
-                user["states"].append([user["states"][-1][1], user["states"][-1][2], action])
-            else:
-                user["rewards"].append(0)
-                user["states"].append(user["states"][-1])
-
-        user["rtgs"] = np.cumsum(user["rewards"][::-1])[::-1]
-        for key in user:
-            user[key] = np.array(user[key])
-
-    return user_trajectory
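
The removed WarmUpScheduler/calc_lr pair implements the familiar Transformer warmup rule: lr(step) = dim_embed ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5), i.e. a roughly linear ramp over the first warmup_steps optimizer steps followed by inverse-square-root decay. A short sketch of how it attaches to a PyTorch optimizer (again only meaningful against the 0.16.0rc0 experimental package; the Linear module is a placeholder):

import torch

# Importable only from replay-rec 0.16.0rc0; removed in 0.17.0.
from replay.experimental.models.dt4rec.utils import WarmUpScheduler, calc_lr

model = torch.nn.Linear(64, 64)  # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base lr is overridden by the scheduler
scheduler = WarmUpScheduler(optimizer, dim_embed=64, warmup_steps=4000)

# Peak lr is reached at step == warmup_steps, where both branches of the min() coincide.
assert abs(calc_lr(4000, 64, 4000) - 64 ** -0.5 * 4000 ** -0.5) < 1e-12

for _ in range(5):
    optimizer.step()   # dummy step; normally preceded by loss.backward()
    scheduler.step()   # sets lr to calc_lr(self._step_count, dim_embed, warmup_steps)
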