xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl

xax/task/loggers/callback.py ADDED
@@ -0,0 +1,44 @@
+ """Defines a logger that calls a callback function with the log line."""
+
+ from typing import Callable
+
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
+
+
+ class CallbackLogger(LoggerImpl):
+     def __init__(
+         self,
+         *,
+         callback: Callable[[LogLine], None] = lambda x: None,
+         error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
+         error_callback: Callable[[LogError], None] = lambda x: None,
+         status_callback: Callable[[LogStatus], None] = lambda x: None,
+         ping_callback: Callable[[LogPing], None] = lambda x: None,
+         file_callback: Callable[[str, str], None] = lambda x, y: None,
+     ) -> None:
+         super().__init__()
+
+         self.callback = callback
+         self.error_summary_callback = error_summary_callback
+         self.error_callback = error_callback
+         self.status_callback = status_callback
+         self.ping_callback = ping_callback
+         self.file_callback = file_callback
+
+     def write(self, line: LogLine) -> None:
+         self.callback(line)
+
+     def write_error_summary(self, error_summary: LogErrorSummary) -> None:
+         self.error_summary_callback(error_summary)
+
+     def write_error(self, error: LogError) -> None:
+         self.error_callback(error)
+
+     def write_status(self, status: LogStatus) -> None:
+         self.status_callback(status)
+
+     def write_ping(self, ping: LogPing) -> None:
+         self.ping_callback(ping)
+
+     def log_file(self, name: str, contents: str) -> None:
+         self.file_callback(name, contents)
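The new `CallbackLogger` forwards every log event to a user-supplied function, with no-op defaults for the rest. A minimal usage sketch (hypothetical callback bodies; module path inferred from the file header above):

```python
from xax.task.loggers.callback import CallbackLogger

# Collect any files the task logs into a plain dict.
captured: dict[str, str] = {}

cb_logger = CallbackLogger(
    file_callback=lambda name, contents: captured.update({name: contents}),
)

cb_logger.log_file("notes.txt", "hello")
assert captured == {"notes.txt": "hello"}  # the callback received the file
```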
xax/task/loggers/state.py CHANGED
@@ -3,8 +3,6 @@
  from pathlib import Path
  from typing import Literal
 
- from omegaconf import DictConfig, OmegaConf
-
  from xax.task.logger import LoggerImpl, LogLine
 
 
@@ -12,9 +10,6 @@ class StateLogger(LoggerImpl):
      def __init__(
          self,
          run_directory: str | Path,
-         git_state_name: str = "git_state.txt",
-         train_code_name: str = "train_code.py",
-         config_name: str = "config.yaml",
          flush_immediately: bool = False,
          open_mode: Literal["w", "a"] = "w",
          line_sep: str = "\n",
@@ -22,24 +17,16 @@ class StateLogger(LoggerImpl):
      ) -> None:
          super().__init__(float("inf"))
 
-         self.git_state_file = Path(run_directory).expanduser().resolve() / git_state_name
-         self.train_code_file = Path(run_directory).expanduser().resolve() / train_code_name
-         self.config_file = Path(run_directory).expanduser().resolve() / config_name
+         self.run_directory = Path(run_directory).expanduser().resolve()
+
          self.flush_immediately = flush_immediately
          self.open_mode = open_mode
          self.line_sep = line_sep
          self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
 
-     def log_git_state(self, git_state: str) -> None:
-         with open(self.git_state_file, "w") as f:
-             f.write(git_state)
-
-     def log_training_code(self, training_code: str) -> None:
-         with open(self.train_code_file, "w") as f:
-             f.write(training_code)
-
-     def log_config(self, config: DictConfig) -> None:
-         OmegaConf.save(config, self.config_file)
+     def log_file(self, name: str, contents: str) -> None:
+         with open(self.run_directory / name, "w") as f:
+             f.write(contents)
 
      def write(self, line: LogLine) -> None:
          pass
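With the bespoke `log_git_state`/`log_training_code`/`log_config` methods gone, `StateLogger` writes any named file into the run directory through the single `log_file` hook. A quick sketch (directory name invented):

```python
from pathlib import Path

from xax.task.loggers.state import StateLogger

run_dir = Path("runs/example")  # hypothetical run directory
run_dir.mkdir(parents=True, exist_ok=True)

state_logger = StateLogger(run_directory=run_dir)
state_logger.log_file("git_state.txt", "commit abc123")  # -> runs/example/git_state.txt
```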
xax/task/loggers/tensorboard.py CHANGED
@@ -12,10 +12,6 @@ import time
  from pathlib import Path
  from typing import TypeVar
 
- import jax
- import PIL.Image
- from omegaconf import DictConfig, OmegaConf
-
  from xax.core.state import Phase
  from xax.nn.parallel import is_master
  from xax.task.logger import LoggerImpl, LogLine
@@ -62,10 +58,7 @@ class TensorboardLogger(LoggerImpl):
 
          self.proc: subprocess.Popen | None = None
 
-         self.git_state: str | None = None
-         self.training_code: str | None = None
-         self.config: DictConfig | None = None
-
+         self.files: dict[str, str] = {}
          self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
          self._started = False
 
@@ -84,7 +77,7 @@ class TensorboardLogger(LoggerImpl):
          port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
 
          while port_is_busy(port):
-             logger.warning(f"Port {port} is busy, waiting...")
+             logger.warning("Port %s is busy, waiting...", port)
              time.sleep(10)
 
          def make_localhost(s: str) -> str:
@@ -160,20 +153,10 @@ class TensorboardLogger(LoggerImpl):
          self._start()
          return self.writers.writer(phase)
 
-     def log_git_state(self, git_state: str) -> None:
-         if not is_master():
-             return
-         self.git_state = f"```\n{git_state}\n```"
-
-     def log_training_code(self, training_code: str) -> None:
+     def log_file(self, name: str, contents: str) -> None:
          if not is_master():
              return
-         self.training_code = f"```python\n{training_code}\n```"
-
-     def log_config(self, config: DictConfig) -> None:
-         if not is_master():
-             return
-         self.config = config
+         self.files[name] = f"```\n{contents}\n```"
 
      def write(self, line: LogLine) -> None:
          if not is_master():
@@ -205,22 +188,22 @@ class TensorboardLogger(LoggerImpl):
 
          for namespace, images in line.images.items():
              for image_key, image_value in images.items():
-                 image = PIL.Image.fromarray(jax.device_get(image_value.pixels))
                  writer.add_image(
                      f"{namespace}/{image_key}",
-                     image,
+                     image_value.image,
                      global_step=line.state.num_steps,
                      walltime=walltime,
                  )
 
-         if self.config is not None:
-             writer.add_text("config", f"```\n{OmegaConf.to_yaml(self.config)}\n```")
-             self.config = None
-
-         if self.git_state is not None:
-             writer.add_text("git", self.git_state)
-             self.git_state = None
+         for namespace, videos in line.videos.items():
+             for video_key, video_value in videos.items():
+                 writer.add_video(
+                     f"{namespace}/{video_key}",
+                     video_value.frames,
+                     fps=video_value.fps,
+                     global_step=line.state.num_steps,
+                 )
 
-         if self.training_code is not None:
-             writer.add_text("code", self.training_code)
-             self.training_code = None
+         for name, contents in self.files.items():
+             writer.add_text(name, contents)
+         self.files.clear()
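The same generalization lands in `TensorboardLogger`: the three dedicated text slots become one `self.files` buffer that is emitted via `add_text` on the next `write` call and then cleared. A standalone sketch of that buffering behavior (invented names; the real logger also wraps contents in Markdown code fences and gates everything on `is_master()`):

```python
from typing import Callable

files: dict[str, str] = {}

def log_file(name: str, contents: str) -> None:
    files[name] = contents  # the real code wraps this in a Markdown code fence

def flush(add_text: Callable[[str, str], None]) -> None:
    # Emit each buffered file once, then clear so the next flush is a no-op.
    for name, contents in files.items():
        add_text(name, contents)
    files.clear()

log_file("config.yaml", "batch_size: 32")
flush(lambda tag, text: print(tag, text))  # prints the buffered file once
flush(lambda tag, text: print(tag, text))  # prints nothing
```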
xax/task/mixins/__init__.py CHANGED
@@ -1,6 +1,8 @@
  """Defines a single interface for all the mixins."""
 
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+ from xax.task.mixins.compile import CompileConfig, CompileMixin
  from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
  from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
@@ -8,4 +10,4 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
- from xax.task.mixins.train import Batch, Model, Output, TrainConfig, TrainMixin
+ from xax.task.mixins.train import TrainConfig, TrainMixin
xax/task/mixins/artifacts.py CHANGED
@@ -8,6 +8,8 @@ from dataclasses import dataclass
  from pathlib import Path
  from typing import Self, TypeVar
 
+ import jax
+
  from xax.core.conf import field, get_run_dir
  from xax.core.state import State
  from xax.nn.parallel import is_master
@@ -19,6 +21,7 @@ from xax.utils.text import show_info
  logger = logging.getLogger(__name__)
 
 
+ @jax.tree_util.register_dataclass
  @dataclass
  class ArtifactsConfig(BaseConfig):
      exp_dir: str | None = field(None, help="The fixed experiment directory")
@@ -43,8 +46,12 @@ class ArtifactsMixin(BaseTask[Config]):
          run_dir = Path(task_file).resolve().parent
          return run_dir / self.task_name
 
-     def set_exp_dir(self, exp_dir: Path) -> Self:
-         self._exp_dir = exp_dir
+     @property
+     def exp_dir(self) -> Path:
+         return self.get_exp_dir()
+
+     def set_exp_dir(self, exp_dir: str | Path) -> Self:
+         self._exp_dir = Path(exp_dir).expanduser().resolve()
          return self
 
      def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
@@ -61,13 +68,16 @@ class ArtifactsMixin(BaseTask[Config]):
          elif not missing_ok:
              raise RuntimeError(f"Lock file not found at {lock_file}")
 
-     @functools.cached_property
-     def exp_dir(self) -> Path:
+     def get_exp_dir(self) -> Path:
+         if self._exp_dir is not None:
+             return self._exp_dir
+
          if self.config.exp_dir is not None:
              exp_dir = Path(self.config.exp_dir).expanduser().resolve()
              exp_dir.mkdir(parents=True, exist_ok=True)
-             logger.log(LOG_STATUS, self.exp_dir)
-             return exp_dir
+             self._exp_dir = exp_dir
+             logger.log(LOG_STATUS, self._exp_dir)
+             return self._exp_dir
 
          def get_exp_dir(run_id: int) -> Path:
              return self.run_dir / f"run_{run_id}"
@@ -81,9 +91,9 @@ class ArtifactsMixin(BaseTask[Config]):
          while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
              run_id += 1
          exp_dir.mkdir(exist_ok=True, parents=True)
-         exp_dir = exp_dir.expanduser().resolve()
-         logger.log(LOG_STATUS, exp_dir)
-         return exp_dir
+         self._exp_dir = exp_dir.expanduser().resolve()
+         logger.log(LOG_STATUS, self._exp_dir)
+         return self._exp_dir
 
      @functools.lru_cache(maxsize=None)
      def stage_environment(self) -> Path | None:
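The `functools.cached_property` is swapped for an explicit `_exp_dir` cache, so the experiment directory can be set or recomputed after construction, and `set_exp_dir` now normalizes strings. A stripped-down sketch of the pattern (invented minimal class; the real mixin also consults the config and lock files):

```python
from pathlib import Path
from typing import Self

class HasExpDir:
    _exp_dir: Path | None = None

    @property
    def exp_dir(self) -> Path:
        return self.get_exp_dir()

    def set_exp_dir(self, exp_dir: str | Path) -> Self:
        self._exp_dir = Path(exp_dir).expanduser().resolve()
        return self

    def get_exp_dir(self) -> Path:
        if self._exp_dir is None:
            self._exp_dir = Path.cwd() / "run_0"  # stand-in for the real search logic
        return self._exp_dir

task = HasExpDir().set_exp_dir("~/experiments/run_1")
assert task.exp_dir == Path("~/experiments/run_1").expanduser().resolve()
```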
xax/task/mixins/checkpointing.py ADDED
@@ -0,0 +1,221 @@
+ """Defines a mixin for handling model checkpointing."""
+
+ import io
+ import json
+ import logging
+ import tarfile
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+ import cloudpickle
+ import jax
+ import optax
+ from jaxtyping import PyTree
+ from omegaconf import DictConfig, OmegaConf
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+ logger = logging.getLogger(__name__)
+
+ CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+ def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+     """Defines the path to the checkpoint for a given state.
+
+     Args:
+         exp_dir: The experiment directory
+         state: The current trainer state
+
+     Returns:
+         The path to the checkpoint file.
+     """
+     if state is None:
+         return exp_dir / "checkpoints" / "ckpt.bin"
+     return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class CheckpointingConfig(ArtifactsConfig):
+     save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+     save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+     only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+     load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+     load_ckpt_strict: bool = field(True, help="If set, only load weights which have a matching key in the model")
+     save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
+
+
+ Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self.__last_ckpt_time = 0.0
+
+     def get_ckpt_path(self, state: State | None = None) -> Path:
+         return get_ckpt_path(self.exp_dir, state)
+
+     def get_init_ckpt_path(self) -> Path | None:
+         if self._exp_dir is not None:
+             ckpt_path = self.get_ckpt_path()
+             if ckpt_path.exists():
+                 return ckpt_path
+         if self.config.load_from_ckpt_path is not None:
+             ckpt_path = Path(self.config.load_from_ckpt_path)
+             assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+             return ckpt_path
+         return None
+
+     def should_checkpoint(self, state: State) -> bool:
+         if self.config.save_every_n_steps is not None:
+             if state.num_steps % self.config.save_every_n_steps == 0:
+                 return True
+         if self.config.save_every_n_seconds is not None:
+             last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+             if cur_time - last_time >= self.config.save_every_n_seconds:
+                 self.__last_ckpt_time = cur_time
+                 return True
+         return False
+
+     @overload
+     def load_checkpoint(
+         self,
+         path: Path,
+     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+     def load_checkpoint(
+         self,
+         path: Path,
+         part: CheckpointPart | None = None,
+     ) -> (
+         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+         | PyTree
+         | optax.GradientTransformation
+         | optax.OptState
+         | State
+         | DictConfig
+     ):
+         with tarfile.open(path, "r:gz") as tar:
+
+             def get_model() -> PyTree:
+                 if (model := tar.extractfile("model")) is None:
+                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                 return cloudpickle.load(model)
+
+             def get_opt() -> optax.GradientTransformation:
+                 if (opt := tar.extractfile("opt")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                 return cloudpickle.load(opt)
+
+             def get_opt_state() -> optax.OptState:
+                 if (opt_state := tar.extractfile("opt_state")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                 return cloudpickle.load(opt_state)
+
+             def get_state() -> State:
+                 if (state := tar.extractfile("state")) is None:
+                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                 return State(**json.loads(state.read().decode()))
+
+             def get_config() -> DictConfig:
+                 if (config := tar.extractfile("config")) is None:
+                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                 return cast(DictConfig, OmegaConf.load(config))
+
+             match part:
+                 case "model":
+                     return get_model()
+                 case "opt":
+                     return get_opt()
+                 case "opt_state":
+                     return get_opt_state()
+                 case "state":
+                     return get_state()
+                 case "config":
+                     return get_config()
+                 case None:
+                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                 case _:
+                     raise ValueError(f"Invalid checkpoint part: {part}")
+
+     def save_checkpoint(
+         self,
+         model: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         state: State,
+     ) -> Path:
+         ckpt_path = self.get_ckpt_path(state)
+
+         if not is_master():
+             return ckpt_path
+
+         # Gets the path to the last checkpoint.
+         logger.info("Saving checkpoint to %s", ckpt_path)
+         last_ckpt_path = self.get_ckpt_path()
+         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+         # Potentially removes the last checkpoint.
+         if last_ckpt_path.exists() and self.config.only_save_most_recent:
+             if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                 base_ckpt.unlink()
+
+         # Combines all temporary files into a single checkpoint TAR file.
+         with tarfile.open(ckpt_path, "w:gz") as tar:
+
+             def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                 with io.BytesIO() as buf:
+                     write_fn(buf)
+                     tarinfo = tarfile.TarInfo(name)
+                     tarinfo.size = buf.tell()
+                     buf.seek(0)
+                     tar.addfile(tarinfo, buf)
+
+             add_file("model", lambda buf: cloudpickle.dump(model, buf))
+             add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+             add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+             add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+             add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+             if self.config.save_tf_model:
+                 try:
+                     from jax.experimental import jax2tf
+                 except ModuleNotFoundError:
+                     raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
+
+                 tf_model = jax2tf.convert(model)
+                 add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
+
+         # Updates the symlink to the new checkpoint.
+         last_ckpt_path.unlink(missing_ok=True)
+         try:
+             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+         except FileExistsError:
+             logger.exception("Exception while trying to update %s", ckpt_path)
+
+         # Marks directory as having artifacts which shouldn't be overwritten.
+         self.add_lock_file("ckpt", exists_ok=True)
+
+         return ckpt_path
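Since a checkpoint is just a gzipped tar with `model`, `opt`, `opt_state`, `state`, and `config` members, any part can be inspected without the mixin. A sketch using only the stdlib (assumes a checkpoint file produced by `save_checkpoint` above):

```python
import json
import tarfile

with tarfile.open("ckpt.bin", "r:gz") as tar:  # hypothetical checkpoint path
    state_file = tar.extractfile("state")
    if state_file is not None:
        print(json.loads(state_file.read().decode()))  # trainer State as JSON
```

Within the mixin, the same data comes back through `load_checkpoint(path)` for everything, or `load_checkpoint(path, "model")` and friends for a single part.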
xax/task/mixins/compile.py ADDED
@@ -0,0 +1,104 @@
+ """Defines a mixin for handling JAX compilation behavior.
+
+ This mixin allows control over JAX compilation settings like jit, pmap, and vmap
+ behavior during initialization and training.
+ """
+
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Generic, TypeVar
+
+ import jax
+
+ from xax.core.conf import field
+ from xax.task.base import BaseConfig, BaseTask
+
+ logger = logging.getLogger(__name__)
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class CompileOptions:
+     # JAX compilation options
+     disable_jit: bool = field(
+         value=False,
+         help="If True, disables JIT compilation",
+     )
+     enable_x64: bool = field(
+         value=False,
+         help="If True, enables 64-bit precision",
+     )
+     default_device: str | None = field(
+         value=None,
+         help="Default device to use (e.g. 'cpu', 'gpu')",
+     )
+
+     # JAX logging options
+     logging_level: str = field(
+         value="INFO",
+         help="JAX logging verbosity level",
+     )
+
+     # JAX cache options
+     cache_dir: str | None = field(
+         value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
+         help="Directory for JAX compilation cache. If None, caching is disabled",
+     )
+     cache_min_size_bytes: int = field(
+         value=-1,
+         help="Minimum size in bytes for cache entries. -1 means no minimum",
+     )
+     cache_min_compile_time_secs: float = field(
+         value=0.0,
+         help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
+     )
+     cache_enable_xla: str = field(
+         value="all",
+         help="Which XLA caches to enable",
+     )
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class CompileConfig(BaseConfig):
+     compile: CompileOptions = field(CompileOptions(), help="Compilation configuration")
+
+
+ Config = TypeVar("Config", bound=CompileConfig)
+
+
+ class CompileMixin(BaseTask[Config], Generic[Config]):
+     """Defines a task mixin for controlling JAX compilation behavior."""
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         cc = self.config.compile
+
+         # Set basic compilation flags
+         if cc.disable_jit:
+             logger.info("Disabling JIT compilation")
+             jax.config.update("jax_disable_jit", True)
+
+         if cc.enable_x64:
+             logger.info("Enabling 64-bit precision")
+             jax.config.update("jax_enable_x64", True)
+
+         if cc.default_device is not None:
+             logger.info("Setting default device to %s", cc.default_device)
+             jax.config.update("jax_default_device", cc.default_device)
+
+         # Set logging level
+         logger.info("Setting JAX logging level to %s", cc.logging_level)
+         jax.config.update("jax_logging_level", cc.logging_level)
+
+         # Configure compilation cache
+         if cc.cache_dir is not None:
+             logger.info("Setting JAX compilation cache directory to %s", cc.cache_dir)
+             jax.config.update("jax_compilation_cache_dir", cc.cache_dir)
+
+             logger.info("Configuring JAX compilation cache parameters")
+             jax.config.update("jax_persistent_cache_min_entry_size_bytes", cc.cache_min_size_bytes)
+             jax.config.update("jax_persistent_cache_min_compile_time_secs", cc.cache_min_compile_time_secs)
+             jax.config.update("jax_persistent_cache_enable_xla_caches", cc.cache_enable_xla)
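Since `CompileMixin.__init__` just maps config fields onto `jax.config.update` calls, the effect can be reproduced directly. A minimal sketch using the same flag names the mixin sets (values are illustrative):

```python
import jax

jax.config.update("jax_disable_jit", False)                      # leave JIT on
jax.config.update("jax_enable_x64", False)                       # keep 32-bit defaults
jax.config.update("jax_compilation_cache_dir", "/tmp/jaxcache")  # enable the persistent cache
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
```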