nextrec 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.4"
1
+ __version__ = "0.2.5"
nextrec/basic/features.py CHANGED
@@ -93,6 +93,8 @@ class FeatureSpecMixin:
93
93
  dense_features: Sequence[DenseFeature] | None = None,
94
94
  sparse_features: Sequence[SparseFeature] | None = None,
95
95
  sequence_features: Sequence[SequenceFeature] | None = None,
96
+ target: str | Sequence[str] | None = None,
97
+ id_columns: str | Sequence[str] | None = None,
96
98
  ) -> None:
97
99
  self.dense_features: List[DenseFeature] = list(dense_features) if dense_features else []
98
100
  self.sparse_features: List[SparseFeature] = list(sparse_features) if sparse_features else []
@@ -100,8 +102,10 @@ class FeatureSpecMixin:
100
102
 
101
103
  self.all_features = self.dense_features + self.sparse_features + self.sequence_features
102
104
  self.feature_names = [feat.name for feat in self.all_features]
105
+ self.target_columns = self._normalize_to_list(target)
106
+ self.id_columns = self._normalize_to_list(id_columns)
103
107
 
104
- def _set_target_config(
108
+ def _set_target_id_config(
105
109
  self,
106
110
  target: str | Sequence[str] | None = None,
107
111
  id_columns: str | Sequence[str] | None = None,
nextrec/basic/layers.py CHANGED
@@ -49,10 +49,6 @@ __all__ = [
49
49
 
50
50
 
51
51
  class PredictionLayer(nn.Module):
52
- _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
53
- _REGRESSION_TASKS = {"regression", "continuous"}
54
- _MULTICLASS_TASKS = {"multiclass", "softmax"}
55
-
56
52
  def __init__(
57
53
  self,
58
54
  task_type: Union[str, Sequence[str]] = "binary",
@@ -131,11 +127,11 @@ class PredictionLayer(nn.Module):
131
127
 
132
128
  def _get_activation(self, task_type: str):
133
129
  task = task_type.lower()
134
- if task in self._CLASSIFICATION_TASKS:
130
+ if task in ['binary','multiclass']:
135
131
  return torch.sigmoid
136
- if task in self._REGRESSION_TASKS:
132
+ if task in ['regression']:
137
133
  return lambda x: x
138
- if task in self._MULTICLASS_TASKS:
134
+ if task in ['multiclass']:
139
135
  return lambda x: torch.softmax(x, dim=-1)
140
136
  raise ValueError(f"Unsupported task_type '{task_type}'.")
141
137