replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
  35. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
  36. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
  37. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
@@ -1,8 +1,9 @@
1
+ import abc
1
2
  import json
2
3
  import pickle
3
4
  import warnings
4
5
  from pathlib import Path
5
- from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
6
+ from typing import Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
6
7
 
7
8
  import numpy as np
8
9
  import polars as pl
@@ -20,6 +21,7 @@ from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset
20
21
  from .utils import ensure_pandas, groupby_sequences
21
22
 
22
23
  SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
24
+ _T = TypeVar("_T")
23
25
 
24
26
 
25
27
  class SequenceTokenizer:
@@ -154,7 +156,6 @@ class SequenceTokenizer:
154
156
 
155
157
  matched_dataset = self._match_features_with_tensor_schema(dataset, schema)
156
158
 
157
- is_polars = isinstance(dataset.interactions, PolarsDataFrame)
158
159
  encoded_dataset = self._encode_dataset(matched_dataset)
159
160
  is_polars = isinstance(encoded_dataset.interactions, PolarsDataFrame)
160
161
  grouped_interactions, query_features, item_features = self._group_dataset(encoded_dataset)
@@ -223,7 +224,11 @@ class SequenceTokenizer:
223
224
  query_features: Optional[SequenceDataFrameLike],
224
225
  item_features: Optional[SequenceDataFrameLike],
225
226
  ) -> SequenceDataFrameLike:
