replay-rec 0.21.0-py3-none-any.whl → 0.21.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/nn/parquet/parquet_module.py +1 -1
- replay/metrics/torch_metrics_builder.py +1 -1
- replay/models/nn/sequential/callbacks/validation_callback.py +14 -4
- replay/nn/lightning/callback/metrics_callback.py +18 -9
- replay/nn/lightning/callback/predictions_callback.py +2 -2
- replay/nn/loss/base.py +3 -3
- replay/nn/loss/login_ce.py +1 -1
- replay/nn/sequential/sasrec/model.py +1 -1
- replay/nn/sequential/twotower/reader.py +14 -5
- replay/nn/transform/template/sasrec.py +3 -3
- replay/nn/transform/template/twotower.py +1 -1
- {replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/METADATA +1 -1
- {replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/RECORD +17 -17
- {replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/WHEEL +0 -0
- {replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/licenses/NOTICE +0 -0
replay/__init__.py
CHANGED

replay/data/nn/parquet/parquet_module.py
CHANGED
@@ -94,7 +94,7 @@ class ParquetModule(L.LightningDataModule):
         missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
         if missing_splits:
             msg = (
-                f"The following dataset paths aren't provided: {','.join(missing_splits)}."
+                f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
                 "Make sure to disable these stages in your Lightning Trainer configuration."
             )
             warnings.warn(msg, stacklevel=2)
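The only change is the trailing space inside the f-string: Python's implicit string concatenation otherwise glues the two sentences together. A minimal standalone sketch (the split names are hypothetical):

missing_splits = ["validate", "test"]  # hypothetical split names

old_msg = (
    f"The following dataset paths aren't provided: {','.join(missing_splits)}."
    "Make sure to disable these stages in your Lightning Trainer configuration."
)
new_msg = (
    f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
    "Make sure to disable these stages in your Lightning Trainer configuration."
)
print(old_msg)  # ...aren't provided: validate,test.Make sure to disable...
print(new_msg)  # ...aren't provided: validate,test. Make sure to disable...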
replay/metrics/torch_metrics_builder.py
CHANGED
@@ -400,7 +400,7 @@ def metrics_to_df(metrics: Mapping[str, float]) -> PandasDataFrame:

     metric_name_and_k = metrics_df["m"].str.split("@", expand=True)
     metrics_df["metric"] = metric_name_and_k[0]
-    metrics_df["k"] = metric_name_and_k[1]
+    metrics_df["k"] = metric_name_and_k[1].astype(int)

     pivoted_metrics = metrics_df.pivot(index="metric", columns="k", values="v")
     pivoted_metrics.index.name = None
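Casting "k" to int fixes column ordering in the pivoted table: as strings, "10" sorts before "2". A minimal sketch of the effect, with illustrative metric values:

import pandas as pd

metrics = {"ndcg@2": 0.5, "ndcg@10": 0.4}  # illustrative values
df = pd.DataFrame(list(metrics.items()), columns=["m", "v"])
name_and_k = df["m"].str.split("@", expand=True)
df["metric"] = name_and_k[0]

df["k"] = name_and_k[1]  # old behavior: "k" stays a string
print(df.pivot(index="metric", columns="k", values="v").columns.tolist())  # ['10', '2']

df["k"] = name_and_k[1].astype(int)  # new behavior: numeric sort
print(df.pivot(index="metric", columns="k", values="v").columns.tolist())  # [2, 10]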
replay/models/nn/sequential/callbacks/validation_callback.py
CHANGED
@@ -162,14 +162,24 @@ class ValidationMetricsCallback(lightning.Callback):
         @rank_zero_only
         def print_metrics() -> None:
             metrics = {}
+
             for name, value in trainer.logged_metrics.items():
                 if "@" in name:
                     metrics[name] = value.item()

-            if metrics:
-
+            if not metrics:
+                return

-
-
+            if len(self._dataloaders_size) > 1:
+                for i in range(len(self._dataloaders_size)):
+                    suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
+                    cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
+                    metrics_df = metrics_to_df(cur_dataloader_metrics)
+
+                    print(suffix)  # noqa: T201
+                    print(metrics_df, "\n")  # noqa: T201
+            else:
+                metrics_df = metrics_to_df(metrics)
+                print(metrics_df, "\n")  # noqa: T201

         print_metrics()
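The rewritten print_metrics first splits logged metrics by validation dataloader, then builds one table per dataloader. A standalone sketch of just the grouping step, assuming Lightning's "/dataloader_idx_{i}" suffix convention for multi-dataloader runs (metric values are illustrative):

logged_metrics = {
    "ndcg@10/dataloader_idx_0": 0.42,
    "recall@10/dataloader_idx_0": 0.31,
    "ndcg@10/dataloader_idx_1": 0.38,
    "recall@10/dataloader_idx_1": 0.27,
}

for i in range(2):
    suffix = "/dataloader_idx_{}".format(i)[1:]  # "dataloader_idx_0", "dataloader_idx_1"
    per_loader = {k.split("/")[0]: v for k, v in logged_metrics.items() if suffix in k}
    print(suffix, per_loader)
# dataloader_idx_0 {'ndcg@10': 0.42, 'recall@10': 0.31}
# dataloader_idx_1 {'ndcg@10': 0.38, 'recall@10': 0.27}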
replay/nn/lightning/callback/metrics_callback.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Any, Optional

 import lightning
 import torch
-from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.rank_zero import rank_zero_only

 from replay.metrics.torch_metrics_builder import (
@@ -64,8 +63,8 @@ class ComputeMetricsCallback(lightning.Callback):
         self._train_column = train_column

     def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
-        if isinstance(dataloaders,
-        return [len(dataloader) for dataloader in dataloaders
+        if isinstance(dataloaders, list):
+            return [len(dataloader) for dataloader in dataloaders]
         return [len(dataloaders)]

     def on_validation_epoch_start(
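The fixed branch handles a plain list of dataloaders explicitly (the dropped CombinedLoader import suggests the old, partially captured check targeted that type instead). A minimal sketch of the new sizing logic with toy dataloaders:

import torch
from torch.utils.data import DataLoader, TensorDataset

def get_dataloaders_size(dataloaders):
    # Mirrors the fixed method body: list of dataloaders vs. a single one.
    if isinstance(dataloaders, list):
        return [len(dataloader) for dataloader in dataloaders]
    return [len(dataloaders)]

ds = TensorDataset(torch.arange(10))
loaders = [DataLoader(ds, batch_size=2), DataLoader(ds, batch_size=5)]
print(get_dataloaders_size(loaders))     # [5, 2] -- batches per dataloader
print(get_dataloaders_size(loaders[0]))  # [5]    -- single-dataloader case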
@@ -123,7 +122,7 @@ class ComputeMetricsCallback(lightning.Callback):
         batch: dict,
         batch_idx: int,
         dataloader_idx: int = 0,
-    ) -> None:
+    ) -> None:
         self._batch_end(
             trainer,
             pl_module,
@@ -159,7 +158,7 @@ class ComputeMetricsCallback(lightning.Callback):
     def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
         self._epoch_end(trainer, pl_module)

-    def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
+    def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
         self._epoch_end(trainer, pl_module)

     def _epoch_end(
@@ -170,14 +169,24 @@ class ComputeMetricsCallback(lightning.Callback):
         @rank_zero_only
         def print_metrics() -> None:
             metrics = {}
+
             for name, value in trainer.logged_metrics.items():
                 if "@" in name:
                     metrics[name] = value.item()

-            if metrics:
-
+            if not metrics:
+                return

-
-
+            if len(self._dataloaders_size) > 1:
+                for i in range(len(self._dataloaders_size)):
+                    suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
+                    cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
+                    metrics_df = metrics_to_df(cur_dataloader_metrics)
+
+                    print(suffix)  # noqa: T201
+                    print(metrics_df, "\n")  # noqa: T201
+            else:
+                metrics_df = metrics_to_df(metrics)
+                print(metrics_df, "\n")  # noqa: T201

         print_metrics()
replay/nn/lightning/callback/predictions_callback.py
CHANGED
@@ -15,11 +15,11 @@ from replay.utils import (
     SparkDataFrame,
 )

-if PYSPARK_AVAILABLE:
+if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import SparkSession
     from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
-else:
+else:
     SparkSession = MissingImport


replay/nn/loss/base.py
CHANGED
@@ -85,7 +85,7 @@ class SampledLossBase(torch.nn.Module):
         # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
         negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)

-        if negative_labels.dim() == 3:
+        if negative_labels.dim() == 3:
             # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
             negative_labels = negative_labels.unsqueeze(-2)
         if num_positives != 1:
@@ -119,7 +119,7 @@ class SampledLossBase(torch.nn.Module):
         positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
         assert positive_labels.size() == (masked_batch_size, 1)

-        if negative_labels.dim() != 1:
+        if negative_labels.dim() != 1:
             # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
             negative_labels = negative_labels[target_padding_mask]
             assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -183,7 +183,7 @@ def mask_negative_logits(
     if negative_labels_ignore_index >= 0:
         negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)

-    if negative_labels.dim() > 1:
+    if negative_labels.dim() > 1:
         # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
         negative_labels = negative_labels.unsqueeze(-2)

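All three base.py edits are whitespace-only. For context, the masking in the last hunk pushes negatives that collide with the ignore index down to a large negative logit so softmax effectively drops them; a minimal sketch with toy values:

import torch

negative_labels_ignore_index = 0  # illustrative padding/ignore index
negative_labels = torch.tensor([[3, 0, 7]])          # [masked_batch_size, num_negatives]
negative_logits = torch.tensor([[0.5, 1.2, -0.3]])

if negative_labels_ignore_index >= 0:
    # In-place fill where a sampled negative equals the ignore index.
    negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)

print(negative_logits)  # tensor([[ 5.0000e-01, -1.0000e+09, -3.0000e-01]])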
replay/nn/loss/login_ce.py
CHANGED
@@ -74,7 +74,7 @@ class LogInCEBase(SampledLossBase):
         positive_labels = positive_labels[masked_target_padding_mask]
         assert positive_labels.size() == (masked_batch_size, num_positives)

-        if negative_labels.dim() > 1:
+        if negative_labels.dim() > 1:
             # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
             negative_labels = negative_labels[masked_target_padding_mask]
             assert negative_labels.size() == (masked_batch_size, num_negatives)
replay/nn/sequential/sasrec/model.py
CHANGED
@@ -141,7 +141,7 @@ class SasRec(torch.nn.Module):
                 feature_type=FeatureType.CATEGORICAL,
                 embedding_dim=256,
                 padding_value=NUM_UNIQUE_ITEMS,
-                cardinality=NUM_UNIQUE_ITEMS
+                cardinality=NUM_UNIQUE_ITEMS,
                 feature_hint=FeatureHint.ITEM_ID,
                 feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
             ),
replay/nn/sequential/twotower/reader.py
CHANGED
@@ -22,7 +22,6 @@ class FeaturesReader:
     :param schema: the same tensor schema used in TwoTower model.
     :param metadata: A dictionary of feature names that
         associated with its shape and padding_value.\n
-        Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
         For details, see the section :ref:`parquet-processing`.
     :param path: path to parquet with dataframe of item features.\n
         **Note:**\n
@@ -30,8 +29,8 @@ class FeaturesReader:
     2. Every feature for item "tower" in `schema` must contain ``feature_sources`` with the names
        of the source features to create correct inverse mapping.
        Also, for each such feature one of the requirements must be met: the ``schema`` for the feature must
-       contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
-       or hint type FeatureHint.ITEM_ID
+       contain ``feature_sources`` with a source of type ``FeatureSource.ITEM_FEATURES``
+       or hint type ``FeatureHint.ITEM_ID``.

     """
     item_feature_names = [
@@ -81,8 +80,18 @@ class FeaturesReader:
         self._features = {}

         for k in features.columns:
-            dtype =
-
+            dtype = np.float32 if schema[k].is_num else np.int64
+            if schema[k].is_list:
+                feature = np.asarray(
+                    features[k].to_list(),
+                    dtype=dtype,
+                )
+            else:
+                feature = features[k].to_numpy(dtype=dtype)
+            feature_tensor = torch.asarray(
+                feature,
+                dtype=torch.float32 if schema[k].is_num else torch.int64,
+            )
             self._features[k] = feature_tensor

     def __getitem__(self, key: str) -> torch.Tensor:
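The new loop routes each feature column through NumPy before building tensors, with list-valued columns converted via to_list(). A standalone sketch of the same conversion, where the DataFrame and the is_num/is_list flags stand in for the schema lookup in the real code:

import numpy as np
import pandas as pd
import torch

features = pd.DataFrame({
    "item_id": [0, 1, 2],                # scalar categorical column
    "genres": [[1, 2], [3, 4], [5, 6]],  # list-valued column of equal lengths
})
is_num = {"item_id": False, "genres": False}  # stand-in for schema[k].is_num
is_list = {"item_id": False, "genres": True}  # stand-in for schema[k].is_list

tensors = {}
for k in features.columns:
    dtype = np.float32 if is_num[k] else np.int64
    if is_list[k]:
        # Lists of equal length stack into a 2-D array.
        feature = np.asarray(features[k].to_list(), dtype=dtype)
    else:
        feature = features[k].to_numpy(dtype=dtype)
    tensors[k] = torch.asarray(feature, dtype=torch.float32 if is_num[k] else torch.int64)

print(tensors["item_id"].shape, tensors["genres"].shape)  # torch.Size([3]) torch.Size([3, 2])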
replay/nn/transform/template/sasrec.py
CHANGED
@@ -14,7 +14,7 @@ def make_default_sasrec_transforms(

     Generated pipeline expects input dataset to contain the following columns:
     1) Query ID column, specified by ``query_column``.
-    2)
+    2) All features specified in the ``tensor_schema``.

     :param tensor_schema: TensorSchema used to infer feature columns.
     :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
@@ -32,12 +32,12 @@ def make_default_sasrec_transforms(
         ),
         UnsqueezeTransform("target_padding_mask", -1),
         UnsqueezeTransform("positive_labels", -1),
-        GroupTransform({"feature_tensors":
+        GroupTransform({"feature_tensors": tensor_schema.names}),
     ]

     val_transforms = [
         RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
-        GroupTransform({"feature_tensors":
+        GroupTransform({"feature_tensors": tensor_schema.names}),
     ]
     test_transforms = copy.deepcopy(val_transforms)

replay/nn/transform/template/twotower.py
CHANGED
@@ -13,7 +13,7 @@ def make_default_twotower_transforms(

     Generated pipeline expects input dataset to contain the following columns:
     1) Query ID column, specified by ``query_column``.
-    2)
+    2) All features specified in the ``tensor_schema``.

     :param tensor_schema: TensorSchema used to infer feature columns.
     :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
{replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-replay/__init__.py,sha256=
+replay/__init__.py,sha256=2kRxqt2GF_2mTRxcddaKhR1p-tGp_fVjPLBFC2gI4os,225
 replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
 replay/data/dataset.py,sha256=yBl-yJVIokgN4prFY949tHe2UVJC_j5xdaulIoSPvQI,31252
 replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
@@ -31,7 +31,7 @@ replay/data/nn/parquet/iterator.py,sha256=X5KXtjdY_uSfMlP9IXBqMzSimBqlAZbYX_Y483
 replay/data/nn/parquet/metadata/__init__.py,sha256=UZX60ANtjo6zX0p43hU9q8fBldVJNCEmGzXjHqz0MJQ,341
 replay/data/nn/parquet/metadata/metadata.py,sha256=jJOL8mieXhX18FO9lgaP95MOtO1l7tY63ldxoOAvzwA,3459
 replay/data/nn/parquet/parquet_dataset.py,sha256=pKthRppp0MstfNwOk9wMrE6wFvDecCtbTKWIri4HGr0,8017
-replay/data/nn/parquet/parquet_module.py,sha256=
+replay/data/nn/parquet/parquet_module.py,sha256=BSf_ev-XtFTsPV9R3y9YO2qa1JHU4Z1Wp7jsXL6GFjM,8209
 replay/data/nn/parquet/partitioned_iterable_dataset.py,sha256=BZEh2EiBKMZxi822-doyTbjDkZQQ62SxAp_NhZVZdmk,1938
 replay/data/nn/parquet/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 replay/data/nn/parquet/utils/compute_length.py,sha256=VWabulpRICy-_Z0ZBXpEmhAIlpXVwTwe9kX2L2XCdbE,2492
@@ -62,7 +62,7 @@ replay/metrics/precision.py,sha256=DRlsgY_b4bJCOSZjCA58N41REMiDt-dbagRSXxfXyvY,2
 replay/metrics/recall.py,sha256=fzpASDiH88zcpXJZTbStQ3nuzzSdhd9k1wjF27rM4wc,2447
 replay/metrics/rocauc.py,sha256=1vaVEK7DQTL8BX-i7A64hTFWyO38aNycscPGrdWKwbA,3282
 replay/metrics/surprisal.py,sha256=HkmYrOuw3jydxFrkidjdcpAcKz2DeOnMsKqwB2g9pwY,7526
-replay/metrics/torch_metrics_builder.py,sha256=
+replay/metrics/torch_metrics_builder.py,sha256=cf0cRGQnBtR9OUUFOzLvOk3pX9rX2613nw5L28J4DDw,14083
 replay/metrics/unexpectedness.py,sha256=LSi-z50l3_yrvLnmToHQzm6Ygf2QpNt_zhk6jdg7QUo,6882
 replay/models/__init__.py,sha256=kECYluQZ83zRUWaHVvnt7Tg3BerHrJy9v8XfRxsqyYY,1123
 replay/models/als.py,sha256=1MFAbcx64tv0MX1wE9CM1NxKD3F3ZDhZUrmt6dvHu74,6220
@@ -115,7 +115,7 @@ replay/models/nn/sequential/bert4rec/lightning.py,sha256=vxAf1H1VfLqgZhOz9fxEMmw
 replay/models/nn/sequential/bert4rec/model.py,sha256=C1AKcQ8KF0XMXERwrFneW9kg7hzc-9FIqhCc-t91F7o,17469
 replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
 replay/models/nn/sequential/callbacks/prediction_callbacks.py,sha256=UtEzO9_f5Jwku9dbz7twr4o2_cV3L-viC4lQuce5l1c,10808
-replay/models/nn/sequential/callbacks/validation_callback.py,sha256=
+replay/models/nn/sequential/callbacks/validation_callback.py,sha256=VDIa8c6Wpekz_AvzdtETajvdkqi2aiJBokE9JEOY3rI,7071
 replay/models/nn/sequential/compiled/__init__.py,sha256=eSVcCaUH5cDJQRbC7K99X7uMNR-Z-KR4TmYOGKWWJCI,531
 replay/models/nn/sequential/compiled/base_compiled_model.py,sha256=f4AuTyx5tufQOtOWUSEgj1cWvMZzSL7YN2Z-PtURgTY,10478
 replay/models/nn/sequential/compiled/bert4rec_compiled.py,sha256=woGI3qk4J2Rb5FyaDwpSCuG-AMfyH34F6Bt5pV-wqk0,6798
@@ -146,8 +146,8 @@ replay/nn/ffn.py,sha256=ivOFu14289URepyEFxYov_XNYMUrINjU-2rEqoXxbnU,4618
 replay/nn/head.py,sha256=csjwQrcA7M7FebgSL1tKDbjfaoni52CymQR0Zt8RhWg,2084
 replay/nn/lightning/__init__.py,sha256=jHiwtYuboGUY4Of18zrkvdWD0xXJ_zuo83-XgiqxSfY,36
 replay/nn/lightning/callback/__init__.py,sha256=ImNEJeIK-wJnqdkZgP8tWTDQHaS9xYqzTEf3FEM0XAw,253
-replay/nn/lightning/callback/metrics_callback.py,sha256=
-replay/nn/lightning/callback/predictions_callback.py,sha256=
+replay/nn/lightning/callback/metrics_callback.py,sha256=AzDsxvNHfjrJdhcgZsMtnKju1TYO86Pc2Knv_tS7HBA,7323
+replay/nn/lightning/callback/predictions_callback.py,sha256=4iS3QwRRFolAwizxDp2guBDJNRvitgOOHhYmMPL8ub0,11307
 replay/nn/lightning/module.py,sha256=jFvevwiriY9alZMBw6KAiRMsJv-dJ8fEVrenVRiuWeI,5246
 replay/nn/lightning/optimizer.py,sha256=1tXhz9RIBHLpEQtZ1PUzCAc4mn6Q_E38zR0nf5km6U8,1778
 replay/nn/lightning/postprocessor/__init__.py,sha256=LhUeOWDD5vRBDXF2tQEjvPKH1rNIlrf5KPbcV66AdtQ,77
@@ -155,10 +155,10 @@ replay/nn/lightning/postprocessor/_base.py,sha256=X0LtYItmxlXt4Sxk3cOdyIK3FG5dij
 replay/nn/lightning/postprocessor/seen_items.py,sha256=h-sfD3vmNCdS7lYvqCfqw9oPqutmaSIuZ0CIidG0Y30,2922
 replay/nn/lightning/scheduler.py,sha256=CUuynPTFrKBrkpmbWR-xpfAkHZ0Vfz_THUDo3uoZi8k,2714
 replay/nn/loss/__init__.py,sha256=YXAXQIN0coj8MxeK5isTGXgvMxhH5pUO6j1D3d7jl3A,471
-replay/nn/loss/base.py,sha256=
+replay/nn/loss/base.py,sha256=XM2ASulAW8Kyg2Vw43I8Tqv1d8cij9NcNirP_RTk4b8,8811
 replay/nn/loss/bce.py,sha256=cPlxdJTBZ0b22K6V9ve4qo7xkp99CjEsnl3_vVGphqs,8373
 replay/nn/loss/ce.py,sha256=jOmhLtKD_E0jX8tUfXpsmaaQVHKKiwXW9USB_GyN3ZU,13218
-replay/nn/loss/login_ce.py,sha256=
+replay/nn/loss/login_ce.py,sha256=ri4KvHQXOVMB5o_vqGY2u8ayatYH9MLZwsXwp6cpDhI,16478
 replay/nn/loss/logout_ce.py,sha256=KhcYyCnUzLZR1sFpxM6_QliLoxmC6MJoLkPOgf_ZYzU,10306
 replay/nn/mask.py,sha256=Jbx7sulGZYfasNaD9CZzJma0cEVaDlxdpzs295507II,3329
 replay/nn/normalization.py,sha256=Z86t5WCr4KfVR9qCCe-EIAwwomnIIxb11PP88WHA1JI,187
@@ -167,11 +167,11 @@ replay/nn/sequential/__init__.py,sha256=jet_ueMz5Bm087JDph7ln87NID7DbCb0WENj-tjo
 replay/nn/sequential/sasrec/__init__.py,sha256=8crj-JL8xeP-cCOCnxCSVF_-R6feKhj0YRHOcaMsqrU,213
 replay/nn/sequential/sasrec/agg.py,sha256=e-IkIO-MMbei2UGxTUopWvloguJoVaZiN31sXkdUVag,2004
 replay/nn/sequential/sasrec/diff_transformer.py,sha256=4ehM5EMizajmWBAzmcj3CYSFl21V1R2b7RDRJlx3O4Q,4790
-replay/nn/sequential/sasrec/model.py,sha256=
+replay/nn/sequential/sasrec/model.py,sha256=Db4IcI4EHzQoO7Vij_ItGvvs8aOJ6ANyHuXp_9v84zs,14801
 replay/nn/sequential/sasrec/transformer.py,sha256=sJf__IPnhbJWDPuFTPSbBGSSntznQtS-hJtJo3iFBkw,4037
 replay/nn/sequential/twotower/__init__.py,sha256=-rEASPqKCbS55MTTgeDZ5irfWfM9or1vNTHZnJN2AcU,124
 replay/nn/sequential/twotower/model.py,sha256=VxUUjldHndCkDjrXGqmxGnTi5fh8vmnr7XNBpYjsqW8,28659
-replay/nn/sequential/twotower/reader.py,sha256=
+replay/nn/sequential/twotower/reader.py,sha256=8z-R0oZDbOaw6eFL3ffyt7yuc3q7qoKFhFrIdiwwJ10,3938
 replay/nn/transform/__init__.py,sha256=9PeaDHmftb0s1gEEgJRNWw6Bl2wfE_-lImatipaHUQ0,705
 replay/nn/transform/copy.py,sha256=ZfNXbMJYTwXDMJ5T8ib9Dh5XOGLjj7gGB4NbBExFZiM,1302
 replay/nn/transform/grouping.py,sha256=XOJoVBk234DI6x05Kqr7KOjLetDaLp2NMAJWHecQcsI,1384
@@ -181,8 +181,8 @@ replay/nn/transform/rename.py,sha256=_uD2e1UmtBRyOTVpHUnZ5xhePmClaGQsc0g7Es-rupE
 replay/nn/transform/reshape.py,sha256=sgswIogWHUwOVp02k13Qopn84LofqLoA4M7U1GAfmio,1359
 replay/nn/transform/sequence_roll.py,sha256=7jf42SgWHU1L7SirqQWXx0h9a6VQQ29kehE4LmdUt9o,1531
 replay/nn/transform/template/__init__.py,sha256=lYzAekZUXwncGR66Nq8YypplGOtL00GFfm0PalGiY5g,106
-replay/nn/transform/template/sasrec.py,sha256=
-replay/nn/transform/template/twotower.py,sha256=
+replay/nn/transform/template/sasrec.py,sha256=0bhb8EhyTafiM9Eh1GfPEFS_syyUCPY3JOzTBr17I-o,1919
+replay/nn/transform/template/twotower.py,sha256=qutfN3iUwGHn1BG2yLEc7hHQoxEwQiFdvXIoe-ej4tM,897
 replay/nn/transform/token_mask.py,sha256=WcalZkY2UCoNiq2mBtu8fqYFOUfqCh21XyDMgvIpeB4,2529
 replay/nn/transform/trim.py,sha256=mPn6LPxu3c3yE14heMSRsDEU4h94tkFiRr62mOa3lKg,1608
 replay/nn/utils.py,sha256=GumtN-QRP9ljXYti3YvuNk13e0Q92xvkYuCJBhaViCI,801
@@ -216,8 +216,8 @@ replay/utils/session_handler.py,sha256=fQo2wseow8yuzKnEXT-aYAXcQIgRbTTXp0v7g1VVi
 replay/utils/spark_utils.py,sha256=GbRp-MuUoO3Pc4chFvlmo9FskSlRLeNlC3Go5pEJ6Ok,27411
 replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
 replay/utils/types.py,sha256=rD9q9CqEXgF4yy512Hv2nXclvwcnfodOnhBZ1HSUI4c,1260
-replay_rec-0.21.
-replay_rec-0.21.
-replay_rec-0.21.
-replay_rec-0.21.
-replay_rec-0.21.
+replay_rec-0.21.1.dist-info/METADATA,sha256=atVJNoBxihnIh3r6q0pORPlykCDaDD5t-HPKJ6ufNIw,13573
+replay_rec-0.21.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+replay_rec-0.21.1.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
+replay_rec-0.21.1.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
+replay_rec-0.21.1.dist-info/RECORD,,
{replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/WHEEL
File without changes
{replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/licenses/LICENSE
File without changes
{replay_rec-0.21.0.dist-info → replay_rec-0.21.1.dist-info}/licenses/NOTICE
File without changes