replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
  35. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
  36. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
  37. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
@@ -31,7 +31,7 @@ class Bert4Rec(lightning.LightningModule):
         loss_sample_count: Optional[int] = None,
         negative_sampling_strategy: str = "global_uniform",
         negatives_sharing: bool = False,
-        optimizer_factory: Optional[OptimizerFactory] = None,
+        optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
         lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
     ):
         """
@@ -65,7 +65,7 @@ class Bert4Rec(lightning.LightningModule):
         :param negatives_sharing: Apply negative sharing in calculating sampled logits.
             Default: ``False``.
         :param optimizer_factory: Optimizer factory.
-            Default: ``None``.
+            Default: ``FatOptimizerFactory``.
         :param lr_scheduler_factory: Learning rate schedule factory.
             Default: ``None``.
         """
@@ -95,6 +95,7 @@ class Bert4Rec(lightning.LightningModule):
         item_count = tensor_schema.item_id_features.item().cardinality
         assert item_count
         self._vocab_size = item_count
+        self.candidates_to_score = None
 
     def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
         """
@@ -107,33 +108,51 @@ class Bert4Rec(lightning.LightningModule):
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
         return loss
 
+    def predict_step(
+        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
+        """
+        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
+        :param batch_idx (int): Batch index.
+        :param dataloader_idx (int): Dataloader index.
+
+        :returns: Calculated scores on prediction batch.
+        """
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+
+    def predict(
+        self,
+        batch: Bert4RecPredictionBatch,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :returns: Calculated scores on prediction batch.
+        """
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask, candidates_to_score)
+
     def forward(
         self,
         feature_tensors: TensorMap,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:  # pragma: no cover
         """
         :param feature_tensors: Batch of features.
         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
         :param tokens_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
 
         :returns: Calculated scores.
         """
-        return self._model_predict(feature_tensors, padding_mask, tokens_mask)
-
-    def predict_step(
-        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
-    ) -> torch.Tensor:
-        """
-        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
-        :param batch_idx (int): Batch index.
-        :param dataloader_idx (int): Dataloader index.
-
-        :returns: Calculated scores on prediction batch.
-        """
-        batch = self._prepare_prediction_batch(batch)
-        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+        return self._model_predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
 
     def validation_step(
         self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
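
Note: ``predict_step`` keeps the Lightning hook signature and uses the model-level ``candidates_to_score``, while the new ``predict`` accepts a per-call candidate subset. A hedged usage sketch (the dataloader and item ids are placeholders):

    # Sketch only; batches come from a Bert4Rec prediction dataloader.
    import torch

    candidates = torch.tensor([7, 42, 128], dtype=torch.long)
    for batch in prediction_dataloader:
        full_scores = model.predict(batch)                                    # [B, vocab_size]
        subset_scores = model.predict(batch, candidates_to_score=candidates)  # [B, 3]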
@@ -150,8 +169,7 @@ class Bert4Rec(lightning.LightningModule):
         """
         :returns: Configured optimizer and lr scheduler.
         """
-        optimizer_factory = self._optimizer_factory or FatOptimizerFactory()
-        optimizer = optimizer_factory.create(self._model.parameters())
+        optimizer = self._optimizer_factory.create(self._model.parameters())
 
         if self._lr_scheduler_factory is None:
             return optimizer
@@ -159,49 +177,20 @@ class Bert4Rec(lightning.LightningModule):
         lr_scheduler = self._lr_scheduler_factory.create(optimizer)
         return [optimizer], [lr_scheduler]
 
-    def _prepare_prediction_batch(self, batch: Bert4RecPredictionBatch) -> Bert4RecPredictionBatch:
-        if batch.padding_mask.shape[1] > self._model.max_len:
-            msg = f"The length of the submitted sequence \
-                must not exceed the maximum length of the sequence. \
-                The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}"
-            raise ValueError(msg)
-
-        if batch.padding_mask.shape[1] < self._model.max_len:
-            query_id, padding_mask, features, _ = batch
-            sequence_item_count = padding_mask.shape[1]
-            for feature_name, feature_tensor in features.items():
-                if self._schema[feature_name].is_cat:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
-                    )
-                else:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0,
-                    ).unsqueeze(-1)
-            padding_mask = torch.nn.functional.pad(
-                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
-            )
-            shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
-            batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
-        return batch
-
     def _model_predict(
         self,
         feature_tensors: TensorMap,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         model: Bert4RecModel
-        if isinstance(self._model, torch.nn.DataParallel):
-            model = cast(Bert4RecModel, self._model.module)  # multigpu
-        else:
-            model = self._model
-        scores = model(feature_tensors, padding_mask, tokens_mask)
-        candidate_scores = scores[:, -1, :]
-        return candidate_scores
+        model = (
+            cast(Bert4RecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
+        )
+        candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
+        scores = model.predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
+        return scores
 
     def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
@@ -504,6 +493,50 @@ class Bert4Rec(lightning.LightningModule):
 
         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)
 
+    @property
+    def optimizer_factory(self) -> OptimizerFactory:
+        """
+        Returns the current optimizer_factory.
+        """
+        return self._optimizer_factory
+
+    @optimizer_factory.setter
+    def optimizer_factory(self, optimizer_factory: OptimizerFactory) -> None:
+        """
+        Sets a new optimizer_factory.
+        :param optimizer_factory: New optimizer factory.
+        """
+        if isinstance(optimizer_factory, OptimizerFactory):
+            self._optimizer_factory = optimizer_factory
+        else:
+            msg = f"Expected optimizer_factory of type OptimizerFactory, got {type(optimizer_factory)}"
+            raise ValueError(msg)
+
+    @property
+    def candidates_to_score(self) -> Union[torch.LongTensor, None]:
+        """
+        Returns the tensor of item ids to calculate scores for.
+        """
+        return self._candidates_to_score
+
+    @candidates_to_score.setter
+    def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        """
+        Sets the tensor of item ids to calculate scores for.
+        :param candidates: Tensor of item ids to calculate scores.
+        """
+        total_item_count = self._model.item_count
+        if isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long:
+            if 0 < candidates.shape[0] <= total_item_count:
+                self._candidates_to_score = candidates
+            else:
+                msg = f"Expected candidates length to be between 1 and {total_item_count=}"
+                raise ValueError(msg)
+        elif candidates is not None:
+            msg = f"Expected candidates to be of type torch.LongTensor or None, got {type(candidates)}"
+            raise ValueError(msg)
+        self._candidates_to_score = candidates
+
     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
         if self._model.enable_embedding_tying is True:
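
Note: the setter accepts either ``None`` or a non-empty ``torch.LongTensor`` with at most ``item_count`` ids; everything else raises ``ValueError``. A short sketch of the contract:

    import torch

    model.candidates_to_score = torch.arange(10, dtype=torch.long)  # ok: LongTensor within the vocabulary
    model.candidates_to_score = None                                # ok: score the full vocabulary again
    model.candidates_to_score = torch.tensor([0.1, 0.2])            # ValueError: dtype is not torch.long
    model.candidates_to_score = [1, 2, 3]                           # ValueError: not a tensor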
@@ -521,3 +554,37 @@ class Bert4Rec(lightning.LightningModule):
         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
         self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
+
+
+def _prepare_prediction_batch(
+    schema: TensorSchema, max_len: int, batch: Bert4RecPredictionBatch
+) -> Bert4RecPredictionBatch:
+    if batch.padding_mask.shape[1] > max_len:
+        msg = (
+            "The length of the submitted sequence "
+            "must not exceed the maximum length of the sequence. "
+            f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
+            f"while the maximum length is {max_len}"
+        )
+        raise ValueError(msg)
+
+    if batch.padding_mask.shape[1] < max_len:
+        query_id, padding_mask, features, _ = batch
+        sequence_item_count = padding_mask.shape[1]
+        for feature_name, feature_tensor in features.items():
+            if schema[feature_name].is_cat:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor,
+                    (max_len - sequence_item_count, 0),
+                    value=schema[feature_name].padding_value,
+                )
+            else:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
+                    (max_len - sequence_item_count, 0),
+                    value=schema[feature_name].padding_value,
+                ).unsqueeze(-1)
+        padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
+        shifted_features, shifted_padding_mask, tokens_mask = _shift_features(schema, features, padding_mask)
+        batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
+    return batch
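
Note: short sequences are left-padded up to ``max_len``, and 0.18.1 pads with each feature's ``padding_value`` instead of a hard-coded 0. A tiny illustration of the underlying call (values made up; the ``(left, right)`` pad widths apply to the last dimension):

    import torch

    seq = torch.tensor([[3, 1, 4]])                    # [B=1, L=3], assume max_len=5
    torch.nn.functional.pad(seq, (5 - 3, 0), value=0)  # tensor([[0, 0, 3, 1, 4]])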
@@ -98,6 +98,27 @@ class Bert4RecModel(torch.nn.Module):
 
         return all_scores  # [B x L x E]
 
+    def predict(
+        self,
+        inputs: TensorMap,
+        pad_mask: torch.BoolTensor,
+        token_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        :param inputs: Batch of features.
+        :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+        :param token_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param candidates_to_score: Item ids to calculate scores.
+            If ``None``, scores are calculated for all items.
+
+        :returns: Calculated scores among candidates_to_score items.
+        """
+        # final_emb: [B x E]
+        final_emb = self.get_query_embeddings(inputs, pad_mask, token_mask)
+        candidate_scores = self.get_logits(final_emb, candidates_to_score)
+        return candidate_scores
+
     def forward_step(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor) -> torch.Tensor:
         """
 
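
Note: unlike ``forward_step``, which returns per-position scores, ``predict`` reduces the sequence to one query embedding per user and scores it against the whole vocabulary or only the requested candidates. A hedged shape sketch (sizes illustrative):

    import torch

    scores_all = model.predict(inputs, pad_mask, token_mask)         # [B, item_count]
    cands = torch.tensor([10, 20, 30], dtype=torch.long)
    scores_sub = model.predict(inputs, pad_mask, token_mask, cands)  # [B, 3]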
@@ -64,6 +64,11 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         self._item_batches.clear()
         self._item_scores.clear()
 
+        candidates = trainer.model.candidates_to_score if hasattr(trainer.model, "candidates_to_score") else None
+        for postprocessor in self._postprocessors:
+            if hasattr(postprocessor, "candidates"):
+                postprocessor.candidates = candidates
+
     def on_predict_batch_end(
         self,
         trainer: lightning.Trainer,  # noqa: ARG002
@@ -88,7 +93,6 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
             torch.cat(self._item_batches),
             torch.cat(self._item_scores),
         )
-
         return prediction
 
     def _compute_pipeline(
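
Note: the callback duck-types postprocessors: anything exposing a ``candidates`` attribute receives the model's ``candidates_to_score`` before prediction starts, so column ``j`` of a reduced score matrix can be mapped back to item ``candidates[j]``. A hypothetical compatible postprocessor (the ``on_prediction`` hook name is an assumption, not taken from this diff):

    from typing import Optional, Tuple
    import torch

    class CandidateAwarePostprocessor:
        def __init__(self) -> None:
            self.candidates: Optional[torch.LongTensor] = None  # filled in by the callback

        def on_prediction(self, query_ids: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # When candidates is set, scores[:, j] refers to item self.candidates[j].
            return query_ids, scores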
@@ -0,0 +1,5 @@
+from replay.utils import OPENVINO_AVAILABLE
+
+if OPENVINO_AVAILABLE:
+    from .bert4rec_compiled import Bert4RecCompiled
+    from .sasrec_compiled import SasRecCompiled
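
Note: because the names are only defined when OpenVINO is importable, downstream code should gate on the same flag. A minimal sketch:

    from replay.utils import OPENVINO_AVAILABLE

    if OPENVINO_AVAILABLE:
        from replay.models.nn.sequential.compiled import Bert4RecCompiled, SasRecCompiled
    else:
        Bert4RecCompiled = SasRecCompiled = None  # openvino extra not installed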
@@ -0,0 +1,261 @@
+import pathlib
+import tempfile
+from abc import abstractmethod
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+import lightning
+import openvino as ov
+import torch
+
+from replay.data.nn import TensorSchema
+
+OptimizedModeType = Literal[
+    "batch",
+    "one_query",
+    "dynamic_batch_size",
+]
+
+
+def _compile_openvino(
+    onnx_path: str,
+    batch_size: int,
+    max_seq_len: int,
+    num_candidates_to_score: Optional[int],
+    num_threads: Optional[int],
+) -> ov.CompiledModel:
+    """
+    Defines the compilation strategy for the OpenVINO backend.
+
+    :param onnx_path: Path to the model representation in ONNX format.
+    :param batch_size: Defines whether the batch axis will be of static or dynamic length.
+    :param max_seq_len: Defines whether the sequence axis will be of static or dynamic length.
+    :param num_candidates_to_score: Defines whether the candidates axis will be of static or dynamic length.
+    :param num_threads: Defines the number of CPU threads for which the model will be compiled by the OpenVINO core.
+        If ``None``, the compiler sets this parameter automatically.
+        Default: ``None``.
+
+    :return: Compiled model.
+    """
+    core = ov.Core()
+    if num_threads is not None:
+        core.set_property("CPU", {"INFERENCE_NUM_THREADS": num_threads})
+    model_onnx = core.read_model(model=onnx_path)
+    inputs_names = [inputs.names.pop() for inputs in model_onnx.inputs]
+    del model_onnx
+
+    candidates_input_id = len(inputs_names) - 1 if num_candidates_to_score is not None else len(inputs_names)
+    model_input_scheme = [(input_name, [batch_size, max_seq_len]) for input_name in inputs_names[:candidates_input_id]]
+    if num_candidates_to_score is not None:
+        model_input_scheme += [(inputs_names[candidates_input_id], [num_candidates_to_score])]
+    model_onnx = ov.convert_model(onnx_path, input=model_input_scheme)
+    return core.compile_model(model=model_onnx, device_name="CPU")
+
+
+class BaseCompiledModel:
+    """
+    Base class of CPU-optimized models for inference via OpenVINO.
+    It is recommended to use the inherited classes rather than this one.
+    """
+
+    def __init__(
+        self,
+        compiled_model: ov.CompiledModel,
+        schema: TensorSchema,
+    ) -> None:
+        """
+        :param compiled_model: Compiled model.
+        :param schema: Tensor schema of the model.
+        """
+        self._batch_size: int
+        self._max_seq_len: int
+        self._inputs_names: List[str]
+        self._output_name: str
+
+        self._set_inner_params_from_openvino_model(compiled_model)
+        self._schema = schema
+        self._model = compiled_model
+
+    @abstractmethod
+    def predict(
+        self,
+        batch: Any,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        Inference on one batch.
+
+        :param batch: Prediction input.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :return: Tensor with scores.
+        """
+
+    def _validate_candidates_to_score(self, candidates: torch.LongTensor) -> None:
+        """Check that the candidates param has the proper type"""
+
+        if not (isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long):
+            msg = (
+                "Expected candidates to be of type ``torch.Tensor`` with dtype ``torch.long``, "
+                f"got {type(candidates)} with dtype {candidates.dtype}."
+            )
+            raise ValueError(msg)
+
+    def _validate_predict_input(self, batch: Any, candidates_to_score: Optional[torch.LongTensor] = None) -> None:
+        if self._num_candidates_to_score is None and candidates_to_score is not None:
+            msg = (
+                "If ``num_candidates_to_score`` is None, "
+                "it is impossible to infer the model with passed ``candidates_to_score``."
+            )
+            raise ValueError(msg)
+
+        if self._batch_size != -1 and batch.padding_mask.shape[0] != self._batch_size:
+            msg = (
+                f"The batch size does not match the defined batch_size={self._batch_size}. "
+                "It is impossible to infer the model with a dynamic batch size in ``mode`` = ``batch``. "
+                "Use ``mode`` = ``dynamic_batch_size``."
+            )
+            raise ValueError(msg)
+
+    def _set_inner_params_from_openvino_model(self, compiled_model: ov.CompiledModel) -> None:
+        """Set params for the ``predict`` method"""
+
+        input_scheme = compiled_model.inputs
+        self._batch_size = input_scheme[0].partial_shape[0].max_length
+        self._max_seq_len = input_scheme[0].partial_shape[1].max_length
+        self._inputs_names = [input.names.pop() for input in compiled_model.inputs]
+        if "candidates_to_score" in self._inputs_names:
+            self._num_candidates_to_score = input_scheme[-1].partial_shape[0].max_length
+        else:
+            self._num_candidates_to_score = None
+        self._output_name = compiled_model.output().names.pop()
+
+    @staticmethod
+    def _validate_num_candidates_to_score(num_candidates: Union[int, None]) -> Union[int, None]:
+        """Check that the num_candidates param is proper"""
+
+        if num_candidates is None:
+            return num_candidates
+        if isinstance(num_candidates, int) and (num_candidates == -1 or num_candidates >= 1):
+            return num_candidates
+
+        msg = (
+            "Expected num_candidates_to_score to be of type ``int`` equal to ``-1`` or a natural number, or ``None``. "
+            f"Got {num_candidates}."
+        )
+        raise ValueError(msg)
+
+    @staticmethod
+    def _get_input_params(
+        mode: OptimizedModeType,
+        batch_size: Optional[int],
+        num_candidates_to_score: Optional[int],
+    ) -> Tuple[Optional[int], Optional[int]]:
+        """Get params for model compilation according to the compilation mode"""
+
+        if mode == "one_query":
+            batch_size = 1
+
+        if mode == "batch":
+            assert batch_size, f"{mode} mode requires `batch_size`"
+
+        if mode == "dynamic_batch_size":
+            batch_size = -1
+
+        num_candidates_to_score = num_candidates_to_score if num_candidates_to_score else None
+        return batch_size, num_candidates_to_score
+
+    @staticmethod
+    def _run_model_compilation(
+        lightning_model: lightning.LightningModule,
+        model_input_sample: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+        model_input_names: List[str],
+        model_dynamic_axes_in_input: Dict[str, Dict],
+        batch_size: int,
+        num_candidates_to_score: Union[int, None],
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> ov.CompiledModel:
+        """
+        Convert the model into ONNX format and compile it with the defined engine.
+
+        :param lightning_model: Lightning model to be compiled.
+        :param model_input_sample: An example of model input with the proper data type.
+        :param model_input_names: Input tensor names.
+        :param model_dynamic_axes_in_input: Dynamic axes in input.
+        :param batch_size: Defines the size of the axis with index 0 in the input of the compiled model.
+        :param num_candidates_to_score: Defines the size of the candidates in the input of the compiled model.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, the compiler sets this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save the ONNX model to this path, if defined.
+            Default: ``None``.
+
+        :return: Compiled model.
+        """
+        max_seq_len = lightning_model._model.max_len
+
+        if onnx_path is None:
+            is_savable = False
+            onnx_file = tempfile.NamedTemporaryFile(suffix=".onnx")
+            onnx_path = onnx_file.name
+        else:
+            is_savable = True
+
+        lightning_model.to_onnx(
+            onnx_path,
+            input_sample=model_input_sample,
+            export_params=True,
+            opset_version=torch.onnx._constants.ONNX_DEFAULT_OPSET,
+            do_constant_folding=True,
+            input_names=model_input_names,
+            output_names=["scores"],
+            dynamic_axes=model_dynamic_axes_in_input,
+        )
+        del lightning_model
+
+        compiled_model = _compile_openvino(onnx_path, batch_size, max_seq_len, num_candidates_to_score, num_threads)
+
+        if not is_savable:
+            onnx_file.close()
+
+        return compiled_model
+
+    @classmethod
+    @abstractmethod
+    def compile(
+        cls,
+        model: Union[lightning.LightningModule, str, pathlib.Path],
+        mode: OptimizedModeType = "one_query",
+        batch_size: Optional[int] = None,
+        num_candidates_to_score: Optional[int] = None,
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> "BaseCompiledModel":
+        """
+        Model compilation.
+
+        :param model: Path to the lightning model saved in .ckpt format, or the model object itself.
+        :param mode: Inference mode; defines the shape of the inputs.
+            Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
+            ``one_query`` - sets the input shape to [1, max_seq_len]\n
+            ``batch`` - sets the input shape to [batch_size, max_seq_len]\n
+            ``dynamic_batch_size`` - sets batch_size to the dynamic range [?, max_seq_len]\n
+            Default: ``one_query``.
+        :param batch_size: Batch size, required for ``batch`` mode.
+            Default: ``None``.
+        :param num_candidates_to_score: Number of item ids to calculate scores for.
+            Could be one of [``None``, ``-1``, ``N``].\n
+            ``-1`` - sets the candidates_to_score shape to the dynamic range [1, ?]\n
+            ``N`` - sets the candidates_to_score shape to [1, N]\n
+            ``None`` - disables candidates_to_score usage\n
+            Default: ``None``.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, the compiler sets this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save the ONNX model to this path, if defined.
+            Default: ``None``.
+        """