226
- processor = _SequenceProcessor(
227
+ sequence_processor_class = (
228
+ _PolarsSequenceProcessor if isinstance(grouped_interactions, PolarsDataFrame) else _PandasSequenceProcessor
229
+ )
230
+
231
+ processor = sequence_processor_class(
227
232
  tensor_schema=schema,
228
233
  query_id_column=feature_schema.query_id_column,
229
234
  item_id_column=feature_schema.item_id_column,
@@ -231,17 +236,7 @@ class SequenceTokenizer:
231
236
  query_features=query_features,
232
237
  item_features=item_features,
233
238
  )
234
-
235
- if isinstance(grouped_interactions, PolarsDataFrame):
236
- return processor.process_features_polars()
237
-
238
- all_features: Dict[str, Union[np.ndarray, List[np.ndarray]]] = {}
239
- all_features[feature_schema.query_id_column] = grouped_interactions[feature_schema.query_id_column].values
240
-
241
- for tensor_feature_name in schema:
242
- all_features[tensor_feature_name] = processor.process_feature(tensor_feature_name)
243
-
244
- return PandasDataFrame(all_features)
239
+ return processor.process_features()
245
240
 
246
241
  @classmethod
247
242
  def _match_features_with_tensor_schema(
@@ -397,7 +392,7 @@ class SequenceTokenizer:
397
392
  f"The specified cardinality of {tensor_feature.name} "
398
393
  f"will be replaced by {dataset_feature.column} from Dataset"
399
394
  )
400
- if dataset_feature.feature_type != FeatureType.CATEGORICAL:
395
+ if dataset_feature.feature_type not in [FeatureType.CATEGORICAL, FeatureType.CATEGORICAL_LIST]:
401
396
  error_msg = (
402
397
  f"TensorFeatureInfo {tensor_feature.name} "
403
398
  f"and FeatureInfo {dataset_feature.column} must be the same FeatureType"
@@ -498,16 +493,9 @@ class SequenceTokenizer:
498
493
  pickle.dump(self, file)
499
494
 
500
495
 
501
- class _SequenceProcessor:
496
+ class _BaseSequenceProcessor(Generic[_T]):
502
497
  """
503
- Class to process sequences of different categorical and numerical features.
504
-
505
- Processing performs over all features in `tensor_schema`. Each feature processing steps
506
- depends on feature type (categorical/numerical), feature source (interactions/query features/item features)
507
- and `grouped_interactions` data format (Pandas/Polars).
508
- If `grouped_interactions` is `PolarsDataFrame` object, then method `process_features_polars` is called.
509
- If `grouped_interactions` is `PandasDataFrame` object, then method `process_features` is called,
510
- with passing all tensor features one by one.
498
+ Base class for sequence processing
511
499
  """
512
500
 
513
501
  def __init__(
@@ -515,29 +503,26 @@ class _SequenceProcessor:
515
503
  tensor_schema: TensorSchema,
516
504
  query_id_column: str,
517
505
  item_id_column: str,
518
- grouped_interactions: SequenceDataFrameLike,
519
- query_features: Optional[SequenceDataFrameLike] = None,
520
- item_features: Optional[SequenceDataFrameLike] = None,
506
+ grouped_interactions: _T,
507
+ query_features: Optional[_T] = None,
508
+ item_features: Optional[_T] = None,
521
509
  ) -> None:
522
510
  self._tensor_schema = tensor_schema
523
511
  self._query_id_column = query_id_column
524
512
  self._item_id_column = item_id_column
525
513
  self._grouped_interactions = grouped_interactions
526
- self._is_polars = isinstance(grouped_interactions, PolarsDataFrame)
527
- if not self._is_polars:
528
- self._query_features = (
529
- query_features.set_index(self._query_id_column).sort_index() if query_features is not None else None
530
- )
531
- self._item_features = (
532
- item_features.set_index(self._item_id_column).sort_index() if item_features is not None else None
533
- )
534
- else:
535
- self._query_features = query_features
536
- self._item_features = item_features
514
+ self._query_features = query_features
515
+ self._item_features = item_features
537
516
 
538
- def process_feature(self, tensor_feature_name: str) -> List[np.ndarray]:
517
+ @abc.abstractmethod
518
+ def process_features(self) -> _T: # pragma: no cover
519
+ """
520
+ For each feature that you want to process, you should call the _process_feature function.
539
521
  """
540
- Process each tensor feature for Pandas dataframes.
522
+
523
+ def _process_feature(self, tensor_feature_name: str) -> _T:
524
+ """
525
+ Process each tensor feature for dataframes.
541
526
 
542
527
  :param tensor_feature_name: name of feature to process.
543
528
 
@@ -550,30 +535,24 @@ class _SequenceProcessor:
550
535
  return self._process_num_feature(tensor_feature)
551
536
  assert False, "Unknown tensor feature type"
552
537
 
553
- def process_features_polars(self) -> PolarsDataFrame:
538
+ def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
554
539
  """
555
- Process all features in `tensor_schema` for Polars dataframes.
556
- Each Polars processing step returns DataFrame with query and target column
557
- to join in result dataframe.
558
-
559
- :returns: processed Polars DataFrame with all features from tensor schema.
540
+ Process numerical tensor feature depends on it source.
560
541
  """
561
- data = self._grouped_interactions.select(self._query_id_column)
562
- for tensor_feature_name in self._tensor_schema:
563
- tensor_feature = self._tensor_schema[tensor_feature_name]
564
- if tensor_feature.is_cat:
565
- data = data.join(self._process_cat_feature(tensor_feature), on=self._query_id_column, how="left")
566
- elif tensor_feature.is_num:
567
- data = data.join(self._process_num_feature(tensor_feature), on=self._query_id_column, how="left")
568
- else:
569
- assert False, "Unknown tensor feature type"
570
- return data
542
+ assert tensor_feature.feature_sources is not None
543
+ if tensor_feature.feature_source.source == FeatureSource.INTERACTIONS:
544
+ return self._process_num_interaction_feature(tensor_feature)
545
+ if tensor_feature.feature_source.source == FeatureSource.QUERY_FEATURES:
546
+ return self._process_num_query_feature(tensor_feature)
547
+ if tensor_feature.feature_source.source == FeatureSource.ITEM_FEATURES:
548
+ return self._process_num_item_feature(tensor_feature)
549
+ assert False, "Unknown tensor feature source table"
571
550
 
572
- def _process_cat_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
551
+ def _process_cat_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
573
552
  """
574
553
  Process categorical tensor feature depends on it source.
575
554
  """
576
- assert tensor_feature.feature_source is not None
555
+ assert tensor_feature.feature_sources is not None
577
556
  if tensor_feature.feature_source.source == FeatureSource.INTERACTIONS:
578
557
  return self._process_cat_interaction_feature(tensor_feature)
579
558
  if tensor_feature.feature_source.source == FeatureSource.QUERY_FEATURES:
@@ -582,102 +561,153 @@ class _SequenceProcessor:
582
561
  return self._process_cat_item_feature(tensor_feature)
583
562
  assert False, "Unknown tensor feature source table"
584
563
 
585
- def _process_num_feature_polars(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
586
- def get_sequence(user, source, data):
587
- if source.source == FeatureSource.INTERACTIONS:
588
- return np.array(
589
- self._grouped_interactions.filter(pl.col(self._query_id_column) == user)[source.column][0],
590
- dtype=np.float32,
591
- ).tolist()
592
- elif source.source == FeatureSource.ITEM_FEATURES:
593
- return (
594
- pl.DataFrame({self._item_id_column: data})
595
- .join(self._item_features, on=self._item_id_column, how="left")
596
- .select(source.column)
597
- .to_numpy()
598
- .reshape(-1)
599
- .tolist()
600
- )
601
- else:
602
- assert False, "Unknown tensor feature source table"
564
+ @abc.abstractmethod
565
+ def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
566
+ pass
603
567
 
604
- result = (
605
- self._grouped_interactions.select(self._query_id_column, self._item_id_column).map_rows(
606
- lambda x: (x[0], [get_sequence(x[0], source, x[1]) for source in tensor_feature.feature_sources])
607
- )
608
- ).rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
568
+ @abc.abstractmethod
569
+ def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
570
+ pass
609
571
 
610
- if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
611
- reshape_size = -1
612
- else:
613
- reshape_size = (-1, len(tensor_feature.feature_sources))
614
-
615
- return pl.DataFrame(
616
- {
617
- self._query_id_column: result[self._query_id_column].to_list(),
618
- tensor_feature.name: [
619
- np.array(x).reshape(reshape_size).tolist() for x in result[tensor_feature.name].to_list()
620
- ],
621
- }
572
+ @abc.abstractmethod
573
+ def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
574
+ pass
575
+
576
+ @abc.abstractmethod
577
+ def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
578
+ pass
579
+
580
+ @abc.abstractmethod
581
+ def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
582
+ pass
583
+
584
+ @abc.abstractmethod
585
+ def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T: # pragma: no cover
586
+ pass
587
+
588
+
589
+ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
590
+ """
591
+ Class to process sequences of different categorical and numerical features.
592
+
593
+ Processing performs over all features in `tensor_schema`. Each feature processing steps
594
+ depends on feature type (categorical/numerical), feature source (interactions/query features/item features)
595
+ and `grouped_interactions` in Pandas DataFrame format.
596
+ """
597
+
598
+ def __init__(
599
+ self,
600
+ tensor_schema: TensorSchema,
601
+ query_id_column: str,
602
+ item_id_column: str,
603
+ grouped_interactions: PandasDataFrame,
604
+ query_features: Optional[PandasDataFrame] = None,
605
+ item_features: Optional[PandasDataFrame] = None,
606
+ ) -> None:
607
+ super().__init__(
608
+ tensor_schema=tensor_schema,
609
+ query_id_column=query_id_column,
610
+ item_id_column=item_id_column,
611
+ grouped_interactions=grouped_interactions,
612
+ query_features=(
613
+ query_features.set_index(query_id_column).sort_index() if query_features is not None else None
614
+ ),
615
+ item_features=item_features.set_index(item_id_column).sort_index() if item_features is not None else None,
622
616
  )
623
617
 
624
- def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
618
+ def process_features(self) -> PandasDataFrame:
619
+ """
620
+ :returns: processed Pandas DataFrame with all features from tensor schema.
621
+ """
622
+ all_features: Dict[str, Union[np.ndarray, List[np.ndarray]]] = {}
623
+ all_features[self._query_id_column] = self._grouped_interactions[self._query_id_column].values
624
+
625
+ for tensor_feature_name in self._tensor_schema:
626
+ all_features[tensor_feature_name] = self._process_feature(tensor_feature_name)
627
+
628
+ return PandasDataFrame(all_features)
629
+
630
+ def _process_num_interaction_feature(
631
+ self, tensor_feature: TensorFeatureInfo
632
+ ) -> Union[List[np.ndarray], List[List]]:
625
633
  """
626
- Process numerical feature for all sources.
634
+ Process numerical interaction feature.
627
635
 
628
636
  :param tensor_feature: tensor feature information.
629
637
 
630
- :returns: sequences for each source for each query.
631
- If feature came from item features then gets item features values.
632
- If feature came from interactions then gets values from interactions.
633
- The results are combined in one sequence array.
638
+ :returns: tensor feature column as a sequences from `grouped_interactions`.
634
639
  """
635
- assert tensor_feature.feature_sources is not None
636
640
  assert tensor_feature.is_seq
637
641
 
638
- if self._is_polars:
639
- return self._process_num_feature_polars(tensor_feature)
642
+ source = tensor_feature.feature_source
643
+ assert source is not None
640
644
 
641
- values: List[np.ndarray] = []
642
- for pos, item_id_sequence in enumerate(self._grouped_interactions[self._item_id_column]):
643
- all_features_for_user = []
644
- for source in tensor_feature.feature_sources:
645
- if source.source == FeatureSource.ITEM_FEATURES:
646
- item_feature = self._item_features[source.column]
647
- feature_sequence = item_feature.loc[item_id_sequence].values
648
- all_features_for_user.append(feature_sequence)
649
- elif source.source == FeatureSource.INTERACTIONS:
650
- sequence = self._grouped_interactions[source.column][pos]
651
- all_features_for_user.append(sequence)
652
- else:
653
- assert False, "Unknown tensor feature source table"
654
- all_seqs = np.array(all_features_for_user, dtype=np.float32)
655
- if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
656
- all_seqs = all_seqs.reshape(-1)
645
+ values = []
646
+ for sequence in self._grouped_interactions[source.column].values:
647
+ if tensor_feature.feature_type == FeatureType.NUMERICAL_LIST:
648
+ values.append(list(sequence))
657
649
  else:
658
- all_seqs = all_seqs.reshape(-1, (len(tensor_feature.feature_sources)))
659
- values.append(all_seqs)
650
+ values.append(np.array(sequence))
660
651
  return values
661
652
 
662
- def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
653
+ def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
663
654
  """
664
- Process categorical interaction feature.
655
+ Process numerical feature from item features dataset.
665
656
 
666
657
  :param tensor_feature: tensor feature information.
667
658
 
668
659
  :returns: tensor feature column as a sequences from `grouped_interactions`.
669
660
  """
670
661
  assert tensor_feature.is_seq
662
+ assert self._item_features is not None
671
663
 
672
664
  source = tensor_feature.feature_source
673
665
  assert source is not None
674
666
 
675
- if self._is_polars:
676
- return self._grouped_interactions.select(self._query_id_column, source.column).rename(
677
- {source.column: tensor_feature.name}
678
- )
667
+ item_feature = self._item_features[source.column]
668
+ values = []
679
669
 
680
- return [np.array(sequence, dtype=np.int64) for sequence in self._grouped_interactions[source.column]]
670
+ for item_id_sequence in self._grouped_interactions[self._item_id_column]:
671
+ feature_sequence = item_feature.loc[item_id_sequence].values
672
+ if tensor_feature.feature_type == FeatureType.NUMERICAL_LIST:
673
+ values.append(feature_sequence.tolist())
674
+ else:
675
+ values.append(np.array(feature_sequence))
676
+
677
+ return values
678
+
679
+ def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
680
+ """
681
+ Process numerical feature from query features dataset.
682
+
683
+ :param tensor_feature: tensor feature information.
684
+
685
+ :returns: tensor feature column as a sequences from `grouped_interactions`.
686
+ """
687
+ return self._process_cat_query_feature(tensor_feature)
688
+
689
+ def _process_cat_interaction_feature(
690
+ self, tensor_feature: TensorFeatureInfo
691
+ ) -> Union[List[np.ndarray], List[List]]:
692
+ """
693
+ Process categorical interaction feature.
694
+
695
+ :param tensor_feature: tensor feature information.
696
+
697
+ :returns: tensor feature column as a sequences from `grouped_interactions`.
698
+ """
699
+ assert tensor_feature.is_seq
700
+
701
+ source = tensor_feature.feature_source
702
+ assert source is not None
703
+
704
+ values = []
705
+ for sequence in self._grouped_interactions[source.column].values:
706
+ if tensor_feature.feature_type == FeatureType.CATEGORICAL_LIST:
707
+ values.append(list(sequence))
708
+ else:
709
+ values.append(np.array(sequence))
710
+ return values
681
711
 
682
712
  def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
683
713
  """
@@ -693,30 +723,22 @@ class _SequenceProcessor:
693
723
  source = tensor_feature.feature_source
694
724
  assert source is not None
695
725
 
696
- if self._is_polars:
697
- if tensor_feature.is_seq:
698
- lengths = self._grouped_interactions.select(
699
- self._query_id_column, pl.col(self._item_id_column).list.len().alias("len")
700
- )
701
- result = self._query_features.join(lengths, on=self._query_id_column, how="left")
702
- repeat_value = "len"
703
- else:
704
- result = self._query_features
705
- repeat_value = 1
706
-
707
- return result.select(self._query_id_column, pl.col(source.column).repeat_by(repeat_value)).rename(
708
- {source.column: tensor_feature.name}
709
- )
710
-
711
726
  query_feature = self._query_features[source.column].values
712
727
  if tensor_feature.is_seq:
713
- return [
714
- np.full(len(item_id_sequence), query_feature[i], dtype=np.int64)
715
- for i, item_id_sequence in enumerate(self._grouped_interactions[self._item_id_column])
716
- ]
717
- return [np.array([query_feature[i]], dtype=np.int64) for i in range(len(self._grouped_interactions))]
728
+ if tensor_feature.is_list:
729
+ result = []
730
+ for i, item_id_sequence in enumerate(self._grouped_interactions[self._item_id_column]):
731
+ seq_len = len(item_id_sequence)
732
+ result.append(np.repeat(query_feature[i], seq_len).reshape(-1, seq_len).T)
733
+ return result
734
+ else:
735
+ return [
736
+ np.full(len(item_id_sequence), query_feature[i])
737
+ for i, item_id_sequence in enumerate(self._grouped_interactions[self._item_id_column])
738
+ ]
739
+ return [np.array([query_feature[i]]).reshape(-1) for i in range(len(self._grouped_interactions))]
718
740
 
719
- def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
741
+ def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
720
742
  """
721
743
  Process categorical feature from item features dataset.
722
744
 
@@ -731,28 +753,154 @@ class _SequenceProcessor:
731
753
  source = tensor_feature.feature_source
732
754
  assert source is not None
733
755
 
734
- if self._is_polars:
756
+ item_feature = self._item_features[source.column]
757
+ values: List[np.ndarray] = []
758
+
759
+ for item_id_sequence in self._grouped_interactions[self._item_id_column]:
760
+ feature_sequence = item_feature.loc[item_id_sequence].values
761
+ if tensor_feature.feature_type == FeatureType.CATEGORICAL_LIST:
762
+ values.append(feature_sequence.tolist())
763
+ else:
764
+ values.append(np.array(feature_sequence, dtype=np.int64))
765
+
766
+ return values
767
+
768
+
769
+ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
770
+ """
771
+ Class to process sequences of different categorical and numerical features.
772
+
773
+ Processing performs over all features in `tensor_schema`. Each feature processing steps
774
+ depends on feature type (categorical/numerical), feature source (interactions/query features/item features)
775
+ and `grouped_interactions` in Polars DataFrame format.
776
+ """
777
+
778
+ def process_features(self) -> PolarsDataFrame:
779
+ """
780
+ :returns: processed Polars DataFrame with all features from tensor schema.
781
+ """
782
+ data = self._grouped_interactions.select(self._query_id_column)
783
+ for tensor_feature_name in self._tensor_schema:
784
+ data = data.join(self._process_feature(tensor_feature_name), on=self._query_id_column, how="left")
785
+ return data
786
+
787
+ def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
788
+ """
789
+ Process numerical interaction feature.
790
+
791
+ :param tensor_feature: tensor feature information.
792
+
793
+ :returns: tensor feature column as a sequences from `grouped_interactions`.
794
+ """
795
+ return self._process_cat_interaction_feature(tensor_feature)
796
+
797
+ def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
798
+ """
799
+ Process numerical feature from query features dataset.
800
+
801
+ :param tensor_feature: tensor feature information.
802
+
803
+ :returns: sequences with length of item sequence for each query for
804
+ sequential features and one size sequences otherwise.
805
+ """
806
+ return self._process_cat_query_feature(tensor_feature)
807
+
808
+ def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
809
+ """
810
+ Process numerical feature from item features dataset.
811
+
812
+ :param tensor_feature: tensor feature information.
813
+
814
+ :returns: item features as a sequence for each item in a sequence
815
+ for each query.
816
+ """
817
+ return self._process_cat_item_feature(tensor_feature)
818
+
819
+ def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
820
+ """
821
+ Process categorical interaction feature.
822
+
823
+ :param tensor_feature: tensor feature information.
824
+
825
+ :returns: tensor feature column as a sequences from `grouped_interactions`.
826
+ """
827
+ assert tensor_feature.is_seq
828
+
829
+ source = tensor_feature.feature_source
830
+ assert source is not None
831
+
832
+ return self._grouped_interactions.select(self._query_id_column, source.column).rename(
833
+ {source.column: tensor_feature.name}
834
+ )
835
+
836
+ def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
837
+ """
838
+ Process categorical feature from query features dataset.
839
+
840
+ :param tensor_feature: tensor feature information.
841
+
842
+ :returns: sequences with length of item sequence for each query for
843
+ sequential features and one size sequences otherwise.
844
+ """
845
+ assert self._query_features is not None
846
+
847
+ source = tensor_feature.feature_source
848
+ assert source is not None
849
+
850
+ if not tensor_feature.is_seq:
851
+ result = self._query_features.select(self._query_id_column, source.column).rename(
852
+ {source.column: tensor_feature.name}
853
+ )
854
+ if not tensor_feature.is_list:
855
+ result = result.with_columns(pl.col(tensor_feature.name).cast(pl.List(pl.Int64)))
856
+ return result
857
+
858
+ lengths = self._grouped_interactions.select(
859
+ self._query_id_column, pl.col(self._item_id_column).list.len().alias("len")
860
+ )
861
+ result = lengths.join(
862
+ self._query_features.select(self._query_id_column, source.column), on=self._query_id_column, how="left"
863
+ )
864
+
865
+ if tensor_feature.is_list:
735
866
  return (
736
- self._grouped_interactions.select(self._query_id_column, self._item_id_column)
737
- .map_rows(
867
+ result.map_rows(
738
868
  lambda x: (
739
869
  x[0],
740
- pl.DataFrame({self._item_id_column: x[1]})
741
- .join(self._item_features, on=self._item_id_column, how="left")
742
- .select(source.column)
743
- .to_numpy()
744
- .reshape(-1)
745
- .tolist(),
870
+ [x[2]] * x[1],
746
871
  )
747
872
  )
748
- .rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
749
- )
873
+ ).rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
750
874
 
751
- item_feature = self._item_features[source.column]
752
- values: List[np.ndarray] = []
875
+ # just categorical branch
876
+ return result.select(self._query_id_column, pl.col(source.column).repeat_by("len")).rename(
877
+ {source.column: tensor_feature.name}
878
+ )
753
879
 
754
- for item_id_sequence in self._grouped_interactions[self._item_id_column]:
755
- feature_sequence = item_feature.loc[item_id_sequence].values
756
- values.append(np.array(feature_sequence, dtype=np.int64))
880
+ def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
881
+ """
882
+ Process categorical feature from item features dataset.
757
883
 
758
- return values
884
+ :param tensor_feature: tensor feature information.
885
+
886
+ :returns: item features as a sequence for each item in a sequence
887
+ for each query.
888
+ """
889
+ assert tensor_feature.is_seq
890
+ assert self._item_features is not None
891
+
892
+ source = tensor_feature.feature_source
893
+ assert source is not None
894
+ return (
895
+ self._grouped_interactions.select(self._query_id_column, self._item_id_column)
896
+ .map_rows(
897
+ lambda x: (
898
+ x[0],
899
+ self._item_features.select(source.column)
900
+ .filter(self._item_features[self._item_id_column].is_in(x[1]))
901
+ .to_series()
902
+ .to_list(),
903
+ )
904
+ )
905
+ .rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
906
+ )