xax 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +2 -15
- xax/core/state.py +22 -17
- xax/task/base.py +11 -5
- xax/task/mixins/train.py +74 -48
- xax/task/script.py +3 -0
- {xax-0.1.8.dist-info → xax-0.1.9.dist-info}/METADATA +1 -1
- {xax-0.1.8.dist-info → xax-0.1.9.dist-info}/RECORD +10 -10
- {xax-0.1.8.dist-info → xax-0.1.9.dist-info}/WHEEL +0 -0
- {xax-0.1.8.dist-info → xax-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.8.dist-info → xax-0.1.9.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED

```diff
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.8"
+__version__ = "0.1.9"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -23,7 +23,6 @@ __all__ = [
     "get_run_dir",
     "load_user_config",
     "State",
-    "cast_phase",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -41,9 +40,6 @@ __all__ = [
     "load_eqx_mlp",
     "make_eqx_mlp",
     "save_eqx",
-    "export",
-    "export_flax",
-    "export_with_params",
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
@@ -180,7 +176,6 @@ NAME_MAP: dict[str, str] = {
    "get_run_dir": "core.conf",
    "load_user_config": "core.conf",
    "State": "core.state",
-    "cast_phase": "core.state",
    "FourierEmbeddings": "nn.embeddings",
    "IdentityPositionalEmbeddings": "nn.embeddings",
    "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -198,9 +193,6 @@ NAME_MAP: dict[str, str] = {
    "load_eqx_mlp": "nn.equinox",
    "make_eqx_mlp": "nn.equinox",
    "save_eqx": "nn.equinox",
-    "export": "nn.export",
-    "export_flax": "nn.export",
-    "export_with_params": "nn.export",
    "euler_to_quat": "nn.geom",
    "get_projected_gravity_vector_from_quat": "nn.geom",
    "quat_to_euler": "nn.geom",
@@ -329,7 +321,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_run_dir,
         load_user_config,
     )
-    from xax.core.state import Phase, State, cast_phase
+    from xax.core.state import Phase, State
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
@@ -354,11 +346,6 @@ if IMPORT_ALL or TYPE_CHECKING:
         make_eqx_mlp,
         save_eqx,
     )
-    from xax.nn.export import (
-        export,
-        export_flax,
-        export_with_params,
-    )
     from xax.nn.geom import (
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
```
xax/core/state.py
CHANGED

```diff
@@ -1,9 +1,10 @@
 """Defines a dataclass for keeping track of the current training state."""
 
 import time
-from dataclasses import dataclass
-from typing import Literal, NotRequired, TypedDict, cast, get_args
+from dataclasses import asdict, dataclass
+from typing import Any, Literal, NotRequired, TypedDict, Unpack, cast
 
+import jax
 from omegaconf import MISSING
 
 from xax.core.conf import field
@@ -11,12 +12,6 @@ from xax.core.conf import field
 Phase = Literal["train", "valid"]
 
 
-def cast_phase(raw_phase: str) -> Phase:
-    args = get_args(Phase)
-    assert raw_phase in args, f"Invalid phase: '{raw_phase}' Valid options are {args}"
-    return cast(Phase, raw_phase)
-
-
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int]
     num_samples: NotRequired[int]
@@ -24,10 +19,11 @@ class StateDict(TypedDict, total=False):
     num_valid_samples: NotRequired[int]
     start_time_s: NotRequired[float]
     elapsed_time_s: NotRequired[float]
-    raw_phase: NotRequired[str]
+    phase: NotRequired[Phase]
 
 
-@dataclass(kw_only=True)
+@jax.tree_util.register_dataclass
+@dataclass(frozen=True, kw_only=True)
 class State:
     num_steps: int = field(MISSING, help="Number of steps so far")
     num_samples: int = field(MISSING, help="Number of sample so far")
@@ -35,15 +31,11 @@ class State:
     num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
     start_time_s: float = field(MISSING, help="Start time of training")
     elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
-    raw_phase: str = field(MISSING, help="Current training phase")
+    _phase: int = field(MISSING, help="Current training phase")
 
     @property
     def phase(self) -> Phase:
-        return cast_phase(self.raw_phase)
-
-    @phase.setter
-    def phase(self, phase: Phase) -> None:
-        self.raw_phase = phase
+        return cast(Phase, ["train", "valid"][self._phase])
 
     @classmethod
     def init_state(cls) -> "State":
@@ -54,7 +46,7 @@ class State:
             num_valid_samples=0,
             start_time_s=time.time(),
             elapsed_time_s=0.0,
-            raw_phase="train",
+            _phase=0,
         )
 
     @property
@@ -69,3 +61,16 @@ class State:
                 return self.num_valid_steps
             case _:
                 raise ValueError(f"Invalid phase: {phase}")
+
+    def replace(self, **kwargs: Unpack[StateDict]) -> "State":
+        extra_kwargs: dict[str, Any] = {}  # noqa: ANN401
+        if "phase" in kwargs:
+            phase = kwargs.pop("phase")
+            match phase:
+                case "train":
+                    extra_kwargs["_phase"] = 0
+                case "valid":
+                    extra_kwargs["_phase"] = 1
+                case _:
+                    raise ValueError(f"Invalid phase: {phase}")
+        return State(**{**asdict(self), **kwargs, **extra_kwargs})
```
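The `State` refactor above turns the training state into a frozen dataclass that is registered as a JAX pytree: the mutable `raw_phase` string and its setter are gone, the phase is stored as the integer `_phase`, and all updates flow through the new functional `replace` method. A minimal sketch of the resulting usage, based only on the fields visible in this diff:

```python
import jax

from xax.core.state import State

state = State.init_state()
assert state.phase == "train"  # _phase == 0 maps to "train"

# Functional update: replace() returns a new frozen State instead of
# mutating in place; the "phase" keyword is mapped onto the private
# integer _phase field (0 = train, 1 = valid).
state = state.replace(phase="valid", num_steps=state.num_steps + 1)
assert state.phase == "valid"

# Because of register_dataclass, a State flattens like any other pytree,
# so it can be passed into jitted functions such as the new train_step.
print(jax.tree_util.tree_leaves(state))
```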
xax/task/base.py
CHANGED

```diff
@@ -16,7 +16,8 @@ from types import TracebackType
 from typing import Generic, Self, TypeVar, cast
 
 import jax
-from omegaconf import Container, DictConfig, OmegaConf
+from omegaconf import DictConfig, OmegaConf
+from omegaconf.base import SCMode
 
 from xax.core.state import State
 from xax.utils.text import camelcase_to_snakecase
@@ -66,9 +67,6 @@ class BaseTask(Generic[Config]):
 
         self.config = config
 
-        if isinstance(self.config, Container):
-            OmegaConf.resolve(self.config)
-
     def on_step_start(self, state: State) -> State:
         return state
 
@@ -195,7 +193,15 @@ class BaseTask(Generic[Config]):
         cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
         cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
 
-        return cast(Config, cfg)
+        return cast(
+            Config,
+            OmegaConf.to_container(
+                cfg,
+                resolve=True,
+                throw_on_missing=True,
+                structured_config_mode=SCMode.INSTANTIATE,
+            ),
+        )
 
     @classmethod
     def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
```
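With this change, `get_config` stops resolving the merged `DictConfig` lazily in `__init__` and instead materializes it eagerly into a real dataclass instance, so interpolations are resolved and missing mandatory values raise at config-load time. A standalone sketch of the mechanism, using a hypothetical `Config` dataclass rather than an actual xax task config:

```python
from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.base import SCMode


@dataclass
class Config:  # hypothetical stand-in for a task config
    batch_size: int = 32
    learning_rate: float = 1e-3


# Merge a structured schema with CLI-style overrides, as get_config does.
cfg = OmegaConf.merge(OmegaConf.structured(Config), OmegaConf.from_dotlist(["batch_size=64"]))

# SCMode.INSTANTIATE turns the DictConfig back into a plain dataclass
# instance; throw_on_missing surfaces MISSING values immediately.
config = OmegaConf.to_container(
    cfg,
    resolve=True,
    throw_on_missing=True,
    structured_config_mode=SCMode.INSTANTIATE,
)
assert isinstance(config, Config) and config.batch_size == 64
```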
xax/task/mixins/train.py
CHANGED

```diff
@@ -53,6 +53,7 @@ from xax.utils.experiments import (
     get_packages_with_versions,
     get_training_code,
 )
+from xax.utils.jax import jit as xax_jit
 from xax.utils.logging import LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
 
@@ -212,30 +213,27 @@ class TrainMixin(
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-        state.elapsed_time_s = time.time() - state.start_time_s
-        return state
+        return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
-            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             state: The current training state.
         """
 
-    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
-            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             state: The current training state.
@@ -248,7 +246,7 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+    def log_step(self, batch: Batch, output: Output, loss: Array, state: State) -> None:
         phase = state.phase
 
         self.logger.log_scalar("loss", loss, namespace="loss")
@@ -257,9 +255,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(model, batch, output, state)
+                self.log_train_step(batch, output, state)
             case "valid":
-                self.log_valid_step(model, batch, output, state)
+                self.log_valid_step(batch, output, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -332,8 +330,7 @@ class TrainMixin(
 
         return model, optimizer, opt_state, state
 
-
-    def get_output(self, model: PyTree, batch: Batch) -> Output:
+    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
@@ -343,11 +340,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
+            state: The current training state.
         """
         raise NotImplementedError("`get_output` must be implemented by the subclass")
 
-
-    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -358,6 +355,7 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
+            state: The current training state.
 
         Returns:
             The computed loss, as a tensor.
@@ -366,24 +364,32 @@ class TrainMixin(
         raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-
-    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
-        output = self.get_output(model, batch)
-        loss = self.compute_loss(model, batch, output)
+    def get_output_and_loss(
+        self,
+        model_static: PyTree,
+        model_arr: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Array, Output]:
+        model = eqx.combine(model_arr, model_static)
+        output = self.get_output(model, batch, state)
+        loss = self.compute_loss(model, batch, output, state)
         return loss, output
 
-    @eqx.filter_jit
     def update(
         self,
-        model: PyTree,
+        model_static: PyTree,
+        model_arr: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+        state: State,
     ) -> tuple[Array, PyTree, optax.OptState, Output]:
-        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
-        updates, opt_state = optimizer.update(grads, opt_state, model)
-        model = eqx.apply_updates(model, updates)
-        return loss, model, opt_state, output
+        grad_fn = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)
+        (loss, output), grads = grad_fn(model_static, model_arr, batch, state)
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return loss, model_arr, opt_state, output
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -457,21 +463,31 @@ class TrainMixin(
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
 
-    @eqx.filter_jit
+    def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
+        return eqx.is_inexact_array(item)
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"])
     def train_step(
         self,
-        model: PyTree,
+        model_static: PyTree,
+        model_arr: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+        state: State,
     ) -> tuple[PyTree, optax.OptState, Array, Output]:
-        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
-        return model, opt_state, loss, output
+        loss, model_arr, opt_state, output = self.update(model_static, model_arr, optimizer, opt_state, batch, state)
+        return model_arr, opt_state, loss, output
 
-    @eqx.filter_jit
-    def val_step(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
-        loss, output = self.get_output_and_loss(model, batch)
-        return loss, output
+    @xax_jit(static_argnames=["self", "model_static"])
+    def val_step(
+        self,
+        model_static: PyTree,
+        model_arr: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Array, Output]:
+        return self.get_output_and_loss(model_static, model_arr, batch, state)
 
     def train_loop(
         self,
@@ -482,36 +498,46 @@ class TrainMixin(
         valid_pf: Iterator[Batch],
         state: State,
     ) -> None:
+        model_arr, model_static = eqx.partition(model, self.model_partition_fn)
+
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-                loss, output = self.val_step(model, valid_batch)
+                state = state.replace(
+                    phase="valid",
+                    num_valid_steps=state.num_valid_steps + 1,
+                    num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
+                )
 
-                state.phase = "valid"
-                self.log_step(model, valid_batch, output, loss, state)
-                state.num_valid_samples += 1
+                loss, output = self.val_step(model_static, model_arr, valid_batch, state)
+                self.log_step(valid_batch, output, loss, state)
 
             state = self.on_step_start(state)
-            train_batch = next(train_pf)
-            state.phase = "train"
-            state.num_steps += 1
-            state.num_samples += self.get_size_of_batch(train_batch) or 0
-
-            model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
-            self.log_step(model, train_batch, output, loss, state)
+            train_batch = next(train_pf)
+            state = state.replace(
+                phase="train",
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
+            )
+
+            model_arr, opt_state, loss, output = self.train_step(
+                model_static=model_static,
+                model_arr=model_arr,
+                optimizer=optimizer,
+                opt_state=opt_state,
+                batch=train_batch,
+                state=state,
+            )
+            self.log_step(train_batch, output, loss, state)
 
             state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
+                model = eqx.combine(model_arr, model_static)
                 self.save_checkpoint(model, optimizer, opt_state, state)
 
         # After finishing training, save the final checkpoint.
+        model = eqx.combine(model_arr, model_static)
         self.save_checkpoint(model, optimizer, opt_state, state)
 
     @contextlib.contextmanager
```
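The loop refactor above follows the standard Equinox partition/combine pattern: the model is split once into its inexact-array leaves (`model_arr`, the only part that is differentiated and updated) and everything else (`model_static`, passed as a static argument so it is not traced), then recombined with `eqx.combine` whenever a whole model is needed, such as for checkpointing. A self-contained sketch of the same pattern on a toy MLP, using `eqx.filter_jit` in place of xax's `xax_jit` wrapper (the plumbing and names here are illustrative, not xax's API):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Toy model and optimizer standing in for the task's real ones.
model = eqx.nn.MLP(in_size=4, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))
optimizer = optax.adam(1e-3)

# Split once, outside the loop: trainable array leaves vs. static structure.
model_arr, model_static = eqx.partition(model, eqx.is_inexact_array)
opt_state = optimizer.init(model_arr)


def loss_fn(model_arr, model_static, x, y):
    model = eqx.combine(model_arr, model_static)  # rebuild the full model inside the trace
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)


@eqx.filter_jit  # non-array inputs (like model_static) are treated as static
def train_step(model_arr, model_static, opt_state, x, y):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model_arr, model_static, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, model_arr)
    model_arr = eqx.apply_updates(model_arr, updates)
    return model_arr, opt_state, loss


x, y = jnp.ones((16, 4)), jnp.zeros((16, 1))
model_arr, opt_state, loss = train_step(model_arr, model_static, opt_state, x, y)
model = eqx.combine(model_arr, model_static)  # reassemble, e.g. before checkpointing
```

Splitting once outside the loop avoids re-partitioning on every step and keeps the jitted step's signature stable, so recompilation only happens when shapes or the static half actually change.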
xax/task/script.py
CHANGED

```diff
@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
@@ -20,6 +22,7 @@ from xax.task.mixins import (
 )
 
 
+@jax.tree_util.register_dataclass
 @dataclass(kw_only=True)
 class ScriptConfig(
     CPUStatsConfig,
```
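`ScriptConfig` gets the same `jax.tree_util.register_dataclass` decorator that `State` received above, which registers the dataclass with JAX's pytree machinery so instances can be flattened and passed through transformations like `jax.jit`. A tiny illustration of what the decorator enables, using a hypothetical dataclass rather than `ScriptConfig` itself:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Mirrors the decorator stack this diff applies to State and ScriptConfig.
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Affine:  # hypothetical pytree dataclass
    scale: jax.Array
    bias: jax.Array


@jax.jit
def apply(p: Affine, x: jax.Array) -> jax.Array:
    # p is flattened into its fields and rebuilt across the jit boundary.
    return p.scale * x + p.bias


p = Affine(scale=jnp.asarray(2.0), bias=jnp.asarray(1.0))
print(apply(p, jnp.arange(3.0)))  # [1. 3. 5.]
```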
{xax-0.1.8.dist-info → xax-0.1.9.dist-info}/RECORD
CHANGED

```diff
@@ -1,10 +1,10 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=_xb60-jl7arZEleSwUw4ElPaq4MzD24_ZYQrnWO5_cs,13391
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
-xax/core/state.py,sha256=
+xax/core/state.py,sha256=WwW0qDm-be9MMOT-bGWEFvaWF4iq2FP9xRSn1zq_4A8,2507
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -14,9 +14,9 @@ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
 xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/task/base.py,sha256=
+xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
-xax/task/script.py,sha256=
+xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=JbrSiBqpgOrdDanNYuAzzh2radPrXOVrHYA6VcxjIzY,23248
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -56,8 +56,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.8.dist-info/licenses/LICENSE,sha256=
-xax-0.1.8.dist-info/METADATA,sha256=
-xax-0.1.8.dist-info/WHEEL,sha256=
-xax-0.1.8.dist-info/top_level.txt,sha256=
-xax-0.1.8.dist-info/RECORD,,
+xax-0.1.9.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.9.dist-info/METADATA,sha256=Ou8KmYWWNxgo_9ZAU2KLaeGeXAxd6b9qJ95ky4HRm-o,1877
+xax-0.1.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.9.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.9.dist-info/RECORD,,
```
{xax-0.1.8.dist-info → xax-0.1.9.dist-info}/WHEEL
File without changes

{xax-0.1.8.dist-info → xax-0.1.9.dist-info}/licenses/LICENSE
File without changes

{xax-0.1.8.dist-info → xax-0.1.9.dist-info}/top_level.txt
File without changes