replay-rec 0.18.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
  35. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
  36. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
  37. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,152 @@
1
+ import pathlib
2
+ from typing import Optional, Union, get_args
3
+
4
+ import openvino as ov
5
+ import torch
6
+
7
+ from replay.data.nn import TensorSchema
8
+ from replay.models.nn.sequential.bert4rec import (
9
+ Bert4Rec,
10
+ Bert4RecPredictionBatch,
11
+ )
12
+ from replay.models.nn.sequential.bert4rec.lightning import _prepare_prediction_batch
13
+ from replay.models.nn.sequential.compiled.base_compiled_model import (
14
+ BaseCompiledModel,
15
+ OptimizedModeType,
16
+ )
17
+
18
+
19
+ class Bert4RecCompiled(BaseCompiledModel):
20
+ """
21
+ Bert4Rec CPU-optimized model for inference via OpenVINO.
22
+ It is recommended to compile model with ``compile`` method and pass ``Bert4Rec`` checkpoint
23
+ or the model object itself into it.
24
+ It is also possible to compile model by yourself and pass it to the ``__init__`` with ``TensorSchema``.
25
+
26
+ **Note** that compilation requires disk write (and maybe delete) permission.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ compiled_model: ov.CompiledModel,
32
+ schema: TensorSchema,
33
+ ) -> None:
34
+ """
35
+ :param compiled_model: Compiled model.
36
+ :param schema: Tensor schema of Bert4Rec model.
37
+ """
38
+ super().__init__(compiled_model, schema)
39
+
40
+ def predict(
41
+ self,
42
+ batch: Bert4RecPredictionBatch,
43
+ candidates_to_score: Optional[torch.LongTensor] = None,
44
+ ) -> torch.Tensor:
45
+ """
46
+ Inference on one batch.
47
+
48
+ :param batch: Prediction input.
49
+ :param candidates_to_score: Item ids to calculate scores.
50
+ Default: ``None``.
51
+
52
+ :return: Tensor with scores.
53
+ """
54
+ self._valilade_predict_input(batch, candidates_to_score)
55
+
56
+ batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
57
+ model_inputs = {
58
+ self._inputs_names[0]: batch.features[self._inputs_names[0]],
59
+ self._inputs_names[1]: batch.padding_mask,
60
+ self._inputs_names[2]: batch.tokens_mask,
61
+ }
62
+ if self._num_candidates_to_score is not None:
63
+ self._validate_candidates_to_score(candidates_to_score)
64
+ model_inputs[self._inputs_names[3]] = candidates_to_score
65
+ return torch.from_numpy(self._model(model_inputs)[self._output_name])
66
+
67
+ @classmethod
68
+ def compile(
69
+ cls,
70
+ model: Union[Bert4Rec, str, pathlib.Path],
71
+ mode: OptimizedModeType = "one_query",
72
+ batch_size: Optional[int] = None,
73
+ num_candidates_to_score: Optional[int] = None,
74
+ num_threads: Optional[int] = None,
75
+ onnx_path: Optional[str] = None,
76
+ ) -> "Bert4RecCompiled":
77
+ """
78
+ Model compilation.
79
+
80
+ :param model: Path to lightning Bert4Rec model saved in .ckpt format or the Bert4Rec object itself.
81
+ :param mode: Inference mode, defines shape of inputs.
82
+ Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
83
+ ``one_query`` - sets input shape to [1, max_seq_len]\n
84
+ ``batch`` - sets input shape to [batch_size, max_seq_len]\n
85
+ ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
86
+ Default: ``one_query``.
87
+ :param batch_size: Batch size, required for ``batch`` mode.
88
+ Default: ``None``.
89
+ :param num_candidates_to_score: Number of item ids to calculate scores.
90
+ Could be one of [``None``, ``-1``, ``N``].\n
91
+ ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
92
+ ``N`` - sets candidates_to_score shape to [1, N]\n
93
+ ``None`` - disables candidates_to_score usage\n
94
+ Default: ``None``.
95
+ :param num_threads: Number of CPU threads to use.
96
+ Must be a natural number or ``None``.
97
+ If ``None``, then compiler will set this parameter automatically.
98
+ Default: ``None``.
99
+ :param onnx_path: Save ONNX model to path, if defined.
100
+ Default: ``None``.
101
+ """
102
+ if mode not in get_args(OptimizedModeType):
103
+ msg = f"Parameter ``mode`` could be one of {get_args(OptimizedModeType)}."
104
+ raise ValueError(msg)
105
+ num_candidates_to_score = Bert4RecCompiled._validate_num_candidates_to_score(num_candidates_to_score)
106
+ if isinstance(model, Bert4Rec):
107
+ lightning_model = model.cpu()
108
+ elif isinstance(model, (str, pathlib.Path)):
109
+ lightning_model = Bert4Rec.load_from_checkpoint(model, map_location=torch.device("cpu"))
110
+
111
+ schema = lightning_model._schema
112
+ item_seq_name = schema.item_id_feature_name
113
+ max_seq_len = lightning_model._model.max_len
114
+
115
+ batch_size, num_candidates_to_score = Bert4RecCompiled._get_input_params(
116
+ mode, batch_size, num_candidates_to_score
117
+ )
118
+
119
+ item_sequence = torch.zeros((1, max_seq_len)).long()
120
+ padding_mask = torch.zeros((1, max_seq_len)).bool()
121
+ tokens_mask = torch.zeros((1, max_seq_len)).bool()
122
+
123
+ model_input_names = [item_seq_name, "padding_mask", "tokens_mask"]
124
+ model_dynamic_axes_in_input = {
125
+ item_seq_name: {0: "batch_size", 1: "max_len"},
126
+ "padding_mask": {0: "batch_size", 1: "max_len"},
127
+ "tokens_mask": {0: "batch_size", 1: "max_len"},
128
+ }
129
+ if num_candidates_to_score:
130
+ candidates_to_score = torch.zeros((1,)).long()
131
+ model_input_names += ["candidates_to_score"]
132
+ model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
133
+ model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask, candidates_to_score)
134
+ else:
135
+ model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask)
136
+
137
+ # Need to disable "Better Transformer" optimizations that interfere with the compilation process
138
+ if hasattr(torch.backends, "mha"):
139
+ torch.backends.mha.set_fastpath_enabled(value=False)
140
+
141
+ compiled_model = Bert4RecCompiled._run_model_compilation(
142
+ lightning_model,
143
+ model_input_sample,
144
+ model_input_names,
145
+ model_dynamic_axes_in_input,
146
+ batch_size,
147
+ num_candidates_to_score,
148
+ num_threads,
149
+ onnx_path,
150
+ )
151
+
152
+ return cls(compiled_model, schema)
@@ -0,0 +1,145 @@
1
+ import pathlib
2
+ from typing import Optional, Union, get_args
3
+
4
+ import openvino as ov
5
+ import torch
6
+
7
+ from replay.data.nn import TensorSchema
8
+ from replay.models.nn.sequential.compiled.base_compiled_model import (
9
+ BaseCompiledModel,
10
+ OptimizedModeType,
11
+ )
12
+ from replay.models.nn.sequential.sasrec import (
13
+ SasRec,
14
+ SasRecPredictionBatch,
15
+ )
16
+ from replay.models.nn.sequential.sasrec.lightning import _prepare_prediction_batch
17
+
18
+
19
+ class SasRecCompiled(BaseCompiledModel):
20
+ """
21
+ SasRec CPU-optimized model for inference via OpenVINO.
22
+ It is recommended to compile model with ``compile`` method and pass ``SasRec`` checkpoint
23
+ or the model object itself into it.
24
+ It is also possible to compile model by yourself and pass it to the ``__init__`` with ``TensorSchema``.
25
+
26
+ **Note** that compilation requires disk write (and maybe delete) permission.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ compiled_model: ov.CompiledModel,
32
+ schema: TensorSchema,
33
+ ) -> None:
34
+ """
35
+ :param compiled_model: Compiled model.
36
+ :param schema: Tensor schema of SasRec model.
37
+ """
38
+ super().__init__(compiled_model, schema)
39
+
40
+ def predict(
41
+ self,
42
+ batch: SasRecPredictionBatch,
43
+ candidates_to_score: Optional[torch.LongTensor] = None,
44
+ ) -> torch.Tensor:
45
+ """
46
+ Inference on one batch.
47
+
48
+ :param batch: Prediction input.
49
+ :param candidates_to_score: Item ids to calculate scores.
50
+ Default: ``None``.
51
+
52
+ :return: Tensor with scores.
53
+ """
54
+ self._valilade_predict_input(batch, candidates_to_score)
55
+
56
+ batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
57
+ model_inputs = {
58
+ self._inputs_names[0]: batch.features[self._inputs_names[0]],
59
+ self._inputs_names[1]: batch.padding_mask,
60
+ }
61
+ if self._num_candidates_to_score is not None:
62
+ self._validate_candidates_to_score(candidates_to_score)
63
+ model_inputs[self._inputs_names[2]] = candidates_to_score
64
+ return torch.from_numpy(self._model(model_inputs)[self._output_name])
65
+
66
+ @classmethod
67
+ def compile(
68
+ cls,
69
+ model: Union[SasRec, str, pathlib.Path],
70
+ mode: OptimizedModeType = "one_query",
71
+ batch_size: Optional[int] = None,
72
+ num_candidates_to_score: Optional[int] = None,
73
+ num_threads: Optional[int] = None,
74
+ onnx_path: Optional[str] = None,
75
+ ) -> "SasRecCompiled":
76
+ """
77
+ Model compilation.
78
+
79
+ :param model: Path to lightning SasRec model saved in .ckpt format or the SasRec object itself.
80
+ :param mode: Inference mode, defines shape of inputs.
81
+ Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
82
+ ``one_query`` - sets input shape to [1, max_seq_len]\n
83
+ ``batch`` - sets input shape to [batch_size, max_seq_len]\n
84
+ ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
85
+ Default: ``one_query``.
86
+ :param batch_size: Batch size, required for ``batch`` mode.
87
+ Default: ``None``.
88
+ :param num_candidates_to_score: Number of item ids to calculate scores.
89
+ Could be one of [``None``, ``-1``, ``N``].\n
90
+ ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
91
+ ``N`` - sets candidates_to_score shape to [1, N]\n
92
+ ``None`` - disables candidates_to_score usage\n
93
+ Default: ``None``.
94
+ :param num_threads: Number of CPU threads to use.
95
+ Must be a natural number or ``None``.
96
+ If ``None``, then compiler will set this parameter automatically.
97
+ Default: ``None``.
98
+ :param onnx_path: Save ONNX model to path, if defined.
99
+ Default: ``None``.
100
+ """
101
+ if mode not in get_args(OptimizedModeType):
102
+ msg = f"Parameter ``mode`` could be one of {get_args(OptimizedModeType)}."
103
+ raise ValueError(msg)
104
+ num_candidates_to_score = SasRecCompiled._validate_num_candidates_to_score(num_candidates_to_score)
105
+ if isinstance(model, SasRec):
106
+ lightning_model = model.cpu()
107
+ elif isinstance(model, (str, pathlib.Path)):
108
+ lightning_model = SasRec.load_from_checkpoint(model, map_location=torch.device("cpu"))
109
+
110
+ schema = lightning_model._schema
111
+ item_seq_name = schema.item_id_feature_name
112
+ max_seq_len = lightning_model._model.max_len
113
+
114
+ batch_size, num_candidates_to_score = SasRecCompiled._get_input_params(
115
+ mode, batch_size, num_candidates_to_score
116
+ )
117
+
118
+ item_sequence = torch.zeros((1, max_seq_len)).long()
119
+ padding_mask = torch.zeros((1, max_seq_len)).bool()
120
+
121
+ model_input_names = [item_seq_name, "padding_mask"]
122
+ model_dynamic_axes_in_input = {
123
+ item_seq_name: {0: "batch_size", 1: "max_len"},
124
+ "padding_mask": {0: "batch_size", 1: "max_len"},
125
+ }
126
+ if num_candidates_to_score:
127
+ candidates_to_score = torch.zeros((1,)).long()
128
+ model_input_names += ["candidates_to_score"]
129
+ model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
130
+ model_input_sample = ({item_seq_name: item_sequence}, padding_mask, candidates_to_score)
131
+ else:
132
+ model_input_sample = ({item_seq_name: item_sequence}, padding_mask)
133
+
134
+ compiled_model = SasRecCompiled._run_model_compilation(
135
+ lightning_model,
136
+ model_input_sample,
137
+ model_input_names,
138
+ model_dynamic_axes_in_input,
139
+ batch_size,
140
+ num_candidates_to_score,
141
+ num_threads,
142
+ onnx_path,
143
+ )
144
+
145
+ return cls(compiled_model, schema)
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Set, Tuple, cast
1
+ from typing import List, Optional, Set, Tuple, Union, cast
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -17,6 +17,8 @@ class RemoveSeenItems(BasePostProcessor):
17
17
  def __init__(self, sequential: SequentialDataset) -> None:
