xax 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -1
- xax/nn/geom.py +34 -0
- xax/task/mixins/checkpointing.py +9 -12
- xax/task/mixins/cpu_stats.py +12 -9
- xax/task/mixins/gpu_stats.py +14 -11
- xax/task/mixins/process.py +14 -8
- xax/task/mixins/train.py +133 -19
- {xax-0.2.0.dist-info → xax-0.2.2.dist-info}/METADATA +1 -1
- {xax-0.2.0.dist-info → xax-0.2.2.dist-info}/RECORD +12 -12
- {xax-0.2.0.dist-info → xax-0.2.2.dist-info}/WHEEL +0 -0
- {xax-0.2.0.dist-info → xax-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.0.dist-info → xax-0.2.2.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.0"
|
15
|
+
__version__ = "0.2.2"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -44,6 +44,7 @@ __all__ = [
|
|
44
44
|
"euler_to_quat",
|
45
45
|
"get_projected_gravity_vector_from_quat",
|
46
46
|
"quat_to_euler",
|
47
|
+
"quat_to_rotmat",
|
47
48
|
"rotate_vector_by_quat",
|
48
49
|
"cross_entropy",
|
49
50
|
"cast_norm_type",
|
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
|
|
206
207
|
"euler_to_quat": "nn.geom",
|
207
208
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
208
209
|
"quat_to_euler": "nn.geom",
|
210
|
+
"quat_to_rotmat": "nn.geom",
|
209
211
|
"rotate_vector_by_quat": "nn.geom",
|
210
212
|
"cross_entropy": "nn.losses",
|
211
213
|
"cast_norm_type": "nn.norm",
|
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
369
371
|
euler_to_quat,
|
370
372
|
get_projected_gravity_vector_from_quat,
|
371
373
|
quat_to_euler,
|
374
|
+
quat_to_rotmat,
|
372
375
|
rotate_vector_by_quat,
|
373
376
|
)
|
374
377
|
from xax.nn.losses import cross_entropy
|
xax/nn/geom.py
CHANGED
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
|
|
177
177
|
y_diff = y_end - y_start
|
178
178
|
bezier = x**3 + 3 * (x**2 * (1 - x))
|
179
179
|
return y_start + y_diff * bezier
|
180
|
+
|
181
|
+
|
182
|
+
def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
    """Converts a quaternion to a rotation matrix.

    The quaternion components are read in (w, x, y, z) order along the last
    axis. The input is normalized before conversion, so it does not need to
    be a unit quaternion.

    Args:
        quat: The quaternion to convert, shape (*, 4).
        eps: A small epsilon value added to the norm to avoid division by zero.

    Returns:
        The rotation matrix, shape (*, 3, 3).

    Raises:
        ValueError: If the last dimension of ``quat`` does not have size 4.
    """
    # Splitting into four parts would silently succeed for any last dimension
    # divisible by four and yield a wrongly shaped result, so fail loudly.
    if quat.ndim < 1 or quat.shape[-1] != 4:
        raise ValueError(f"Expected a quaternion of shape (*, 4), got shape {quat.shape}")

    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]

    xx = 1 - 2 * (y * y + z * z)
    xy = 2 * (x * y - z * w)
    xz = 2 * (x * z + y * w)
    yx = 2 * (x * y + z * w)
    yy = 1 - 2 * (x * x + z * z)
    yz = 2 * (y * z - x * w)
    zx = 2 * (x * z - y * w)
    zy = 2 * (y * z + x * w)
    zz = 1 - 2 * (x * x + y * y)

    # Each row has shape (*, 3); stacking the rows on axis -2 gives (*, 3, 3)
    # in row-major order.
    rows = [
        jnp.stack([xx, xy, xz], axis=-1),
        jnp.stack([yx, yy, yz], axis=-1),
        jnp.stack([zx, zy, zz], axis=-1),
    ]
    return jnp.stack(rows, axis=-2)
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -63,10 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
63
63
|
|
64
64
|
def get_init_ckpt_path(self) -> Path | None:
|
65
65
|
if self._exp_dir is not None:
|
66
|
-
ckpt_path = self.get_ckpt_path()
|
67
|
-
if not ckpt_path.exists():
|
68
|
-
logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
|
69
|
-
else:
|
66
|
+
if (ckpt_path := self.get_ckpt_path()).exists():
|
70
67
|
return ckpt_path
|
71
68
|
if self.config.load_from_ckpt_path is not None:
|
72
69
|
ckpt_path = Path(self.config.load_from_ckpt_path)
|
@@ -86,7 +83,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
86
83
|
return False
|
87
84
|
|
88
85
|
@overload
|
89
|
-
def load_checkpoint(
|
86
|
+
def load_ckpt_with_template(
|
90
87
|
self,
|
91
88
|
path: Path,
|
92
89
|
*,
|
@@ -97,7 +94,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
97
94
|
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
|
98
95
|
|
99
96
|
@overload
|
100
|
-
def load_checkpoint(
|
97
|
+
def load_ckpt_with_template(
|
101
98
|
self,
|
102
99
|
path: Path,
|
103
100
|
*,
|
@@ -106,7 +103,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
106
103
|
) -> tuple[PyTree, State, Config]: ...
|
107
104
|
|
108
105
|
@overload
|
109
|
-
def load_checkpoint(
|
106
|
+
def load_ckpt_with_template(
|
110
107
|
self,
|
111
108
|
path: Path,
|
112
109
|
*,
|
@@ -115,7 +112,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
115
112
|
) -> PyTree: ...
|
116
113
|
|
117
114
|
@overload
|
118
|
-
def load_checkpoint(
|
115
|
+
def load_ckpt_with_template(
|
119
116
|
self,
|
120
117
|
path: Path,
|
121
118
|
*,
|
@@ -124,7 +121,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
124
121
|
) -> optax.GradientTransformation: ...
|
125
122
|
|
126
123
|
@overload
|
127
|
-
def load_checkpoint(
|
124
|
+
def load_ckpt_with_template(
|
128
125
|
self,
|
129
126
|
path: Path,
|
130
127
|
*,
|
@@ -133,7 +130,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
133
130
|
) -> optax.OptState: ...
|
134
131
|
|
135
132
|
@overload
|
136
|
-
def load_checkpoint(
|
133
|
+
def load_ckpt_with_template(
|
137
134
|
self,
|
138
135
|
path: Path,
|
139
136
|
*,
|
@@ -141,14 +138,14 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
141
138
|
) -> State: ...
|
142
139
|
|
143
140
|
@overload
|
144
|
-
def load_checkpoint(
|
141
|
+
def load_ckpt_with_template(
|
145
142
|
self,
|
146
143
|
path: Path,
|
147
144
|
*,
|
148
145
|
part: Literal["config"],
|
149
146
|
) -> Config: ...
|
150
147
|
|
151
|
-
def load_checkpoint(
|
148
|
+
def load_ckpt_with_template(
|
152
149
|
self,
|
153
150
|
path: Path,
|
154
151
|
*,
|
xax/task/mixins/cpu_stats.py
CHANGED
@@ -218,33 +218,36 @@ class CPUStatsMonitor:
|
|
218
218
|
class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
219
219
|
"""Defines a task mixin for getting CPU statistics."""
|
220
220
|
|
221
|
-
_cpu_stats_monitor: CPUStatsMonitor
|
221
|
+
_cpu_stats_monitor: CPUStatsMonitor | None
|
222
222
|
|
223
223
|
def __init__(self, config: Config) -> None:
|
224
224
|
super().__init__(config)
|
225
225
|
|
226
|
-
self.
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
)
|
226
|
+
if (ctx := self.multiprocessing_context) is not None and (mgr := self.multiprocessing_manager) is not None:
|
227
|
+
self._cpu_stats_monitor = CPUStatsMonitor(self.config.cpu_stats.ping_interval, ctx, mgr)
|
228
|
+
else:
|
229
|
+
self._cpu_stats_monitor = None
|
231
230
|
|
232
231
|
def on_training_start(self, state: State) -> State:
|
233
232
|
state = super().on_training_start(state)
|
234
233
|
|
235
|
-
self._cpu_stats_monitor.start()
|
234
|
+
if (monitor := self._cpu_stats_monitor) is not None:
|
235
|
+
monitor.start()
|
236
236
|
return state
|
237
237
|
|
238
238
|
def on_training_end(self, state: State) -> State:
|
239
239
|
state = super().on_training_end(state)
|
240
240
|
|
241
|
-
self._cpu_stats_monitor.stop()
|
241
|
+
if (monitor := self._cpu_stats_monitor) is not None:
|
242
|
+
monitor.stop()
|
242
243
|
return state
|
243
244
|
|
244
245
|
def on_step_start(self, state: State) -> State:
|
245
246
|
state = super().on_step_start(state)
|
246
247
|
|
247
|
-
monitor = self._cpu_stats_monitor
|
248
|
+
if (monitor := self._cpu_stats_monitor) is None:
|
249
|
+
return state
|
250
|
+
|
248
251
|
stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
|
249
252
|
|
250
253
|
if stats is not None:
|
xax/task/mixins/gpu_stats.py
CHANGED
@@ -234,24 +234,27 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
234
234
|
def __init__(self, config: Config) -> None:
|
235
235
|
super().__init__(config)
|
236
236
|
|
237
|
-
|
238
|
-
|
239
|
-
self.
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
237
|
+
if (
|
238
|
+
shutil.which("nvidia-smi") is not None
|
239
|
+
and (ctx := self.multiprocessing_context) is not None
|
240
|
+
and (mgr := self.multiprocessing_manager) is not None
|
241
|
+
):
|
242
|
+
self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, ctx, mgr)
|
243
|
+
else:
|
244
|
+
self._gpu_stats_monitor = None
|
244
245
|
|
245
246
|
def on_training_start(self, state: State) -> State:
|
246
247
|
state = super().on_training_start(state)
|
247
|
-
|
248
|
-
|
248
|
+
|
249
|
+
if (monitor := self._gpu_stats_monitor) is not None:
|
250
|
+
monitor.start()
|
249
251
|
return state
|
250
252
|
|
251
253
|
def on_training_end(self, state: State) -> State:
|
252
254
|
state = super().on_training_end(state)
|
253
|
-
|
254
|
-
|
255
|
+
|
256
|
+
if (monitor := self._gpu_stats_monitor) is not None:
|
257
|
+
monitor.stop()
|
255
258
|
return state
|
256
259
|
|
257
260
|
def on_step_start(self, state: State) -> State:
|
xax/task/mixins/process.py
CHANGED
@@ -20,6 +20,7 @@ logger: logging.Logger = logging.getLogger(__name__)
|
|
20
20
|
@dataclass
|
21
21
|
class ProcessConfig(BaseConfig):
|
22
22
|
multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
|
23
|
+
disable_multiprocessing: bool = field(False, help="If set, disable multiprocessing")
|
23
24
|
|
24
25
|
|
25
26
|
Config = TypeVar("Config", bound=ProcessConfig)
|
@@ -28,27 +29,32 @@ Config = TypeVar("Config", bound=ProcessConfig)
|
|
28
29
|
class ProcessMixin(BaseTask[Config], Generic[Config]):
|
29
30
|
"""Defines a base trainer mixin for handling monitoring processes."""
|
30
31
|
|
31
|
-
_mp_ctx: BaseContext
|
32
|
-
_mp_manager: SyncManager
|
32
|
+
_mp_ctx: BaseContext | None
|
33
|
+
_mp_manager: SyncManager | None
|
33
34
|
|
34
35
|
def __init__(self, config: Config) -> None:
|
35
36
|
super().__init__(config)
|
36
37
|
|
37
|
-
self.
|
38
|
-
|
38
|
+
if self.config.disable_multiprocessing:
|
39
|
+
self._mp_ctx = None
|
40
|
+
self._mp_manager = None
|
41
|
+
else:
|
42
|
+
self._mp_ctx = mp.get_context(config.multiprocessing_context)
|
43
|
+
self._mp_manager = self._mp_ctx.Manager()
|
39
44
|
|
40
45
|
@property
|
41
|
-
def multiprocessing_context(self) -> BaseContext:
|
46
|
+
def multiprocessing_context(self) -> BaseContext | None:
|
42
47
|
return self._mp_ctx
|
43
48
|
|
44
49
|
@property
|
45
|
-
def multiprocessing_manager(self) -> SyncManager:
|
50
|
+
def multiprocessing_manager(self) -> SyncManager | None:
|
46
51
|
return self._mp_manager
|
47
52
|
|
48
53
|
def on_training_end(self, state: State) -> State:
|
49
54
|
state = super().on_training_end(state)
|
50
55
|
|
51
|
-
self._mp_manager.shutdown()
|
52
|
-
|
56
|
+
if self._mp_manager is not None:
|
57
|
+
self._mp_manager.shutdown()
|
58
|
+
self._mp_manager.join()
|
53
59
|
|
54
60
|
return state
|
xax/task/mixins/train.py
CHANGED
@@ -12,6 +12,7 @@ import time
|
|
12
12
|
import traceback
|
13
13
|
from abc import ABC, abstractmethod
|
14
14
|
from dataclasses import asdict, dataclass, is_dataclass
|
15
|
+
from pathlib import Path
|
15
16
|
from threading import Thread
|
16
17
|
from typing import (
|
17
18
|
Any,
|
@@ -39,7 +40,7 @@ from xax.core.state import Phase, State
|
|
39
40
|
from xax.nn.functions import set_random_seed
|
40
41
|
from xax.nn.parallel import is_master
|
41
42
|
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
42
|
-
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
|
43
|
+
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
|
43
44
|
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
|
44
45
|
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
45
46
|
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
|
@@ -54,7 +55,7 @@ from xax.utils.experiments import (
|
|
54
55
|
get_training_code,
|
55
56
|
)
|
56
57
|
from xax.utils.jax import jit as xax_jit
|
57
|
-
from xax.utils.logging import LOG_STATUS
|
58
|
+
from xax.utils.logging import LOG_PING, LOG_STATUS
|
58
59
|
from xax.utils.text import highlight_exception_message, show_info
|
59
60
|
from xax.utils.types.frozen_dict import FrozenDict
|
60
61
|
|
@@ -340,12 +341,7 @@ class TrainMixin(
|
|
340
341
|
|
341
342
|
if init_ckpt_path is not None:
|
342
343
|
logger.info("Loading checkpoint from %s", init_ckpt_path)
|
343
|
-
|
344
|
-
model, state, config = self.load_checkpoint(
|
345
|
-
init_ckpt_path,
|
346
|
-
part="model_state_config",
|
347
|
-
model_template=model_spec,
|
348
|
-
)
|
344
|
+
model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
|
349
345
|
config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
|
350
346
|
if config_diff:
|
351
347
|
logger.warning("Loaded config differs from current config:\n%s", config_diff)
|
@@ -353,17 +349,11 @@ class TrainMixin(
|
|
353
349
|
if not load_optimizer:
|
354
350
|
return model, state
|
355
351
|
|
356
|
-
|
357
|
-
|
358
|
-
optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
|
359
|
-
|
360
|
-
# Loads the optimizer state.
|
361
|
-
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
362
|
-
opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
|
363
|
-
|
352
|
+
optimizer = self.load_ckpt(init_ckpt_path, part="opt")
|
353
|
+
opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
|
364
354
|
return model, optimizer, opt_state, state
|
365
355
|
|
366
|
-
logger.info("
|
356
|
+
logger.info("Starting a new training run")
|
367
357
|
model = self.get_model(key)
|
368
358
|
state = State.init_state()
|
369
359
|
|
@@ -375,6 +365,131 @@ class TrainMixin(
|
|
375
365
|
|
376
366
|
return model, optimizer, opt_state, state
|
377
367
|
|
368
|
+
@overload
|
369
|
+
def load_ckpt(
|
370
|
+
self,
|
371
|
+
path: Path,
|
372
|
+
*,
|
373
|
+
part: Literal["all"],
|
374
|
+
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
|
375
|
+
|
376
|
+
@overload
|
377
|
+
def load_ckpt(
|
378
|
+
self,
|
379
|
+
path: Path,
|
380
|
+
*,
|
381
|
+
part: Literal["model_state_config"],
|
382
|
+
) -> tuple[PyTree, State, Config]: ...
|
383
|
+
|
384
|
+
@overload
|
385
|
+
def load_ckpt(
|
386
|
+
self,
|
387
|
+
path: Path,
|
388
|
+
*,
|
389
|
+
part: Literal["model"],
|
390
|
+
) -> PyTree: ...
|
391
|
+
|
392
|
+
@overload
|
393
|
+
def load_ckpt(
|
394
|
+
self,
|
395
|
+
path: Path,
|
396
|
+
*,
|
397
|
+
part: Literal["opt"],
|
398
|
+
) -> optax.GradientTransformation: ...
|
399
|
+
|
400
|
+
@overload
|
401
|
+
def load_ckpt(
|
402
|
+
self,
|
403
|
+
path: Path,
|
404
|
+
*,
|
405
|
+
part: Literal["opt_state"],
|
406
|
+
model: PyTree | None = None,
|
407
|
+
optimizer: optax.GradientTransformation | None = None,
|
408
|
+
) -> optax.OptState: ...
|
409
|
+
|
410
|
+
@overload
|
411
|
+
def load_ckpt(
|
412
|
+
self,
|
413
|
+
path: Path,
|
414
|
+
*,
|
415
|
+
part: Literal["state"],
|
416
|
+
) -> State: ...
|
417
|
+
|
418
|
+
@overload
|
419
|
+
def load_ckpt(
|
420
|
+
self,
|
421
|
+
path: Path,
|
422
|
+
*,
|
423
|
+
part: Literal["config"],
|
424
|
+
) -> Config: ...
|
425
|
+
|
426
|
+
def load_ckpt(
|
427
|
+
self,
|
428
|
+
path: str | Path,
|
429
|
+
*,
|
430
|
+
part: CheckpointPart = "all",
|
431
|
+
model: PyTree | None = None,
|
432
|
+
optimizer: optax.GradientTransformation | None = None,
|
433
|
+
) -> (
|
434
|
+
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
|
435
|
+
| tuple[PyTree, State, Config]
|
436
|
+
| PyTree
|
437
|
+
| optax.GradientTransformation
|
438
|
+
| optax.OptState
|
439
|
+
| State
|
440
|
+
| Config
|
441
|
+
):
|
442
|
+
path = Path(path)
|
443
|
+
|
444
|
+
# This key isn't used for anything, it's just a required argument.
|
445
|
+
key = jax.random.PRNGKey(0)
|
446
|
+
|
447
|
+
match part:
|
448
|
+
case "model_state_config":
|
449
|
+
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
450
|
+
return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
|
451
|
+
|
452
|
+
case "model":
|
453
|
+
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
454
|
+
return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
|
455
|
+
|
456
|
+
case "config":
|
457
|
+
return self.load_ckpt_with_template(path, part="config")
|
458
|
+
|
459
|
+
case "opt":
|
460
|
+
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
461
|
+
return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
|
462
|
+
|
463
|
+
case "opt_state":
|
464
|
+
if model is None:
|
465
|
+
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
466
|
+
model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
|
467
|
+
if optimizer is None:
|
468
|
+
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
469
|
+
optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
|
470
|
+
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
471
|
+
return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
|
472
|
+
|
473
|
+
case "state":
|
474
|
+
return self.load_ckpt_with_template(path, part="state")
|
475
|
+
|
476
|
+
case "config":
|
477
|
+
return self.load_ckpt_with_template(path, part="config")
|
478
|
+
|
479
|
+
case "all":
|
480
|
+
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
481
|
+
model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
|
482
|
+
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
483
|
+
optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
|
484
|
+
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
485
|
+
opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
|
486
|
+
state = self.load_ckpt_with_template(path, part="state")
|
487
|
+
config = self.load_ckpt_with_template(path, part="config")
|
488
|
+
return model, optimizer, opt_state, state, config
|
489
|
+
|
490
|
+
case _:
|
491
|
+
raise ValueError(f"Unknown checkpoint part: {part}")
|
492
|
+
|
378
493
|
def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
|
379
494
|
"""Gets the output from the model.
|
380
495
|
|
@@ -529,8 +644,7 @@ class TrainMixin(
|
|
529
644
|
self._last_printed_remaining_time = state.elapsed_time_s
|
530
645
|
remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
|
531
646
|
termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
|
532
|
-
|
533
|
-
jax.debug.print("Estimated finish time: {}", termination_time)
|
647
|
+
logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
|
534
648
|
|
535
649
|
def get_remaining_percent(self, state: State) -> float | None:
|
536
650
|
if self.config.max_steps is None:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=Yj2SgoKyIAQzg3bt-hAS4gf0fqlfVBR4pv4JgpTl7-s,14182
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
|
10
10
|
xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
|
11
11
|
xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
|
12
12
|
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
13
|
-
xax/nn/geom.py,sha256=
|
13
|
+
xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
15
|
xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
@@ -32,16 +32,16 @@ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,65
|
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
|
-
xax/task/mixins/cpu_stats.py,sha256=
|
37
|
+
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
38
38
|
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
39
|
-
xax/task/mixins/gpu_stats.py,sha256=
|
39
|
+
xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
|
40
40
|
xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
|
41
|
-
xax/task/mixins/process.py,sha256=
|
41
|
+
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.2.dist-info/METADATA,sha256=Ku0h6R6WToJ4rMYhcswGLXtIGVtzouWIGelHZFW30IM,1882
|
63
|
+
xax-0.2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.2.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|