xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
@@ -0,0 +1,206 @@
1
+ """Defines custom collation functions for PyTorch datasets."""
2
+
3
+ from dataclasses import is_dataclass
4
+ from typing import Any, Callable, Literal
5
+
6
+ import numpy as np
7
+ from PIL.Image import Image as PILImage
8
+
9
+ CollateMode = Literal["stack", "concat"]
10
+
11
+
12
def is_named_tuple(obj: Any) -> bool:  # noqa: ANN401
    """Checks whether an object looks like a named tuple instance.

    Args:
        obj: The object to inspect

    Returns:
        True if the object is a tuple which also exposes both the `_asdict`
        and `_fields` attributes, otherwise False
    """
    if not isinstance(obj, tuple):
        return False
    return all(hasattr(obj, attr) for attr in ("_asdict", "_fields"))
14
+
15
+
16
+ def pad_sequence(
17
+ tensors: list[np.ndarray],
18
+ *,
19
+ dim: int = 0,
20
+ max_length: int | None = None,
21
+ left_pad: bool = False,
22
+ left_truncate: bool = False,
23
+ pad_value: int | float | bool = 0,
24
+ ) -> list[np.ndarray]:
25
+ """Pads or truncates a sequence of tensors to the same length.
26
+
27
+ Args:
28
+ tensors: The tensors to pad or truncate
29
+ dim: The dimension to pad or truncate
30
+ max_length: The maximum tensor length
31
+ left_pad: If set, pad on the left side, otherwise pad the right side
32
+ left_truncate: If set, truncate on the left side, otherwise truncate
33
+ on the right side
34
+ pad_value: The padding value to use
35
+
36
+ Returns:
37
+ The padded or truncated tensors
38
+
39
+ Raises:
40
+ ValueError: If the tensor dimensions are invalid
41
+ """
42
+ if not tensors:
43
+ return tensors
44
+
45
+ num_dims = tensors[0].ndim
46
+ if num_dims == 0:
47
+ raise ValueError("Tensor dimensions must be greater than zero")
48
+ if not all(t.ndim == num_dims for t in tensors):
49
+ tensor_dims = {t.ndim for t in tensors}
50
+ raise ValueError(f"All tensors should have the same number of dimensions; got {tensor_dims}")
51
+
52
+ dim = dim if dim >= 0 else num_dims + dim
53
+ target_length = int(max(t.shape[dim] for t in tensors))
54
+ if max_length is not None:
55
+ target_length = min(target_length, max_length)
56
+
57
+ def pad_tensor(t: np.ndarray) -> np.ndarray:
58
+ length = t.shape[dim]
59
+ if length > target_length:
60
+ t = np.take(t, range(length - target_length if left_truncate else 0, target_length), axis=dim)
61
+ elif length < target_length:
62
+ padding_shape = [target_length - s if i == dim else s for i, s in enumerate(t.shape)]
63
+ padding = np.full(padding_shape, fill_value=pad_value)
64
+ t = np.concatenate((padding, t) if left_pad else (t, padding), axis=dim)
65
+ return t
66
+
67
+ return list(map(pad_tensor, tensors))
68
+
69
+
70
+ def pad_all(
71
+ tensors: list[np.ndarray],
72
+ *,
73
+ max_length: int | None = None,
74
+ left_pad: bool = False,
75
+ left_truncate: bool = False,
76
+ pad_value: int | float | bool = 0,
77
+ ) -> list[np.ndarray]:
78
+ """Pads all tensors to the same shape.
79
+
80
+ Args:
81
+ tensors: The tensors to pad
82
+ max_length: The maximum tensor length
83
+ left_pad: If set, pad on the left side, otherwise pad the right side
84
+ left_truncate: If set, truncate on the left side, otherwise truncate
85
+ on the right side
86
+ pad_value: The padding value to use
87
+
88
+ Returns:
89
+ The padded tensors
90
+ """
91
+ if not tensors:
92
+ return tensors
93
+
94
+ # Gets the tensor dimension.
95
+ all_dims = set(t.ndim for t in tensors)
96
+ assert len(all_dims) == 1, f"Got different numbers of tensor dimensions: {all_dims}"
97
+ dims = list(all_dims)[0]
98
+
99
+ for dim in range(dims):
100
+ all_sizes = set(t.shape[dim] for t in tensors)
101
+ if len(all_sizes) > 1:
102
+ tensors = pad_sequence(
103
+ tensors,
104
+ dim=dim,
105
+ max_length=max_length,
106
+ left_pad=left_pad,
107
+ left_truncate=left_truncate,
108
+ pad_value=pad_value,
109
+ )
110
+
111
+ return tensors
112
+
113
+
114
def collate(
    items: list[Any],
    *,
    mode: CollateMode | Callable[[list[np.ndarray]], np.ndarray] = "stack",
    pad: bool | Callable[[list[np.ndarray]], list[np.ndarray]] = False,
) -> Any | None:  # noqa: ANN401
    """Defines a general-purpose collating function.

    The type of the first item decides how the whole list is handled. Arrays
    are stacked or concatenated, numbers and PIL images are converted to
    arrays first, and dictionaries, lists, tuples, named tuples and
    dataclasses are collated recursively, element by element. Anything else
    is returned unchanged.

    Args:
        items: The list of items to collate
        mode: Either `stack`, `concat`, or a custom function which is called on
            a list of tensors and returns a single tensor
        pad: If set to True, pads sequences using the default padding function.
            Can also pass a function which will perform padding. Padding is
            not applied when `mode` is a custom function

    Returns:
        The collated item, or None if the item list was empty or its first
        item was None

    Raises:
        NotImplementedError: If the mode is invalid
    """
    # Local import: the module itself only imports `is_dataclass`.
    from dataclasses import fields

    if len(items) == 0:
        return None
    item = items[0]

    # A None first item means the whole batch collates to None.
    if item is None:
        return None

    # Tensors are either concatenated or stacked.
    if isinstance(item, np.ndarray):
        if callable(mode):
            return mode(items)
        if isinstance(mode, str):
            if isinstance(pad, bool) and pad:
                pad = pad_all
            if callable(pad):
                items = pad(items)
            if mode == "stack":
                return np.stack(items, axis=0)
            if mode == "concat":
                return np.concatenate(items, axis=0)
            raise NotImplementedError(f"Invalid collate mode: {mode}")
        raise NotImplementedError(f"Invalid mode type: {type(mode)}")

    # Numbers are converted to a list of tensors.
    if isinstance(item, (bool, int, float, complex, np.bool_, np.number)):
        return collate([np.asarray(i) for i in items], mode=mode, pad=pad)

    # Collate dictionaries if they have the same keys. Keys are visited in the
    # first item's order so the output ordering is deterministic.
    if isinstance(item, dict):
        item_keys = set(item.keys())
        if all(set(i.keys()) == item_keys for i in items):
            return {key: collate([i[key] for i in items], mode=mode, pad=pad) for key in item}

    # Collate lists and tuples if they have the same lengths.
    if isinstance(item, (list, tuple)) and all(len(i) == len(item) for i in items):
        output_list = [collate([i[j] for i in items], mode=mode, pad=pad) for j in range(len(item))]
        if is_named_tuple(item):
            return type(item)(*output_list)  # type: ignore[arg-type]
        if isinstance(item, tuple):
            return tuple(output_list)
        return output_list

    # Handles dataclass instances. Only fields accepted by the generated
    # `__init__` are collated, so `init=False` fields and `slots=True`
    # dataclasses (which have no `__dict__`) are supported.
    if is_dataclass(item) and not isinstance(item, type):
        init_names = [f.name for f in fields(item) if f.init]
        collated_fields = {
            name: collate([getattr(i, name) for i in items], mode=mode, pad=pad) for name in init_names
        }
        return item.__class__(**collated_fields)

    # All images are converted to tensors. Checked after the builtin types,
    # none of which can be a PIL image, so the ordering does not affect results.
    if isinstance(item, PILImage):
        return collate([np.asarray(i) for i in items], mode=mode, pad=pad)

    # By default, don't do anything.
    return items
196
+
197
+
198
def collate_non_null(
    items: list[Any],
    *,
    mode: CollateMode | Callable[[list[np.ndarray]], np.ndarray] = "stack",
    pad: bool | Callable[[list[np.ndarray]], list[np.ndarray]] = False,
) -> Any:  # noqa: ANN401
    """Collates a list of items, asserting that the result is not None.

    Args:
        items: The list of items to collate
        mode: Either `stack`, `concat`, or a custom function, forwarded
            unchanged to `collate`
        pad: The padding option, forwarded unchanged to `collate`

    Returns:
        The value returned by `collate`, which is asserted to be non-None
    """
    result = collate(items, mode=mode, pad=pad)
    assert result is not None
    return result