xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/data/collate.py

```diff
@@ -0,0 +1,206 @@
+"""Defines custom collation functions for PyTorch datasets."""
+
+from dataclasses import is_dataclass
+from typing import Any, Callable, Literal
+
+import numpy as np
+from PIL.Image import Image as PILImage
+
+CollateMode = Literal["stack", "concat"]
+
+
+def is_named_tuple(obj: Any) -> bool:  # noqa: ANN401
+    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+
+
+def pad_sequence(
+    tensors: list[np.ndarray],
+    *,
+    dim: int = 0,
+    max_length: int | None = None,
+    left_pad: bool = False,
+    left_truncate: bool = False,
+    pad_value: int | float | bool = 0,
+) -> list[np.ndarray]:
+    """Pads or truncates a sequence of tensors to the same length.
+
+    Args:
+        tensors: The tensors to pad or truncate
+        dim: The dimension to pad or truncate
+        max_length: The maximum tensor length
+        left_pad: If set, pad on the left side, otherwise pad the right side
+        left_truncate: If set, truncate on the left side, otherwise truncate
+            on the right side
+        pad_value: The padding value to use
+
+    Returns:
+        The padded or truncated tensors
+
+    Raises:
+        ValueError: If the tensor dimensions are invalid
+    """
+    if not tensors:
+        return tensors
+
+    num_dims = tensors[0].ndim
+    if num_dims == 0:
+        raise ValueError("Tensor dimensions must be greater than zero")
+    if not all(t.ndim == num_dims for t in tensors):
+        tensor_dims = {t.ndim for t in tensors}
+        raise ValueError(f"All tensors should have the same number of dimensions; got {tensor_dims}")
+
+    dim = dim if dim >= 0 else num_dims + dim
+    target_length = int(max(t.shape[dim] for t in tensors))
+    if max_length is not None:
+        target_length = min(target_length, max_length)
+
+    def pad_tensor(t: np.ndarray) -> np.ndarray:
+        length = t.shape[dim]
+        if length > target_length:
+            t = np.take(t, range(length - target_length, length) if left_truncate else range(target_length), axis=dim)
+        elif length < target_length:
+            padding_shape = [target_length - s if i == dim else s for i, s in enumerate(t.shape)]
+            padding = np.full(padding_shape, fill_value=pad_value, dtype=t.dtype)
+            t = np.concatenate((padding, t) if left_pad else (t, padding), axis=dim)
+        return t
+
+    return list(map(pad_tensor, tensors))
+
+
+def pad_all(
+    tensors: list[np.ndarray],
+    *,
+    max_length: int | None = None,
+    left_pad: bool = False,
+    left_truncate: bool = False,
+    pad_value: int | float | bool = 0,
+) -> list[np.ndarray]:
+    """Pads all tensors to the same shape.
+
+    Args:
+        tensors: The tensors to pad
+        max_length: The maximum tensor length
+        left_pad: If set, pad on the left side, otherwise pad the right side
+        left_truncate: If set, truncate on the left side, otherwise truncate
+            on the right side
+        pad_value: The padding value to use
+
+    Returns:
+        The padded tensors
+    """
+    if not tensors:
+        return tensors
+
+    # Gets the tensor dimension.
+    all_dims = {t.ndim for t in tensors}
+    assert len(all_dims) == 1, f"Got different numbers of tensor dimensions: {all_dims}"
+    dims = next(iter(all_dims))
+
+    for dim in range(dims):
+        all_sizes = {t.shape[dim] for t in tensors}
+        if len(all_sizes) > 1:
+            tensors = pad_sequence(
+                tensors,
+                dim=dim,
+                max_length=max_length,
+                left_pad=left_pad,
+                left_truncate=left_truncate,
+                pad_value=pad_value,
+            )
+
+    return tensors
+
+
+def collate(
+    items: list[Any],
+    *,
+    mode: CollateMode | Callable[[list[np.ndarray]], np.ndarray] = "stack",
+    pad: bool | Callable[[list[np.ndarray]], list[np.ndarray]] = False,
+) -> Any | None:  # noqa: ANN401
+    """Defines a general-purpose collating function.
+
+    Args:
+        items: The list of items to collate
+        mode: Either `stack`, `concat`, or a custom function which is called on
+            a list of tensors and returns a single tensor
+        pad: If set to True, pads sequences using the default padding function.
+            Can also pass a function which will perform padding
+
+    Returns:
+        The collated item, or None if the item list was empty
+
+    Raises:
+        NotImplementedError: If the mode is invalid
+    """
+    if len(items) == 0:
+        return None
+    item = items[0]
+
+    # Any None items should be filtered out.
+    if item is None:
+        return None
+
+    # Tensors are either concatenated or stacked.
+    if isinstance(item, np.ndarray):
+        if callable(mode):
+            return mode(items)
+        if isinstance(mode, str):
+            if isinstance(pad, bool) and pad:
+                pad = pad_all
+            if callable(pad):
+                items = pad(items)
+            if mode == "stack":
+                return np.stack(items, axis=0)
+            if mode == "concat":
+                return np.concatenate(items, axis=0)
+            raise NotImplementedError(f"Invalid collate mode: {mode}")
+        raise NotImplementedError(f"Invalid mode type: {type(mode)}")
+
+    # All images are converted to tensors.
+    if isinstance(item, PILImage):
+        return collate([np.asarray(i) for i in items], mode=mode, pad=pad)
+
+    # Numbers are converted to a list of tensors.
+    if isinstance(item, (bool, int, float, complex, np.bool_, np.number)):
+        return collate([np.asarray(i) for i in items], mode=mode, pad=pad)
+
+    # Collate dictionaries if they have the same keys.
+    if isinstance(item, dict) and all(set(i.keys()) == set(item.keys()) for i in items):
+        output_dict = {}
+        item_keys = set(item.keys())
+        for key in item_keys:
+            output_dict[key] = collate([i[key] for i in items], mode=mode, pad=pad)
+        return output_dict
+
+    # Collate lists and tuples if they have the same lengths.
+    if isinstance(item, (list, tuple)) and all(len(i) == len(item) for i in items):
+        output_list = []
+        for j in range(len(item)):
+            output_list.append(collate([i[j] for i in items], mode=mode, pad=pad))
+        if is_named_tuple(item):
+            return type(item)(*output_list)  # type: ignore[arg-type]
+        if isinstance(item, tuple):
+            return tuple(output_list)
+        return output_list
+
+    # Handles dataclasses.
+    if is_dataclass(item):
+        output_dict = {}
+        item_keys = item.__dict__.keys()
+        for key in item_keys:
+            output_dict[key] = collate([getattr(i, key) for i in items], mode=mode, pad=pad)
+        return item.__class__(**output_dict)
+
+    # By default, don't do anything.
+    return items
+
+
+def collate_non_null(
+    items: list[Any],
+    *,
+    mode: CollateMode | Callable[[list[np.ndarray]], np.ndarray] = "stack",
+    pad: bool | Callable[[list[np.ndarray]], list[np.ndarray]] = False,
+) -> Any:  # noqa: ANN401
+    collated = collate(items, mode=mode, pad=pad)
+    assert collated is not None
+    return collated
```
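To make the padding semantics above concrete, here is a minimal usage sketch of `pad_sequence`, assuming xax 0.0.5 is installed so the module imports as `xax.utils.data.collate`; the arrays and printed outputs are illustrative, derived from the code shown above.

```python
import numpy as np

from xax.utils.data.collate import pad_sequence

# Two ragged 1-D arrays.
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])

# Right-pad both arrays to the longest length (4) with zeros.
right = pad_sequence([a, b], dim=0, pad_value=0)
print(right[1])  # [5 6 0 0]

# Left-pad instead, aligning shorter sequences to the right edge.
left = pad_sequence([a, b], dim=0, left_pad=True, pad_value=0)
print(left[1])  # [0 0 5 6]

# Cap the target length: `a` is truncated to its first two elements
# (or its last two if left_truncate=True); `b` already fits.
capped = pad_sequence([a, b], dim=0, max_length=2)
print(capped[0])  # [1 2]
```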
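And a sketch of the recursive `collate` entry point on dict-structured samples, under the same assumption about the installed package. Passing `pad=True` routes ragged arrays through `pad_all` before they are stacked, which is the typical path when batching variable-length sequences from a dataset.

```python
import numpy as np

from xax.utils.data.collate import collate

# Each sample is a dict, as a dataset __getitem__ might return.
samples = [
    {"tokens": np.array([1, 2, 3]), "label": 0},
    {"tokens": np.array([4, 5]), "label": 1},
]

# The dict branch recurses per key; `pad=True` zero-pads the ragged
# "tokens" arrays to a common length before stacking them.
batch = collate(samples, mode="stack", pad=True)
print(batch["tokens"].shape)  # (2, 3)
print(batch["label"])         # [0 1]
```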