xax-0.0.7-py3-none-any.whl → xax-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +220 -44
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -35
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/logging.py +12 -2
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/RECORD +27 -22
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/top_level.txt +0 -0
xax/task/mixins/step_wrapper.py
CHANGED
@@ -1,53 +1,39 @@
 """Defines a mixin to wrap some steps in a context manager."""

+import time
 from dataclasses import dataclass
 from types import TracebackType
-from typing import
+from typing import Callable, ContextManager, TypeVar

-import equinox as eqx
 import jax

 from xax.task.base import BaseConfig, BaseTask

-StepType = Literal[
-    "backward",
-    "change_mode",
-    "clip_grads",
-    "create_optimizers",
-    "forward",
-    "get_dataloader",
-    "get_dataset",
-    "get_prefetcher",
-    "get_model",
-    "get_optimizer",
-    "get_initial_opt_state",
-    "get_update_fn",
-    "load_checkpoint",
-    "log_losses",
-    "model_to_device",
-    "on_step_end",
-    "on_step_start",
-    "save_checkpoint",
-    "step",
-    "update_state",
-    "write_logs",
-    "zero_grads",
-]
-

 class StepContext(ContextManager):
     """Context manager to get the current step type."""

-    CURRENT_STEP:
+    CURRENT_STEP: str | None = None

-    def __init__(
+    def __init__(
+        self,
+        step: str,
+        on_context_start: Callable[[str], None],
+        on_context_end: Callable[[str, float], None],
+    ) -> None:
         self.step = step
+        self.start_time = 0.0
+        self.on_context_start = on_context_start
+        self.on_context_end = on_context_end

     def __enter__(self) -> None:
         StepContext.CURRENT_STEP = self.step
+        self.start_time = time.time()
+        self.on_context_start(self.step)

     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
         StepContext.CURRENT_STEP = None
+        self.on_context_end(self.step, time.time() - self.start_time)


 @jax.tree_util.register_dataclass
@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)

-
-
-
+    def step_context(self, step: str) -> ContextManager:
+        return StepContext(step, self.on_context_start, self.on_context_stop)
+
+    def on_context_start(self, step: str) -> None:
+        pass
+
+    def on_context_stop(self, step: str, elapsed_time: float) -> None:
+        pass
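Usage sketch for the reworked StepContext: it now times the wrapped block and reports through the two callbacks. The print-based callbacks below are illustrative placeholders, not part of the package:

    import time

    from xax.task.mixins.step_wrapper import StepContext


    def on_start(step: str) -> None:
        print(f"starting {step}")


    def on_stop(step: str, elapsed: float) -> None:
        # Called on exit with the step name and elapsed wall-clock seconds.
        print(f"{step} finished in {elapsed:.3f}s")


    with StepContext("model_step", on_start, on_stop):
        time.sleep(0.1)  # stand-in for the real work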
xax/task/mixins/train.py
CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )

 import equinox as eqx
@@ -35,6 +36,7 @@ from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return
+            return False

         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
@@ -154,7 +156,6 @@ class TrainConfig(
     valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
     valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
-    batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
@@ -183,6 +184,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)

+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,
@@ -279,31 +283,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))

+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()

         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state

-
-
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()

-
-
+        if not load_optimizer:
+            return model, state

-
-
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)

-        return model, optimizer, opt_state,
+        return model, optimizer, opt_state, state

     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:
@@ -424,6 +450,7 @@ class TrainMixin(
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
@@ -456,7 +483,8 @@ class TrainMixin(
             while not self.is_training_over(state):
                 if self.valid_step_timer.is_valid_step(state):
                     valid_batch = next(valid_pf)
-
+                    with self.step_context("model_step"):
+                        model, loss, output = self.val_step(model, valid_batch)

                     # Perform logging.
                     with self.step_context("write_logs"):
@@ -464,22 +492,19 @@ class TrainMixin(
                         self.log_step(model, valid_batch, output, loss, state)
                         state.num_valid_samples += 1

-
-                state = self.on_step_start(state)
+                state = self.on_step_start(state)

-                with self.step_context("
+                with self.step_context("model_step"):
                     train_batch = next(train_pf)
                     model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)

-                # Perform logging.
                 with self.step_context("write_logs"):
                     state.phase = "train"
                     self.log_step(model, train_batch, output, loss, state)
                     state.num_steps += 1
                     state.num_samples += self.get_size_of_batch(train_batch) or 0

-
-                state = self.on_step_end(state)
+                state = self.on_step_end(state)

                 if self.should_checkpoint(state):
                     self.save_checkpoint(model, optimizer, opt_state, state)
@@ -496,14 +521,9 @@ class TrainMixin(
         except NotImplementedError:
             pass

-
-
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)

         try:
             with train_pf as train_pf_ctx:
@@ -520,14 +540,9 @@ class TrainMixin(
         except NotImplementedError:
             pass

-
-
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)

         try:
             with valid_pf as valid_pf_ctx:
@@ -559,7 +574,7 @@ class TrainMixin(
         Thread(target=self.log_state, daemon=True).start()

         key, model_key = jax.random.split(key)
-        model, optimizer, opt_state, state = self.load_initial_state(model_key)
+        model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
         state = self.on_training_start(state)

         def on_exit() -> None:
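Calling sketch for the new load_initial_state overloads: the return shape now depends on load_optimizer, and the training entry point passes load_optimizer=True. Here, task stands in for any instantiated TrainMixin subclass:

    import jax

    key = jax.random.PRNGKey(0)

    # Default: just the model and the training state.
    model, state = task.load_initial_state(key)

    # As the training loop now does: also build the optimizer and its state.
    model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)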
xax/task/script.py
CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
     ProcessMixin,
     RunnableConfig,
     RunnableMixin,
-    StepContextConfig,
-    StepContextMixin,
 )


@@ -28,7 +26,6 @@ class ScriptConfig(
     GPUStatsConfig,
     ProcessConfig,
     LoggerConfig,
-    StepContextConfig,
     ArtifactsConfig,
     RunnableConfig,
     BaseConfig,
@@ -44,7 +41,6 @@ class Script(
     GPUStatsMixin[ConfigT],
     ProcessMixin[ConfigT],
     LoggerMixin[ConfigT],
-    StepContextMixin[ConfigT],
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
xax/utils/debugging.py
ADDED
@@ -0,0 +1,49 @@
+"""Defines some useful Jax debugging utilities."""
+
+from collections import deque
+from collections.abc import Iterable, Mapping
+from typing import Any, Callable, Deque
+
+from jaxtyping import Array
+
+
+def get_named_leaves(
+    obj: Any,  # noqa: ANN401
+    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
+    max_depth: int = 100,
+) -> list[tuple[str, Any]]:  # noqa: ANN401
+    ret: list[tuple[str, Any]] = []
+    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
+    q.append((0, "", obj))
+
+    while q:
+        depth, name, node = q.popleft()
+
+        if depth > max_depth:
+            continue
+
+        if hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
+            for cname, cnode in node.__dict__.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Mapping):
+            for cname, cnode in node.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Iterable):
+            for i, cnode in enumerate(node):
+                gname = f"{name}.{i}" if name else str(i)
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+    return ret
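Usage sketch for get_named_leaves, which walks object attributes, mappings, and iterables breadth-first and returns dotted names for every array leaf; the container below is an illustrative example:

    import jax.numpy as jnp

    from xax.utils.debugging import get_named_leaves

    params = {"encoder": {"w": jnp.zeros((3, 4)), "b": jnp.zeros(4)}, "layers": [jnp.ones(2)]}

    for name, leaf in get_named_leaves(params):
        print(name, leaf.shape)
    # encoder.w (3, 4)
    # encoder.b (4,)
    # layers.0 (2,)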
xax/utils/experiments.py
CHANGED
@@ -23,7 +23,8 @@ import urllib.request
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from
+from types import TracebackType
+from typing import Any, Iterator, Self, TypeVar, cast
 from urllib.parse import urlparse

 import git
@@ -116,19 +117,19 @@ class StateTimer:
         logs: dict[str, dict[str, int | float]] = {}

         # Logs step statistics.
-        logs["
+        logs["⌛ steps"] = {
             "total": self.step_timer.steps,
             "per-second": self.step_timer.steps_per_second,
         }

         # Logs sample statistics.
-        logs["
+        logs["⌛ samples"] = {
             "total": self.sample_timer.steps,
             "per-second": self.sample_timer.steps_per_second,
         }

         # Logs full iteration statistics.
-        logs["
+        logs["⌛ dt"] = {
             "iter": self.iter_timer.iter_seconds,
         }

@@ -147,6 +148,24 @@ class IntervalTicker:
         return False


+class ContextTimer:
+    def __init__(self) -> None:
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+
+    def __enter__(self) -> Self:
+        self.start_time = time.time()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        self.elapsed_time = time.time() - self.start_time
+
+
 def abs_path(path: str) -> str:
     return str(Path(path).resolve())

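Usage sketch for the new ContextTimer helper, which records the wall-clock time of a with-block and exposes it afterwards as elapsed_time:

    import time

    from xax.utils.experiments import ContextTimer

    with ContextTimer() as timer:
        time.sleep(0.05)  # stand-in for the real work

    print(f"block took {timer.elapsed_time:.3f} seconds")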
xax/utils/jaxpr.py
ADDED
@@ -0,0 +1,77 @@
+"""Visualize JAXPR."""
+
+from pathlib import Path
+
+import jax
+import jax.core
+
+
+def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+    """Save the JAXPR to a DOT file.
+
+    Example usage:
+
+        grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
+        save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")
+
+    Then, you can visualize the JAXPR using Graphviz:
+
+        dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png
+
+    Args:
+        closed_jaxpr: The closed JAXPR to save.
+        filename: The filename to save the JAXPR to.
+    """
+    if hasattr(closed_jaxpr, "jaxpr"):
+        jaxpr = closed_jaxpr.jaxpr
+    else:
+        jaxpr = closed_jaxpr
+
+    with open(filename, "w") as f:
+        f.write("digraph Jaxpr {\n")
+
+        var_names: dict[jax.core.Var, str] = {}
+        var_count = 0
+
+        def get_var_name(var: jax.core.Var) -> str:
+            """Get a unique name for a variable."""
+            nonlocal var_names, var_count
+
+            # Handle Literal objects specially since they're not hashable
+            if isinstance(var, jax.core.Literal):
+                # Create a name based on the literal value
+                name = f"lit_{var.val}"
+                return name
+
+            # For other variables
+            if var not in var_names:
+                name = f"var_{var_count}"
+                var_names[var] = name
+                var_count += 1
+            return var_names[var]
+
+        for var in jaxpr.invars:
+            node_name = get_var_name(var)
+            f.write(f'  {node_name} [label="{node_name}\\n(input)"];\n')
+
+        eq_count = 0
+        for eq in jaxpr.eqns:
+            eq_node = f"eq{eq_count}"
+            label = f"{eq.primitive.name}"
+            f.write(f'  {eq_node} [shape=box, label="{label}"];\n')
+
+            for invar in eq.invars:
+                var_name = get_var_name(invar)
+                f.write(f"  {var_name} -> {eq_node};\n")
+
+            for outvar in eq.outvars:
+                var_name = get_var_name(outvar)
+                f.write(f"  {eq_node} -> {var_name};\n")
+
+            eq_count += 1
+
+        for var in jaxpr.outvars:
+            node_name = get_var_name(var)
+            f.write(f'  {node_name} [peripheries=2, label="{node_name}\\n(output)"];\n')
+
+        f.write("}\n")
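End-to-end sketch of save_jaxpr_dot, following the example in its docstring; loss_fn here is an arbitrary illustrative function:

    import jax
    import jax.numpy as jnp

    from xax.utils.jaxpr import save_jaxpr_dot


    def loss_fn(w: jax.Array) -> jax.Array:
        return jnp.sum(jnp.tanh(w) ** 2)


    # Trace to a closed jaxpr, then dump the computation graph as Graphviz DOT.
    closed_jaxpr = jax.make_jaxpr(loss_fn)(jnp.ones(8))
    save_jaxpr_dot(closed_jaxpr, "loss_fn.dot")
    # Render with: dot -Tpng loss_fn.dot > loss_fn.png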
xax/utils/logging.py
CHANGED
@@ -140,7 +140,13 @@ class ColoredFormatter(logging.Formatter):
         return logging.Formatter.format(self, record)


-def configure_logging(
+def configure_logging(
+    prefix: str | None = None,
+    *,
+    rank: int | None = None,
+    world_size: int | None = None,
+    debug: bool | None = None,
+) -> None:
     """Instantiates logging.

     This captures logs and reroutes them to the Toasts module, which is
@@ -151,6 +157,7 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
         prefix: An optional prefix to add to the logger
         rank: The current rank, or None if not using multiprocessing
         world_size: The total world size, or None if not using multiprocessing
+        debug: Whether to enable debug logging
     """
     if rank is not None or world_size is not None:
         assert rank is not None and world_size is not None
@@ -168,7 +175,10 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
     stream_handler.addFilter(filter)
     root_logger.addHandler(stream_handler)

-
+    if debug is None:
+        root_logger.setLevel(logging._nameToLevel[config.log_level])
+    else:
+        root_logger.setLevel(logging.DEBUG if debug else logging.INFO)

     # Avoid junk logs from other libraries.
     if config.hide_third_party_logs: