replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/batches.py +8 -0
  7. replay/data/nn/parquet/constants/device.py +3 -0
  8. replay/data/nn/parquet/constants/filesystem.py +3 -0
  9. replay/data/nn/parquet/constants/metadata.py +5 -0
  10. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  11. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  12. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  13. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  14. replay/data/nn/parquet/impl/indexing.py +123 -0
  15. replay/data/nn/parquet/impl/masking.py +20 -0
  16. replay/data/nn/parquet/impl/named_columns.py +100 -0
  17. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  18. replay/data/nn/parquet/impl/utils.py +17 -0
  19. replay/data/nn/parquet/info/distributed_info.py +40 -0
  20. replay/data/nn/parquet/info/partitioning.py +132 -0
  21. replay/data/nn/parquet/info/replicas.py +67 -0
  22. replay/data/nn/parquet/info/worker_info.py +43 -0
  23. replay/data/nn/parquet/iterable_dataset.py +119 -0
  24. replay/data/nn/parquet/iterator.py +61 -0
  25. replay/data/nn/parquet/metadata/__init__.py +19 -0
  26. replay/data/nn/parquet/metadata/metadata.py +116 -0
  27. replay/data/nn/parquet/parquet_dataset.py +176 -0
  28. replay/data/nn/parquet/parquet_module.py +178 -0
  29. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  30. replay/data/nn/parquet/utils/compute_length.py +66 -0
  31. replay/data/nn/schema.py +12 -14
  32. replay/data/nn/sequence_tokenizer.py +5 -0
  33. replay/data/nn/sequential_dataset.py +4 -0
  34. replay/data/nn/torch_sequential_dataset.py +5 -0
  35. replay/data/utils/batching.py +69 -0
  36. replay/data/utils/typing/__init__.py +0 -0
  37. replay/data/utils/typing/dtype.py +65 -0
  38. replay/metrics/torch_metrics_builder.py +20 -14
  39. replay/models/nn/loss/sce.py +2 -7
  40. replay/models/nn/optimizer_utils/__init__.py +6 -1
  41. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  42. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  43. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  44. replay/models/nn/sequential/bert4rec/model.py +11 -11
  45. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  46. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  47. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  48. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  49. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  50. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  51. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  52. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  53. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  54. replay/models/nn/sequential/sasrec/model.py +14 -9
  55. replay/nn/__init__.py +8 -0
  56. replay/nn/agg.py +109 -0
  57. replay/nn/attention.py +158 -0
  58. replay/nn/embedding.py +283 -0
  59. replay/nn/ffn.py +135 -0
  60. replay/nn/head.py +49 -0
  61. replay/nn/lightning/__init__.py +1 -0
  62. replay/nn/lightning/callback/__init__.py +9 -0
  63. replay/nn/lightning/callback/metrics_callback.py +183 -0
  64. replay/nn/lightning/callback/predictions_callback.py +314 -0
  65. replay/nn/lightning/module.py +123 -0
  66. replay/nn/lightning/optimizer.py +60 -0
  67. replay/nn/lightning/postprocessor/__init__.py +2 -0
  68. replay/nn/lightning/postprocessor/_base.py +51 -0
  69. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  70. replay/nn/lightning/scheduler.py +91 -0
  71. replay/nn/loss/__init__.py +22 -0
  72. replay/nn/loss/base.py +197 -0
  73. replay/nn/loss/bce.py +216 -0
  74. replay/nn/loss/ce.py +317 -0
  75. replay/nn/loss/login_ce.py +373 -0
  76. replay/nn/loss/logout_ce.py +230 -0
  77. replay/nn/mask.py +87 -0
  78. replay/nn/normalization.py +9 -0
  79. replay/nn/output.py +37 -0
  80. replay/nn/sequential/__init__.py +9 -0
  81. replay/nn/sequential/sasrec/__init__.py +7 -0
  82. replay/nn/sequential/sasrec/agg.py +53 -0
  83. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  84. replay/nn/sequential/sasrec/model.py +377 -0
  85. replay/nn/sequential/sasrec/transformer.py +107 -0
  86. replay/nn/sequential/twotower/__init__.py +2 -0
  87. replay/nn/sequential/twotower/model.py +674 -0
  88. replay/nn/sequential/twotower/reader.py +89 -0
  89. replay/nn/transform/__init__.py +22 -0
  90. replay/nn/transform/copy.py +38 -0
  91. replay/nn/transform/grouping.py +39 -0
  92. replay/nn/transform/negative_sampling.py +182 -0
  93. replay/nn/transform/next_token.py +100 -0
  94. replay/nn/transform/rename.py +33 -0
  95. replay/nn/transform/reshape.py +41 -0
  96. replay/nn/transform/sequence_roll.py +48 -0
  97. replay/nn/transform/template/__init__.py +2 -0
  98. replay/nn/transform/template/sasrec.py +53 -0
  99. replay/nn/transform/template/twotower.py +22 -0
  100. replay/nn/transform/token_mask.py +69 -0
  101. replay/nn/transform/trim.py +51 -0
  102. replay/nn/utils.py +28 -0
  103. replay/preprocessing/filters.py +128 -0
  104. replay/preprocessing/label_encoder.py +36 -33
  105. replay/preprocessing/utils.py +209 -0
  106. replay/splitters/__init__.py +1 -0
  107. replay/splitters/random_next_n_splitter.py +224 -0
  108. replay/utils/common.py +10 -4
  109. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
  110. replay_rec-0.21.0.dist-info/RECORD +223 -0
  111. replay/experimental/metrics/__init__.py +0 -62
  112. replay/experimental/metrics/base_metric.py +0 -603
  113. replay/experimental/metrics/coverage.py +0 -97
  114. replay/experimental/metrics/experiment.py +0 -175
  115. replay/experimental/metrics/hitrate.py +0 -26
  116. replay/experimental/metrics/map.py +0 -30
  117. replay/experimental/metrics/mrr.py +0 -18
  118. replay/experimental/metrics/ncis_precision.py +0 -31
  119. replay/experimental/metrics/ndcg.py +0 -49
  120. replay/experimental/metrics/precision.py +0 -22
  121. replay/experimental/metrics/recall.py +0 -25
  122. replay/experimental/metrics/rocauc.py +0 -49
  123. replay/experimental/metrics/surprisal.py +0 -90
  124. replay/experimental/metrics/unexpectedness.py +0 -76
  125. replay/experimental/models/__init__.py +0 -50
  126. replay/experimental/models/admm_slim.py +0 -257
  127. replay/experimental/models/base_neighbour_rec.py +0 -200
  128. replay/experimental/models/base_rec.py +0 -1386
  129. replay/experimental/models/base_torch_rec.py +0 -234
  130. replay/experimental/models/cql.py +0 -454
  131. replay/experimental/models/ddpg.py +0 -932
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  133. replay/experimental/models/dt4rec/gpt1.py +0 -401
  134. replay/experimental/models/dt4rec/trainer.py +0 -127
  135. replay/experimental/models/dt4rec/utils.py +0 -264
  136. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  137. replay/experimental/models/hierarchical_recommender.py +0 -331
  138. replay/experimental/models/implicit_wrap.py +0 -131
  139. replay/experimental/models/lightfm_wrap.py +0 -303
  140. replay/experimental/models/mult_vae.py +0 -332
  141. replay/experimental/models/neural_ts.py +0 -986
  142. replay/experimental/models/neuromf.py +0 -406
  143. replay/experimental/models/scala_als.py +0 -293
  144. replay/experimental/models/u_lin_ucb.py +0 -115
  145. replay/experimental/nn/data/__init__.py +0 -1
  146. replay/experimental/nn/data/schema_builder.py +0 -102
  147. replay/experimental/preprocessing/__init__.py +0 -3
  148. replay/experimental/preprocessing/data_preparator.py +0 -839
  149. replay/experimental/preprocessing/padder.py +0 -229
  150. replay/experimental/preprocessing/sequence_generator.py +0 -208
  151. replay/experimental/scenarios/__init__.py +0 -1
  152. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  153. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  154. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  155. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  156. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  157. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  158. replay/experimental/utils/logger.py +0 -24
  159. replay/experimental/utils/model_handler.py +0 -186
  160. replay/experimental/utils/session_handler.py +0 -44
  161. replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
  162. /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
  163. /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
  164. /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
  165. /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
  166. /replay/{experimental → data}/utils/__init__.py +0 -0
  167. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  168. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  169. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/data/nn/parquet/impl/indexing.py
@@ -0,0 +1,123 @@
+ from typing import Union
+
+ import torch
+
+
+ def raw_get_offsets(lengths: torch.LongTensor) -> torch.LongTensor:
+     """
+     Performs offset calculation, defined simply as a cumulative sum of
+     the provided lengths tensor.
+
+     :param lengths: A tensor containing lengths of each individual row in a dataset's column.
+     :return: A tensor of offsets for each row.
+     """
+     zero = torch.zeros((1,), device=lengths.device, dtype=torch.int64)
+     cumsum = torch.cumsum(lengths, dim=-1)
+     return torch.cat([zero, cumsum])
+
+
+ def get_offsets(lengths: torch.LongTensor) -> torch.LongTensor:
+     """
+     Sanitizes row lengths, then calculates offsets for each row.
+     The calculation itself is performed via the ``raw_get_offsets`` method.
+
+     :param lengths: A tensor containing lengths of each individual row in a dataset's column.
+     :raises ValueError: If the lengths tensor is of invalid shape or contains negative values.
+
+     :return: A tensor of offsets for each row.
+     """
+     if lengths.ndim != 1:
+         msg = f"Lengths must be strictly 1D. Got {lengths.ndim}D."
+         raise ValueError(msg)
+     min_length = torch.min(lengths.detach()).cpu().item()
+     if min_length < 0:
+         msg = f"There is a negative length. Got {min_length}."
+         raise ValueError(msg)
+     return raw_get_offsets(lengths)
+
+
+ LengthType = Union[int, torch.LongTensor]
+
+
+ def raw_get_mask(
+     indices: torch.LongTensor,
+     offsets: torch.LongTensor,
+     length: LengthType,
+ ) -> tuple[torch.BoolTensor, torch.LongTensor]:
+     """
+     Performs mask construction.
+     Given the data itself, its offsets and the expected sequence length, returns two tensors.
+
+     The first tensor is the padding mask, where ``False`` represents a padded value that was not present in the data,
+     and ``True`` represents a real element from the dataset.
+
+     The second tensor is the data itself, left-padded with a 0 to the desired length.
+
+     :param indices: A tensor of indices to be sampled from the dataset.
+     :param offsets: A tensor containing individual offsets for each of the column's rows.
+     :param length: The total number of elements in a dataset's column.
+
+     :return: Constructed mask.
+     """
+     length = torch.asarray(length, dtype=torch.int64, device=indices.device)
+
+     # For every "line", the start element index matches the offset, while the end is the offset of the next line
+     last = offsets[indices + 1]
+     first = offsets[indices + 0]
+
+     per_line = length - (last - first)
+
+     arange = torch.arange(length, dtype=torch.int64, device=offsets.device)
+     raw_indices = (first[:, None] - per_line[:, None]) + arange[None, :]
+     mask = (first[:, None] <= raw_indices) & (raw_indices < last[:, None])
+
+     assert torch.all(torch.sum(mask, dim=-1, dtype=torch.int64) == torch.minimum(last - first, length)).cpu().item()
+
+     output_indices = torch.where(mask, raw_indices, 0)
+     assert torch.all((torch.max(output_indices, dim=-1).values < last) | (last == first)).cpu().item()
+     return (mask, output_indices)
+
+
+ def get_mask(
+     indices: torch.LongTensor,
+     offsets: torch.LongTensor,
+     length: LengthType,
+ ) -> tuple[torch.BoolTensor, torch.LongTensor]:
+     """
+     Performs input sanity checks, then constructs a mask from the inputs.
+     The mask calculation itself is performed via the ``raw_get_mask`` method.
+
+     :param indices: A tensor of indices to be sampled from the dataset.
+     :param offsets: A tensor containing individual offsets for each of the column's rows.
+     :param length: The total number of elements in a dataset's column.
+
+     :raises ValueError: When malformed or otherwise invalid arguments are provided.
+     :raises IndexError: When sampling indices missing from the dataset, or none at all.
+     :raises RuntimeError: When the provided tensors are not on the same device.
+
+     :return: Constructed mask.
+     """
+     if torch.asarray(length).cpu().item() < 1:
+         msg = f"Length must be a positive number. Got {length}"
+         raise ValueError(msg)
+     if torch.numel(indices) < 1:
+         msg = f"Indices must be non-empty. Got {torch.numel(indices)}."
+         raise IndexError(msg)
+     if indices.device != offsets.device:  # pragma: no cover
+         msg = f"Devices must match. Got {indices.device} vs {offsets.device}"
+         raise RuntimeError(msg)
+     if offsets.ndim != 1:
+         msg = f"Offsets must be strictly 1D. Got {offsets.ndim}D."
+         raise ValueError(msg)
+     min_index = torch.min(indices.detach()).cpu().item()
+     if min_index < 0:
+         msg = f"Index is too small. Got {min_index}."
+         raise IndexError(msg)
+     max_index = torch.max(indices.detach()).cpu().item()
+     if torch.numel(offsets) < max_index:
+         msg = f"Index is too large. Got {max_index}."
+         raise IndexError(msg)
+     if not torch.all(offsets[:-1] <= offsets[1:]).cpu().item():
+         msg = "Offset sequence is not monotonic."
+         raise ValueError(msg)
+     return raw_get_mask(indices, offsets, length)
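A minimal usage sketch (not part of the diff) showing how get_offsets and get_mask compose for a ragged column; the tensors are hypothetical and the expected outputs in the comments follow from the code above.

import torch
from replay.data.nn.parquet.impl.indexing import get_mask, get_offsets

lengths = torch.tensor([2, 0, 3])  # hypothetical per-row lengths of a ragged column
offsets = get_offsets(lengths)     # tensor([0, 2, 2, 5])
indices = torch.tensor([0, 2])     # rows to sample
mask, flat_indices = get_mask(indices, offsets, length=4)
# mask marks real elements, with rows left-padded to length 4:
# [[False, False, True, True],
#  [False, True,  True, True]]
# flat_indices point into the flattened column data, with 0 where padded.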
replay/data/nn/parquet/impl/masking.py
@@ -0,0 +1,20 @@
+ from typing import Callable
+
+ from replay.data.nn.parquet.collate import general_collate
+ from replay.data.nn.parquet.constants.batches import GeneralCollateFn
+ from replay.data.nn.parquet.info.replicas import ReplicasInfo, ReplicasInfoProtocol
+
+ DEFAULT_COLLATE_FN: GeneralCollateFn = general_collate
+
+ DEFAULT_MASK_POSTFIX: str = "_mask"
+
+
+ def default_make_mask_name(postfix: str) -> Callable[[str], str]:
+     def function(name: str) -> str:
+         return f"{name}{postfix}"
+
+     return function
+
+
+ DEFAULT_MAKE_MASK_NAME = default_make_mask_name(DEFAULT_MASK_POSTFIX)
+ DEFAULT_REPLICAS_INFO: ReplicasInfoProtocol = ReplicasInfo()
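A tiny illustration (not from the diff) of the mask-name helper; the column name "item_id" and the alternative postfix are hypothetical.

from replay.data.nn.parquet.impl.masking import DEFAULT_MAKE_MASK_NAME, default_make_mask_name

DEFAULT_MAKE_MASK_NAME("item_id")            # "item_id_mask"
default_make_mask_name("_valid")("item_id")  # "item_id_valid"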
replay/data/nn/parquet/impl/named_columns.py
@@ -0,0 +1,100 @@
+ from collections.abc import Sequence
+ from typing import Callable
+
+ import torch
+
+ from replay.data.nn.parquet.impl.masking import DEFAULT_MAKE_MASK_NAME
+
+ from .column_protocol import ColumnProtocol
+
+ Batch = dict[str, torch.Tensor]
+
+
+ def deduce_device(columns: Sequence[ColumnProtocol]) -> torch.device:
+     """
+     Sanity check for matching devices on all of the dataset's columns.
+
+     :param columns: A list of the dataset's column data.
+     :raises RuntimeError: If any of the columns have mismatching devices.
+     :return: The determined columns' device.
+     """
+     assert len(columns) > 0
+     device = columns[0].device
+
+     def is_correct_device(column: ColumnProtocol) -> bool:
+         return column.device == device
+
+     if not all(map(is_correct_device, columns)):  # pragma: no cover
+         msg = "Columns must be all on the same device."
+         raise RuntimeError(msg)
+     return device
+
+
+ def deduce_length(columns: Sequence[ColumnProtocol]) -> int:
+     """
+     Sanity check for matching lengths on all of the dataset's columns.
+
+     :param columns: A list of the dataset's column data.
+     :raises RuntimeError: If any column's row count differs from the others.
+     :return: The determined columns' length.
+     """
+     assert len(columns) > 0
+     length = columns[0].length
+
+     def is_correct_length(column: ColumnProtocol) -> bool:
+         return column.length == length
+
+     if not all(map(is_correct_length, columns)):
+         msg = "Columns must have the same lengths."
+         raise RuntimeError(msg)
+     assert length > 0
+     return length
+
+
+ def deduce_length_device(columns: dict[str, ColumnProtocol]) -> tuple[int, torch.device]:
+     """A combined check for both matching devices and lengths."""
+     raw = [*columns.values()]
+     columns_length = deduce_length(raw)
+     columns_device = deduce_device(raw)
+     del raw
+     return (columns_length, columns_device)
+
+
+ class NamedColumns:
+     """
+     Representation of a data batch read from the filesystem.
+     This representation contains all of the columns read into memory, as well as
+     metadata such as their length and current device.
+     """
+
+     def __init__(
+         self,
+         columns: dict[str, ColumnProtocol],
+         make_mask_name: Callable[[str], str] = DEFAULT_MAKE_MASK_NAME,
+     ) -> None:
+         """
+         :param columns: Column data read from the filesystem.
+         :param make_mask_name: A function generating matching mask names for each column.
+         """
+         self.columns_length, self.columns_device = deduce_length_device(columns)
+
+         self.columns = columns
+         self.make_mask_name = make_mask_name
+
+     @property
+     def length(self) -> int:
+         return self.columns_length
+
+     @property
+     def device(self) -> torch.device:
+         return self.columns_device
+
+     def __len__(self) -> int:
+         return self.columns_length
+
+     def __getitem__(self, indices: torch.LongTensor) -> Batch:
+         indices = indices.to(device=self.device)
+         result = {}
+         for name, column in self.columns.items():
+             result[self.make_mask_name(name)], result[name] = column[indices]
+         return result
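A minimal sketch (not part of the diff) combining NumericColumn with NamedColumns; the column names and values are hypothetical, and padding=0 is passed explicitly because the value of DEFAULT_PADDING is not shown in this diff.

import torch
from replay.data.nn.parquet.impl.named_columns import NamedColumns
from replay.data.nn.parquet.impl.numeric_column import NumericColumn

columns = {
    "user_id": NumericColumn(data=torch.tensor([10, 11, 12]), padding=0),
    "rating": NumericColumn(data=torch.tensor([1.0, 0.0, 5.0]), padding=0),
}
batch = NamedColumns(columns)[torch.tensor([0, 2])]
# batch["user_id"] == tensor([10, 12]); batch["user_id_mask"] == tensor([True, True])
# batch["rating"]  == tensor([1., 5.]); batch["rating_mask"]  == tensor([True, True])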
replay/data/nn/parquet/impl/numeric_column.py
@@ -0,0 +1,110 @@
+ from typing import Any, Optional
+
+ import pyarrow as pa
+ import torch
+
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+ from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
+ from replay.data.nn.parquet.metadata import Metadata, get_numeric_columns
+ from replay.data.utils.typing.dtype import pyarrow_to_torch
+
+ from .column_protocol import OutputType
+ from .utils import ensure_mutable
+
+
+ class NumericColumn:
+     """A representation of a numeric column, containing a single number in each of its rows."""
+
+     def __init__(
+         self,
+         data: torch.Tensor,
+         mask: Optional[torch.BoolTensor] = None,
+         padding: Any = DEFAULT_PADDING,
+     ) -> None:
+         """
+         :param data: A tensor containing column data.
+         :param mask: A mask tensor to differentiate real values from paddings. Default: ``None``.
+         :param padding: Padding to use for future indexing of non-existent data. Default: value of ``DEFAULT_PADDING``.
+         """
+         self.padding: Any = padding
+         self.data = data
+         self.mask = mask
+
+     @property
+     def length(self) -> int:
+         result = torch.numel(self.data)
+         if self.mask is not None:
+             assert result == torch.numel(self.mask)
+         return result
+
+     def __len__(self) -> int:
+         return self.length
+
+     @property
+     def device(self) -> torch.device:
+         result = self.data.device
+         if self.mask is not None:
+             assert result == self.mask.device
+         return result
+
+     def _get_mask(self, indices: torch.LongTensor) -> torch.BoolTensor:
+         mask = torch.ones_like(indices, dtype=torch.bool) if self.mask is None else self.mask[indices]
+         return mask
+
+     def __getitem__(self, indices: torch.LongTensor) -> OutputType:
+         indices = indices.to(device=self.device)
+         mask = self._get_mask(indices)
+         output = torch.where(mask, self.data[indices], self.padding)
+         return (mask, output)
+
+
+ def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE, padding: Any = DEFAULT_PADDING) -> OutputType:
+     """
+     Converts a PyArrow array into a PyTorch tensor.
+
+     :param array: Original PyArrow array.
+     :param device: Target device to send the resulting tensor to. Default: value of ``DEFAULT_DEVICE``.
+     :param padding: Value to fill null values with. Default: value of ``DEFAULT_PADDING``.
+
+     :return: A PyTorch tensor obtained from the original array.
+     """
+     dtype = pyarrow_to_torch(array.type)
+
+     mask_torch = None
+     if array.null_count > 0:
+         mask_torch = torch.asarray(
+             ensure_mutable(array.is_valid().to_numpy(zero_copy_only=False)),
+             device=device,
+             dtype=torch.bool,
+         )
+
+     array_torch = torch.asarray(
+         ensure_mutable(array.fill_null(padding).to_numpy()),
+         device=device,
+         dtype=dtype,
+     )
+     return (mask_torch, array_torch)
+
+
+ def to_numeric_columns(
+     data: pa.RecordBatch,
+     metadata: Metadata,
+     device: torch.device = DEFAULT_DEVICE,
+     padding: Any = DEFAULT_PADDING,
+ ) -> dict[str, NumericColumn]:
+     """
+     Converts a PyArrow batch of data to a set of ``NumericColumn``s.
+     This function filters only those columns matching its format from the full batch.
+
+     :param data: A PyArrow batch of column data.
+     :param metadata: Metadata containing information about columns' formats.
+     :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``.
+     :param padding: Padding to use for future indexing of non-existent data. Default: value of ``DEFAULT_PADDING``.
+
+     :return: A dict of tensors containing the dataset's numeric columns.
+     """
+     result = {}
+     for column_name in get_numeric_columns(metadata):
+         mask, torch_array = to_torch(data.column(column_name), device, padding)
+         result[column_name] = NumericColumn(data=torch_array, mask=mask, padding=padding)
+     return result
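A small sketch (not part of the diff) of converting a nullable PyArrow column; it assumes pyarrow_to_torch maps pa.int64() to torch.int64 and uses an explicit padding of 0.

import pyarrow as pa
import torch
from replay.data.nn.parquet.impl.numeric_column import NumericColumn, to_torch

array = pa.array([3, None, 7], type=pa.int64())
mask, values = to_torch(array, device=torch.device("cpu"), padding=0)
# mask   -> tensor([True, False, True])  (False where the source value was null)
# values -> tensor([3, 0, 7])
column = NumericColumn(data=values, mask=mask, padding=0)
col_mask, col_values = column[torch.tensor([1, 2])]
# col_mask -> tensor([False, True]); col_values -> tensor([0, 7])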
replay/data/nn/parquet/impl/utils.py
@@ -0,0 +1,17 @@
+ import numpy as np
+
+ WRITEABLE_FLAG: str = "WRITEABLE"
+
+
+ def ensure_mutable(array: np.ndarray) -> np.ndarray:
+     """
+     Ensures the resulting NumPy array is mutable by making a copy if it is not.
+
+     :param array: Array to be checked for mutability.
+     :return: The original array if it is already writable, otherwise a mutable copy.
+     """
+     if not array.flags[WRITEABLE_FLAG]:
+         result = array.copy()
+         assert result.flags[WRITEABLE_FLAG]
+         return result
+     return array
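A tiny illustration (not from the diff) of ensure_mutable on a read-only array, such as one produced by a zero-copy Arrow-to-NumPy conversion.

import numpy as np
from replay.data.nn.parquet.impl.utils import ensure_mutable

frozen = np.arange(4)
frozen.setflags(write=False)      # simulate a read-only, zero-copy buffer
mutable = ensure_mutable(frozen)  # returns a writable copy
assert mutable.flags["WRITEABLE"] and not frozen.flags["WRITEABLE"]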
replay/data/nn/parquet/info/distributed_info.py
@@ -0,0 +1,40 @@
+ from typing import Protocol
+
+ import torch.distributed as dist
+
+
+ class DistributedInfo:
+     """Wrapper class for Torch's distributed environment metadata."""
+
+     def __iter__(self):
+         yield self.rank
+         yield self.world_size
+
+     @property
+     def is_distributed(self) -> bool:
+         if dist.is_available():
+             return dist.is_initialized()
+         return False
+
+     @property
+     def rank(self) -> int:
+         if self.is_distributed:
+             return dist.get_rank()
+         return 0
+
+     @property
+     def world_size(self) -> int:
+         if self.is_distributed:
+             return dist.get_world_size()
+         return 1
+
+
+ class DistributedInfoProtocol(Protocol):
+     @property
+     def rank(self) -> int: ...
+
+     @property
+     def world_size(self) -> int: ...
+
+
+ DEFAULT_DISTRIBUTED_INFO: DistributedInfo = DistributedInfo()
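A sketch (not from the diff) of the single-process fallback: without an initialized process group the wrapper reports rank 0 in a world of size 1.

from replay.data.nn.parquet.info.distributed_info import DistributedInfo

info = DistributedInfo()
rank, world_size = info  # __iter__ yields rank, then world_size
# outside torch.distributed: rank == 0, world_size == 1, info.is_distributed is False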
replay/data/nn/parquet/info/partitioning.py
@@ -0,0 +1,132 @@
+ from functools import lru_cache
+ from math import ceil
+ from typing import Optional, Union
+
+ import torch
+
+ from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
+
+
+ def validate_length(length: int) -> int:
+     if length < 1:
+         msg = f"Length is invalid. Got {length}."
+         raise ValueError(msg)
+     return length
+
+
+ def validate_num_replicas(num_replicas: int) -> int:
+     if num_replicas < 1:
+         msg = f"Num Replicas is invalid. Got {num_replicas}."
+         raise ValueError(msg)
+     return num_replicas
+
+
+ def validate_curr_replica(curr_replica: int, num_replicas: int) -> int:
+     num_replicas = validate_num_replicas(num_replicas)
+     if (curr_replica < 0) or (num_replicas <= curr_replica):
+         msg = f"Curr Replicas is invalid. Got {curr_replica}."
+         raise ValueError(msg)
+     return curr_replica
+
+
+ @lru_cache
+ def _partitioning_length(length: int, num_replicas: int) -> int:
+     length = validate_length(length)
+     num_replicas = validate_num_replicas(num_replicas)
+
+     result = length
+     if length % num_replicas != 0:
+         raw_per_replica = length / num_replicas
+         per_replica = ceil(raw_per_replica)
+         new_length = per_replica * num_replicas
+         assert (new_length - length) < num_replicas
+         result = new_length
+     assert result % num_replicas == 0
+     assert length <= result
+     return result
+
+
+ def partitioning_length(length: int, num_replicas: int) -> int:
+     return _partitioning_length(length, num_replicas)
+
+
+ @lru_cache
+ def _partitioning_per_replica(length: int, num_replicas: int) -> int:
+     full_length = partitioning_length(length, num_replicas)
+     result = full_length // num_replicas
+     assert result <= length
+     assert result > 0
+     return result
+
+
+ def partitioning_per_replica(length: int, num_replicas: int) -> int:
+     return _partitioning_per_replica(length, num_replicas)
+
+
+ class Partitioning:
+     """Utility class for calculating valid indices across multiple replicas."""
+
+     def __init__(
+         self,
+         curr_replica: int,
+         num_replicas: int,
+         device: Union[torch.device, str] = DEFAULT_DEVICE,
+         generator: Optional[torch.Generator] = None,
+     ) -> None:
+         """
+         :param curr_replica: ID of the current replica.
+         :param num_replicas: Total number of active replicas.
+         :param device: Target device to send the indices tensor to.
+             Default: value of ``DEFAULT_DEVICE``.
+         :param generator: A pseudo-random number generator for index shuffling. Default: ``None``.
+         """
+         self.device = torch.device(device)
+         self.generator = generator
+         self.num_replicas = validate_num_replicas(num_replicas)
+         self.curr_replica = validate_curr_replica(curr_replica, self.num_replicas)
+
+     def generate_raw_indices(self, length: int) -> torch.LongTensor:
+         full_length = partitioning_length(length, self.num_replicas)
+
+         if self.generator is None:
+             raw_indices = torch.arange(full_length, dtype=torch.int64, device=self.device)
+         else:
+             raw_indices = torch.randperm(full_length, dtype=torch.int64, generator=self.generator)
+             raw_indices = raw_indices.to(device=self.device)
+
+         assert torch.max(raw_indices).cpu().item() < full_length
+         assert torch.numel(raw_indices) == full_length
+         assert raw_indices.device == self.device
+
+         return raw_indices
+
+     def replica_indices(self, raw_indices: torch.LongTensor) -> torch.LongTensor:
+         full_length = torch.numel(raw_indices)
+         slc = slice(self.curr_replica, full_length, self.num_replicas)
+         replica_indices = raw_indices[slc].clone()
+
+         assert torch.max(replica_indices).cpu().item() < full_length
+
+         return replica_indices
+
+     def generate(self, length: int) -> torch.LongTensor:
+         raw_indices = self.generate_raw_indices(length)
+         full_length = partitioning_length(length, self.num_replicas)
+
+         assert torch.numel(raw_indices) == full_length
+
+         replica_indices = self.replica_indices(raw_indices)
+         per_replica = partitioning_per_replica(length, self.num_replicas)
+
+         assert torch.numel(replica_indices) == per_replica
+
+         indices = torch.remainder(replica_indices, length)
+
+         assert torch.max(indices).cpu().item() < length
+         assert torch.numel(indices) == per_replica
+         assert indices.device == self.device
+
+         return indices
+
+     def __call__(self, length: int) -> torch.LongTensor:
+         return self.generate(length)
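A minimal sketch (not part of the diff) of how Partitioning pads a length of 5 up to 6 so it divides evenly across 2 replicas; the values follow from the code above with no generator, i.e. sequential indices.

import torch
from replay.data.nn.parquet.info.partitioning import Partitioning

part0 = Partitioning(curr_replica=0, num_replicas=2, device="cpu")
part1 = Partitioning(curr_replica=1, num_replicas=2, device="cpu")
part0(5)  # tensor([0, 2, 4])
part1(5)  # tensor([1, 3, 0])  (the padded index 5 wraps around via 5 % 5 == 0)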
replay/data/nn/parquet/info/replicas.py
@@ -0,0 +1,67 @@
+ from typing import Protocol
+
+ from .distributed_info import DEFAULT_DISTRIBUTED_INFO, DistributedInfoProtocol
+ from .worker_info import DEFAULT_WORKER_INFO, WorkerInfoProtocol
+
+
+ def num_replicas(
+     worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
+     distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
+ ) -> int:
+     return worker_info.num_workers * distributed_info.world_size
+
+
+ def curr_replica(
+     worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
+     distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
+ ) -> int:
+     result = worker_info.id + worker_info.num_workers * distributed_info.rank
+     assert result < num_replicas(worker_info, distributed_info)
+     return result
+
+
+ class ReplicasInfoProtocol(Protocol):
+     @property
+     def num_replicas(self) -> int: ...
+
+     @property
+     def curr_replica(self) -> int: ...
+
+
+ class ReplicasInfo:
+     """
+     A replica metadata generator.
+
+     By default, assumes the standard Torch DDP training/inference procedure,
+     where each replica (a distinct worker on a specific device) is expected to process
+     a separate chunk of the dataset.
+
+     This behavior can be modified by providing custom ``worker_info`` and ``distributed_info`` objects
+     able to provide info about the local worker count and the world size/rank respectively.
+     """
+
+     def __init__(
+         self,
+         worker_info: WorkerInfoProtocol = DEFAULT_WORKER_INFO,
+         distributed_info: DistributedInfoProtocol = DEFAULT_DISTRIBUTED_INFO,
+     ) -> None:
+         """
+         :param worker_info: An object adhering to the ``WorkerInfoProtocol`` and used to obtain the local worker count.
+             Default: value of ``DEFAULT_WORKER_INFO`` - an implementation using ``torch.utils.data.get_worker_info()``.
+         :param distributed_info: An object adhering to the ``DistributedInfoProtocol`` and used to obtain
+             world size and rank. Default: value of ``DEFAULT_DISTRIBUTED_INFO`` - an implementation using the
+             ``torch.distributed`` module.
+         """
+         self.worker_info = worker_info
+         self.distributed_info = distributed_info
+
+     @property
+     def num_replicas(self) -> int:
+         return num_replicas(worker_info=self.worker_info, distributed_info=self.distributed_info)
+
+     @property
+     def curr_replica(self) -> int:
+         return curr_replica(worker_info=self.worker_info, distributed_info=self.distributed_info)
+
+
+ DEFAULT_REPLICAS_INFO: ReplicasInfoProtocol = ReplicasInfo()
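A sketch (not from the diff): in a plain single-process run, with no DataLoader workers and no DDP, there is exactly one replica.

from replay.data.nn.parquet.info.replicas import ReplicasInfo

info = ReplicasInfo()
info.num_replicas  # 1 (one local worker * world size 1)
info.curr_replica  # 0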
replay/data/nn/parquet/info/worker_info.py
@@ -0,0 +1,43 @@
+ from typing import Any, Iterator, Optional, Protocol
+
+ import torch.utils.data as data
+
+
+ class WorkerInfoProtocol(Protocol):
+     @property
+     def id(self) -> int: ...
+
+     @property
+     def num_workers(self) -> int: ...
+
+
+ class WorkerInfo:
+     """Wrapper class for Torch's worker metadata."""
+
+     def __iter__(self) -> Iterator[int]:
+         yield self.id
+
+     @property
+     def worker_info(self) -> Optional[Any]:
+         return data.get_worker_info()
+
+     @property
+     def is_parallel(self) -> bool:
+         return self.worker_info is not None
+
+     @property
+     def id(self) -> int:
+         wi: Optional[data.WorkerInfo] = self.worker_info
+         if wi is not None:
+             return wi.id
+         return 0
+
+     @property
+     def num_workers(self) -> int:
+         wi: Optional[data.WorkerInfo] = self.worker_info
+         if wi is not None:
+             return wi.num_workers
+         return 1
+
+
+ DEFAULT_WORKER_INFO: WorkerInfo = WorkerInfo()
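A sketch (not from the diff) of the main-process defaults; inside a DataLoader worker the values come from torch.utils.data.get_worker_info() instead.

from replay.data.nn.parquet.info.worker_info import WorkerInfo

info = WorkerInfo()
info.is_parallel  # False in the main process
info.id           # 0
info.num_workers  # 1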