xax 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.1.10"
+__version__ = "0.1.11"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,16 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "DiagSSMBlock",
+    "DiscreteTimeS4",
+    "S4",
+    "S4Layer",
+    "S6Layer",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +203,16 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "DiagSSMBlock": "nn.ssm",
+    "DiscreteTimeS4": "nn.ssm",
+    "S4": "nn.ssm",
+    "S4Layer": "nn.ssm",
+    "S6Layer": "nn.ssm",
+    "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
     "SingleProcessLauncher": "task.launchers.single_process",
@@ -351,8 +365,10 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
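Since the new names are added to both __all__ and NAME_MAP, they should be importable from the top-level package (a quick orientation sketch; it assumes the lazy attribute lookup that NAME_MAP feeds in the rest of this file):

    from xax import S4, SSMBlock, cross_entropy  # resolves to xax.nn.ssm / xax.nn.losses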
xax/nn/losses.py ADDED
@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
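As written above, cross_entropy indexes pred_y with the integer labels y and negates the mean, so pred_y should already hold log-probabilities. A minimal usage sketch (the shapes and the log_softmax step are illustrative assumptions, not part of the package):

    import jax
    import jax.numpy as jnp
    import xax

    logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # (batch, classes)
    log_probs = jax.nn.log_softmax(logits, axis=1)               # log-probabilities
    labels = jnp.array([3, 1, 7, 0])                             # (batch,) integer class ids
    loss = xax.cross_entropy(labels, log_probs, axis=1)          # scalar mean negative log-likelihood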
xax/nn/norm.py CHANGED
@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args

 import jax.numpy as jnp
+from jaxtyping import Array

 NormType = Literal["l1", "l2"]

@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)


-def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
xax/nn/ssm.py ADDED
@@ -0,0 +1,296 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class DiscreteTimeS4(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S4Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        delta_a = self.delta * self.a
+        a_bar = jnp.exp(delta_a)
+        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
+        h = a_bar * h + b_bar.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S6Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array:
+        pass
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat @ h + self.b_mat.T @ x
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        return self.a_mat
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_mat = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat * h + self.b_mat.T @ x
+        h = jax.nn.tanh(h)
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
+        kernel = kernel[:, None, :]  # (H, 1, L)
+        return kernel
+
+    def forward_across_time(self, x: Array) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        tsz, nhid = x.shape
+
+        # Compute s = x @ U.T + b, with shape (N, T, H)
+        s = self.b_mat.T @ x
+        s = s.T  # (H, T)
+
+        kernel = self.get_kernel(tsz)  # (H, 1, T)
+        kernel_flipped = jnp.flip(kernel, axis=-1)
+
+        # Pad s on the left along the time axis (pad length T-1)
+        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
+
+        # Perform depthwise (grouped) 1D convolution.
+        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
+        # The dimension_numbers are chosen so that the channel dimension is second.
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCH", "OIH", "NCH"),
+            feature_group_count=nhid,
+        )
+        # conv_out has shape (N, H, T); transpose to (N, T, H)
+        conv_out = jnp.transpose(conv_out, (0, 2, 1))
+        return conv_out
+
+    def naive_forward_accross_time(self, x: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        h_0 = jnp.zeros(self.a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x)
+        return h_seq
+
+
+class S4(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["ssm", "diag"] = "ssm",
+        skip_connections: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
+        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "ssm":
+                    return SSMBlock(hidden_size, key=key)
+                case "diag":
+                    return DiagSSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.proj_out(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        embedded = self.vocab_embedding(x)
+        return jax.nn.gelu(self.proj_in(embedded))
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+
+        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        _, y_seq = jax.lax.scan(step, hs, x_emb)
+        return y_seq
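To show how the new module is driven end to end, here is a minimal sketch of running the S4 model over a short token sequence (the hyperparameters and token values are illustrative, not taken from the package):

    import jax
    import jax.numpy as jnp
    import xax

    model = xax.S4(
        input_size=128,      # vocabulary size for the embedding layer
        hidden_size=64,
        output_size=128,
        num_layers=2,
        block_type="diag",   # or "ssm" for the dense block
        key=jax.random.PRNGKey(0),
    )

    tokens = jnp.array([1, 5, 42, 7])        # (T,) integer token ids
    logits = model.predict_sequence(tokens)  # (T, output_size), one prediction per step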
xax/task/mixins/train.py CHANGED
@@ -29,6 +29,7 @@ from typing import (

 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -162,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")


 Config = TypeVar("Config", bound=TrainConfig)
@@ -403,12 +405,12 @@ class TrainMixin(
         model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         metrics = self.compute_metrics(model, batch, output, loss, state)
-        return loss, (output, FrozenDict(metrics))
+        return loss, (output, metrics)

     def update(
         self,
@@ -418,13 +420,44 @@ class TrainMixin(
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1])(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
-        model_arr = eqx.apply_updates(model_arr, updates)
-        return model_arr, opt_state, output, metrics
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics

     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -512,7 +545,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output, metrics
+        return model_arr, opt_state, output, FrozenDict(metrics)

     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
@@ -523,7 +556,7 @@
         state: State,
     ) -> tuple[Output, FrozenDict[str, Array]]:
         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
-        return output, metrics
+        return output, FrozenDict(metrics)

     def train_loop(
         self,
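The new apply_gradients_with_clipping above scales every gradient by global_grad_clip / max(grad_norm, 1e-6) and skips the optimizer step entirely when the global norm is NaN or Inf. Here is a self-contained sketch of that rule outside the mixin (the parameters, optimizer, and gradient values are illustrative):

    import jax
    import jax.numpy as jnp
    import optax

    global_grad_clip = 10.0                  # mirrors the new TrainConfig field
    params = {"w": jnp.ones((3, 3))}
    grads = {"w": jnp.full((3, 3), 5.0)}
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)

    grad_norm = optax.global_norm(grads)

    def apply(_):
        # Rescale the global gradient norm to the configured value, then apply the update.
        grad_factor = global_grad_clip / jnp.maximum(grad_norm, 1e-6)
        clipped = jax.tree.map(lambda g: g * grad_factor, grads)
        updates, new_opt_state = optimizer.update(clipped, opt_state, params)
        return optax.apply_updates(params, updates), new_opt_state

    # Skip the update entirely when the gradient norm is NaN or Inf.
    new_params, new_opt_state = jax.lax.cond(
        jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
        lambda _: (params, opt_state),
        apply,
        None,
    )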
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.10
+Version: 0.1.11
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
+xax/__init__.py,sha256=2JdSxsZphJJFVMGBVXNc0hP2p0FVOu5y7xSgPRNeyNY,13835
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -11,8 +11,10 @@ xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
 xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
-xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=eFeGkV1pkVGc0vNrQbykCbFnlPXQqsqVA_JVzLBHD28,9865
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
+xax/task/mixins/train.py,sha256=lgLHiHQtnDK0XS3SwHTYZtDv5CTbPRN1-p_K9KiIpHQ,26000
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
-xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.10.dist-info/RECORD,,
+xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.11.dist-info/METADATA,sha256=qDhn5EGxdiuEe5gQUZiBC430sXhJOPRWboTvsh2onxs,1878
+xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.11.dist-info/RECORD,,