replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +27 -1
- replay/data/dataset_utils/dataset_label_encoder.py +6 -3
- replay/data/nn/schema.py +37 -16
- replay/data/nn/sequence_tokenizer.py +313 -165
- replay/data/nn/torch_sequential_dataset.py +17 -8
- replay/data/nn/utils.py +14 -7
- replay/data/schema.py +10 -6
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +1 -0
- replay/models/base_rec.py +18 -21
- replay/models/lin_ucb.py +407 -0
- replay/models/nn/sequential/bert4rec/dataset.py +17 -4
- replay/models/nn/sequential/bert4rec/lightning.py +121 -54
- replay/models/nn/sequential/bert4rec/model.py +21 -0
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
- replay/models/nn/sequential/compiled/__init__.py +5 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
- replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
- replay/models/nn/sequential/sasrec/dataset.py +17 -1
- replay/models/nn/sequential/sasrec/lightning.py +126 -50
- replay/models/nn/sequential/sasrec/model.py +3 -4
- replay/preprocessing/__init__.py +7 -1
- replay/preprocessing/discretizer.py +719 -0
- replay/preprocessing/label_encoder.py +384 -52
- replay/splitters/cold_user_random_splitter.py +1 -1
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +7 -8
- replay/utils/session_handler.py +3 -4
- replay/utils/spark_utils.py +15 -1
- replay/utils/types.py +8 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/bert4rec/lightning.py

```diff
@@ -31,7 +31,7 @@ class Bert4Rec(lightning.LightningModule):
         loss_sample_count: Optional[int] = None,
         negative_sampling_strategy: str = "global_uniform",
         negatives_sharing: bool = False,
-        optimizer_factory:
+        optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
         lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
     ):
         """
```
```diff
@@ -65,7 +65,7 @@ class Bert4Rec(lightning.LightningModule):
         :param negatives_sharing: Apply negative sharing in calculating sampled logits.
             Default: ``False``.
         :param optimizer_factory: Optimizer factory.
-            Default: ``
+            Default: ``FatOptimizerFactory``.
         :param lr_scheduler_factory: Learning rate schedule factory.
             Default: ``None``.
         """
```
```diff
@@ -95,6 +95,7 @@ class Bert4Rec(lightning.LightningModule):
         item_count = tensor_schema.item_id_features.item().cardinality
         assert item_count
         self._vocab_size = item_count
+        self.candidates_to_score = None
 
     def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
         """
```
```diff
@@ -107,33 +108,51 @@ class Bert4Rec(lightning.LightningModule):
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
         return loss
 
+    def predict_step(
+        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
+        """
+        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
+        :param batch_idx (int): Batch index.
+        :param dataloader_idx (int): Dataloader index.
+
+        :returns: Calculated scores on prediction batch.
+        """
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+
+    def predict(
+        self,
+        batch: Bert4RecPredictionBatch,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :returns: Calculated scores on prediction batch.
+        """
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask, candidates_to_score)
+
     def forward(
         self,
         feature_tensors: TensorMap,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:  # pragma: no cover
         """
         :param feature_tensors: Batch of features.
         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
         :param tokens_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
 
         :returns: Calculated scores.
         """
-        return self._model_predict(feature_tensors, padding_mask, tokens_mask)
-
-    def predict_step(
-        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
-    ) -> torch.Tensor:
-        """
-        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
-        :param batch_idx (int): Batch index.
-        :param dataloader_idx (int): Dataloader index.
-
-        :returns: Calculated scores on prediction batch.
-        """
-        batch = self._prepare_prediction_batch(batch)
-        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+        return self._model_predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
 
     def validation_step(
         self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
```
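In practice, the hunk above splits scoring into two entry points: `predict_step` keeps the Lightning hook signature, while the new `predict` accepts an explicit candidate list; `_model_predict` (further below) falls back to the model-level `candidates_to_score` attribute when no list is passed. A minimal usage sketch, assuming `model` is an already-trained `Bert4Rec` and `batch` is a `Bert4RecPredictionBatch` (neither is defined in this diff):

```python
import torch

candidates = torch.tensor([7, 42, 1001], dtype=torch.long)  # item ids to score

# Per-call restriction: scores come back as [batch_size, len(candidates)].
scores = model.predict(batch, candidates_to_score=candidates)

# Session-wide restriction: predict_step (and thus Trainer.predict) picks
# this up via _model_predict's fallback to self.candidates_to_score.
model.candidates_to_score = candidates
```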
```diff
@@ -150,8 +169,7 @@ class Bert4Rec(lightning.LightningModule):
         """
         :returns: Configured optimizer and lr scheduler.
         """
-
-        optimizer = optimizer_factory.create(self._model.parameters())
+        optimizer = self._optimizer_factory.create(self._model.parameters())
 
         if self._lr_scheduler_factory is None:
             return optimizer
```
```diff
@@ -159,49 +177,20 @@ class Bert4Rec(lightning.LightningModule):
         lr_scheduler = self._lr_scheduler_factory.create(optimizer)
         return [optimizer], [lr_scheduler]
 
-    def _prepare_prediction_batch(self, batch: Bert4RecPredictionBatch) -> Bert4RecPredictionBatch:
-        if batch.padding_mask.shape[1] > self._model.max_len:
-            msg = f"The length of the submitted sequence \
-                must not exceed the maximum length of the sequence. \
-                The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}"
-            raise ValueError(msg)
-
-        if batch.padding_mask.shape[1] < self._model.max_len:
-            query_id, padding_mask, features, _ = batch
-            sequence_item_count = padding_mask.shape[1]
-            for feature_name, feature_tensor in features.items():
-                if self._schema[feature_name].is_cat:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
-                    )
-                else:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0,
-                    ).unsqueeze(-1)
-            padding_mask = torch.nn.functional.pad(
-                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
-            )
-        shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
-        batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
-        return batch
-
     def _model_predict(
         self,
         feature_tensors: TensorMap,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
+        candidates_to_score: torch.LongTensor = None,
     ) -> torch.Tensor:
         model: Bert4RecModel
-
-
-
-
-        scores = model(feature_tensors, padding_mask, tokens_mask)
-
-        return candidate_scores
+        model = (
+            cast(Bert4RecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
+        )
+        candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
+        scores = model.predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
+        return scores
 
     def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
```
```diff
@@ -504,6 +493,50 @@ class Bert4Rec(lightning.LightningModule):
 
         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)
 
+    @property
+    def optimizer_factory(self) -> OptimizerFactory:
+        """
+        Returns current optimizer_factory.
+        """
+        return self._optimizer_factory
+
+    @optimizer_factory.setter
+    def optimizer_factory(self, optimizer_factory: OptimizerFactory) -> None:
+        """
+        Sets new optimizer_factory.
+        :param optimizer_factory: New optimizer factory.
+        """
+        if isinstance(optimizer_factory, OptimizerFactory):
+            self._optimizer_factory = optimizer_factory
+        else:
+            msg = f"Expected optimizer_factory of type OptimizerFactory, got {type(optimizer_factory)}"
+            raise ValueError(msg)
+
+    @property
+    def candidates_to_score(self) -> Union[torch.LongTensor, None]:
+        """
+        Returns tensor of item ids to calculate scores.
+        """
+        return self._candidates_to_score
+
+    @candidates_to_score.setter
+    def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        """
+        Sets tensor of item ids to calculate scores.
+        :param candidates: Tensor of item ids to calculate scores.
+        """
+        total_item_count = self._model.item_count
+        if isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long:
+            if 0 < candidates.shape[0] <= total_item_count:
+                self._candidates_to_score = candidates
+            else:
+                msg = f"Expected candidates length to be between 1 and {total_item_count=}"
+                raise ValueError(msg)
+        elif candidates is not None:
+            msg = f"Expected candidates to be of type torch.LongTensor or None, got {type(candidates)}"
+            raise ValueError(msg)
+        self._candidates_to_score = candidates
+
     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
         if self._model.enable_embedding_tying is True:
```
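Given the checks above, the setter accepts either ``None`` or a non-empty ``torch.long`` tensor no longer than the item count. A short sketch of accepted and rejected inputs, assuming `model` is a constructed `Bert4Rec` (values are illustrative):

```python
import torch

model.candidates_to_score = torch.arange(10, dtype=torch.long)  # accepted: LongTensor, valid length
model.candidates_to_score = None                                # accepted: score all items again

# Both of these raise ValueError:
# model.candidates_to_score = torch.tensor([0.5, 1.5])           # wrong dtype (not torch.long)
# model.candidates_to_score = torch.empty(0, dtype=torch.long)   # empty: length must be >= 1
```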
```diff
@@ -521,3 +554,37 @@ class Bert4Rec(lightning.LightningModule):
         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
         self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
+
+
+def _prepare_prediction_batch(
+    schema: TensorSchema, max_len: int, batch: Bert4RecPredictionBatch
+) -> Bert4RecPredictionBatch:
+    if batch.padding_mask.shape[1] > max_len:
+        msg = (
+            "The length of the submitted sequence "
+            "must not exceed the maximum length of the sequence. "
+            f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
+            f"while the maximum length is {max_len}"
+        )
+        raise ValueError(msg)
+
+    if batch.padding_mask.shape[1] < max_len:
+        query_id, padding_mask, features, _ = batch
+        sequence_item_count = padding_mask.shape[1]
+        for feature_name, feature_tensor in features.items():
+            if schema[feature_name].is_cat:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor,
+                    (max_len - sequence_item_count, 0),
+                    value=schema[feature_name].padding_value,
+                )
+            else:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
+                    (max_len - sequence_item_count, 0),
+                    value=schema[feature_name].padding_value,
+                ).unsqueeze(-1)
+        padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
+    shifted_features, shifted_padding_mask, tokens_mask = _shift_features(schema, features, padding_mask)
+    batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
+    return batch
```
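Note the padding direction: `torch.nn.functional.pad` takes a `(left, right)` pair for the last dimension, so short sequences are left-padded and the most recent interactions stay at the end. A self-contained illustration (item ids and `max_len` are made up; the real padding value comes from `schema[feature_name].padding_value`):

```python
import torch

max_len = 5
item_ids = torch.tensor([[10, 11, 12]])  # [batch=1, seq_len=3]

padded = torch.nn.functional.pad(item_ids, (max_len - item_ids.shape[1], 0), value=0)
print(padded)  # tensor([[ 0,  0, 10, 11, 12]])
```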
replay/models/nn/sequential/bert4rec/model.py

```diff
@@ -98,6 +98,27 @@ class Bert4RecModel(torch.nn.Module):
 
         return all_scores  # [B x L x E]
 
+    def predict(
+        self,
+        inputs: TensorMap,
+        pad_mask: torch.BoolTensor,
+        token_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        :param inputs: Batch of features.
+        :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+        :param token_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param candidates_to_score: Item ids to calculate scores.
+            If `None`, predicts for all items.
+
+        :returns: Calculated scores among candidates_to_score items.
+        """
+        # final_emb: [B x E]
+        final_emb = self.get_query_embeddings(inputs, pad_mask, token_mask)
+        candidate_scores = self.get_logits(final_emb, candidates_to_score)
+        return candidate_scores
+
     def forward_step(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor) -> torch.Tensor:
         """
 
```
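The new `Bert4RecModel.predict` is one embedding pass plus a logits lookup over an optional candidate subset. The shape contract can be checked without the model itself: restricting scoring to candidates is equivalent to slicing the full score matrix (a plain-matmul stand-in for `get_logits`, which in the real model also respects the embedding-tying setup):

```python
import torch

batch, emb_dim, vocab = 2, 8, 100
final_emb = torch.randn(batch, emb_dim)   # [B x E] query embeddings
item_emb = torch.randn(vocab, emb_dim)    # item embedding table
candidates = torch.tensor([3, 17, 99])

all_scores = final_emb @ item_emb.T                    # [B x V]
candidate_scores = final_emb @ item_emb[candidates].T  # [B x len(candidates)]
assert torch.allclose(candidate_scores, all_scores[:, candidates])
```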
replay/models/nn/sequential/callbacks/prediction_callbacks.py

```diff
@@ -64,6 +64,11 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         self._item_batches.clear()
         self._item_scores.clear()
 
+        candidates = trainer.model.candidates_to_score if hasattr(trainer.model, "candidates_to_score") else None
+        for postprocessor in self._postprocessors:
+            if hasattr(postprocessor, "candidates"):
+                postprocessor.candidates = candidates
+
     def on_predict_batch_end(
         self,
         trainer: lightning.Trainer,  # noqa: ARG002
```
```diff
@@ -88,7 +93,6 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
             torch.cat(self._item_batches),
             torch.cat(self._item_scores),
         )
-
         return prediction
 
     def _compute_pipeline(
```
replay/models/nn/sequential/compiled/base_compiled_model.py (new file)

```diff
@@ -0,0 +1,261 @@
+import pathlib
+import tempfile
+from abc import abstractmethod
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+import lightning
+import openvino as ov
+import torch
+
+from replay.data.nn import TensorSchema
+
+OptimizedModeType = Literal[
+    "batch",
+    "one_query",
+    "dynamic_batch_size",
+]
+
+
+def _compile_openvino(
+    onnx_path: str,
+    batch_size: int,
+    max_seq_len: int,
+    num_candidates_to_score: Optional[int],
+    num_threads: Optional[int],
+) -> ov.CompiledModel:
+    """
+    Defines the compilation strategy for the OpenVINO backend.
+
+    :param onnx_path: Path to the model representation in ONNX format.
+    :param batch_size: Defines whether batch will be static or dynamic length.
+    :param max_seq_len: Defines whether sequence will be static or dynamic length.
+    :param num_candidates_to_score: Defines whether candidates will be static or dynamic length.
+    :param num_threads: Defines number of CPU threads for which the model will be compiled by the OpenVINO core.
+        If ``None``, then the compiler will set this parameter automatically.
+        Default: ``None``.
+
+    :return: Compiled model.
+    """
+    core = ov.Core()
+    if num_threads is not None:
+        core.set_property("CPU", {"INFERENCE_NUM_THREADS": num_threads})
+    model_onnx = core.read_model(model=onnx_path)
+    inputs_names = [inputs.names.pop() for inputs in model_onnx.inputs]
+    del model_onnx
+
+    candidates_input_id = len(inputs_names) - 1 if num_candidates_to_score is not None else len(inputs_names)
+    model_input_scheme = [(input_name, [batch_size, max_seq_len]) for input_name in inputs_names[:candidates_input_id]]
+    if num_candidates_to_score is not None:
+        model_input_scheme += [(inputs_names[candidates_input_id], [num_candidates_to_score])]
+    model_onnx = ov.convert_model(onnx_path, input=model_input_scheme)
+    return core.compile_model(model=model_onnx, device_name="CPU")
+
+
+class BaseCompiledModel:
+    """
+    Base class of CPU-optimized model for inference via OpenVINO.
+    It is recommended to use inherited classes and not to use this one.
+    """
+
+    def __init__(
+        self,
+        compiled_model: ov.CompiledModel,
+        schema: TensorSchema,
+    ) -> None:
+        """
+        :param compiled_model: Compiled model.
+        :param schema: Tensor schema of SasRec model.
+        """
+        self._batch_size: int
+        self._max_seq_len: int
+        self._inputs_names: List[str]
+        self._output_name: str
+
+        self._set_inner_params_from_openvino_model(compiled_model)
+        self._schema = schema
+        self._model = compiled_model
+
+    @abstractmethod
+    def predict(
+        self,
+        batch: Any,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        Inference on one batch.
+
+        :param batch: Prediction input.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :return: Tensor with scores.
+        """
+
+    def _validate_candidates_to_score(self, candidates: torch.LongTensor) -> None:
+        """Check if candidates param has proper type"""
+
+        if not (isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long):
+            msg = (
+                "Expected candidates to be of type ``torch.Tensor`` with dtype ``torch.long``, "
+                f"got {type(candidates)} with dtype {candidates.dtype}."
+            )
+            raise ValueError(msg)
+
+    def _validate_predict_input(self, batch: Any, candidates_to_score: Optional[torch.LongTensor] = None) -> None:
+        if self._num_candidates_to_score is None and candidates_to_score is not None:
+            msg = (
+                "If ``num_candidates_to_score`` is None, "
+                "it is impossible to infer the model with passed ``candidates_to_score``."
+            )
+            raise ValueError(msg)
+
+        if self._batch_size != -1 and batch.padding_mask.shape[0] != self._batch_size:
+            msg = (
+                f"The batch is smaller than the defined batch_size={self._batch_size}. "
+                "It is impossible to infer the model with dynamic batch size in ``mode`` = ``batch``. "
+                "Use ``mode`` = ``dynamic_batch_size``."
+            )
+            raise ValueError(msg)
+
+    def _set_inner_params_from_openvino_model(self, compiled_model: ov.CompiledModel) -> None:
+        """Set params for ``predict`` method"""
+
+        input_scheme = compiled_model.inputs
+        self._batch_size = input_scheme[0].partial_shape[0].max_length
+        self._max_seq_len = input_scheme[0].partial_shape[1].max_length
+        self._inputs_names = [input.names.pop() for input in compiled_model.inputs]
+        if "candidates_to_score" in self._inputs_names:
+            self._num_candidates_to_score = input_scheme[-1].partial_shape[0].max_length
+        else:
+            self._num_candidates_to_score = None
+        self._output_name = compiled_model.output().names.pop()
+
+    @staticmethod
+    def _validate_num_candidates_to_score(num_candidates: Union[int, None]) -> Union[int, None]:
+        """Check if num_candidates param is proper"""
+
+        if num_candidates is None:
+            return num_candidates
+        if isinstance(num_candidates, int) and (num_candidates == -1 or num_candidates >= 1):
+            return num_candidates
+
+        msg = (
+            "Expected num_candidates_to_score to be of type ``int``, equal to ``-1``, ``natural number`` or ``None``. "
+            f"Got {num_candidates}."
+        )
+        raise ValueError(msg)
+
+    @staticmethod
+    def _get_input_params(
+        mode: OptimizedModeType,
+        batch_size: Optional[int],
+        num_candidates_to_score: Optional[int],
+    ) -> Tuple[Optional[int], Optional[int]]:
+        """Get params for model compilation according to compilation mode"""
+
+        if mode == "one_query":
+            batch_size = 1
+
+        if mode == "batch":
+            assert batch_size, f"{mode} mode requires `batch_size`"
+            batch_size = batch_size
+
+        if mode == "dynamic_batch_size":
+            batch_size = -1
+
+        num_candidates_to_score = num_candidates_to_score if num_candidates_to_score else None
+        return batch_size, num_candidates_to_score
+
+    @staticmethod
+    def _run_model_compilation(
+        lightning_model: lightning.LightningModule,
+        model_input_sample: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+        model_input_names: List[str],
+        model_dynamic_axes_in_input: Dict[str, Dict],
+        batch_size: int,
+        num_candidates_to_score: Union[int, None],
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> ov.CompiledModel:
+        """
+        Model conversion into ONNX format and compilation with defined engine.
+
+        :param lightning_model: Lightning model to be compiled.
+        :param model_input_sample: An example of model input with proper data type.
+        :param model_input_names: Input tensor names.
+        :param model_dynamic_axes_in_input: Dynamic axes in input.
+        :param batch_size: Defines the size of the axis with index 0 in the input of the compiled model.
+        :param num_candidates_to_score: Defines the size of the candidates in the input of the compiled model.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, then the compiler will set this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save ONNX model to path, if defined.
+            Default: ``None``.
+
+        :return: Compiled model.
+        """
+        max_seq_len = lightning_model._model.max_len
+
+        if onnx_path is None:
+            is_savable = False
+            onnx_file = tempfile.NamedTemporaryFile(suffix=".onnx")
+            onnx_path = onnx_file.name
+        else:
+            is_savable = True
+
+        lightning_model.to_onnx(
+            onnx_path,
+            input_sample=model_input_sample,
+            export_params=True,
+            opset_version=torch.onnx._constants.ONNX_DEFAULT_OPSET,
+            do_constant_folding=True,
+            input_names=model_input_names,
+            output_names=["scores"],
+            dynamic_axes=model_dynamic_axes_in_input,
+        )
+        del lightning_model
+
+        compiled_model = _compile_openvino(onnx_path, batch_size, max_seq_len, num_candidates_to_score, num_threads)
+
+        if not is_savable:
+            onnx_file.close()
+
+        return compiled_model
+
+    @classmethod
+    @abstractmethod
+    def compile(
+        cls,
+        model: Union[lightning.LightningModule, str, pathlib.Path],
+        mode: OptimizedModeType = "one_query",
+        batch_size: Optional[int] = None,
+        num_candidates_to_score: Optional[int] = None,
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> "BaseCompiledModel":
+        """
+        Model compilation.
+
+        :param model: Path to lightning model saved in .ckpt format or the model object itself.
+        :param mode: Inference mode, defines shape of inputs.
+            Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
+            ``one_query`` - sets input shape to [1, max_seq_len]\n
+            ``batch`` - sets input shape to [batch_size, max_seq_len]\n
+            ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
+            Default: ``one_query``.
+        :param batch_size: Batch size, required for ``batch`` mode.
+            Default: ``None``.
+        :param num_candidates_to_score: Number of item ids to calculate scores.
+            Could be one of [``None``, ``-1``, ``N``].\n
+            ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
+            ``N`` - sets candidates_to_score shape to [1, N]\n
+            ``None`` - disable candidates_to_score usage\n
+            Default: ``None``.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, then the compiler will set this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save ONNX model to path, if defined.
+            Default: ``None``.
+        """
```