18
18
  super().__init__()
19
19
  self._sequential = sequential
20
+ self._apply_candidates = False
21
+ self._candidates = None
20
22
 
21
23
  def on_validation(
22
24
  self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
@@ -30,6 +32,7 @@ class RemoveSeenItems(BasePostProcessor):
30
32
 
31
33
  :returns: modified query ids and scores and ground truth dataset
32
34
  """
35
+ self._apply_candidates = False
33
36
  modified_scores = self._compute_scores(query_ids, scores)
34
37
  return query_ids, modified_scores, ground_truth
35
38
 
@@ -42,6 +45,7 @@ class RemoveSeenItems(BasePostProcessor):
42
45
 
43
46
  :returns: modified query ids and scores
44
47
  """
48
+ self._apply_candidates = True
45
49
  modified_scores = self._compute_scores(query_ids, scores)
46
50
  return query_ids, modified_scores
47
51
 
@@ -56,6 +60,13 @@ class RemoveSeenItems(BasePostProcessor):
56
60
  value: float,
57
61
  ) -> torch.Tensor:
58
62
  flat_item_ids_on_device = flat_item_ids.to(scores.device)
63
+
64
+ if self._apply_candidates and self._candidates is not None:
65
+ item_count = self._sequential.schema.item_id_features.item().cardinality
66
+ assert item_count
67
+ _scores = torch.full((scores.shape[0], item_count), -float("inf")).to(scores.device)
68
+ _scores[:, self._candidates] = torch.reshape(scores, _scores[:, self._candidates].shape)
69
+ scores = _scores
59
70
  if scores.is_contiguous():
60
71
  scores.view(-1)[flat_item_ids_on_device] = value
61
72
  else:
@@ -80,6 +91,21 @@ class RemoveSeenItems(BasePostProcessor):
80
91
  flat_seen_item_ids_np = np.concatenate(item_id_sequences)
81
92
  return torch.LongTensor(flat_seen_item_ids_np)
82
93
 
94
+ @property
95
+ def candidates(self) -> Union[torch.LongTensor, None]:
96
+ """
97
+ Returns tensor of item ids to calculate scores.
98
+ """
99
+ return self._candidates
100
+
101
+ @candidates.setter
102
+ def candidates(self, candidates: Optional[torch.LongTensor] = None) -> None:
103
+ """
104
+ Sets tensor of item ids to calculate scores.
105
+ :param candidates: Tensor of item ids to calculate scores.
106
+ """
107
+ self._candidates = candidates
108
+
83
109
 
84
110
  class SampleItems(BasePostProcessor):
85
111
  """
@@ -10,6 +10,7 @@ from replay.data.nn import (
10
10
  TorchSequentialDataset,
11
11
  TorchSequentialValidationDataset,
12
12
  )
13
+ from replay.utils.model_handler import deprecation_warning
13
14
 
14
15
 
15
16
  class SasRecTrainingBatch(NamedTuple):
@@ -30,6 +31,10 @@ class SasRecTrainingDataset(TorchDataset):
30
31
  Dataset that generates samples to train SasRec-like model
31
32
  """
