xax-0.0.5-py3-none-any.whl → xax-0.0.6-py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
xax/task/mixins/train.py CHANGED
@@ -13,14 +13,24 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+from typing import (
+    Any,
+    Generator,
+    Generic,
+    Iterator,
+    Literal,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)
 
 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array, PyTree
+from jaxtyping import Array, PRNGKeyArray, PyTree
 from omegaconf import DictConfig
 
 from xax.core.conf import field
@@ -130,6 +140,7 @@ class ValidStepTimer:
         return False
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
     CheckpointingConfig,
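Note: `jax.tree_util.register_dataclass` registers the config dataclass as a JAX pytree, which is what lets it flow through jit-compiled functions. A minimal, self-contained sketch of the behavior, assuming a recent JAX release that infers data vs. metadata fields from the dataclass definition:

    from dataclasses import dataclass, field

    import jax

    @jax.tree_util.register_dataclass
    @dataclass
    class Hyperparams:
        learning_rate: float  # data field: becomes a pytree leaf
        batch_size: int = field(metadata=dict(static=True))  # metadata: static under jit

    hp = Hyperparams(learning_rate=1e-3, batch_size=32)
    print(jax.tree_util.tree_leaves(hp))  # [0.001]; batch_size is treated as auxiliary data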
@@ -191,16 +202,13 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)
 
-    def prng_key(self) -> jnp.ndarray:
+    def prng_key(self) -> PRNGKeyArray:
         return jax.random.PRNGKey(self.config.random_seed)
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-        return state.replace(
-            {
-                "elapsed_time_s": time.time() - state.start_time_s,
-            },
-        )
+        state.elapsed_time_s = time.time() - state.start_time_s
+        return state
 
     def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
@@ -228,16 +236,19 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
-        phase = state.phase
-
-        # Log the state timers.
-        timer = self.state_timers[phase]
+    def log_state_timers(self, state: State) -> None:
+        timer = self.state_timers[state.phase]
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
                 self.logger.log_scalar(k, v, namespace=ns)
 
+    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+        phase = state.phase
+
+        self.logger.log_scalar("loss", loss, namespace="loss")
+        self.log_state_timers(state)
+
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
@@ -247,8 +258,10 @@ class TrainMixin(
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
+        self.write_logs(state)
+
     @abstractmethod
-    def get_model(self) -> PyTree:
+    def get_model(self, key: PRNGKeyArray) -> PyTree:
         """Returns the Equinox model to train.
 
         Returns:
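Note: `get_model` now receives an explicit PRNG key because Equinox modules take a key at construction time. A hypothetical implementation (the `MLP` sizes are invented for illustration):

    import equinox as eqx
    import jax
    from jaxtyping import PRNGKeyArray

    def get_model(key: PRNGKeyArray) -> eqx.Module:
        # In a real task this would be a method on the Task subclass.
        return eqx.nn.MLP(in_size=784, out_size=10, width_size=128, depth=2, key=key)

    model = get_model(jax.random.PRNGKey(0))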
@@ -266,7 +279,10 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
-    def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
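Note: `get_initial_opt_state` filters the model down to its array leaves before handing it to optax, since the optimizer state should only track trainable arrays. A minimal sketch of the same idea:

    import equinox as eqx
    import jax
    import optax

    model = eqx.nn.Linear(3, 2, key=jax.random.PRNGKey(0))
    optimizer = optax.sgd(1e-2)
    # eqx.filter replaces non-array leaves with None, so optax only sees arrays.
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))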
@@ -279,18 +295,18 @@ class TrainMixin(
             return model, optimizer, opt_state, state
 
         with self.step_context("get_model"):
-            model = self.get_model()
+            model = self.get_model(key)
 
         with self.step_context("get_optimizer"):
             optimizer = self.get_optimizer()
 
         with self.step_context("get_initial_opt_state"):
-            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+            opt_state = self.get_initial_opt_state(model, optimizer)
 
         return model, optimizer, opt_state, State.init_state()
 
-    @abstractmethod
-    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+    @eqx.filter_jit
+    def get_output(self, model: PyTree, batch: Batch) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
@@ -300,10 +316,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
-            state: The current training state.
         """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")
 
-    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+    @eqx.filter_jit
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -314,7 +331,6 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
-            state: The current training state.
 
         Returns:
             The computed loss, as a tensor.
@@ -323,9 +339,10 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-    def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
-        output = self.get_output(model, batch, state)
-        loss = self.compute_loss(model, batch, output, state)
+    @eqx.filter_jit
+    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
+        output = self.get_output(model, batch)
+        loss = self.compute_loss(model, batch, output)
         return loss, output
 
     @eqx.filter_jit
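Note: `@eqx.filter_jit` is used instead of plain `jax.jit` because the model pytree mixes array leaves with non-array fields; filter_jit traces the arrays and holds everything else static. A small self-contained example:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Scale(eqx.Module):
        weight: jax.Array
        name: str  # non-array leaf: held static by filter_jit

        def __call__(self, x: jax.Array) -> jax.Array:
            return self.weight * x

    @eqx.filter_jit
    def apply(model: Scale, x: jax.Array) -> jax.Array:
        return model(x)

    print(apply(Scale(weight=jnp.ones(3), name="demo"), jnp.arange(3.0)))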
@@ -335,10 +352,9 @@ class TrainMixin(
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-        state: State,
     ) -> tuple[Array, PyTree, optax.OptState, Output]:
-        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state)
+        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
+        updates, opt_state = optimizer.update(grads, opt_state, model)
         model = eqx.apply_updates(model, updates)
         return loss, model, opt_state, output
 
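Note: passing the model as the third argument to `optimizer.update` matters for optax transformations that read the current parameter values, such as the decoupled weight decay in `optax.adamw`. A minimal sketch:

    import jax.numpy as jnp
    import optax

    params = {"w": jnp.ones(3)}
    optimizer = optax.adamw(1e-3)  # weight decay needs the current params
    opt_state = optimizer.init(params)

    grads = {"w": jnp.full(3, 0.5)}
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)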
@@ -377,7 +393,13 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            logger.info("Estimated finish time: %s", termination_time)
+            # logger.info("Estimated finish time: %s", termination_time)
+            jax.debug.print("Estimated finish time: {}", termination_time)
+
+    def get_remaining_percent(self, state: State) -> float | None:
+        if self.config.max_steps is None:
+            return None
+        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
 
     def is_training_over(self, state: State) -> bool:
         if self._training_over_flag:
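Note: unlike `logger.info`, `jax.debug.print` fires on every call even inside jit-compiled code, where an ordinary Python print or log statement would only run once at trace time; that appears to be the motivation for the swap. For example:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def double(x: jax.Array) -> jax.Array:
        jax.debug.print("x = {}", x)  # prints at runtime on every call
        return x * 2

    double(jnp.arange(3))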
@@ -385,7 +407,6 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0
 
@@ -400,59 +421,124 @@ class TrainMixin(
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")
 
-    def get_remaining_percent(self, state: State) -> float | None:
-        if self.config.max_steps is None:
-            return None
-        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
-
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
-        self.logger.log_git_state(get_git_state(self))
-        self.logger.log_training_code(get_training_code(self))
-        self.logger.log_config(cast(DictConfig, self.config))
+        self.logger.log_file("git_state.txt", get_git_state(self))
+        self.logger.log_file("training_code.txt", get_training_code(self))
+        self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
 
+    @eqx.filter_jit
     def train_step(
         self,
         model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+    ) -> tuple[PyTree, optax.OptState, Array, Output]:
+        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
+        return model, opt_state, loss, output
+
+    @eqx.filter_jit
+    def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
+        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
+        return model, loss, output
+
+    def train_loop(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
         state: State,
-    ) -> tuple[PyTree, optax.OptState, State]:
-        state = state.with_phase("train")
-        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
-        self.logger.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return (
-            model,
-            opt_state,
-            state.replace(
-                {
-                    "num_steps": state.num_steps + 1,
-                    "num_samples": state.num_samples + (self.get_size_of_batch(batch) or 0),
-                },
-            ),
-        )
+    ) -> None:
+        while not self.is_training_over(state):
+            if self.valid_step_timer.is_valid_step(state):
+                valid_batch = next(valid_pf)
+                model, loss, output = self.val_step(model, valid_batch)
+
+                # Perform logging.
+                with self.step_context("write_logs"):
+                    state.phase = "valid"
+                    self.log_step(model, valid_batch, output, loss, state)
+                    state.num_valid_samples += 1
+
+            with self.step_context("on_step_start"):
+                state = self.on_step_start(state)
+
+            with self.step_context("update_state"):
+                train_batch = next(train_pf)
+                model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
+
+            # Perform logging.
+            with self.step_context("write_logs"):
+                state.phase = "train"
+                self.log_step(model, train_batch, output, loss, state)
+                state.num_steps += 1
+                state.num_samples += self.get_size_of_batch(train_batch) or 0
+
+            with self.step_context("on_step_end"):
+                state = self.on_step_end(state)
+
+            if self.should_checkpoint(state):
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
-    def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
-        state = state.with_phase("valid")
-        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
-        self.logger.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return model, state.replace(
-            {
-                "num_valid_steps": state.num_valid_steps + 1,
-                "num_valid_samples": state.num_valid_samples + (self.get_size_of_batch(batch) or 0),
-            },
-        )
+        # After finishing training, save the final checkpoint.
+        self.save_checkpoint(model, optimizer, opt_state, state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            train_ds = self.get_dataset("train")
+
+        with self.step_context("get_dataloader"):
+            train_dl = self.get_dataloader(train_ds, "train")
+
+        with self.step_context("get_prefetcher"):
+            train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            valid_ds = self.get_dataset("valid")
+
+        with self.step_context("get_dataloader"):
+            valid_dl = self.get_dataloader(valid_ds, "valid")
+
+        with self.step_context("get_prefetcher"):
+            valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")
 
     def run(self) -> None:
-        self.run_training_loop()
+        self.run_training()
 
-    def run_training_loop(self) -> None:
+    def run_training(self) -> None:
         """Runs the training loop.
 
         Args:
@@ -464,33 +550,16 @@ class TrainMixin(
         Raises:
             ValueError: If the task is not a supervised learning task
         """
-        with contextlib.ExitStack() as ctx:
+        with self:
+            key = self.prng_key()
+
             self.set_loggers()
 
             if is_master():
                 Thread(target=self.log_state, daemon=True).start()
 
-            # Gets the datasets.
-            with self.step_context("get_dataset"):
-                train_ds = self.get_dataset("train")
-                valid_ds = self.get_dataset("valid")
-
-            # Gets the dataloaders.
-            with self.step_context("get_dataloader"):
-                train_dl = self.get_dataloader(train_ds, "train")
-                valid_dl = self.get_dataloader(valid_ds, "valid")
-
-            # Gets the prefetchers.
-            with self.step_context("get_prefetcher"):
-                train_pf = self.get_prefetcher(train_dl)
-                valid_pf = self.get_prefetcher(valid_dl)
-
-            ctx.enter_context(self)
-            ctx.enter_context(train_pf)
-            ctx.enter_context(valid_pf)
-
-            model, optimizer, opt_state, state = self.load_initial_state()
-
+            key, model_key = jax.random.split(key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key)
             state = self.on_training_start(state)
 
             def on_exit() -> None:
@@ -499,43 +568,34 @@ class TrainMixin(
             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
-            try:
-                while True:
-                    while True:
-                        if self.is_training_over(state):
-                            raise TrainingFinishedError
-
-                        if self.valid_step_timer.is_valid_step(state):
-                            model, state = self.val_step(model, next(valid_pf), state)
-
-                        with self.step_context("on_step_start"):
-                            state = self.on_step_start(state)
-
-                        model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)
-
-                        with self.step_context("on_step_end"):
-                            state = self.on_step_end(state)
-
-                        if self.should_checkpoint(state):
-                            self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except TrainingFinishedError:
-                if is_master():
-                    show_info(
-                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                        important=True,
+            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+                try:
+                    self.train_loop(
+                        model=model,
+                        optimizer=optimizer,
+                        opt_state=opt_state,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
                     )
-                self.save_checkpoint(model, optimizer, opt_state, state)
-
-            except (KeyboardInterrupt, bdb.BdbQuit):
-                if is_master():
-                    show_info("Interrupted training", important=True)
-
-            except BaseException:
-                exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), "  ")
-                sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
-                sys.stdout.flush()
-                self.save_checkpoint(model, optimizer, opt_state, state)
 
-            finally:
-                state = self.on_training_end(state)
+                except TrainingFinishedError:
+                    if is_master():
+                        show_info(
+                            f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                            important=True,
+                        )
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), "  ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                finally:
+                    state = self.on_training_end(state)
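Note: the dataset/dataloader/prefetcher setup now lives behind `@contextlib.contextmanager` helpers, so each iterator's cleanup is tied to its own `with` block instead of a shared `ExitStack`. A minimal, self-contained sketch of the pattern:

    import contextlib
    from typing import Generator, Iterator

    @contextlib.contextmanager
    def batches(n: int) -> Generator[Iterator[int], None, None]:
        it = iter(range(n))
        try:
            yield it  # hand the iterator to the with-block
        finally:
            print("Closing prefetcher")  # cleanup runs even if the loop raises

    with batches(3) as it:
        print(list(it))  # [0, 1, 2]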
xax/task/script.py CHANGED
@@ -22,7 +22,7 @@ from xax.task.mixins import (
 )
 
 
-@dataclass
+@dataclass(kw_only=True)
 class ScriptConfig(
     CPUStatsConfig,
     GPUStatsConfig,
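Note: `@dataclass(kw_only=True)` (Python 3.10+) sidesteps the classic dataclass-inheritance pitfall where a required field in a subclass follows a defaulted field in a base class. A short illustration with invented names:

    from dataclasses import dataclass

    @dataclass(kw_only=True)
    class BaseConfig:
        verbose: bool = False  # defaulted field in the base class

    @dataclass(kw_only=True)
    class DerivedConfig(BaseConfig):
        name: str  # required field after a default: legal only because all fields are keyword-only

    cfg = DerivedConfig(name="demo")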
xax/task/task.py CHANGED
@@ -3,12 +3,16 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
     CheckpointingConfig,
     CheckpointingMixin,
+    CompileConfig,
+    CompileMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -28,10 +32,12 @@ from xax.task.mixins import (
 )
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class Config(
     TrainConfig,
     CheckpointingConfig,
+    CompileConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -51,6 +57,7 @@ ConfigT = TypeVar("ConfigT", bound=Config)
 class Task(
     TrainMixin[ConfigT],
     CheckpointingMixin[ConfigT],
+    CompileMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
xax/utils/tensorboard.py CHANGED
@@ -2,10 +2,14 @@
 
 import functools
 import io
+import os
+import tempfile
 import time
 from pathlib import Path
 from typing import Literal, TypedDict
 
+import numpy as np
+import PIL.Image
 from PIL.Image import Image as PILImage
 from tensorboard.compat.proto.config_pb2 import RunMetadata
 from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
@@ -186,6 +190,50 @@ class TensorboardWriter:
             walltime=walltime,
         )
 
+    def add_video(
+        self,
+        tag: str,
+        value: np.ndarray,
+        global_step: int | None = None,
+        walltime: float | None = None,
+        fps: int = 30,
+    ) -> None:
+        assert value.ndim == 4, "Video must be 4D array (T, H, W, C)"
+        images = [PIL.Image.fromarray(frame) for frame in value]
+
+        # Create temporary file for GIF
+        temp_file = tempfile.NamedTemporaryFile(suffix=".gif", delete=False)
+        try:
+            images[0].save(temp_file.name, save_all=True, append_images=images[1:], duration=int(1000 / fps), loop=0)
+            with open(temp_file.name, "rb") as f:
+                video_string = f.read()
+
+        finally:
+            # Clean up temporary file
+            try:
+                os.remove(temp_file.name)
+            except OSError:
+                pass
+
+        # Add to summary
+        self.pb_writer.add_summary(
+            Summary(
+                value=[
+                    Summary.Value(
+                        tag=tag,
+                        image=Summary.Image(
+                            height=value.shape[1],
+                            width=value.shape[2],
+                            colorspace=value.shape[3],
+                            encoded_image_string=video_string,
+                        ),
+                    ),
+                ],
+            ),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
     def add_text(
         self,
         tag: str,
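Note: the new `add_video` encodes the frames as an animated GIF via PIL and stores the bytes in an ordinary image summary. The core encoding step, as a standalone sketch:

    import numpy as np
    import PIL.Image

    frames = (np.random.rand(8, 16, 16, 3) * 255).astype(np.uint8)  # (T, H, W, C)
    images = [PIL.Image.fromarray(f) for f in frames]
    images[0].save(
        "clip.gif",
        save_all=True,
        append_images=images[1:],
        duration=int(1000 / 30),  # per-frame duration in milliseconds at 30 fps
        loop=0,  # loop forever
    )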
xax-0.0.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: xax
-Version: 0.0.5
+Version: 0.0.6
 Summary: The xax project
 Home-page: https://github.com/dpshai/xax
 Author: Benjamin Bolte
@@ -12,6 +12,8 @@ Requires-Dist: jaxtyping
 Requires-Dist: equinox
 Requires-Dist: optax
 Requires-Dist: dpshdl
+Requires-Dist: chex
+Requires-Dist: importlib-resources
 Requires-Dist: cloudpickle
 Requires-Dist: pillow
 Requires-Dist: omegaconf
@@ -28,6 +30,14 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
+Dynamic: author
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # xax
 
xax-0.0.6.dist-info/RECORD CHANGED
@@ -0,0 +1,52 @@
+xax/__init__.py,sha256=RTUsDh_R0TFa09q-_U0vd-eCYRC-bCaHqHlayp8U2hU,9736
+xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
+xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
+xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
+xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
+xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
+xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
+xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
+xax/task/logger.py,sha256=orN1jmM4SIR2EiYk8bNoJZscmhX1FytADBU6p9qpows,29256
+xax/task/script.py,sha256=4LyXrpj0V36TjAZT4lvQeiOTqa5U2tommHKwgWDCE24,1025
+xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
+xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
+xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
+xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
+xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
+xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
+xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
+xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
+xax/task/loggers/tensorboard.py,sha256=FGW96z77oG0Kf3cO6Zznx5U3kJNzPWcuSkpY4RnbFCo,6909
+xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
+xax/task/mixins/artifacts.py,sha256=1H7ZbR-KSsXhVtqGVlqMi-TXfn1-dM7YnTCLVuw594s,3835
+xax/task/mixins/checkpointing.py,sha256=AMlobojybvJdDZcNCxm1DHSVC_2Qvnu_MbRcsc_8eoA,8508
+xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
+xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
+xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
+xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
+xax/task/mixins/logger.py,sha256=CIQ4w4K3FcxN6A9xUfITdVkulSxPa4iaTe6cbs9ruaM,1958
+xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
+xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
+xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhcVQ,1560
+xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
+xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
+xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
+xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
+xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
+xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
+xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
+xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
+xax-0.0.6.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.0.6.dist-info/METADATA,sha256=YO2c2PUMWkH1ILfPhFWKK4Sodbo9qUpUOCIkm4aLHfg,1171
+xax-0.0.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+xax-0.0.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.0.6.dist-info/RECORD,,
xax-0.0.6.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.6.0)
+Generator: setuptools (75.8.2)
 Root-Is-Purelib: true
 Tag: py3-none-any
 