replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from collections.abc import Iterator
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils.data as data
|
|
6
|
+
|
|
7
|
+
from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
|
|
8
|
+
from replay.data.utils.batching import UniformBatching, uniform_batch_count
|
|
9
|
+
|
|
10
|
+
from .impl.named_columns import NamedColumns
|
|
11
|
+
from .info.partitioning import Partitioning, partitioning_per_replica
|
|
12
|
+
from .info.replicas import ReplicasInfoProtocol
|
|
13
|
+
|
|
14
|
+
# A single batch yielded by ``IterableDataset.__iter__``: name -> tensor.
Batch = dict[str, torch.Tensor]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def validate_batch_size(batch_size: int) -> int:
    """
    Ensure the requested batch size is strictly positive.

    :param batch_size: Requested batch size.
    :raises ValueError: If ``batch_size`` is zero or negative.
    :return: ``batch_size`` unchanged, so the call can be used inline in assignments.
    """
    is_invalid = batch_size <= 0
    if not is_invalid:
        return batch_size
    raise ValueError(f"batch_size must be a positive integer. Got {batch_size=}")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class IterableDataset(data.IterableDataset):
    """
    An iterable dataset used for processing a single partition of data.
    Supports distributed training, where data is divided between replicas, and reproducible random shuffling.

    A replica is a worker or a set of workers for which a unique chunk of data will be assigned
    during distributed training/inference.
    """

    def __init__(
        self,
        named_columns: NamedColumns,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
        replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
    ) -> None:
        """
        :param named_columns: Structured data presented as columns.
        :param batch_size: Batch size. Must be a positive integer.
        :param generator: Random number generator for batch shuffling.
            If ``None``, shuffling will be disabled. Default: ``None``.
        :param replicas_info: A connector object capable of fetching total replica count and replica id during runtime.
            Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode.
        :raises ValueError: If ``batch_size`` is not positive.
        """
        super().__init__()

        self.named_columns = named_columns
        self.generator = generator
        self.replicas_info = replicas_info
        self.batch_size = validate_batch_size(batch_size)

    @property
    def device(self) -> torch.device:
        """Returns the device containing the dataset."""
        return self.named_columns.device

    @property
    def full_length(self) -> int:
        """Returns the total amount of elements in ``named_columns`` (before splitting between replicas)."""
        return self.named_columns.length

    @property
    def length_per_replica(self) -> int:
        """Returns the number of elements assigned to each replica."""
        return partitioning_per_replica(self.full_length, self.replicas_info.num_replicas)

    @property
    def length(self) -> int:
        """Returns the total number of batches available to the current replica."""
        return uniform_batch_count(self.length_per_replica, self.batch_size)

    def __len__(self) -> int:
        """Returns the number of batches available to the current replica (same as ``length``)."""
        return self.length

    def get_indices(self) -> torch.LongTensor:
        """
        Generates indices corresponding to data assigned to current replica.

        When ``generator`` is set, the indices are shuffled using it.

        :return: tensor containing relevant indices.
        """
        partitioning = Partitioning(
            curr_replica=self.replicas_info.curr_replica,
            num_replicas=self.replicas_info.num_replicas,
            device=self.named_columns.device,
            generator=self.generator,
        )
        indices = partitioning(self.full_length)
        # Internal consistency check: the partitioner must hand out exactly the
        # number of elements ``length_per_replica`` promises.
        assert self.length_per_replica == torch.numel(indices)
        return indices

    def get_batching(self) -> UniformBatching:
        """
        Creates a batching object which splits the current replica's data into batches.

        :return: The batching object; iterating it yields ``(first, last)`` bounds of each batch.
        """
        batching = UniformBatching(
            length=self.length_per_replica,
            batch_size=self.batch_size,
        )
        # Internal consistency check: batching must agree with ``length``/``__len__``.
        assert len(batching) == self.length
        return batching

    def __iter__(self) -> Iterator[Batch]:
        """Batched data iterator over the elements assigned to the current replica."""
        batching = self.get_batching()
        indices = self.get_indices()

        for first, last in batching:
            batch_ids = indices[first:last]
            yield self.named_columns[batch_ids]
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from collections.abc import Iterator
|
|
2
|
+
from typing import Any, Callable, Optional
|
|
3
|
+
|
|
4
|
+
import pyarrow.dataset as da
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
|
|
8
|
+
from replay.data.nn.parquet.impl.masking import DEFAULT_MAKE_MASK_NAME
|
|
9
|
+
|
|
10
|
+
from .impl.array_1d_column import to_array_1d_columns
|
|
11
|
+
from .impl.array_2d_column import to_array_2d_columns
|
|
12
|
+
from .impl.named_columns import NamedColumns
|
|
13
|
+
from .impl.numeric_column import to_numeric_columns
|
|
14
|
+
from .metadata import Metadata
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BatchesIterator:
    """
    Read a PyArrow dataset record batch by record batch and convert every record
    batch into ``NamedColumns`` built from its numeric, 1D-array and 2D-array columns.
    """

    def __init__(
        self,
        metadata: Metadata,
        dataset: da.Dataset,
        batch_size: int,
        make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
        device: torch.device = DEFAULT_DEVICE,
        pyarrow_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        :param metadata: Metadata describing the structure and types of input data;
            its keys select which dataset columns are read.
        :param dataset: Pyarrow dataset implementing the ``to_batches`` method.
        :param batch_size: Number of rows requested per record batch.
            Record batches returned by PyArrow may be smaller than this value.
        :param make_mask_name: Mask name generation function. Default: value of ``DEFAULT_MAKE_MASK_NAME``.
        :param device: Device the converted tensors are placed on. Default: value of ``DEFAULT_DEVICE``.
        :param pyarrow_kwargs: Extra keyword arguments forwarded to ``to_batches``. Default: ``None`` (no extras).
        """
        self.dataset = dataset
        self.metadata = metadata
        self.batch_size = batch_size
        self.make_mask_name = make_mask_name
        self.device = device
        self.pyarrow_kwargs = {} if pyarrow_kwargs is None else pyarrow_kwargs

    def __iter__(self) -> Iterator[NamedColumns]:
        wanted_columns = list(self.metadata.keys())
        record_batches = self.dataset.to_batches(
            batch_size=self.batch_size,
            columns=wanted_columns,
            **self.pyarrow_kwargs,
        )
        converters = (to_numeric_columns, to_array_1d_columns, to_array_2d_columns)
        for record_batch in record_batches:
            converted = {}
            # Later converters override earlier ones on key clashes, as with ``{**a, **b, **c}``.
            for convert in converters:
                converted.update(convert(record_batch, self.metadata, self.device))
            yield NamedColumns(columns=converted, make_mask_name=self.make_mask_name)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .metadata import (
|
|
2
|
+
ColumnMetadata,
|
|
3
|
+
Metadata,
|
|
4
|
+
get_1d_array_columns,
|
|
5
|
+
get_2d_array_columns,
|
|
6
|
+
get_numeric_columns,
|
|
7
|
+
get_padding,
|
|
8
|
+
get_shape,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Public API of this package; every name is re-exported from ``.metadata``.
__all__ = [
    "ColumnMetadata",
    "Metadata",
    "get_1d_array_columns",
    "get_2d_array_columns",
    "get_numeric_columns",
    "get_padding",
    "get_shape",
]
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any, Union
|
|
3
|
+
|
|
4
|
+
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
from replay.data.nn.parquet.constants.metadata import (
|
|
7
|
+
DEFAULT_PADDING,
|
|
8
|
+
PADDING_FLAG,
|
|
9
|
+
SHAPE_FLAG,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Scalar types a single metadata field may hold.
# NOTE(review): shape entries may also be lists of ints (make_shape_check accepts lists),
# which this alias does not cover.
FieldType: TypeAlias = Union[bool, int, float, str]
# Metadata of one column: field name (e.g. SHAPE_FLAG / PADDING_FLAG) -> value.
ColumnMetadata: TypeAlias = dict[str, FieldType]
# Metadata of a whole dataset: column name -> that column's metadata.
Metadata: TypeAlias = dict[str, ColumnMetadata]

# Predicate over a single column's metadata.
ColumnCheck: TypeAlias = Callable[[ColumnMetadata], bool]
# Callable taking a ColumnCheck and returning a bool.
# NOTE(review): not referenced elsewhere in this module; confirm it is still needed.
CheckColumn: TypeAlias = Callable[[ColumnCheck], bool]
# Function returning the names of the columns in a Metadata mapping that match some criterion.
Listing: TypeAlias = Callable[[Metadata], list[str]]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def make_shape_check(dim: int) -> ColumnCheck:
    """
    Build a predicate telling whether a column is an array with ``dim`` dimensions.

    A column qualifies when its metadata has a ``SHAPE_FLAG`` entry that is either
    a list of exactly ``dim`` integers or, only when ``dim == 1``, a bare integer.

    :param dim: Target number of dimensions.
    :return: Predicate over a single column's metadata.
    """

    def check(column_metadata: ColumnMetadata) -> bool:
        if SHAPE_FLAG not in column_metadata:
            return False
        shape: Any = column_metadata[SHAPE_FLAG]
        if isinstance(shape, int) and dim == 1:
            return True
        if not isinstance(shape, list):
            return False
        return len(shape) == dim and all(isinstance(size, int) for size in shape)

    return check
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def make_not_check(check: ColumnCheck) -> ColumnCheck:
    """
    Negate a column predicate.

    :param check: Predicate to negate.
    :return: Predicate returning ``True`` exactly when ``check`` returns a falsy value.
    """

    def function(column_metadata: ColumnMetadata) -> bool:
        return not check(column_metadata)

    return function
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def all_column_checks(*checks: ColumnCheck) -> ColumnCheck:
    """
    Combine several column predicates with logical AND.

    :param checks: Predicates to combine. With no predicates, every column is accepted.
    :return: Predicate that passes only if every check passes; evaluation stops at the first failing check.
    """

    def combined(column_metadata: ColumnMetadata) -> bool:
        return all(check(column_metadata) for check in checks)

    return combined
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Column is a 1D array: its shape is a bare int or a one-element list of ints.
is_array_1d = all_column_checks(make_shape_check(dim=1))
# Column is a 2D array: its shape is a two-element list of ints.
is_array_2d = all_column_checks(make_shape_check(dim=2))
# Column is treated as a scalar number: anything that is neither a 1D nor a 2D array,
# including columns without a shape entry and columns with an unsupported shape.
is_number = all_column_checks(
    make_not_check(is_array_1d),
    make_not_check(is_array_2d),
)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def make_listing(check: ColumnCheck) -> Listing:
    """
    Build a function that lists the columns accepted by ``check``.

    :param check: Check function to validate against.
    :return: Function mapping a ``Metadata`` dict to the sorted names of the columns passing ``check``.
    """

    def listing(metadata: Metadata) -> list[str]:
        return sorted(name for name, column_metadata in metadata.items() if check(column_metadata))

    return listing
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Sorted names of the 1D-array columns in a Metadata mapping.
get_1d_array_columns = make_listing(is_array_1d)
# Sorted names of the 2D-array columns in a Metadata mapping.
get_2d_array_columns = make_listing(is_array_2d)
# Sorted names of the scalar (non-array) columns in a Metadata mapping.
get_numeric_columns = make_listing(is_number)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_padding(metadata: Metadata, column_name: str) -> Any:
    """
    Look up the padding value declared for a column.

    :param metadata: Dataset metadata.
    :param column_name: Name of the column.
    :raises KeyError: If ``column_name`` is absent from ``metadata``.
    :return: The column's ``PADDING_FLAG`` value, or ``DEFAULT_PADDING`` when none is declared.
    """
    if column_name in metadata:
        return metadata[column_name].get(PADDING_FLAG, DEFAULT_PADDING)
    msg = f"Column {column_name} not found in metadata."
    raise KeyError(msg)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_shape(metadata: Metadata, column_name: str) -> Union[int, list[int]]:
    """
    Return the declared shape of an array column.

    :param metadata: Dataset metadata.
    :param column_name: Name of the column.
    :raises KeyError: If ``column_name`` is absent from ``metadata``.
    :raises ValueError: If the column is not a 1D/2D array column, or if any
        dimension of its shape is smaller than 1.
    :return: The shape exactly as stored in the metadata: a bare ``int`` for 1D columns
        declared that way, otherwise a list of ints.
    """
    if column_name not in metadata:
        msg = f"Column {column_name} not found in metadata."
        raise KeyError(msg)
    column_metadata = metadata[column_name]
    if is_number(column_metadata):
        msg = f"Column {column_name} is not an array."
        raise ValueError(msg)
    result: Any = column_metadata[SHAPE_FLAG]

    # Normalize to a list only for validation; the stored value is returned unchanged
    # so callers relying on a bare-int 1D shape keep working.
    dims: list[Any] = result if isinstance(result, list) else [result]

    for i, size in enumerate(dims):
        if size < 1:
            msg = f"Shape for column {column_name} at position {i} is not a positive integer."
            raise ValueError(msg)
    return result
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from collections.abc import Callable, Iterator
|
|
3
|
+
from typing import Optional, Union, cast
|
|
4
|
+
|
|
5
|
+
import pyarrow.dataset as ds
|
|
6
|
+
import pyarrow.fs as fs
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import IterableDataset
|
|
9
|
+
|
|
10
|
+
from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
|
|
11
|
+
from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
|
|
12
|
+
from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
|
|
13
|
+
from replay.data.nn.parquet.constants.filesystem import DEFAULT_FILESYSTEM
|
|
14
|
+
from replay.data.nn.parquet.impl.masking import (
|
|
15
|
+
DEFAULT_COLLATE_FN,
|
|
16
|
+
DEFAULT_MAKE_MASK_NAME,
|
|
17
|
+
)
|
|
18
|
+
from replay.data.nn.parquet.info.replicas import ReplicasInfoProtocol
|
|
19
|
+
from replay.data.nn.parquet.utils.compute_length import compute_fixed_size_length
|
|
20
|
+
|
|
21
|
+
from .fixed_batch_dataset import FixedBatchSizeDataset
|
|
22
|
+
from .iterator import BatchesIterator
|
|
23
|
+
from .metadata import Metadata
|
|
24
|
+
from .partitioned_iterable_dataset import PartitionedIterableDataset
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ParquetDataset(IterableDataset):
    """
    Combination dataset and sampler for batch-wise reading and processing of Parquet files.

    This implementation allows one to read data using a PyArrow Dataset, convert it into structured columns,
    split it into partitions, and then into batches needed for model training.
    Supports distributed training and reproducible random shuffling.

    During data loader operation, a partition of size ``partition_size`` is read.
    There may be situations where the size of the read partition is less than
    ``partition_size`` - this depends on the number of rows in the data fragment.
    A fragment is a single Parquet file in the file system.

    The partition will be read by every worker, split according to their replica ID,
    processed and the result will be returned as a batch of size ``batch_size``.
    Please note that the resulting batch size may be less than ``batch_size``.

    For maximum efficiency when reading and processing data, as well as improved data shuffling,
    it is recommended to set ``partition_size`` to a value several times larger than ``batch_size``
    (a warning is emitted below a 20:1 ratio).

    **Note:**

    * ``ParquetDataset`` supports only numeric values (boolean/integer/float),
      therefore, the data paths passed as arguments must contain encoded data.
    * For optimal performance, set the ``OMP_NUM_THREADS`` and ``ARROW_IO_THREADS`` to match
      the number of available CPU cores.

    """

    def __init__(
        self,
        source: Union[str, list[str]],
        metadata: Metadata,
        partition_size: int,
        batch_size: int,
        filesystem: Union[str, fs.FileSystem] = DEFAULT_FILESYSTEM,
        make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
        device: torch.device = DEFAULT_DEVICE,
        generator: Optional[torch.Generator] = None,
        replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
        collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
        **kwargs,
    ) -> None:
        """
        :param source: The path or list of paths to files/directories containing data in Parquet format.
        :param metadata: Metadata describing the data structure.
            The structure of each column is defined by the following values:

            ``shape`` - the shape of the column being read.
            If the column contains only one value, this parameter does not need to be specified.
            If the column contains a one-dimensional array, the parameter must be a number or an array
            containing one number.
            If the column contains a two-dimensional array, the parameter
            must be an array containing two numbers.

            ``padding`` - padding value that will fill the arrays if their length is less
            than that specified in the `shape` parameter.
        :param partition_size: Partition size when reading data from Parquet files. Must be a positive integer.
        :param batch_size: The size of the batch that will be returned during iteration. Must be a positive integer.
        :param filesystem: A PyArrow's Filesystem object used to access data, or a URI-based path
            to infer the filesystem from. Default: value of ``DEFAULT_FILESYSTEM``.
        :param make_mask_name: Mask name generation function. Default: value of ``DEFAULT_MAKE_MASK_NAME``.
        :param device: The device on which the data will be generated. Default: value of ``DEFAULT_DEVICE``.
        :param generator: Random number generator for batch shuffling.
            If ``None``, shuffling will be disabled. Default: ``None``.
        :param replicas_info: A connector object capable of fetching total replica count and replica id during runtime.
            Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode.
        :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
        :param kwargs: Extra options. Recognized keys:

            * ``pyarrow_dataset_kwargs`` - keyword arguments forwarded to ``pyarrow.dataset.dataset``;
            * ``pyarrow_to_batches_kwargs`` - keyword arguments forwarded to the dataset's ``to_batches``.

            Other keys are ignored.
        :raises ValueError: If ``batch_size`` or ``partition_size`` is not a positive integer.
        :raises TypeError: If ``filesystem`` is neither a string nor a ``pyarrow.fs.FileSystem``.
        """
        # Validate sizes before using them in arithmetic: a zero ``batch_size`` would
        # otherwise surface as a ZeroDivisionError in the ratio check below.
        if batch_size <= 0:
            msg = f"batch_size must be a positive integer. Got {batch_size=}"
            raise ValueError(msg)
        if partition_size <= 0:
            msg = f"partition_size must be a positive integer. Got {partition_size=}"
            raise ValueError(msg)

        if partition_size // batch_size < 20:
            msg = (
                "Suboptimal parameters: partition to batch size ratio too low. "
                "Recommended proportion of partition size to batch size is at least 20:1. "
                f"Got: {partition_size=}, {batch_size=}."
            )
            warnings.warn(msg, stacklevel=2)

        if (partition_size % batch_size) != 0:
            msg = (
                "Suboptimal parameters: partition size is not multiple of batch size. "
                f"Got: {partition_size=}, {batch_size=}."
            )
            warnings.warn(msg, stacklevel=2)

        if isinstance(filesystem, str):
            # Only the filesystem is needed; the path part of the URI is discarded.
            filesystem, _ = fs.FileSystem.from_uri(filesystem)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(filesystem, fs.FileSystem):
            msg = (
                "filesystem must be a URI string or a pyarrow.fs.FileSystem instance. "
                f"Got: {type(filesystem).__name__}."
            )
            raise TypeError(msg)
        self.filesystem = cast(fs.FileSystem, filesystem)

        self.pyarrow_dataset = ds.dataset(
            source,
            filesystem=self.filesystem,
            format="parquet",
            **kwargs.get("pyarrow_dataset_kwargs", {}),
        )

        self.batch_size = batch_size
        self.partition_size = partition_size
        self.replicas_info = replicas_info
        self.metadata = metadata

        # Reads whole partitions (``partition_size`` rows) from the Parquet files.
        self.iterator = BatchesIterator(
            dataset=self.pyarrow_dataset,
            metadata=self.metadata,
            batch_size=partition_size,
            device=device,
            make_mask_name=make_mask_name,
            pyarrow_kwargs=kwargs.get("pyarrow_to_batches_kwargs", {}),
        )

        # Splits each partition between replicas and into ``batch_size`` chunks.
        self.raw_dataset = PartitionedIterableDataset(
            batch_size=batch_size,
            iterable=self.iterator,
            generator=generator,
            replicas_info=replicas_info,
        )

        self.dataset = FixedBatchSizeDataset(
            dataset=self.raw_dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        # Set to ``False`` to make ``len()`` raise TypeError instead of scanning the data.
        self.do_compute_length = True
        # Cached lengths keyed by the number of replicas they were computed for.
        self.cached_lengths: dict[int, int] = {}

    def compute_length(self) -> int:
        """
        Returns the length of the dataset counted in fixed-size batches.

        The result is cached per ``num_replicas``; a warning is emitted when the
        replica count differs from a previously cached one.
        """
        num_replicas = self.replicas_info.num_replicas
        if num_replicas not in self.cached_lengths:
            if len(self.cached_lengths) > 0:
                msg = "`num_replicas` changed. Unable to reuse cached length."
                warnings.warn(msg, stacklevel=2)
            curr_length = compute_fixed_size_length(
                iterable=self.iterator,
                num_replicas=num_replicas,
                batch_size=self.batch_size,
            )
            self.cached_lengths[num_replicas] = curr_length
        return self.cached_lengths[num_replicas]

    def __len__(self) -> int:
        """
        Returns the number of batches, computed via ``compute_length``.

        :raises TypeError: If ``do_compute_length`` is ``False``.
        """
        if self.do_compute_length:
            return self.compute_length()
        msg = "This instance doesn't support `len()` method. You can enable it by setting `do_compute_length=True`."
        raise TypeError(msg)

    def __iter__(self) -> Iterator[GeneralBatch]:
        """Iterates over collated fixed-size batches."""
        return iter(self.dataset)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import warnings
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Literal, Optional, Union, get_args
|
|
5
|
+
|
|
6
|
+
import lightning as L # noqa: N812
|
|
7
|
+
import torch
|
|
8
|
+
from lightning.pytorch.trainer.states import RunningStage
|
|
9
|
+
from lightning.pytorch.utilities import CombinedLoader
|
|
10
|
+
from typing_extensions import TypeAlias, override
|
|
11
|
+
|
|
12
|
+
from replay.data.nn.parquet.constants.filesystem import DEFAULT_FILESYSTEM
|
|
13
|
+
from replay.data.nn.parquet.impl.masking import (
|
|
14
|
+
DEFAULT_COLLATE_FN,
|
|
15
|
+
DEFAULT_MAKE_MASK_NAME,
|
|
16
|
+
DEFAULT_REPLICAS_INFO,
|
|
17
|
+
)
|
|
18
|
+
from replay.data.nn.parquet.parquet_dataset import ParquetDataset
|
|
19
|
+
|
|
20
|
+
# Pipeline stages for which per-stage transforms (and configuration) can be provided.
TransformStage: TypeAlias = Literal["train", "validate", "test", "predict"]

# Default per-stage configuration: only the "train" stage gets a generator
# (torch's global default generator).
# NOTE(review): presumably forwarded to ParquetDataset, where a non-None generator enables
# shuffling; this dict references the shared global generator, so confirm ParquetModule copies
# it (the module imports ``copy``) before mutating it.
DEFAULT_CONFIG = {"train": {"generator": torch.default_generator}}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ParquetModule(L.LightningDataModule):
    """
    Standardized DataModule with batch-wise support via `ParquetDataset`.

    Allows for unified access to all data splits across the training/inference pipeline without loading
    the full dataset into memory. See the :ref:`parquet-processing` section for details.

    ParquetModule provides per-batch data loading and preprocessing via transform pipelines.
    See the :ref:`Transforms` section for information about available batch transforms.

    **Note:**

    * ``ParquetModule`` supports only numeric values (boolean/integer/float),
      therefore, the data paths passed as arguments must contain encoded data.
    * For optimal performance, set the OMP_NUM_THREADS and ARROW_IO_THREADS to match
      the number of available CPU cores.
    * It's possible to use all train/validate/test/predict splits; in that case the paths to the splits
      should be passed as the corresponding arguments of ``ParquetModule``.
      Alternatively, some of the split paths may be omitted,
      but then do not forget to configure the PyTorch Lightning Trainer instance accordingly.
      For example, if you don't want to use validation data, you can leave the ``validate_path`` parameter
      unset in ``ParquetModule`` and set ``limit_val_batches=0`` in ``lightning.Trainer``.

    """

    def __init__(
        self,
        batch_size: int,
        metadata: dict,
        transforms: dict[TransformStage, list[torch.nn.Module]],
        config: Optional[dict] = None,
        *,
        train_path: Optional[str] = None,
        validate_path: Optional[Union[str, list[str]]] = None,
        test_path: Optional[Union[str, list[str]]] = None,
        predict_path: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        :param batch_size: Target batch size.
        :param metadata: A dictionary that maps each data split to a dictionary of feature names,
            where each feature is associated with its shape and padding_value.\n
            Example: {"train": {"item_id" : {"shape": 100, "padding_value": 7657}}}.\n
            For details, see the section :ref:`parquet-processing`.
        :param transforms: Dict specifying a sequence of Transform modules for each data split.
        :param config: Dict specifying configuration options of ``ParquetDataset`` (partition_size, generator,
            filesystem, collate_fn, make_mask_name, replicas_info) for each data split.
            Default: a copy of ``DEFAULT_CONFIG``.\n
            In most scenarios, the default configuration is sufficient.
        :param train_path: Path to the Parquet file containing train data split. Default: ``None``.
        :param validate_path: Path to the Parquet file or files containing validation data split. Default: ``None``.
        :param test_path: Path to the Parquet file or files containing testing data split. Default: ``None``.
        :param predict_path: Path to the Parquet file or files containing prediction data split. Default: ``None``.
        :raises KeyError: If none of the split paths is provided, or if ``transforms`` contains
            none of the known stage names.
        :raises TypeError: If ``train_path`` is given as a collection of paths instead of a single path.
        """
        if not any([train_path, validate_path, test_path, predict_path]):
            msg = (
                f"{type(self)}.__init__() expects at least one of "
                "['train_path', 'validate_path', 'test_path', 'predict_path'], but none were provided."
            )
            raise KeyError(msg)

        # Only non-train splits may be combined from several paths (see ``setup``).
        if train_path and not isinstance(train_path, str) and isinstance(train_path, Iterable):
            msg = "'train_path' does not support multiple datapaths."
            raise TypeError(msg)

        super().__init__()
        if config is None:
            # Copy the containers so that mutating ``self.config`` never alters the module-level
            # default shared by other instances. Values (e.g. the generator) are intentionally shared.
            config = {split: dict(options) for split, options in DEFAULT_CONFIG.items()}

        self.datapaths = {"train": train_path, "validate": validate_path, "test": test_path, "predict": predict_path}
        missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
        if missing_splits:
            msg = (
                f"The following dataset paths aren't provided: {', '.join(missing_splits)}. "
                "Make sure to disable these stages in your Lightning Trainer configuration."
            )
            warnings.warn(msg, stacklevel=2)

        self.metadata = copy.deepcopy(metadata)
        self.batch_size = batch_size
        self.config = config

        self.datasets: dict[str, Union[ParquetDataset, CombinedLoader]] = {}
        self.transforms = transforms
        self.compiled_transforms = self.prepare_transforms(transforms)

    def prepare_transforms(
        self, transforms: dict[TransformStage, list[torch.nn.Module]]
    ) -> dict[TransformStage, torch.nn.Sequential]:
        """
        Compile the transform pipeline of every stage into a ``torch.nn.Sequential`` module.

        :param transforms: Python dict where keys are stage names (train, validate, test, predict)
            and values are the corresponding transform pipelines for every stage.
        :raises KeyError: If none of the keys of ``transforms`` is a known stage name.
        :returns: Compiled transform pipelines, keyed exactly like ``transforms``.
        """
        known_stages = get_args(TransformStage)
        if not any(subset in known_stages for subset in transforms):
            msg = (
                f"Expected transform.keys()={list(transforms.keys())} to contain at least "
                f"one of {known_stages}, but none were found."
            )
            raise KeyError(msg)

        # Keys such as a mistyped "val" would otherwise be compiled but never selected
        # by ``on_after_batch_transfer``, which looks pipelines up by stage name.
        unknown_stages = [subset for subset in transforms if subset not in known_stages]
        if unknown_stages:
            msg = (
                f"Transforms were provided for unrecognized stages {unknown_stages}; "
                f"expected stage names are {known_stages}."
            )
            warnings.warn(msg, stacklevel=2)

        return {subset: torch.nn.Sequential(*transform_set) for subset, transform_set in transforms.items()}

    @override
    def setup(self, stage):
        """
        Build a dataset for every split that has a configured path.

        ``stage`` is not used to filter splits: every split with a path is (re)built on each call.
        A split given as a list of paths is wrapped into a sequential ``CombinedLoader``.

        :raises KeyError: If a split has a path but no entry in ``metadata``.
        """
        for subset in get_args(TransformStage):
            subset_datapaths = self.datapaths.get(subset, None)
            if subset_datapaths is None:
                continue

            if subset not in self.metadata:
                msg = f"A path for the '{subset}' split was provided, but `metadata` has no entry for it."
                raise KeyError(msg)

            subset_config = self.config.get(subset, {})
            shared_kwargs = {
                "metadata": self.metadata[subset],
                "batch_size": self.batch_size,
                "partition_size": subset_config.get("partition_size", 2**17),
                "generator": subset_config.get("generator", None),
                "filesystem": subset_config.get("filesystem", DEFAULT_FILESYSTEM),
                "make_mask_name": subset_config.get("make_mask_name", DEFAULT_MAKE_MASK_NAME),
                "replicas_info": subset_config.get("replicas_info", DEFAULT_REPLICAS_INFO),
                "collate_fn": subset_config.get("collate_fn", DEFAULT_COLLATE_FN),
            }

            if isinstance(subset_datapaths, list):
                loaders = [ParquetDataset(**{"source": path, **shared_kwargs}) for path in subset_datapaths]
                self.datasets[subset] = CombinedLoader(loaders, mode="sequential")
            else:
                self.datasets[subset] = ParquetDataset(**{"source": subset_datapaths, **shared_kwargs})

    @override
    def train_dataloader(self):
        """Return the train dataset built by ``setup``."""
        return self.datasets["train"]

    @override
    def val_dataloader(self):
        """Return the validation dataset (or sequential ``CombinedLoader``) built by ``setup``."""
        return self.datasets["validate"]

    @override
    def test_dataloader(self):
        """Return the test dataset (or sequential ``CombinedLoader``) built by ``setup``."""
        return self.datasets["test"]

    @override
    def predict_dataloader(self):
        """Return the prediction dataset (or sequential ``CombinedLoader``) built by ``setup``."""
        return self.datasets["predict"]

    @override
    def on_after_batch_transfer(self, batch, _dataloader_idx):
        """
        Apply the compiled transform pipeline of the current trainer stage to ``batch``.

        The sanity-check stage reuses the validation pipeline.

        :raises KeyError: If no transform pipeline was provided for the current stage.
        """
        stage = self.trainer.state.stage
        target = RunningStage.VALIDATING if stage is RunningStage.SANITY_CHECKING else stage
        stage_name = str(target.value)

        pipeline = self.compiled_transforms.get(stage_name)
        if pipeline is None:
            msg = (
                f"No transforms were provided for the '{stage_name}' stage. "
                f"Stages with transforms: {list(self.compiled_transforms)}."
            )
            raise KeyError(msg)
        return pipeline(batch)
|