32
33
 
34
+ @deprecation_warning(
35
+ "`padding_value` parameter will be removed in future versions. "
36
+ "Instead, you should specify `padding_value` for each column in TensorSchema"
37
+ )
33
38
  def __init__(
34
39
  self,
35
40
  sequential: SequentialDataset,
@@ -90,7 +95,10 @@ class SasRecTrainingDataset(TorchDataset):
90
95
 
91
96
  output_features: MutableTensorMap = {}
92
97
  for feature_name in self._schema:
93
- output_features[feature_name] = features[feature_name][: -self._sequence_shift]
98
+ feature = features[feature_name]
99
+ if self._schema[feature_name].is_seq:
100
+ feature = feature[: -self._sequence_shift]
101
+ output_features[feature_name] = feature
94
102
 
95
103
  output_features_padding_mask = padding_mask[: -self._sequence_shift]
96
104
 
@@ -119,6 +127,10 @@ class SasRecPredictionDataset(TorchDataset):
119
127
  Dataset that generates samples to infer SasRec-like model
120
128
  """
121
129
 
130
+ @deprecation_warning(
131
+ "`padding_value` parameter will be removed in future versions. "
132
+ "Instead, you should specify `padding_value` for each column in TensorSchema"
133
+ )
122
134
  def __init__(
123
135
  self,
124
136
  sequential: SequentialDataset,
@@ -167,6 +179,10 @@ class SasRecValidationDataset(TorchDataset):
167
179
  Dataset that generates samples to infer and validate SasRec-like model
168
180
  """
169
181
 
182
+ @deprecation_warning(
183
+ "`padding_value` parameter will be removed in future versions. "
184
+ "Instead, you should specify `padding_value` for each column in TensorSchema"
185
+ )
170
186
  def __init__(
171
187
  self,
172
188
  sequential: SequentialDataset,
@@ -33,7 +33,7 @@ class SasRec(lightning.LightningModule):
33
33
  loss_sample_count: Optional[int] = None,
34
34
  negative_sampling_strategy: str = "global_uniform",
35
35
  negatives_sharing: bool = False,
36
- optimizer_factory: Optional[OptimizerFactory] = None,
36
+ optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
37
37
  lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
38
38
  ):
39
39
  """
@@ -63,7 +63,7 @@ class SasRec(lightning.LightningModule):
63
63
  :param negatives_sharing: Apply negative sharing in calculating sampled logits.
64
64
  Default: ``False``.
65
65
  :param optimizer_factory: Optimizer factory.
66
- Default: ``None``.
66
+ Default: ``FatOptimizerFactory``.
67
67
  :param lr_scheduler_factory: Learning rate schedule factory.
68
68
  Default: ``None``.
69
69
  """
@@ -92,6 +92,7 @@ class SasRec(lightning.LightningModule):
92
92
  item_count = tensor_schema.item_id_features.item().cardinality
93
93
  assert item_count
94
94
  self._vocab_size = item_count
95
+ self.candidates_to_score = None
95
96
 
96
97
  def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
97
98
  """
@@ -106,30 +107,58 @@ class SasRec(lightning.LightningModule):
106
107
  self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
107
108
  return loss
108
109
 
109
- def forward(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor: # pragma: no cover
110
+ def predict_step(
111
+ self,
112
+ batch: SasRecPredictionBatch,
113
+ batch_idx: int, # noqa: ARG002
114
+ dataloader_idx: int = 0, # noqa: ARG002
115
+ ) -> torch.Tensor:
110
116
  """
111
- :param feature_tensors: Batch of features.
112
- :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
117
+ :param batch: Batch of prediction data.
118
+ :param batch_idx: Batch index.
119
+ :param dataloader_idx: Dataloader index.
113
120
 
114
121
  :returns: Calculated scores.
115
122
  """
116
- return self._model_predict(feature_tensors, padding_mask)
123
+ batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
124
+ return self._model_predict(batch.features, batch.padding_mask)
117
125
 
118
- def predict_step(
119
- self, batch: SasRecPredictionBatch, batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
126
+ def predict(
127
+ self,
128
+ batch: SasRecPredictionBatch,
129
+ candidates_to_score: Optional[torch.LongTensor] = None,
120
130
  ) -> torch.Tensor:
121
131
  """
122
132
  :param batch: Batch of prediction data.
123
- :param batch_idx: Batch index.
124
- :param dataloader_idx: Dataloader index.
133
+ :param candidates_to_score: Item ids to calculate scores.
134
+ Default: ``None``.
125
135
 
126
136
  :returns: Calculated scores.
127
137
  """
128
- batch = self._prepare_prediction_batch(batch)
129
- return self._model_predict(batch.features, batch.padding_mask)
138
+ batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
139
+ return self._model_predict(batch.features, batch.padding_mask, candidates_to_score)
140
+
141
+ def forward(
142
+ self,
143
+ feature_tensors: TensorMap,
144
+ padding_mask: torch.BoolTensor,
145
+ candidates_to_score: Optional[torch.LongTensor] = None,
146
+ ) -> torch.Tensor: # pragma: no cover
147
+ """
148
+ :param feature_tensors: Batch of features.
149
+ :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
150
+ :param candidates_to_score: Item ids to calculate scores.
151
+ Default: ``None``.
152
+
153
+ :returns: Calculated scores.
154
+ """
155
+ return self._model_predict(feature_tensors, padding_mask, candidates_to_score)
130
156
 
131
157
  def validation_step(
132
- self, batch: SasRecValidationBatch, batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
158
+ self,
159
+ batch: SasRecValidationBatch,
160
+ batch_idx: int, # noqa: ARG002
161
+ dataloader_idx: int = 0, # noqa: ARG002
133
162
  ) -> torch.Tensor:
134
163
  """
135
164
  :param batch (SasRecValidationBatch): Batch of prediction data.
@@ -143,8 +172,7 @@ class SasRec(lightning.LightningModule):
143
172
  """
144
173
  :returns: Configured optimizer and lr scheduler.
145
174
  """
146
- optimizer_factory = self._optimizer_factory or FatOptimizerFactory()
147
- optimizer = optimizer_factory.create(self._model.parameters())
175
+ optimizer = self._optimizer_factory.create(self._model.parameters())
148
176
 
149
177
  if self._lr_scheduler_factory is None:
150
178
  return optimizer
@@ -152,38 +180,16 @@ class SasRec(lightning.LightningModule):
152
180
  lr_scheduler = self._lr_scheduler_factory.create(optimizer)
153
181
  return [optimizer], [lr_scheduler]
154
182
 
155
- def _prepare_prediction_batch(self, batch: SasRecPredictionBatch) -> SasRecPredictionBatch:
156
- if batch.padding_mask.shape[1] > self._model.max_len:
157
- msg = f"The length of the submitted sequence \
158
- must not exceed the maximum length of the sequence. \
159
- The length of the sequence is given {batch.padding_mask.shape[1]}, \
160
- while the maximum length is {self._model.max_len}"
161
- raise ValueError(msg)
162
-
163
- if batch.padding_mask.shape[1] < self._model.max_len:
164
- query_id, padding_mask, features = batch
165
- sequence_item_count = padding_mask.shape[1]
166
- for feature_name, feature_tensor in features.items():
167
- if self._schema[feature_name].is_cat:
168
- features[feature_name] = torch.nn.functional.pad(
169
- feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
170
- )
171
- else:
172
- features[feature_name] = torch.nn.functional.pad(
173
- feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
174
- (self._model.max_len - sequence_item_count, 0),
175
- value=0,
176
- ).unsqueeze(-1)
177
- padding_mask = torch.nn.functional.pad(
178
- padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
179
- )
180
- batch = SasRecPredictionBatch(query_id, padding_mask, features)
181
- return batch
182
-
183
- def _model_predict(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:
183
+ def _model_predict(
184
+ self,
185
+ feature_tensors: TensorMap,
186
+ padding_mask: torch.BoolTensor,
187
+ candidates_to_score: torch.LongTensor = None,
188
+ ) -> torch.Tensor:
184
189
  model: SasRecModel
185
190
  model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
186
- scores = model.predict(feature_tensors, padding_mask)
191
+ candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
192
+ scores = model.predict(feature_tensors, padding_mask, candidates_to_score)
187
193
  return scores
188
194
 
189
195
  def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
@@ -479,6 +485,50 @@ class SasRec(lightning.LightningModule):
479
485
 
480
486
  self._set_new_item_embedder_to_model(new_embedding, new_vocab_size)
481
487
 
488
+ @property
489
+ def optimizer_factory(self) -> OptimizerFactory:
490
+ """
491
+ Returns current optimizer_factory.
492
+ """
493
+ return self._optimizer_factory
494
+
495
+ @optimizer_factory.setter
496
+ def optimizer_factory(self, optimizer_factory: OptimizerFactory) -> None:
497
+ """
498
+ Sets new optimizer_factory.
499
+ :param optimizer_factory: New optimizer factory.
500
+ """
501
+ if isinstance(optimizer_factory, OptimizerFactory):
502
+ self._optimizer_factory = optimizer_factory
503
+ else:
504
+ msg = f"Expected optimizer_factory of type OptimizerFactory, got {type(optimizer_factory)}"
505
+ raise ValueError(msg)
506
+
507
+ @property
508
+ def candidates_to_score(self) -> Union[torch.LongTensor, None]:
509
+ """
510
+ Returns tensor of item ids to calculate scores.
511
+ """
512
+ return self._candidates_to_score
513
+
514
+ @candidates_to_score.setter
515
+ def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
516
+ """
517
+ Sets tensor of item ids to calculate scores.
518
+ :param candidates: Tensor of item ids to calculate scores.
519
+ """
520
+ total_item_count = self._model.item_count
521
+ if isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long:
522
+ if 0 < candidates.shape[0] <= total_item_count:
523
+ self._candidates_to_score = candidates
524
+ else:
525
+ msg = f"Expected candidates length to be between 1 and {total_item_count=}"
526
+ raise ValueError(msg)
527
+ elif candidates is not None:
528
+ msg = f"Expected candidates to be of type torch.LongTensor or None, got {type(candidates)}"
529
+ raise ValueError(msg)
530
+ self._candidates_to_score = candidates
531
+
482
532
  def _set_new_item_embedder_to_model(self, new_embedding: torch.nn.Embedding, new_vocab_size: int):
483
533
  self._model.item_embedder.item_emb = new_embedding
484
534
  self._model._head._item_embedder = self._model.item_embedder
@@ -486,11 +536,37 @@ class SasRec(lightning.LightningModule):
486
536
  self._model.item_count = new_vocab_size
487
537
  self._model.padding_idx = new_vocab_size
488
538
  self._model.masking.padding_idx = new_vocab_size
489
- self._model.candidates_to_score = torch.tensor(
490
- list(range(new_embedding.weight.data.shape[0] - 1)),
491
- device=self._model.candidates_to_score.device,
492
- dtype=torch.long,
493
- )
494
539
  self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
495
540
  new_embedding.weight.data.shape[0] - 1
496
541
  )
542
+
543
+
544
+ def _prepare_prediction_batch(
545
+ schema: TensorSchema, max_len: int, batch: SasRecPredictionBatch
546
+ ) -> SasRecPredictionBatch:
547
+ if batch.padding_mask.shape[1] > max_len:
548
+ msg = (
549
+ "The length of the submitted sequence "
550
+ "must not exceed the maximum length of the sequence. "
551
+ f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
552
+ f"while the maximum length is {max_len}"
553
+ )
554
+ raise ValueError(msg)
555
+
556
+ if batch.padding_mask.shape[1] < max_len:
557
+ query_id, padding_mask, features = batch
558
+ sequence_item_count = padding_mask.shape[1]
559
+ for feature_name, feature_tensor in features.items():
560
+ if schema[feature_name].is_cat:
561
+ features[feature_name] = torch.nn.functional.pad(
562
+ feature_tensor, (max_len - sequence_item_count, 0), value=0
563
+ )
564
+ else:
565
+ features[feature_name] = torch.nn.functional.pad(
566
+ feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
567
+ (max_len - sequence_item_count, 0),
568
+ value=0,
569
+ ).unsqueeze(-1)
570
+ padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
571
+ batch = SasRecPredictionBatch(query_id, padding_mask, features)
572
+ return batch