xax-0.0.5-py3-none-any.whl → xax-0.0.7-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
xax/task/mixins/train.py CHANGED
@@ -13,14 +13,24 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+from typing import (
+    Any,
+    Generator,
+    Generic,
+    Iterator,
+    Literal,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)

 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array, PyTree
+from jaxtyping import Array, PRNGKeyArray, PyTree
 from omegaconf import DictConfig

 from xax.core.conf import field
@@ -130,6 +140,7 @@ class ValidStepTimer:
         return False


+@jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
     CheckpointingConfig,
@@ -191,16 +202,13 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)

-    def prng_key(self) -> jnp.ndarray:
+    def prng_key(self) -> PRNGKeyArray:
         return jax.random.PRNGKey(self.config.random_seed)

     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-        return state.replace(
-            {
-                "elapsed_time_s": time.time() - state.start_time_s,
-            },
-        )
+        state.elapsed_time_s = time.time() - state.start_time_s
+        return state

     def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
@@ -228,16 +236,19 @@ class TrainMixin(
             state: The current training state.
         """

-    def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
-        phase = state.phase
-
-        # Log the state timers.
-        timer = self.state_timers[phase]
+    def log_state_timers(self, state: State) -> None:
+        timer = self.state_timers[state.phase]
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
                 self.logger.log_scalar(k, v, namespace=ns)

+    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+        phase = state.phase
+
+        self.logger.log_scalar("loss", loss, namespace="loss")
+        self.log_state_timers(state)
+
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
@@ -247,8 +258,10 @@ class TrainMixin(
             case _:
                 raise KeyError(f"Unknown phase: {phase}")

+        self.write_logs(state)
+
     @abstractmethod
-    def get_model(self) -> PyTree:
+    def get_model(self, key: PRNGKeyArray) -> PyTree:
         """Returns the Equinox model to train.

         Returns:
@@ -266,7 +279,10 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))

-    def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()

         if init_ckpt_path is not None:
@@ -279,18 +295,18 @@ class TrainMixin(
                 return model, optimizer, opt_state, state

         with self.step_context("get_model"):
-            model = self.get_model()
+            model = self.get_model(key)

         with self.step_context("get_optimizer"):
             optimizer = self.get_optimizer()

         with self.step_context("get_initial_opt_state"):
-            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+            opt_state = self.get_initial_opt_state(model, optimizer)

         return model, optimizer, opt_state, State.init_state()

-    @abstractmethod
-    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+    @eqx.filter_jit
+    def get_output(self, model: PyTree, batch: Batch) -> Output:
         """Gets the output from the model.

         By default, we assume the model is a function that takes the batch as
@@ -300,10 +316,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
-            state: The current training state.
         """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")

-    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+    @eqx.filter_jit
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
         """Gets the loss for the current batch.

         By default, we assume the model is a function that takes the batch as
@@ -314,7 +331,6 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
-            state: The current training state.

         Returns:
             The computed loss, as a tensor.
@@ -323,9 +339,10 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output

-    def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
-        output = self.get_output(model, batch, state)
-        loss = self.compute_loss(model, batch, output, state)
+    @eqx.filter_jit
+    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
+        output = self.get_output(model, batch)
+        loss = self.compute_loss(model, batch, output)
         return loss, output

     @eqx.filter_jit
@@ -335,10 +352,9 @@ class TrainMixin(
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-        state: State,
     ) -> tuple[Array, PyTree, optax.OptState, Output]:
-        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state)
+        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
+        updates, opt_state = optimizer.update(grads, opt_state, model)
         model = eqx.apply_updates(model, updates)
         return loss, model, opt_state, output

@@ -377,7 +393,13 @@ class TrainMixin(
         self._last_printed_remaining_time = state.elapsed_time_s
         remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
         termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-        logger.info("Estimated finish time: %s", termination_time)
+        # logger.info("Estimated finish time: %s", termination_time)
+        jax.debug.print("Estimated finish time: {}", termination_time)
+
+    def get_remaining_percent(self, state: State) -> float | None:
+        if self.config.max_steps is None:
+            return None
+        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps

     def is_training_over(self, state: State) -> bool:
         if self._training_over_flag:
@@ -385,7 +407,6 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0

@@ -400,59 +421,124 @@ class TrainMixin(
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")

-    def get_remaining_percent(self, state: State) -> float | None:
-        if self.config.max_steps is None:
-            return None
-        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
-
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
-        self.logger.log_git_state(get_git_state(self))
-        self.logger.log_training_code(get_training_code(self))
-        self.logger.log_config(cast(DictConfig, self.config))
+        self.logger.log_file("git_state.txt", get_git_state(self))
+        self.logger.log_file("training_code.txt", get_training_code(self))
+        self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))

+    @eqx.filter_jit
     def train_step(
         self,
         model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+    ) -> tuple[PyTree, optax.OptState, Array, Output]:
+        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
+        return model, opt_state, loss, output
+
+    @eqx.filter_jit
+    def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
+        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
+        return model, loss, output
+
+    def train_loop(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
         state: State,
-    ) -> tuple[PyTree, optax.OptState, State]:
-        state = state.with_phase("train")
-        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
-        self.logger.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return (
-            model,
-            opt_state,
-            state.replace(
-                {
-                    "num_steps": state.num_steps + 1,
-                    "num_samples": state.num_samples + (self.get_size_of_batch(batch) or 0),
-                },
-            ),
-        )
+    ) -> None:
+        while not self.is_training_over(state):
+            if self.valid_step_timer.is_valid_step(state):
+                valid_batch = next(valid_pf)
+                model, loss, output = self.val_step(model, valid_batch)
+
+                # Perform logging.
+                with self.step_context("write_logs"):
+                    state.phase = "valid"
+                    self.log_step(model, valid_batch, output, loss, state)
+                    state.num_valid_samples += 1
+
+            with self.step_context("on_step_start"):
+                state = self.on_step_start(state)
+
+            with self.step_context("update_state"):
+                train_batch = next(train_pf)
+                model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
+
+            # Perform logging.
+            with self.step_context("write_logs"):
+                state.phase = "train"
+                self.log_step(model, train_batch, output, loss, state)
+                state.num_steps += 1
+                state.num_samples += self.get_size_of_batch(train_batch) or 0
+
+            with self.step_context("on_step_end"):
+                state = self.on_step_end(state)
+
+            if self.should_checkpoint(state):
+                self.save_checkpoint(model, optimizer, opt_state, state)

-    def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
-        state = state.with_phase("valid")
-        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
-        self.logger.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return model, state.replace(
-            {
-                "num_valid_steps": state.num_valid_steps + 1,
-                "num_valid_samples": state.num_valid_samples + (self.get_size_of_batch(batch) or 0),
-            },
-        )
+        # After finishing training, save the final checkpoint.
+        self.save_checkpoint(model, optimizer, opt_state, state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            train_ds = self.get_dataset("train")
+
+        with self.step_context("get_dataloader"):
+            train_dl = self.get_dataloader(train_ds, "train")
+
+        with self.step_context("get_prefetcher"):
+            train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            valid_ds = self.get_dataset("valid")
+
+        with self.step_context("get_dataloader"):
+            valid_dl = self.get_dataloader(valid_ds, "valid")
+
+        with self.step_context("get_prefetcher"):
+            valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")

     def run(self) -> None:
-        self.run_training_loop()
+        self.run_training()

-    def run_training_loop(self) -> None:
+    def run_training(self) -> None:
         """Runs the training loop.

         Args:
@@ -464,33 +550,16 @@ class TrainMixin(
         Raises:
             ValueError: If the task is not a supervised learning task
         """
-        with contextlib.ExitStack() as ctx:
+        with self:
+            key = self.prng_key()
+
             self.set_loggers()

             if is_master():
                 Thread(target=self.log_state, daemon=True).start()

-            # Gets the datasets.
-            with self.step_context("get_dataset"):
-                train_ds = self.get_dataset("train")
-                valid_ds = self.get_dataset("valid")
-
-            # Gets the dataloaders.
-            with self.step_context("get_dataloader"):
-                train_dl = self.get_dataloader(train_ds, "train")
-                valid_dl = self.get_dataloader(valid_ds, "valid")
-
-            # Gets the prefetchers.
-            with self.step_context("get_prefetcher"):
-                train_pf = self.get_prefetcher(train_dl)
-                valid_pf = self.get_prefetcher(valid_dl)
-
-            ctx.enter_context(self)
-            ctx.enter_context(train_pf)
-            ctx.enter_context(valid_pf)
-
-            model, optimizer, opt_state, state = self.load_initial_state()
-
+            key, model_key = jax.random.split(key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key)
             state = self.on_training_start(state)

             def on_exit() -> None:
@@ -499,43 +568,34 @@ class TrainMixin(
             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)

-            try:
-                while True:
-                    while True:
-                        if self.is_training_over(state):
-                            raise TrainingFinishedError
-
-                        if self.valid_step_timer.is_valid_step(state):
-                            model, state = self.val_step(model, next(valid_pf), state)
-
-                        with self.step_context("on_step_start"):
-                            state = self.on_step_start(state)
-
-                        model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)
-
-                        with self.step_context("on_step_end"):
-                            state = self.on_step_end(state)
-
-                        if self.should_checkpoint(state):
-                            self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except TrainingFinishedError:
-                if is_master():
-                    show_info(
-                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                        important=True,
+            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+                try:
+                    self.train_loop(
+                        model=model,
+                        optimizer=optimizer,
+                        opt_state=opt_state,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
                     )
-                self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except (KeyboardInterrupt, bdb.BdbQuit):
-                if is_master():
-                    show_info("Interrupted training", important=True)
-
-            except BaseException:
-                exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
-                sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
-                sys.stdout.flush()
-                self.save_checkpoint(model, optimizer, opt_state, state)

-            finally:
-                state = self.on_training_end(state)
+                except TrainingFinishedError:
+                    if is_master():
+                        show_info(
+                            f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                            important=True,
+                        )
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                finally:
+                    state = self.on_training_end(state)
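
Note: the train.py changes above reshape the subclass-facing API: get_model now receives a PRNG key, while get_output and compute_loss drop their state argument and are compiled with eqx.filter_jit. Below is a minimal, hedged sketch of a task written against these 0.0.7 signatures; only the overridden method names and signatures come from the diff, while the MLP model, optimizer choice, and (inputs, labels) batch layout are illustrative assumptions, and the data-loading hooks are omitted.

    # Hedged sketch of a subclass against the 0.0.7 hooks shown in this diff.
    # The model, loss, and batch layout are assumptions for illustration only.
    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import optax
    from jaxtyping import Array, PRNGKeyArray, PyTree

    from xax.task.task import Config, Task


    class MyTask(Task[Config]):
        def get_model(self, key: PRNGKeyArray) -> PyTree:
            # The key is threaded in from run_training -> load_initial_state.
            return eqx.nn.MLP(in_size=784, out_size=10, width_size=128, depth=2, key=key)

        def get_optimizer(self) -> optax.GradientTransformation:
            return optax.adamw(1e-3)

        def get_output(self, model: PyTree, batch: tuple[Array, Array]) -> Array:
            x, _ = batch  # assumed (inputs, labels) batch layout
            return jax.vmap(model)(x)

        def compute_loss(self, model: PyTree, batch: tuple[Array, Array], output: Array) -> Array:
            _, y = batch
            return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, y))
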
xax/task/script.py CHANGED
@@ -22,7 +22,7 @@ from xax.task.mixins import (
 )


-@dataclass
+@dataclass(kw_only=True)
 class ScriptConfig(
     CPUStatsConfig,
     GPUStatsConfig,
xax/task/task.py CHANGED
@@ -3,12 +3,16 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar

+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
     CheckpointingConfig,
     CheckpointingMixin,
+    CompileConfig,
+    CompileMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -28,10 +32,12 @@ from xax.task.mixins import (
 )


+@jax.tree_util.register_dataclass
 @dataclass
 class Config(
     TrainConfig,
     CheckpointingConfig,
+    CompileConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -51,6 +57,7 @@ ConfigT = TypeVar("ConfigT", bound=Config)
 class Task(
     TrainMixin[ConfigT],
     CheckpointingMixin[ConfigT],
+    CompileMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
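
Note: the @jax.tree_util.register_dataclass decorator added above (to Config here and to TrainConfig in train.py) registers the config dataclass as a pytree node, so config instances can flow through jax.tree_util operations and jitted code. A minimal illustration of the same bare-decorator usage on a made-up dataclass, assuming a JAX version recent enough to infer the data/meta field split automatically, as the bare usage here requires:

    # Illustrative only: a made-up dataclass registered the same way as Config above.
    from dataclasses import dataclass

    import jax


    @jax.tree_util.register_dataclass
    @dataclass
    class DemoConfig:
        learning_rate: float = 1e-3
        batch_size: int = 32


    cfg = DemoConfig()
    # Registered dataclasses participate in pytree operations; both fields are
    # treated as data leaves here since neither is marked static.
    print(jax.tree_util.tree_leaves(cfg))  # [0.001, 32]
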
xax/utils/jax.py CHANGED
@@ -1,14 +1,140 @@
 """Defines some utility functions for interfacing with Jax."""

+import inspect
+import logging
+import os
+import time
+from functools import wraps
+from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src import sharding_impls
+from jax._src.lib import xla_client as xc
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_COMPILE_TIMEOUT = 1.0

 Number = int | float | np.ndarray | jnp.ndarray


+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
 def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
     if isinstance(value, (int, float)):
         return float(value)
     if isinstance(value, (np.ndarray, jnp.ndarray)):
         return float(value.item())
     raise TypeError(f"Unexpected type: {type(value)}")
+
+
+def get_hash(obj: object) -> int:
+    """Get a hash of an object.
+
+    If the object is hashable, use the hash. Otherwise, use the id.
+    """
+    if hasattr(obj, "__hash__"):
+        return hash(obj)
+    return id(obj)
+
+
+def jit(
+    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    static_argnums: int | Sequence[int] | None = None,
+    static_argnames: str | Iterable[str] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,  # noqa: ANN401
+    compiler_options: dict[str, Any] | None = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """Wrapper function that provides utility improvements over Jax's JIT.
+
+    Specifically, this function works on class methods, is toggleable, and
+    detects recompilations by matching hash values.
+
+    This is meant to be used as a decorator factory, and the decorated function
+    calls `wrapped`.
+    """
+
+    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
+        class JitState:
+            compilation_count = 0
+            last_arg_dict: dict[str, int] | None = None
+
+        sig = inspect.signature(fn)
+        param_names = list(sig.parameters.keys())
+
+        jitted_fn = jax.jit(
+            fn,
+            in_shardings=in_shardings,
+            out_shardings=out_shardings,
+            static_argnums=static_argnums,
+            static_argnames=static_argnames,
+            donate_argnums=donate_argnums,
+            donate_argnames=donate_argnames,
+            keep_unused=keep_unused,
+            device=device,
+            backend=backend,
+            inline=inline,
+            abstracted_axes=abstracted_axes,
+            compiler_options=compiler_options,
+        )
+
+        @wraps(fn)
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
+                return fn(*args, **kwargs)
+
+            do_profile = os.environ.get("JIT_PROFILE", "0") == "1"
+
+            if do_profile:
+                class_name = (args[0].__class__.__name__) + "." if fn.__name__ == "__call__" else ""
+                logger.info(
+                    "Currently running %s (count: %s)",
+                    f"{class_name}{fn.__name__}",
+                    JitState.compilation_count,
+                )
+
+            start_time = time.time()
+            res = jitted_fn(*args, **kwargs)
+            end_time = time.time()
+            runtime = end_time - start_time
+
+            # if this is true, if runtime is higher than COMPILE_TIMEOUT, we recompile
+            # TODO: we should probably reimplement the lower-level jitting logic to avoid this
+            if do_profile:
+                arg_dict = {}
+                for i, arg in enumerate(args):
+                    if i < len(param_names):
+                        arg_dict[param_names[i]] = get_hash(arg)
+                for k, v in kwargs.items():
+                    arg_dict[k] = get_hash(v)
+
+                logger.info("Hashing took %s seconds", runtime)
+                JitState.compilation_count += 1
+
+                if JitState.last_arg_dict is not None:
+                    all_keys = set(arg_dict.keys()) | set(JitState.last_arg_dict.keys())
+                    for k in all_keys:
+                        prev = JitState.last_arg_dict.get(k, "N/A")
+                        curr = arg_dict.get(k, "N/A")

+                        if prev != curr:
+                            logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)
+
+                JitState.last_arg_dict = arg_dict
+
+            return cast(R, res)
+
+        return wrapped
+
+    return decorator
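
Note: per its docstring, the new jit helper in xax/utils/jax.py is a decorator factory around jax.jit whose behaviour is toggled by environment variables: DEBUG=1 bypasses compilation entirely, and JIT_PROFILE=1 enables the logging path that tracks a call count and per-argument hash changes to help spot what triggered a recompilation. A usage sketch under those assumptions; the decorated function itself is made up:

    # Hedged usage sketch for the jit helper added above; scaled_sum is made up.
    import jax.numpy as jnp
    from jaxtyping import Array

    from xax.utils.jax import jit


    @jit(static_argnames=("scale",))
    def scaled_sum(x: Array, scale: float = 1.0) -> Array:
        return scale * jnp.sum(x)


    # Run with JIT_PROFILE=1 to enable the profiling/logging path, or with
    # DEBUG=1 to skip jax.jit and call the plain Python function instead.
    print(scaled_sum(jnp.arange(4.0), scale=2.0))  # 12.0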