xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only. Lines shown as "…" were truncated or elided by the registry's diff viewer.
- xax/__init__.py +121 -3
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +101 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jax.py +126 -0
- xax/utils/jaxpr.py +77 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +238 -0
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/RECORD +28 -20
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/task/mixins/train.py
CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )
 
 import equinox as eqx
@@ -35,6 +36,7 @@ from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return
+            return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
@@ -183,6 +185,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,
@@ -279,31 +284,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            …
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state
 
-            …
-            …
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()
 
-        …
-        …
+        if not load_optimizer:
+            return model, state
 
-        …
-        …
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)
 
-        return model, optimizer, opt_state, …
+        return model, optimizer, opt_state, state
 
     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:
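Note: the overload pair above lets type checkers narrow the return type from the load_optimizer flag. A minimal standalone sketch of the same pattern (toy function, not from the package):

    from typing import Literal, overload

    @overload
    def load(with_opt: Literal[False] = False) -> tuple[str]: ...
    @overload
    def load(with_opt: Literal[True]) -> tuple[str, str]: ...

    def load(with_opt: bool = False) -> tuple[str] | tuple[str, str]:
        # Sketch: callers passing with_opt=True are statically typed as getting
        # the wider tuple; the bool-typed implementation handles both at runtime.
        return ("model", "opt_state") if with_opt else ("model",)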
@@ -424,6 +451,7 @@ class TrainMixin(
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
@@ -456,7 +484,8 @@ class TrainMixin(
             while not self.is_training_over(state):
                 if self.valid_step_timer.is_valid_step(state):
                     valid_batch = next(valid_pf)
-                    …
+                    with self.step_context("model_step"):
+                        model, loss, output = self.val_step(model, valid_batch)
 
                     # Perform logging.
                     with self.step_context("write_logs"):
@@ -464,22 +493,19 @@ class TrainMixin(
                        self.log_step(model, valid_batch, output, loss, state)
                         state.num_valid_samples += 1
 
-                …
-                    state = self.on_step_start(state)
+                state = self.on_step_start(state)
 
-                with self.step_context("…
+                with self.step_context("model_step"):
                     train_batch = next(train_pf)
                     model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
 
-                # Perform logging.
                 with self.step_context("write_logs"):
                     state.phase = "train"
                     self.log_step(model, train_batch, output, loss, state)
                     state.num_steps += 1
                     state.num_samples += self.get_size_of_batch(train_batch) or 0
 
-                …
-                    state = self.on_step_end(state)
+                state = self.on_step_end(state)
 
                 if self.should_checkpoint(state):
                     self.save_checkpoint(model, optimizer, opt_state, state)
@@ -496,14 +522,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        …
-        …
-        …
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-        …
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)
 
         try:
             with train_pf as train_pf_ctx:
@@ -520,14 +541,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        …
-        …
-        …
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-        …
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
 
         try:
             with valid_pf as valid_pf_ctx:
@@ -559,7 +575,7 @@ class TrainMixin(
         Thread(target=self.log_state, daemon=True).start()
 
         key, model_key = jax.random.split(key)
-        model, optimizer, opt_state, state = self.load_initial_state(model_key)
+        model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
         state = self.on_training_start(state)
 
         def on_exit() -> None:
xax/task/script.py
CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
     ProcessMixin,
     RunnableConfig,
     RunnableMixin,
-    StepContextConfig,
-    StepContextMixin,
 )
 
 
@@ -28,7 +26,6 @@ class ScriptConfig(
     GPUStatsConfig,
     ProcessConfig,
     LoggerConfig,
-    StepContextConfig,
     ArtifactsConfig,
     RunnableConfig,
     BaseConfig,
@@ -44,7 +41,6 @@ class Script(
     GPUStatsMixin[ConfigT],
     ProcessMixin[ConfigT],
     LoggerMixin[ConfigT],
-    StepContextMixin[ConfigT],
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
xax/utils/debugging.py
ADDED
@@ -0,0 +1,49 @@
+"""Defines some useful Jax debugging utilities."""
+
+from collections import deque
+from collections.abc import Iterable, Mapping
+from typing import Any, Callable, Deque
+
+from jaxtyping import Array
+
+
+def get_named_leaves(
+    obj: Any,  # noqa: ANN401
+    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
+    max_depth: int = 100,
+) -> list[tuple[str, Any]]:  # noqa: ANN401
+    ret: list[tuple[str, Any]] = []
+    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
+    q.append((0, "", obj))
+
+    while q:
+        depth, name, node = q.popleft()
+
+        if depth > max_depth:
+            continue
+
+        if hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
+            for cname, cnode in node.__dict__.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Mapping):
+            for cname, cnode in node.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Iterable):
+            for i, cnode in enumerate(node):
+                gname = f"{name}.{i}" if name else str(i)
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+    return ret
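As a usage sketch (the params tree here is hypothetical), get_named_leaves walks attributes, mappings, and iterables breadth-first and returns a dotted path for each array leaf:

    import jax.numpy as jnp

    from xax.utils.debugging import get_named_leaves

    # Sketch: any nested container of arrays works the same way.
    params = {"encoder": {"w": jnp.zeros((4, 8)), "b": jnp.zeros(8)}}
    for name, leaf in get_named_leaves(params):
        print(name, leaf.shape)  # prints "encoder.w (4, 8)" and "encoder.b (8,)"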
xax/utils/experiments.py
CHANGED
@@ -23,7 +23,8 @@ import urllib.request
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from …
+from types import TracebackType
+from typing import Any, Iterator, Self, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -116,19 +117,19 @@ class StateTimer:
         logs: dict[str, dict[str, int | float]] = {}
 
         # Logs step statistics.
-        logs["…
+        logs["⌛ steps"] = {
             "total": self.step_timer.steps,
             "per-second": self.step_timer.steps_per_second,
         }
 
         # Logs sample statistics.
-        logs["…
+        logs["⌛ samples"] = {
             "total": self.sample_timer.steps,
             "per-second": self.sample_timer.steps_per_second,
         }
 
         # Logs full iteration statistics.
-        logs["…
+        logs["⌛ dt"] = {
             "iter": self.iter_timer.iter_seconds,
         }
 
@@ -147,6 +148,24 @@ class IntervalTicker:
         return False
 
 
+class ContextTimer:
+    def __init__(self) -> None:
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+
+    def __enter__(self) -> Self:
+        self.start_time = time.time()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        self.elapsed_time = time.time() - self.start_time
+
+
 def abs_path(path: str) -> str:
     return str(Path(path).resolve())
 
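The new ContextTimer measures wall-clock time for a with-block; a minimal usage sketch (the timed body is a stand-in):

    import time

    from xax.utils.experiments import ContextTimer

    with ContextTimer() as timer:
        time.sleep(0.1)  # stand-in for real work

    print(f"block took {timer.elapsed_time:.3f}s")  # elapsed_time is set on exit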
xax/utils/jax.py
CHANGED
@@ -1,14 +1,140 @@
 """Defines some utility functions for interfacing with Jax."""
 
+import inspect
+import logging
+import os
+import time
+from functools import wraps
+from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src import sharding_impls
+from jax._src.lib import xla_client as xc
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_COMPILE_TIMEOUT = 1.0
 
 Number = int | float | np.ndarray | jnp.ndarray
 
 
+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
 def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
     if isinstance(value, (int, float)):
         return float(value)
     if isinstance(value, (np.ndarray, jnp.ndarray)):
         return float(value.item())
     raise TypeError(f"Unexpected type: {type(value)}")
+
+
+def get_hash(obj: object) -> int:
+    """Get a hash of an object.
+
+    If the object is hashable, use the hash. Otherwise, use the id.
+    """
+    if hasattr(obj, "__hash__"):
+        return hash(obj)
+    return id(obj)
+
+
+def jit(
+    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    static_argnums: int | Sequence[int] | None = None,
+    static_argnames: str | Iterable[str] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,  # noqa: ANN401
+    compiler_options: dict[str, Any] | None = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """Wrapper function that provides utility improvements over Jax's JIT.
+
+    Specifically, this function works on class methods, is toggleable, and
+    detects recompilations by matching hash values.
+
+    This is meant to be used as a decorator factory, and the decorated function
+    calls `wrapped`.
+    """
+
+    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
+        class JitState:
+            compilation_count = 0
+            last_arg_dict: dict[str, int] | None = None
+
+        sig = inspect.signature(fn)
+        param_names = list(sig.parameters.keys())
+
+        jitted_fn = jax.jit(
+            fn,
+            in_shardings=in_shardings,
+            out_shardings=out_shardings,
+            static_argnums=static_argnums,
+            static_argnames=static_argnames,
+            donate_argnums=donate_argnums,
+            donate_argnames=donate_argnames,
+            keep_unused=keep_unused,
+            device=device,
+            backend=backend,
+            inline=inline,
+            abstracted_axes=abstracted_axes,
+            compiler_options=compiler_options,
+        )
+
+        @wraps(fn)
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
+                return fn(*args, **kwargs)
+
+            do_profile = os.environ.get("JIT_PROFILE", "0") == "1"
+
+            if do_profile:
+                class_name = (args[0].__class__.__name__) + "." if fn.__name__ == "__call__" else ""
+                logger.info(
+                    "Currently running %s (count: %s)",
+                    f"{class_name}{fn.__name__}",
+                    JitState.compilation_count,
+                )
+
+            start_time = time.time()
+            res = jitted_fn(*args, **kwargs)
+            end_time = time.time()
+            runtime = end_time - start_time
+
+            # if this is true, if runtime is higher than COMPILE_TIMEOUT, we recompile
+            # TODO: we should probably reimplement the lower-level jitting logic to avoid this
+            if do_profile:
+                arg_dict = {}
+                for i, arg in enumerate(args):
+                    if i < len(param_names):
+                        arg_dict[param_names[i]] = get_hash(arg)
+                for k, v in kwargs.items():
+                    arg_dict[k] = get_hash(v)
+
+                logger.info("Hashing took %s seconds", runtime)
+                JitState.compilation_count += 1
+
+                if JitState.last_arg_dict is not None:
+                    all_keys = set(arg_dict.keys()) | set(JitState.last_arg_dict.keys())
+                    for k in all_keys:
+                        prev = JitState.last_arg_dict.get(k, "N/A")
+                        curr = arg_dict.get(k, "N/A")
+
+                        if prev != curr:
+                            logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)
+
+                JitState.last_arg_dict = arg_dict
+
+            return cast(R, res)
+
+        return wrapped
+
+    return decorator
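A usage sketch (toy function; the environment-variable behavior is as implemented above): the wrapper is a decorator factory that accepts jax.jit's keyword options, DEBUG=1 bypasses compilation entirely, and JIT_PROFILE=1 logs per-call timing and argument-hash changes:

    import jax.numpy as jnp

    from xax.utils.jax import jit

    # Sketch: "double" is a static argument, so changing it triggers a recompile.
    @jit(static_argnames=("double",))
    def scale(x: jnp.ndarray, double: bool = False) -> jnp.ndarray:
        return x * 2 if double else x

    print(scale(jnp.arange(3), double=True))  # run with JIT_PROFILE=1 to see hash logs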
xax/utils/jaxpr.py
ADDED
@@ -0,0 +1,77 @@
+"""Visualize JAXPR."""
+
+from pathlib import Path
+
+import jax
+import jax.core
+
+
+def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+    """Save the JAXPR to a DOT file.
+
+    Example usage:
+
+        grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
+        save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")
+
+    Then, you can visualize the JAXPR using Graphviz:
+
+        dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png
+
+    Args:
+        closed_jaxpr: The closed JAXPR to save.
+        filename: The filename to save the JAXPR to.
+    """
+    if hasattr(closed_jaxpr, "jaxpr"):
+        jaxpr = closed_jaxpr.jaxpr
+    else:
+        jaxpr = closed_jaxpr
+
+    with open(filename, "w") as f:
+        f.write("digraph Jaxpr {\n")
+
+        var_names: dict[jax.core.Var, str] = {}
+        var_count = 0
+
+        def get_var_name(var: jax.core.Var) -> str:
+            """Get a unique name for a variable."""
+            nonlocal var_names, var_count
+
+            # Handle Literal objects specially since they're not hashable
+            if isinstance(var, jax.core.Literal):
+                # Create a name based on the literal value
+                name = f"lit_{var.val}"
+                return name
+
+            # For other variables
+            if var not in var_names:
+                name = f"var_{var_count}"
+                var_names[var] = name
+                var_count += 1
+            return var_names[var]
+
+        for var in jaxpr.invars:
+            node_name = get_var_name(var)
+            f.write(f'    {node_name} [label="{node_name}\\n(input)"];\n')
+
+        eq_count = 0
+        for eq in jaxpr.eqns:
+            eq_node = f"eq{eq_count}"
+            label = f"{eq.primitive.name}"
+            f.write(f'    {eq_node} [shape=box, label="{label}"];\n')
+
+            for invar in eq.invars:
+                var_name = get_var_name(invar)
+                f.write(f"    {var_name} -> {eq_node};\n")
+
+            for outvar in eq.outvars:
+                var_name = get_var_name(outvar)
+                f.write(f"    {eq_node} -> {var_name};\n")
+
+            eq_count += 1
+
+        for var in jaxpr.outvars:
+            node_name = get_var_name(var)
+            f.write(f'    {node_name} [peripheries=2, label="{node_name}\\n(output)"];\n')
+
+        f.write("}\n")
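An end-to-end sketch of the docstring's workflow (the loss function is hypothetical):

    import jax
    import jax.numpy as jnp

    from xax.utils.jaxpr import save_jaxpr_dot

    # Sketch: any traceable function works here.
    def loss_fn(w: jnp.ndarray) -> jnp.ndarray:
        return jnp.sum(w * w)

    save_jaxpr_dot(jax.make_jaxpr(loss_fn)(jnp.ones(4)), "loss_fn_jaxpr.dot")
    # render with Graphviz: dot -Tpng loss_fn_jaxpr.dot > loss_fn_jaxpr.png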
xax/utils/profile.py
ADDED
@@ -0,0 +1,61 @@
+"""Profiling utilities."""
+
+import logging
+import os
+import time
+from functools import wraps
+from typing import Callable, ParamSpec, TypeVar
+
+logger = logging.getLogger(__name__)
+
+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
+def profile(fn: Callable[P, R]) -> Callable[P, R]:
+    """Profiling decorator that tracks function call count and execution time.
+
+    Activated when the PROFILE environment variable is set to "1".
+
+    Returns:
+        A decorated function with profiling capabilities.
+    """
+
+    class ProfileState:
+        call_count = 0
+        total_time = 0.0
+
+    @wraps(fn)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+        if os.environ.get("PROFILE", "0") != "1":
+            return fn(*args, **kwargs)
+
+        start_time = time.time()
+        res = fn(*args, **kwargs)
+        end_time = time.time()
+        runtime = end_time - start_time
+
+        ProfileState.call_count += 1
+        ProfileState.total_time += runtime
+
+        # Handle class methods by showing class name
+        if fn.__name__ == "__call__" or (args and hasattr(args[0], "__class__")):
+            try:
+                class_name = args[0].__class__.__name__ + "."
+            except (IndexError, AttributeError):
+                class_name = ""
+        else:
+            class_name = ""
+
+        logger.info(
+            "%s %s - call #%s, took %s seconds, total: %s seconds",
+            class_name,
+            fn.__name__,
+            ProfileState.call_count,
+            runtime,
+            ProfileState.total_time,
+        )
+
+        return res
+
+    return wrapped
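A usage sketch (the decorated function is a stand-in); without PROFILE=1 the wrapper falls straight through to the original function:

    import os

    from xax.utils.profile import profile

    os.environ["PROFILE"] = "1"  # sketch: enable profiling for this process

    @profile
    def warmup() -> None:
        pass

    warmup()  # logs the call count, per-call runtime, and cumulative total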