xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +102 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/nn/geom.py +75 -0
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/jax.py +126 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +50 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/METADATA +12 -2
- xax-0.0.7.dist-info/RECORD +55 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/top_level.txt +0 -0
xax/task/mixins/train.py
CHANGED
@@ -13,14 +13,24 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import
+from typing import (
+    Any,
+    Generator,
+    Generic,
+    Iterator,
+    Literal,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)
 
 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array, PyTree
+from jaxtyping import Array, PRNGKeyArray, PyTree
 from omegaconf import DictConfig
 
 from xax.core.conf import field
@@ -130,6 +140,7 @@ class ValidStepTimer:
         return False
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
     CheckpointingConfig,
@@ -191,16 +202,13 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)
 
-    def prng_key(self) ->
+    def prng_key(self) -> PRNGKeyArray:
        return jax.random.PRNGKey(self.config.random_seed)
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-
-
-                "elapsed_time_s": time.time() - state.start_time_s,
-            },
-        )
+        state.elapsed_time_s = time.time() - state.start_time_s
+        return state
 
     def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
@@ -228,16 +236,19 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def
-
-
-        # Log the state timers.
-        timer = self.state_timers[phase]
+    def log_state_timers(self, state: State) -> None:
+        timer = self.state_timers[state.phase]
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
                 self.logger.log_scalar(k, v, namespace=ns)
 
