xax 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- xax/__init__.py +1 -1
- xax/task/mixins/train.py +62 -26
- xax/utils/pytree.py +11 -11
- {xax-0.1.9.dist-info → xax-0.1.10.dist-info}/METADATA +1 -1
- {xax-0.1.9.dist-info → xax-0.1.10.dist-info}/RECORD +8 -8
- {xax-0.1.9.dist-info → xax-0.1.10.dist-info}/WHEEL +0 -0
- {xax-0.1.9.dist-info → xax-0.1.10.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.9.dist-info → xax-0.1.10.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/mixins/train.py
CHANGED
@@ -56,6 +56,7 @@ from xax.utils.experiments import (
 from xax.utils.jax import jit as xax_jit
 from xax.utils.logging import LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict

 logger = logging.getLogger(__name__)

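The newly imported FrozenDict is the container used for all of the metric plumbing below. A minimal sketch of the pattern, using only what this diff itself shows (construction from a plain dict and .items() iteration; the immutability remark is what the name suggests rather than anything verified here):

import jax.numpy as jnp

from xax.utils.types.frozen_dict import FrozenDict

metrics = FrozenDict({"loss": jnp.asarray(0.25)})
for k, v in metrics.items():  # iterates like a regular mapping
    print(k, v.item())
# Freezing the mapping means a jitted step can return it as auxiliary
# output without callers mutating it between logging hooks.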
@@ -215,7 +216,7 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)

-    def log_train_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the training phase.

         This function is called after the model forward pass and before the
@@ -224,10 +225,11 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """

-    def log_valid_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the validation phase.

         This function is called after the model forward pass. It is called in
@@ -236,6 +238,7 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """

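Both hooks gain the same metrics parameter, so downstream tasks that override them must add it to their signatures. A hypothetical override sketch (the SupervisedTask base name and the extra scalar key are illustrative assumptions, not part of this diff):

from jaxtyping import Array

from xax.utils.types.frozen_dict import FrozenDict


class MyTask(SupervisedTask):  # hypothetical concrete task base
    def log_train_step(self, batch, output, metrics: FrozenDict[str, Array], state) -> None:
        # The loss now arrives precomputed via compute_metrics, so the hook
        # can log derived values without re-running the model.
        if "loss" in metrics:
            self.logger.log_scalar("loss_copy", metrics["loss"].item())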
@@ -246,18 +249,23 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)

-    def log_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         phase = state.phase

-        …
+        for k, v in metrics.items():
+            if v.size == 1:
+                self.logger.log_scalar(k, v.item())
+            else:
+                self.logger.log_histogram(k, v)
+
         self.log_state_timers(state)

         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, state)
+                self.log_train_step(batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, state)
+                self.log_valid_step(batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")

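The new body routes each metric by size: anything with a single element goes to log_scalar, larger arrays to log_histogram. The same dispatch in a standalone, runnable form (print stands in for the logger):

import jax.numpy as jnp

metrics = {"loss": jnp.asarray(0.25), "grad_norms": jnp.asarray([0.1, 0.9, 0.4])}
for k, v in metrics.items():
    if v.size == 1:
        print(f"scalar    {k} = {v.item():.3f}")  # what log_scalar would receive
    else:
        print(f"histogram {k}: {v.size} values")  # what log_histogram would receive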
@@ -364,32 +372,59 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output

+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
     @xax_jit(static_argnames=["self", "model_static"])
     def get_output_and_loss(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Output]:
+    ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
-        return loss, output
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, FrozenDict(metrics))

     def update(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output]:
-        grad_fn = jax.grad(self.get_output_and_loss, argnums=1, has_aux=True)
-        grads, output = grad_fn(model_static, model_arr, batch, state)
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
         updates, opt_state = optimizer.update(grads, opt_state, model_arr)
         model_arr = eqx.apply_updates(model_arr, updates)
-        return model_arr, opt_state, output
+        return model_arr, opt_state, output, metrics

     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
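compute_metrics is the new extension point: it runs inside the jitted step, and whatever it returns is what reaches log_step and the phase hooks. A hedged sketch of a classification-style override (the base class name, the batch.labels field, and the accuracy key are illustrative assumptions):

import jax.numpy as jnp


class ClassificationTask(SupervisedTask):  # hypothetical concrete task base
    def compute_metrics(self, model, batch, output, loss, state):
        # Keep the default {"loss": loss} entry and add task-specific metrics.
        metrics = super().compute_metrics(model, batch, output, loss, state)
        preds = jnp.argmax(output, axis=-1)
        metrics["accuracy"] = jnp.mean(preds == batch.labels)
        return metrics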
@@ -469,25 +504,26 @@ class TrainMixin(
     @xax_jit(static_argnames=["self", "model_static", "optimizer"])
     def train_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output]:
-        model_arr, opt_state, output = self.update(model_static, model_arr, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+        return model_arr, opt_state, output, metrics

     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Output]:
-        return self.get_output_and_loss(model_static, model_arr, batch, state)
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, metrics

     def train_loop(
         self,
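The swap of model_arr and model_static is what lets update call jax.grad(..., argnums=0) with static_argnums=[1]: the differentiated argument now always comes first. The underlying grad-with-aux pattern in isolation (a standalone sketch, not xax code):

import jax
import jax.numpy as jnp


def loss_fn(params, cfg, x):
    # Returning (loss, aux) with has_aux=True passes aux through untouched.
    pred = params["w"] * x
    loss = jnp.mean((pred - 1.0) ** 2)
    return loss, {"pred_mean": jnp.mean(pred)}


grad_fn = jax.grad(loss_fn, argnums=0, has_aux=True)  # grads w.r.t. params only
grads, aux = grad_fn({"w": jnp.asarray(2.0)}, None, jnp.arange(3.0))
print(grads["w"], aux["pred_mean"])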
@@ -509,8 +545,8 @@ class TrainMixin(
                     num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
                 )

-                _, output = self.val_step(model_static, model_arr, valid_batch, state)
-                self.log_step(valid_batch, output, state)
+                output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                self.log_step(valid_batch, output, metrics, state)

                 state = self.on_step_start(state)
                 train_batch = next(train_pf)
@@ -520,15 +556,15 @@ class TrainMixin(
                     num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
                 )

-                model_arr, opt_state, output = self.train_step(
-                    model_static=model_static,
+                model_arr, opt_state, output, metrics = self.train_step(
                     model_arr=model_arr,
+                    model_static=model_static,
                     optimizer=optimizer,
                     opt_state=opt_state,
                     batch=train_batch,
                     state=state,
                 )
-                self.log_step(train_batch, output, state)
+                self.log_step(train_batch, output, metrics, state)

                 state = self.on_step_end(state)

xax/utils/pytree.py
CHANGED
@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:

 def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
     """Get a slice of a pytree."""
-    return jax.tree_map(lambda x: slice_array(x, start, slice_length), pytree)
+    return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)


 def flatten_array(x: Array, flatten_size: int) -> Array:
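Every change in this file is the same mechanical migration to the jax.tree namespace provided by newer JAX releases; the removed spellings are truncated in this diff view, so the deprecated jax.tree_map alias shown above is an inference. A standalone before/after sketch:

import jax
import jax.numpy as jnp

tree = {"a": jnp.ones(3), "b": jnp.zeros((2, 2))}

doubled = jax.tree.map(lambda x: x * 2, tree)  # new spelling
# doubled = jax.tree_map(lambda x: x * 2, tree)  # deprecated alias this diff moves away from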
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:

 def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
     """Flatten a pytree into a (flatten_size, ...) pytree."""
-    return jax.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
+    return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)


 def pytree_has_nans(pytree: PyTree) -> Array:
     """Check if a pytree has any NaNs."""
     has_nans = jax.tree_util.tree_reduce(
         lambda a, b: jnp.logical_or(a, b),
-        jax.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
+        jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
     )
     return has_nans

@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:

 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
     # Tricky: we need tree.map here because jnp.where expects array leaves.
-    return jax.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
+    return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)


 def compute_nan_ratio(pytree: PyTree) -> Array:
     """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-    nan_counts = jax.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
-    total_counts = jax.tree_map(lambda x: x.size, pytree)
+    nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+    total_counts = jax.tree.map(lambda x: x.size, pytree)

     total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
     total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
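compute_nan_ratio pairs the per-leaf map with jax.tree_util.tree_reduce; the same pattern checked standalone:

import jax
import jax.numpy as jnp

tree = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.array([jnp.nan, jnp.nan])}
nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), tree)
total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
print(total_nans)  # 3 of 4 elements are NaN, so the ratio would be 0.75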
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
         # Reshape back to the original shape
         return permuted.reshape(orig_shape)

-    return jax.tree_map(permute_array, data)
+    return jax.tree.map(permute_array, data)


 def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
             return x[tuple(idx_grids)]
         return x

-    return jax.tree_map(permute_array, data)
+    return jax.tree.map(permute_array, data)


 TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
         transpose_info[path] = (transpose_order, original_shape)
         return x

-    jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
+    jax.tree.map_with_path(prepare_for_shuffle, data)

     # Create a transposed pytree
     def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@ def reshuffle_pytree_along_dims(
             return transposed_data[path]
         return x

-    transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
+    transposed_pytree = jax.tree.map_with_path(get_transposed, data)

     # Reshuffle the transposed pytree along the leading dimensions
     reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@ def reshuffle_pytree_along_dims(
             return transpose_back(x, transpose_order, original_shape)
         return x

-    return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
+    return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
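The *_with_path variants migrate the same way, from jax.tree_util.tree_map_with_path to jax.tree.map_with_path (the old spelling is inferred from the truncated removed lines). A small standalone sketch:

import jax
from jax.tree_util import keystr

tree = {"layer": {"b": 2.0, "w": 1.0}}
# The mapped function receives (path, leaf) pairs.
labeled = jax.tree.map_with_path(lambda path, x: f"{keystr(path)}={x}", tree)
print(labeled)  # {'layer': {'b': "['layer']['b']=2.0", 'w': "['layer']['w']=1.0"}}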
{xax-0.1.9.dist-info → xax-0.1.10.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=…
+xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280…
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=…
+xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +48,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=…
+xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
 xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +56,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706…
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.9.dist-info/licenses/LICENSE,sha256=…
-xax-0.1.9.dist-info/METADATA,sha256=…
-xax-0.1.9.dist-info/WHEEL,sha256=…
-xax-0.1.9.dist-info/top_level.txt,sha256=…
-xax-0.1.9.dist-info/RECORD,,
+xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
+xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.10.dist-info/RECORD,,
{xax-0.1.9.dist-info → xax-0.1.10.dist-info}/WHEEL
File without changes
{xax-0.1.9.dist-info → xax-0.1.10.dist-info}/licenses/LICENSE
File without changes
{xax-0.1.9.dist-info → xax-0.1.10.dist-info}/top_level.txt
File without changes