replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/data/nn/parquet/iterable_dataset.py
@@ -0,0 +1,119 @@
+ from collections.abc import Iterator
+ from typing import Optional
+
+ import torch
+ import torch.utils.data as data
+
+ from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
+ from replay.data.utils.batching import UniformBatching, uniform_batch_count
+
+ from .impl.named_columns import NamedColumns
+ from .info.partitioning import Partitioning, partitioning_per_replica
+ from .info.replicas import ReplicasInfoProtocol
+
+ Batch = dict[str, torch.Tensor]
+
+
+ def validate_batch_size(batch_size: int) -> int:
+     if batch_size <= 0:
+         msg = f"batch_size must be a positive integer. Got {batch_size=}"
+         raise ValueError(msg)
+     return batch_size
+
+
+ class IterableDataset(data.IterableDataset):
+     """
+     An iterable dataset used for processing a single partition of data.
+     Supports distributed training, where data is divided between replicas, and reproducible random shuffling.
+
+     A replica is a worker or a set of workers to which a unique chunk of data is assigned
+     during distributed training/inference.
+     """
+
+     def __init__(
+         self,
+         named_columns: NamedColumns,
+         batch_size: int,
+         generator: Optional[torch.Generator] = None,
+         replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
+     ) -> None:
+         """
+         :param named_columns: Structured data presented as columns.
+         :param batch_size: Batch size.
+         :param generator: Random number generator for batch shuffling.
+             If ``None``, shuffling will be disabled. Default: ``None``.
+         :param replicas_info: A connector object capable of fetching the total replica count and the replica id at runtime.
+             Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode.
+         """
+         super().__init__()
+
+         self.named_columns = named_columns
+         self.generator = generator
+         self.replicas_info = replicas_info
+         self.batch_size = validate_batch_size(batch_size)
+
+     @property
+     def device(self) -> torch.device:
+         """Returns the device containing the dataset."""
+         return self.named_columns.device
+
+     @property
+     def full_length(self) -> int:
+         """Returns the total number of elements in `named_columns`."""
+         return self.named_columns.length
+
+     @property
+     def length_per_replica(self) -> int:
+         """Returns the total number of available elements per replica."""
+         full_length = self.named_columns.length
+         num_replicas = self.replicas_info.num_replicas
+         return partitioning_per_replica(full_length, num_replicas)
+
+     @property
+     def length(self) -> int:
+         """Returns the total number of batches available to the current replica."""
+         batch_size = self.batch_size
+         per_replica = self.length_per_replica
+         return uniform_batch_count(per_replica, batch_size)
+
+     def __len__(self) -> int:
+         """Returns the total number of batches in the dataset."""
+         return self.length
+
+     def get_indices(self) -> torch.LongTensor:
+         """
+         Generates the indices corresponding to the data assigned to the current replica.
+
+         :return: tensor containing the relevant indices.
+         """
+         partitioning = Partitioning(
+             curr_replica=self.replicas_info.curr_replica,
+             num_replicas=self.replicas_info.num_replicas,
+             device=self.named_columns.device,
+             generator=self.generator,
+         )
+         indices = partitioning(self.full_length)
+         assert self.length_per_replica == torch.numel(indices)
+         return indices
+
+     def get_batching(self) -> UniformBatching:
+         """
+         Creates a partitioning object which splits the data into batches.
+
+         :return: The partitioning object.
+         """
+         batching = UniformBatching(
+             length=self.length_per_replica,
+             batch_size=self.batch_size,
+         )
+         assert len(batching) == self.length
+         return batching
+
+     def __iter__(self) -> Iterator[Batch]:
+         """Batched data iterator."""
+         batching = self.get_batching()
+         indices = self.get_indices()
+
+         for first, last in iter(batching):
+             batch_ids = indices[first:last]
+             yield self.named_columns[batch_ids]
replay/data/nn/parquet/iterator.py
@@ -0,0 +1,61 @@
+ from collections.abc import Iterator
+ from typing import Any, Callable, Optional
+
+ import pyarrow.dataset as da
+ import torch
+
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+ from replay.data.nn.parquet.impl.masking import DEFAULT_MAKE_MASK_NAME
+
+ from .impl.array_1d_column import to_array_1d_columns
+ from .impl.array_2d_column import to_array_2d_columns
+ from .impl.named_columns import NamedColumns
+ from .impl.numeric_column import to_numeric_columns
+ from .metadata import Metadata
+
+
+ class BatchesIterator:
+     """Iterator for batch-wise extraction of data from a Parquet dataset with conversion to structured columns."""
+
+     def __init__(
+         self,
+         metadata: Metadata,
+         dataset: da.Dataset,
+         batch_size: int,
+         make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
+         device: torch.device = DEFAULT_DEVICE,
+         pyarrow_kwargs: Optional[dict[str, Any]] = None,
+     ) -> None:
+         """
+         :param metadata: Metadata describing the structure and types of the input data.
+         :param dataset: PyArrow dataset implementing the ``to_batches`` method.
+         :param batch_size: Batch size sampled from a single partition.
+             The resulting batch will not always match it in size due to mismatches between
+             the target batch size and the partition size.
+         :param make_mask_name: Mask name generation function. Default: value of ``DEFAULT_MAKE_MASK_NAME``.
+         :param device: The device on which the data will be generated. Default: value of ``DEFAULT_DEVICE``.
+         :param pyarrow_kwargs: Additional parameters for the PyArrow dataset's ``to_batches`` method. Default: ``None``.
+         """
+         if pyarrow_kwargs is None:
+             pyarrow_kwargs = {}
+         self.dataset = dataset
+         self.metadata = metadata
+         self.batch_size = batch_size
+         self.make_mask_name = make_mask_name
+         self.device = device
+         self.pyarrow_kwargs = pyarrow_kwargs
+
+     def __iter__(self) -> Iterator[NamedColumns]:
+         for batch in self.dataset.to_batches(
+             batch_size=self.batch_size,
+             columns=list(self.metadata.keys()),
+             **self.pyarrow_kwargs,
+         ):
+             yield NamedColumns(
+                 columns={
+                     **to_numeric_columns(batch, self.metadata, self.device),
+                     **to_array_1d_columns(batch, self.metadata, self.device),
+                     **to_array_2d_columns(batch, self.metadata, self.device),
+                 },
+                 make_mask_name=self.make_mask_name,
+             )
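The two building blocks above fit together as follows: `BatchesIterator` turns partition-sized PyArrow record batches into `NamedColumns`, and `IterableDataset` slices a single `NamedColumns` partition into shuffled batches on the current replica. A minimal orientation sketch (the data path, column names, and the ``shape``/``padding`` metadata keys are illustrative assumptions, not values taken from the package):

import pyarrow.dataset as ds
import torch

from replay.data.nn.parquet.iterable_dataset import IterableDataset
from replay.data.nn.parquet.iterator import BatchesIterator

# Hypothetical encoded Parquet data: a scalar "user_id" column and a padded
# 1-D "item_ids" sequence column of length 100.
metadata = {
    "user_id": {},
    "item_ids": {"shape": 100, "padding": 0},
}

partitions = BatchesIterator(
    metadata=metadata,
    dataset=ds.dataset("/data/train.parquet", format="parquet"),  # hypothetical path
    batch_size=2**17,  # partition-sized record batches
)

for named_columns in partitions:
    # Each partition is further split into shuffled batches by IterableDataset.
    dataset = IterableDataset(
        named_columns=named_columns,
        batch_size=256,
        generator=torch.Generator().manual_seed(42),  # None disables shuffling
    )
    for batch in dataset:  # batch is a dict[str, torch.Tensor]
        ...

In practice `PartitionedIterableDataset` and `ParquetDataset` (further below) wire these pieces together, so direct use like this is rarely needed.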
replay/data/nn/parquet/metadata/__init__.py
@@ -0,0 +1,19 @@
+ from .metadata import (
+     ColumnMetadata,
+     Metadata,
+     get_1d_array_columns,
+     get_2d_array_columns,
+     get_numeric_columns,
+     get_padding,
+     get_shape,
+ )
+
+ __all__ = [
+     "ColumnMetadata",
+     "Metadata",
+     "get_1d_array_columns",
+     "get_2d_array_columns",
+     "get_numeric_columns",
+     "get_padding",
+     "get_shape",
+ ]
replay/data/nn/parquet/metadata/metadata.py
@@ -0,0 +1,116 @@
+ from collections.abc import Callable
+ from typing import Any, Union
+
+ from typing_extensions import TypeAlias
+
+ from replay.data.nn.parquet.constants.metadata import (
+     DEFAULT_PADDING,
+     PADDING_FLAG,
+     SHAPE_FLAG,
+ )
+
+ FieldType: TypeAlias = Union[bool, int, float, str]
+ ColumnMetadata: TypeAlias = dict[str, FieldType]
+ Metadata: TypeAlias = dict[str, ColumnMetadata]
+
+ ColumnCheck: TypeAlias = Callable[[ColumnMetadata], bool]
+ CheckColumn: TypeAlias = Callable[[ColumnCheck], bool]
+ Listing: TypeAlias = Callable[[Metadata], list[str]]
+
+
+ def make_shape_check(dim: int) -> ColumnCheck:
+     """
+     Constructs a function which checks a column's shape.
+
+     :param dim: Target number of dimensions.
+     """
+
+     def function(column_metadata: ColumnMetadata) -> bool:
+         if SHAPE_FLAG in column_metadata:
+             value: Any = column_metadata[SHAPE_FLAG]
+             if dim == 1 and isinstance(value, int):
+                 return True
+             if isinstance(value, list):
+                 result: bool = len(value) == dim
+                 if result:
+
+                     def is_int(v: Any) -> bool:
+                         return isinstance(v, int)
+
+                     result &= all(map(is_int, value))
+                 return result
+         return False
+
+     return function
+
+
+ def make_not_check(check: ColumnCheck) -> ColumnCheck:
+     def function(column_metadata: ColumnMetadata) -> bool:
+         return not check(column_metadata)
+
+     return function
+
+
+ def all_column_checks(*checks: ColumnCheck) -> ColumnCheck:
+     def function(column_metadata: ColumnMetadata) -> bool:
+         def perform_check(check):
+             return check(column_metadata)
+
+         return all(map(perform_check, checks))
+
+     return function
+
+
+ is_array_1d = all_column_checks(make_shape_check(dim=1))
+ is_array_2d = all_column_checks(make_shape_check(dim=2))
+ is_number = all_column_checks(
+     make_not_check(is_array_1d),
+     make_not_check(is_array_2d),
+ )
+
+
+ def make_listing(check: ColumnCheck) -> Listing:
+     """
+     Filtering function for selecting columns that pass the provided check.
+
+     :param check: Check function to validate against.
+     """
+
+     def function(metadata: Metadata) -> list[str]:
+         result: list[str] = []
+         for col_name, col_meta in metadata.items():
+             if check(col_meta):
+                 result.append(col_name)
+         return sorted(result)
+
+     return function
+
+
+ get_1d_array_columns = make_listing(is_array_1d)
+ get_2d_array_columns = make_listing(is_array_2d)
+ get_numeric_columns = make_listing(is_number)
+
+
+ def get_padding(metadata: Metadata, column_name: str) -> Any:
+     if column_name not in metadata:
+         msg = f"Column {column_name} not found in metadata."
+         raise KeyError(msg)
+     return metadata[column_name].get(PADDING_FLAG, DEFAULT_PADDING)
+
+
+ def get_shape(metadata: Metadata, column_name: str) -> list[int]:
+     if column_name not in metadata:
+         msg = f"Column {column_name} not found in metadata."
+         raise KeyError(msg)
+     if is_number(metadata[column_name]):
+         msg = f"Column {column_name} is not an array."
+         raise ValueError(msg)
+     result: Any = metadata[column_name][SHAPE_FLAG]
+
+     array_res: list[Any] = result if isinstance(result, list) else [result]
+
+     for i in range(len(array_res)):
+         if array_res[i] < 1:
+             msg = f"Shape for column {column_name} at position {i} is not a positive integer."
+             raise ValueError(msg)
+     return result
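To illustrate how the listing helpers above classify columns, a small sketch (it assumes the metadata key names ``shape`` and ``padding`` match the ``SHAPE_FLAG`` and ``PADDING_FLAG`` constants from constants/metadata.py, which is not shown in this diff; the column names are hypothetical):

from replay.data.nn.parquet.metadata import (
    get_1d_array_columns,
    get_2d_array_columns,
    get_numeric_columns,
    get_padding,
    get_shape,
)

# Hypothetical column metadata: one scalar column, one padded sequence column,
# and one 2-D feature-matrix column.
metadata = {
    "user_id": {},
    "item_ids": {"shape": 100, "padding": 0},
    "item_features": {"shape": [100, 16], "padding": 0.0},
}

get_numeric_columns(metadata)         # ["user_id"]
get_1d_array_columns(metadata)        # ["item_ids"]
get_2d_array_columns(metadata)        # ["item_features"]
get_shape(metadata, "item_features")  # [100, 16]
get_padding(metadata, "user_id")      # falls back to DEFAULT_PADDING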
replay/data/nn/parquet/parquet_dataset.py
@@ -0,0 +1,176 @@
+ import warnings
+ from collections.abc import Callable, Iterator
+ from typing import Optional, Union, cast
+
+ import pyarrow.dataset as ds
+ import pyarrow.fs as fs
+ import torch
+ from torch.utils.data import IterableDataset
+
+ from replay.data.nn.parquet import DEFAULT_REPLICAS_INFO
+ from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+ from replay.data.nn.parquet.constants.filesystem import DEFAULT_FILESYSTEM
+ from replay.data.nn.parquet.impl.masking import (
+     DEFAULT_COLLATE_FN,
+     DEFAULT_MAKE_MASK_NAME,
+ )
+ from replay.data.nn.parquet.info.replicas import ReplicasInfoProtocol
+ from replay.data.nn.parquet.utils.compute_length import compute_fixed_size_length
+
+ from .fixed_batch_dataset import FixedBatchSizeDataset
+ from .iterator import BatchesIterator
+ from .metadata import Metadata
+ from .partitioned_iterable_dataset import PartitionedIterableDataset
+
+
+ class ParquetDataset(IterableDataset):
+     """
+     Combined dataset and sampler for batch-wise reading and processing of Parquet files.
+
+     This implementation allows one to read data using a PyArrow Dataset, convert it into structured columns,
+     split it into partitions, and then into the batches needed for model training.
+     Supports distributed training and reproducible random shuffling.
+
+     During data loader operation, a partition of size ``partition_size`` is read.
+     There may be situations where the size of the read partition is less than
+     ``partition_size`` - this depends on the number of rows in the data fragment.
+     A fragment is a single Parquet file in the file system.
+
+     The partition will be read by every worker, split according to their replica ID,
+     processed, and the result will be returned as a batch of size ``batch_size``.
+     Please note that the resulting batch size may be less than ``batch_size``.
+
+     For maximum efficiency when reading and processing data, as well as improved data shuffling,
+     it is recommended to set ``partition_size`` several times larger than ``batch_size``.
+
+     **Note:**
+
+     * ``ParquetDataset`` supports only numeric values (boolean/integer/float),
+       therefore, the data paths passed as arguments must contain encoded data.
+     * For optimal performance, set the ``OMP_NUM_THREADS`` and ``ARROW_IO_THREADS`` environment variables to match
+       the number of available CPU cores.
+
+     """
+
+     def __init__(
+         self,
+         source: Union[str, list[str]],
+         metadata: Metadata,
+         partition_size: int,
+         batch_size: int,
+         filesystem: Union[str, fs.FileSystem] = DEFAULT_FILESYSTEM,
+         make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
+         device: torch.device = DEFAULT_DEVICE,
+         generator: Optional[torch.Generator] = None,
+         replicas_info: ReplicasInfoProtocol = DEFAULT_REPLICAS_INFO,
+         collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
+         **kwargs,
+     ) -> None:
+         """
+         :param source: The path or list of paths to files/directories containing data in Parquet format.
+         :param metadata: Metadata describing the data structure.
+             The structure of each column is defined by the following values:
+
+             ``shape`` - the dimension of the column being read.
+             If the column contains only one value, this parameter does not need to be specified.
+             If the column contains a one-dimensional array, the parameter must be a number or an array
+             containing one number.
+             If the column contains a two-dimensional array, the parameter
+             must be an array containing two numbers.
+
+             ``padding`` - padding value that will fill the arrays if their length is less
+             than that specified in the ``shape`` parameter.
+         :param partition_size: Partition size when reading data from Parquet files.
+         :param batch_size: The size of the batch that will be returned during iteration.
+         :param filesystem: A PyArrow FileSystem object used to access data, or a URI-based path
+             to infer the filesystem from. Default: value of ``DEFAULT_FILESYSTEM``.
+         :param make_mask_name: Mask name generation function. Default: value of ``DEFAULT_MAKE_MASK_NAME``.
+         :param device: The device on which the data will be generated. Default: value of ``DEFAULT_DEVICE``.
+         :param generator: Random number generator for batch shuffling.
+             If ``None``, shuffling will be disabled. Default: ``None``.
+         :param replicas_info: A connector object capable of fetching the total replica count and the replica id at runtime.
+             Default: value of ``DEFAULT_REPLICAS_INFO`` - a pre-built connector which assumes standard Torch DDP mode
+             based on the ``torch.utils.data`` and ``torch.distributed`` modules.
+         :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
+         """
+         if partition_size // batch_size < 20:
+             msg = (
+                 "Suboptimal parameters: partition to batch size ratio too low. "
+                 "Recommended proportion of partition size to batch size is at least 20:1. "
+                 f"Got: {partition_size=}, {batch_size=}."
+             )
+             warnings.warn(msg, stacklevel=2)
+
+         if (partition_size % batch_size) != 0:
+             msg = (
+                 "Suboptimal parameters: partition size is not a multiple of batch size. "
+                 f"Got: {partition_size=}, {batch_size=}."
+             )
+             warnings.warn(msg, stacklevel=2)
+
+         if isinstance(filesystem, str):
+             filesystem, _ = fs.FileSystem.from_uri(filesystem)
+         assert isinstance(filesystem, fs.FileSystem)
+         self.filesystem = cast(fs.FileSystem, filesystem)
+
+         self.pyarrow_dataset = ds.dataset(
+             source,
+             filesystem=self.filesystem,
+             format="parquet",
+             **kwargs.get("pyarrow_dataset_kwargs", {}),
+         )
+
+         self.batch_size = batch_size
+         self.partition_size = partition_size
+         self.replicas_info = replicas_info
+         self.metadata = metadata
+
+         self.iterator = BatchesIterator(
+             dataset=self.pyarrow_dataset,
+             metadata=self.metadata,
+             batch_size=partition_size,
+             device=device,
+             make_mask_name=make_mask_name,
+             pyarrow_kwargs=kwargs.get("pyarrow_to_batches_kwargs", {}),
+         )
+
+         self.raw_dataset = PartitionedIterableDataset(
+             batch_size=batch_size,
+             iterable=self.iterator,
+             generator=generator,
+             replicas_info=replicas_info,
+         )
+
+         self.dataset = FixedBatchSizeDataset(
+             dataset=self.raw_dataset,
+             batch_size=batch_size,
+             collate_fn=collate_fn,
+         )
+
+         self.do_compute_length = True
+         self.cached_lengths: dict[int, int] = {}
+
+     def compute_length(self) -> int:
+         """Returns the length of the dataset counted in fixed-size batches."""
+         num_replicas = self.replicas_info.num_replicas
+         if num_replicas not in self.cached_lengths:
+             if len(self.cached_lengths) > 0:
+                 msg = "`num_replicas` changed. Unable to reuse cached length."
+                 warnings.warn(msg, stacklevel=2)
+             curr_length = compute_fixed_size_length(
+                 iterable=self.iterator,
+                 num_replicas=num_replicas,
+                 batch_size=self.batch_size,
+             )
+             self.cached_lengths[num_replicas] = curr_length
+         return self.cached_lengths[num_replicas]
+
+     def __len__(self) -> int:
+         if self.do_compute_length:
+             return self.compute_length()
+         msg = "This instance doesn't support the `len()` method. You can enable it by setting `do_compute_length=True`."
+         raise TypeError(msg)
+
+     def __iter__(self) -> Iterator[GeneralBatch]:
+         return iter(self.dataset)
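A minimal usage sketch for `ParquetDataset` (paths, column names, and the ``shape``/``padding`` keys are placeholders; the data is assumed to be already label-encoded, as the docstring requires):

import torch
from torch.utils.data import DataLoader

from replay.data.nn.parquet.parquet_dataset import ParquetDataset

# Hypothetical metadata for an encoded interactions table.
metadata = {
    "user_id": {},
    "item_ids": {"shape": 100, "padding": 0},
}

dataset = ParquetDataset(
    source="/data/train/",            # hypothetical directory of Parquet files
    metadata=metadata,
    partition_size=20 * 512,          # keeps the recommended >= 20:1 ratio
    batch_size=512,
    generator=torch.Generator().manual_seed(0),  # reproducible shuffling
)

# The dataset already yields ready-made batches, so automatic batching is disabled.
loader = DataLoader(dataset, batch_size=None)
for batch in loader:
    ...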
replay/data/nn/parquet/parquet_module.py
@@ -0,0 +1,178 @@
+ import copy
+ import warnings
+ from collections.abc import Iterable
+ from typing import Literal, Optional, Union, get_args
+
+ import lightning as L  # noqa: N812
+ import torch
+ from lightning.pytorch.trainer.states import RunningStage
+ from lightning.pytorch.utilities import CombinedLoader
+ from typing_extensions import TypeAlias, override
+
+ from replay.data.nn.parquet.constants.filesystem import DEFAULT_FILESYSTEM
+ from replay.data.nn.parquet.impl.masking import (
+     DEFAULT_COLLATE_FN,
+     DEFAULT_MAKE_MASK_NAME,
+     DEFAULT_REPLICAS_INFO,
+ )
+ from replay.data.nn.parquet.parquet_dataset import ParquetDataset
+
+ TransformStage: TypeAlias = Literal["train", "validate", "test", "predict"]
+
+ DEFAULT_CONFIG = {"train": {"generator": torch.default_generator}}
+
+
+ class ParquetModule(L.LightningDataModule):
+     """
+     Standardized DataModule with batch-wise support via `ParquetDataset`.
+
+     Allows for unified access to all data splits across the training/inference pipeline without loading
+     the full dataset into memory. See the :ref:`parquet-processing` section for details.
+
+     ParquetModule provides per-batch data loading and preprocessing via transform pipelines.
+     See the :ref:`Transforms` section for information about the available batch transforms.
+
+     **Note:**
+
+     * ``ParquetModule`` supports only numeric values (boolean/integer/float),
+       therefore, the data paths passed as arguments must contain encoded data.
+     * For optimal performance, set the OMP_NUM_THREADS and ARROW_IO_THREADS environment variables to match
+       the number of available CPU cores.
+     * It is possible to use all of the train/validate/test/predict splits; in that case the paths to the splits
+       should be passed as the corresponding arguments of ``ParquetModule``.
+       Alternatively, some of the split paths may be left unspecified,
+       but then remember to configure the PyTorch Lightning Trainer instance accordingly.
+       For example, if you do not want to use validation data, you can omit the ``validate_path`` parameter
+       of ``ParquetModule`` and set ``limit_val_batches=0`` in ``lightning.Trainer``.
+
+     """
+
+     def __init__(
+         self,
+         batch_size: int,
+         metadata: dict,
+         transforms: dict[TransformStage, list[torch.nn.Module]],
+         config: Optional[dict] = None,
+         *,
+         train_path: Optional[str] = None,
+         validate_path: Optional[Union[str, list[str]]] = None,
+         test_path: Optional[Union[str, list[str]]] = None,
+         predict_path: Optional[Union[str, list[str]]] = None,
+     ) -> None:
+         """
+         :param batch_size: Target batch size.
+         :param metadata: A dictionary that maps each data split to a dictionary of feature names,
+             where each feature is associated with its shape and padding value.\n
+             Example: {"train": {"item_id" : {"shape": 100, "padding_value": 7657}}}.\n
+             For details, see the section :ref:`parquet-processing`.
+         :param config: Dict specifying configuration options of ``ParquetDataset`` (generator,
+             filesystem, collate_fn, make_mask_name, replicas_info) for each data split.
+             Default: ``DEFAULT_CONFIG``.\n
+             In most scenarios, the default configuration is sufficient.
+         :param transforms: Dict specifying the sequence of Transform modules for each data split.
+         :param train_path: Path to the Parquet file containing the train data split. Default: ``None``.
+         :param validate_path: Path to the Parquet file or files containing the validation data split. Default: ``None``.
+         :param test_path: Path to the Parquet file or files containing the testing data split. Default: ``None``.
+         :param predict_path: Path to the Parquet file or files containing the prediction data split. Default: ``None``.
+         """
+         if not any([train_path, validate_path, test_path, predict_path]):
+             msg = (
+                 f"{type(self)}.__init__() expects at least one of "
+                 "['train_path', 'validate_path', 'test_path', 'predict_path'], but none were provided."
+             )
+             raise KeyError(msg)
+
+         if train_path and not isinstance(train_path, str) and isinstance(train_path, Iterable):
+             msg = "'train_path' does not support multiple datapaths."
+             raise TypeError(msg)
+
+         super().__init__()
+         if config is None:
+             config = DEFAULT_CONFIG
+
+         self.datapaths = {"train": train_path, "validate": validate_path, "test": test_path, "predict": predict_path}
+         missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
+         if missing_splits:
+             msg = (
+                 f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
+                 "Make sure to disable these stages in your Lightning Trainer configuration."
+             )
+             warnings.warn(msg, stacklevel=2)
+
+         self.metadata = copy.deepcopy(metadata)
+         self.batch_size = batch_size
+         self.config = config
+
+         self.datasets: dict[str, Union[ParquetDataset, CombinedLoader]] = {}
+         self.transforms = transforms
+         self.compiled_transforms = self.prepare_transforms(transforms)
+
+     def prepare_transforms(
+         self, transforms: dict[TransformStage, list[torch.nn.Module]]
+     ) -> dict[TransformStage, torch.nn.Sequential]:
+         """
+         Perform meta adjustments based on the provided transform pipelines,
+         then compile each subset into a `torch.nn.Sequential` module.
+
+         :param transforms: Python dict where keys are stage names (train, validate, test, predict)
+             and values are the corresponding transform pipelines for every stage.
+         :returns: Compiled transform pipelines.
+         """
+         if not any(subset in get_args(TransformStage) for subset in transforms):
+             msg = (
+                 f"Expected transform.keys()={list(transforms.keys())} to contain at least "
+                 f"one of {get_args(TransformStage)}, but none were found."
+             )
+             raise KeyError(msg)
+
+         compiled_transforms = {}
+         for subset, transform_set in transforms.items():
+             compiled_transforms[subset] = torch.nn.Sequential(*transform_set)
+
+         return compiled_transforms
+
+     @override
+     def setup(self, stage):
+         for subset in get_args(TransformStage):
+             subset_datapaths = self.datapaths.get(subset, None)
+             if subset_datapaths is not None:
+                 subset_config = self.config.get(subset, {})
+                 shared_kwargs = {
+                     "metadata": self.metadata[subset],
+                     "batch_size": self.batch_size,
+                     "partition_size": subset_config.get("partition_size", 2**17),
+                     "generator": subset_config.get("generator", None),
+                     "filesystem": subset_config.get("filesystem", DEFAULT_FILESYSTEM),
+                     "make_mask_name": subset_config.get("make_mask_name", DEFAULT_MAKE_MASK_NAME),
+                     "replicas_info": subset_config.get("replicas_info", DEFAULT_REPLICAS_INFO),
+                     "collate_fn": subset_config.get("collate_fn", DEFAULT_COLLATE_FN),
+                 }
+
+                 if isinstance(subset_datapaths, list):
+                     loaders = [ParquetDataset(**{"source": path, **shared_kwargs}) for path in subset_datapaths]
+                     self.datasets[subset] = CombinedLoader(loaders, mode="sequential")
+                 else:
+                     self.datasets[subset] = ParquetDataset(**{"source": subset_datapaths, **shared_kwargs})
+
+     @override
+     def train_dataloader(self):
+         return self.datasets["train"]
+
+     @override
+     def val_dataloader(self):
+         return self.datasets["validate"]
+
+     @override
+     def test_dataloader(self):
+         return self.datasets["test"]
+
+     @override
+     def predict_dataloader(self):
+         return self.datasets["predict"]
+
+     @override
+     def on_after_batch_transfer(self, batch, _dataloader_idx):
+         stage = self.trainer.state.stage
+         target = RunningStage.VALIDATING if stage is RunningStage.SANITY_CHECKING else stage
+
+         return self.compiled_transforms[str(target.value)](batch)
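Finally, a hedged end-to-end sketch of plugging `ParquetModule` into a Lightning `Trainer`. Split paths and feature names are placeholders, and `torch.nn.Identity()` stands in for the real batch transforms shipped under replay/nn/transform/ in this release:

import lightning as L
import torch

from replay.data.nn.parquet.parquet_module import ParquetModule

# Per-split metadata in the format shown in the docstring above (placeholder values).
metadata = {
    "train": {"item_id": {"shape": 100, "padding_value": 0}},
    "validate": {"item_id": {"shape": 100, "padding_value": 0}},
}

data_module = ParquetModule(
    batch_size=512,
    metadata=metadata,
    transforms={
        "train": [torch.nn.Identity()],     # placeholder transform pipelines
        "validate": [torch.nn.Identity()],
    },
    train_path="/data/train.parquet",       # hypothetical paths
    validate_path="/data/validate.parquet",
)

trainer = L.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=data_module)  # `model` is any LightningModule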