+    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+        phase = state.phase
+
+        self.logger.log_scalar("loss", loss, namespace="loss")
+        self.log_state_timers(state)
+
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
@@ -247,8 +258,10 @@ class TrainMixin(
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
+        self.write_logs(state)
+
     @abstractmethod
-    def get_model(self) -> PyTree:
+    def get_model(self, key: PRNGKeyArray) -> PyTree:
         """Returns the Equinox model to train.
 
         Returns:
@@ -266,7 +279,10 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
-    def load_initial_state(
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
@@ -279,18 +295,18 @@ class TrainMixin(
                 return model, optimizer, opt_state, state
 
         with self.step_context("get_model"):
-            model = self.get_model()
+            model = self.get_model(key)
 
         with self.step_context("get_optimizer"):
             optimizer = self.get_optimizer()
 
         with self.step_context("get_initial_opt_state"):
-            opt_state =
+            opt_state = self.get_initial_opt_state(model, optimizer)
 
         return model, optimizer, opt_state, State.init_state()
 
-    @
-    def get_output(self, model: PyTree, batch: Batch
+    @eqx.filter_jit
+    def get_output(self, model: PyTree, batch: Batch) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
@@ -300,10 +316,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
-            state: The current training state.
         """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")
 
-
+    @eqx.filter_jit
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -314,7 +331,6 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
-            state: The current training state.
 
         Returns:
             The computed loss, as a tensor.
@@ -323,9 +339,10 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-
-
-
+    @eqx.filter_jit
+    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
+        output = self.get_output(model, batch)
+        loss = self.compute_loss(model, batch, output)
         return loss, output
 
     @eqx.filter_jit
@@ -335,10 +352,9 @@ class TrainMixin(
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-        state: State,
     ) -> tuple[Array, PyTree, optax.OptState, Output]:
-        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch
-        updates, opt_state = optimizer.update(grads, opt_state)
+        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
+        updates, opt_state = optimizer.update(grads, opt_state, model)
         model = eqx.apply_updates(model, updates)
         return loss, model, opt_state, output
 
@@ -377,7 +393,13 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            logger.info("Estimated finish time: %s", termination_time)
+            # logger.info("Estimated finish time: %s", termination_time)
+            jax.debug.print("Estimated finish time: {}", termination_time)
+
+    def get_remaining_percent(self, state: State) -> float | None:
+        if self.config.max_steps is None:
+            return None
+        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
 
     def is_training_over(self, state: State) -> bool:
         if self._training_over_flag:
@@ -385,7 +407,6 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0
 
@@ -400,59 +421,124 @@ class TrainMixin(
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")
 
-    def get_remaining_percent(self, state: State) -> float | None:
-        if self.config.max_steps is None:
-            return None
-        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
-
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
-        self.logger.
-        self.logger.
-        self.logger.
+        self.logger.log_file("git_state.txt", get_git_state(self))
+        self.logger.log_file("training_code.txt", get_training_code(self))
+        self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
 
+    @eqx.filter_jit
     def train_step(
         self,
         model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+    ) -> tuple[PyTree, optax.OptState, Array, Output]:
+        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
+        return model, opt_state, loss, output
+
+    @eqx.filter_jit
+    def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
+        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
+        return model, loss, output
+
+    def train_loop(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
         state: State,
-    ) ->
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ) -> None:
+        while not self.is_training_over(state):
+            if self.valid_step_timer.is_valid_step(state):
+                valid_batch = next(valid_pf)
+                model, loss, output = self.val_step(model, valid_batch)
+
+                # Perform logging.
+                with self.step_context("write_logs"):
+                    state.phase = "valid"
+                    self.log_step(model, valid_batch, output, loss, state)
+                    state.num_valid_samples += 1
+
+            with self.step_context("on_step_start"):
+                state = self.on_step_start(state)
+
+            with self.step_context("update_state"):
+                train_batch = next(train_pf)
+                model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
+
+            # Perform logging.
+            with self.step_context("write_logs"):
+                state.phase = "train"
+                self.log_step(model, train_batch, output, loss, state)
+                state.num_steps += 1
+                state.num_samples += self.get_size_of_batch(train_batch) or 0
+
+            with self.step_context("on_step_end"):
+                state = self.on_step_end(state)
+
+            if self.should_checkpoint(state):
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
-
-
-
-
-
-
-
-
-
-
-
-
+        # After finishing training, save the final checkpoint.
+        self.save_checkpoint(model, optimizer, opt_state, state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            train_ds = self.get_dataset("train")
+
+        with self.step_context("get_dataloader"):
+            train_dl = self.get_dataloader(train_ds, "train")
+
+        with self.step_context("get_prefetcher"):
+            train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            valid_ds = self.get_dataset("valid")
+
+        with self.step_context("get_dataloader"):
+            valid_dl = self.get_dataloader(valid_ds, "valid")
+
+        with self.step_context("get_prefetcher"):
+            valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")
 
     def run(self) -> None:
-        self.
+        self.run_training()
 
-    def
+    def run_training(self) -> None:
         """Runs the training loop.
 
         Args:
@@ -464,33 +550,16 @@ class TrainMixin(
         Raises:
             ValueError: If the task is not a supervised learning task
         """
-        with
+        with self:
+            key = self.prng_key()
+
             self.set_loggers()
 
             if is_master():
                 Thread(target=self.log_state, daemon=True).start()
 
-
-
-                train_ds = self.get_dataset("train")
-                valid_ds = self.get_dataset("valid")
-
-            # Gets the dataloaders.
-            with self.step_context("get_dataloader"):
-                train_dl = self.get_dataloader(train_ds, "train")
-                valid_dl = self.get_dataloader(valid_ds, "valid")
-
-            # Gets the prefetchers.
-            with self.step_context("get_prefetcher"):
-                train_pf = self.get_prefetcher(train_dl)
-                valid_pf = self.get_prefetcher(valid_dl)
-
-            ctx.enter_context(self)
-            ctx.enter_context(train_pf)
-            ctx.enter_context(valid_pf)
-
-            model, optimizer, opt_state, state = self.load_initial_state()
-
+            key, model_key = jax.random.split(key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key)
             state = self.on_training_start(state)
 
             def on_exit() -> None:
@@ -499,43 +568,34 @@ class TrainMixin(
             # Handle user-defined interrupts during the training loop.
            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
-
-
-
-
-
-
-
-
-
-                    with self.step_context("on_step_start"):
-                        state = self.on_step_start(state)
-
-                    model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)
-
-                    with self.step_context("on_step_end"):
-                        state = self.on_step_end(state)
-
-                    if self.should_checkpoint(state):
-                        self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except TrainingFinishedError:
-                if is_master():
-                    show_info(
-                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                        important=True,
+            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+                try:
+                    self.train_loop(
+                        model=model,
+                        optimizer=optimizer,
+                        opt_state=opt_state,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
                     )
-                self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except (KeyboardInterrupt, bdb.BdbQuit):
-                if is_master():
-                    show_info("Interrupted training", important=True)
-
-            except BaseException:
-                exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
-                sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
-                sys.stdout.flush()
-                self.save_checkpoint(model, optimizer, opt_state, state)
 
-
-
+                except TrainingFinishedError:
+                    if is_master():
+                        show_info(
+                            f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                            important=True,
+                        )
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                finally:
+                    state = self.on_training_end(state)
xax/task/script.py
CHANGED
xax/task/task.py
CHANGED
@@ -3,12 +3,16 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
     CheckpointingConfig,
     CheckpointingMixin,
+    CompileConfig,
+    CompileMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -28,10 +32,12 @@ from xax.task.mixins import (
 )
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class Config(
     TrainConfig,
     CheckpointingConfig,
+    CompileConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -51,6 +57,7 @@ ConfigT = TypeVar("ConfigT", bound=Config)
 class Task(
     TrainMixin[ConfigT],
     CheckpointingMixin[ConfigT],
+    CompileMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
xax/utils/jax.py
CHANGED
@@ -1,14 +1,140 @@
 """Defines some utility functions for interfacing with Jax."""
 
+import inspect
+import logging
+import os
+import time
+from functools import wraps
+from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src import sharding_impls
+from jax._src.lib import xla_client as xc
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_COMPILE_TIMEOUT = 1.0
 
 Number = int | float | np.ndarray | jnp.ndarray
 
 
+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
 def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
     if isinstance(value, (int, float)):
         return float(value)
     if isinstance(value, (np.ndarray, jnp.ndarray)):
         return float(value.item())
     raise TypeError(f"Unexpected type: {type(value)}")
+
+
+def get_hash(obj: object) -> int:
+    """Get a hash of an object.
+
+    If the object is hashable, use the hash. Otherwise, use the id.
+    """
+    if hasattr(obj, "__hash__"):
+        return hash(obj)
+    return id(obj)
+
+
+def jit(
+    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    static_argnums: int | Sequence[int] | None = None,
+    static_argnames: str | Iterable[str] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,  # noqa: ANN401
+    compiler_options: dict[str, Any] | None = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """Wrapper function that provides utility improvements over Jax's JIT.
+
+    Specifically, this function works on class methods, is toggleable, and
+    detects recompilations by matching hash values.
+
+    This is meant to be used as a decorator factory, and the decorated function
+    calls `wrapped`.
+    """
+
+    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
+        class JitState:
+            compilation_count = 0
+            last_arg_dict: dict[str, int] | None = None
+
+        sig = inspect.signature(fn)
+        param_names = list(sig.parameters.keys())
+
+        jitted_fn = jax.jit(
+            fn,
+            in_shardings=in_shardings,
+            out_shardings=out_shardings,
+            static_argnums=static_argnums,
+            static_argnames=static_argnames,
+            donate_argnums=donate_argnums,
+            donate_argnames=donate_argnames,
+            keep_unused=keep_unused,
+            device=device,
+            backend=backend,
+            inline=inline,
+            abstracted_axes=abstracted_axes,
+            compiler_options=compiler_options,
+        )
+
+        @wraps(fn)
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
+                return fn(*args, **kwargs)
+
+            do_profile = os.environ.get("JIT_PROFILE", "0") == "1"
+
+            if do_profile:
+                class_name = (args[0].__class__.__name__) + "." if fn.__name__ == "__call__" else ""
+                logger.info(
+                    "Currently running %s (count: %s)",
+                    f"{class_name}{fn.__name__}",
+                    JitState.compilation_count,
+                )
+
+            start_time = time.time()
+            res = jitted_fn(*args, **kwargs)
+            end_time = time.time()
+            runtime = end_time - start_time
+
+            # if this is true, if runtime is higher than COMPILE_TIMEOUT, we recompile
+            # TODO: we should probably reimplement the lower-level jitting logic to avoid this
+            if do_profile:
+                arg_dict = {}
+                for i, arg in enumerate(args):
+                    if i < len(param_names):
+                        arg_dict[param_names[i]] = get_hash(arg)
+                for k, v in kwargs.items():
+                    arg_dict[k] = get_hash(v)
+
+                logger.info("Hashing took %s seconds", runtime)
+                JitState.compilation_count += 1
+
+                if JitState.last_arg_dict is not None:
+                    all_keys = set(arg_dict.keys()) | set(JitState.last_arg_dict.keys())
+                    for k in all_keys:
+                        prev = JitState.last_arg_dict.get(k, "N/A")
+                        curr = arg_dict.get(k, "N/A")
+
+                        if prev != curr:
+                            logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)
+
+                JitState.last_arg_dict = arg_dict
+
+            return cast(R, res)
+
+        return wrapped
+
+    return decorator