xax-0.1.15-py3-none-any.whl → xax-0.1.16-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/task/logger.py +42 -10
- xax/task/loggers/json.py +12 -4
- xax/task/loggers/stdout.py +21 -16
- xax/task/loggers/tensorboard.py +2 -2
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/train.py +24 -6
- xax/utils/experiments.py +7 -20
- {xax-0.1.15.dist-info → xax-0.1.16.dist-info}/METADATA +1 -1
- {xax-0.1.15.dist-info → xax-0.1.16.dist-info}/RECORD +14 -14
- {xax-0.1.15.dist-info → xax-0.1.16.dist-info}/WHEEL +0 -0
- {xax-0.1.15.dist-info → xax-0.1.16.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.15.dist-info → xax-0.1.16.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/logger.py
CHANGED
@@ -205,6 +205,12 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
205
205
|
return np.array(array)
|
206
206
|
|
207
207
|
|
208
|
+
@dataclass(kw_only=True)
|
209
|
+
class LogString:
|
210
|
+
value: str
|
211
|
+
secondary: bool
|
212
|
+
|
213
|
+
|
208
214
|
@dataclass(kw_only=True)
|
209
215
|
class LogImage:
|
210
216
|
image: PILImage
|
@@ -223,6 +229,12 @@ class LogVideo:
|
|
223
229
|
fps: int
|
224
230
|
|
225
231
|
|
232
|
+
@dataclass(kw_only=True)
|
233
|
+
class LogScalar:
|
234
|
+
value: Number
|
235
|
+
secondary: bool
|
236
|
+
|
237
|
+
|
226
238
|
@dataclass(kw_only=True)
|
227
239
|
class LogDistribution:
|
228
240
|
mean: Number
|
@@ -243,10 +255,10 @@ class LogHistogram:
|
|
243
255
|
@dataclass(kw_only=True)
|
244
256
|
class LogLine:
|
245
257
|
state: State
|
246
|
-
scalars: dict[str, dict[str,
|
258
|
+
scalars: dict[str, dict[str, LogScalar]]
|
247
259
|
distributions: dict[str, dict[str, LogDistribution]]
|
248
260
|
histograms: dict[str, dict[str, LogHistogram]]
|
249
|
-
strings: dict[str, dict[str,
|
261
|
+
strings: dict[str, dict[str, LogString]]
|
250
262
|
images: dict[str, dict[str, LogImage]]
|
251
263
|
videos: dict[str, dict[str, LogVideo]]
|
252
264
|
|
@@ -515,10 +527,10 @@ class Logger:
|
|
515
527
|
"""Defines an intermediate container which holds values to log somewhere else."""
|
516
528
|
|
517
529
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
518
|
-
self.scalars: dict[str, dict[str, Callable[[],
|
530
|
+
self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
|
519
531
|
self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
|
520
532
|
self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
|
521
|
-
self.strings: dict[str, dict[str, Callable[[],
|
533
|
+
self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
|
522
534
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
523
535
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
524
536
|
self.default_namespace = default_namespace
|
@@ -616,13 +628,23 @@ class Logger:
|
|
616
628
|
def resolve_namespace(self, namespace: str | None = None) -> str:
|
617
629
|
return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
|
618
630
|
|
619
|
-
def log_scalar(
|
631
|
+
def log_scalar(
|
632
|
+
self,
|
633
|
+
key: str,
|
634
|
+
value: Callable[[], Number] | Number,
|
635
|
+
*,
|
636
|
+
namespace: str | None = None,
|
637
|
+
secondary: bool = False,
|
638
|
+
) -> None:
|
620
639
|
"""Logs a scalar value.
|
621
640
|
|
622
641
|
Args:
|
623
642
|
key: The key being logged
|
624
643
|
value: The scalar value being logged
|
625
644
|
namespace: An optional logging namespace
|
645
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
646
|
+
less important than other values, and some downstream loggers
|
647
|
+
will not display it)
|
626
648
|
"""
|
627
649
|
if not self.active:
|
628
650
|
raise RuntimeError("The logger is not active")
|
@@ -632,11 +654,11 @@ class Logger:
|
|
632
654
|
assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
|
633
655
|
|
634
656
|
@functools.lru_cache(maxsize=None)
|
635
|
-
def scalar_future() ->
|
657
|
+
def scalar_future() -> LogScalar:
|
636
658
|
with ContextTimer() as timer:
|
637
659
|
value_concrete = value() if callable(value) else value
|
638
660
|
logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
|
639
|
-
return value_concrete
|
661
|
+
return LogScalar(value=value_concrete, secondary=secondary)
|
640
662
|
|
641
663
|
self.scalars[namespace][key] = scalar_future
|
642
664
|
|
@@ -770,21 +792,31 @@ class Logger:
|
|
770
792
|
|
771
793
|
self.histograms[namespace][key] = histogram_future
|
772
794
|
|
773
|
-
def log_string(
|
795
|
+
def log_string(
|
796
|
+
self,
|
797
|
+
key: str,
|
798
|
+
value: Callable[[], str] | str,
|
799
|
+
*,
|
800
|
+
namespace: str | None = None,
|
801
|
+
secondary: bool = False,
|
802
|
+
) -> None:
|
774
803
|
"""Logs a string value.
|
775
804
|
|
776
805
|
Args:
|
777
806
|
key: The key being logged
|
778
807
|
value: The string value being logged
|
779
808
|
namespace: An optional logging namespace
|
809
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
810
|
+
less important than other values, and some downstream loggers
|
811
|
+
will not display it)
|
780
812
|
"""
|
781
813
|
if not self.active:
|
782
814
|
raise RuntimeError("The logger is not active")
|
783
815
|
namespace = self.resolve_namespace(namespace)
|
784
816
|
|
785
817
|
@functools.lru_cache(maxsize=None)
|
786
|
-
def value_future() ->
|
787
|
-
return value() if callable(value) else value
|
818
|
+
def value_future() -> LogString:
|
819
|
+
return LogString(value=value() if callable(value) else value, secondary=secondary)
|
788
820
|
|
789
821
|
self.strings[namespace][key] = value_future
|
790
822
|
|
xax/task/loggers/json.py
CHANGED
@@ -3,11 +3,19 @@
|
|
3
3
|
import json
|
4
4
|
import sys
|
5
5
|
from dataclasses import asdict
|
6
|
-
from typing import Any, Literal, TextIO
|
6
|
+
from typing import Any, Literal, Mapping, TextIO
|
7
7
|
|
8
8
|
from jaxtyping import Array
|
9
9
|
|
10
|
-
from xax.task.logger import
|
10
|
+
from xax.task.logger import (
|
11
|
+
LogError,
|
12
|
+
LoggerImpl,
|
13
|
+
LogLine,
|
14
|
+
LogPing,
|
15
|
+
LogScalar,
|
16
|
+
LogStatus,
|
17
|
+
LogString,
|
18
|
+
)
|
11
19
|
|
12
20
|
|
13
21
|
def get_json_value(value: Any) -> Any: # noqa: ANN401
|
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
|
|
61
69
|
def get_json(self, line: LogLine) -> str:
|
62
70
|
data: dict = {"state": asdict(line.state)}
|
63
71
|
|
64
|
-
def add_logs(log:
|
72
|
+
def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
|
65
73
|
for namespace, values in log.items():
|
66
74
|
if self.remove_unicode_from_namespaces:
|
67
75
|
namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
|
68
76
|
if namespace not in data:
|
69
77
|
data[namespace] = {}
|
70
78
|
for k, v in values.items():
|
71
|
-
data[namespace][k] = get_json_value(v)
|
79
|
+
data[namespace][k] = get_json_value(v.value)
|
72
80
|
|
73
81
|
add_logs(line.scalars, data)
|
74
82
|
add_logs(line.strings, data)
|
xax/task/loggers/stdout.py
CHANGED
@@ -4,11 +4,20 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import sys
|
6
6
|
from collections import deque
|
7
|
-
from typing import Any, Deque, TextIO
|
7
|
+
from typing import Any, Deque, Mapping, TextIO
|
8
8
|
|
9
9
|
from jaxtyping import Array
|
10
10
|
|
11
|
-
from xax.task.logger import
|
11
|
+
from xax.task.logger import (
|
12
|
+
LogError,
|
13
|
+
LogErrorSummary,
|
14
|
+
LoggerImpl,
|
15
|
+
LogLine,
|
16
|
+
LogPing,
|
17
|
+
LogScalar,
|
18
|
+
LogStatus,
|
19
|
+
LogString,
|
20
|
+
)
|
12
21
|
from xax.utils.text import Color, colored, format_timedelta
|
13
22
|
|
14
23
|
|
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
|
|
95
104
|
def write_log_window(self, line: LogLine) -> None:
|
96
105
|
namespace_to_lines: dict[str, dict[str, str]] = {}
|
97
106
|
|
98
|
-
def add_logs(
|
107
|
+
def add_logs(
|
108
|
+
log: Mapping[str, Mapping[str, LogScalar | LogString]],
|
109
|
+
namespace_to_lines: dict[str, dict[str, str]],
|
110
|
+
) -> None:
|
99
111
|
for namespace, values in log.items():
|
100
|
-
if not self.log_timers and namespace.startswith("⌛"):
|
101
|
-
continue
|
102
|
-
if not self.log_perf and namespace.startswith("🔧"):
|
103
|
-
continue
|
104
|
-
if not self.log_optim and namespace.startswith("📉"):
|
105
|
-
continue
|
106
|
-
if not self.log_fp and namespace.startswith("⚖️"):
|
107
|
-
continue
|
108
|
-
if namespace not in namespace_to_lines:
|
109
|
-
namespace_to_lines[namespace] = {}
|
110
112
|
for k, v in values.items():
|
111
|
-
|
113
|
+
if v.secondary:
|
114
|
+
continue
|
115
|
+
if namespace not in namespace_to_lines:
|
116
|
+
namespace_to_lines[namespace] = {}
|
117
|
+
v_str = as_str(v.value, self.precision)
|
112
118
|
namespace_to_lines[namespace][k] = v_str
|
113
119
|
|
114
120
|
add_logs(line.scalars, namespace_to_lines)
|
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
|
|
116
122
|
if not namespace_to_lines:
|
117
123
|
return
|
118
124
|
|
119
|
-
self.write_fp.write("\n")
|
120
125
|
for namespace, lines in sorted(namespace_to_lines.items()):
|
121
|
-
self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
|
126
|
+
self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
|
122
127
|
for k, v in lines.items():
|
123
128
|
self.write_fp.write(f" ↪ {k}: {v}\n")
|
124
129
|
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -158,7 +158,7 @@ class TensorboardLogger(LoggerImpl):
|
|
158
158
|
for scalar_key, scalar_value in scalars.items():
|
159
159
|
writer.add_scalar(
|
160
160
|
f"{namespace}/{scalar_key}",
|
161
|
-
as_float(scalar_value),
|
161
|
+
as_float(scalar_value.value),
|
162
162
|
global_step=line.state.num_steps,
|
163
163
|
walltime=walltime,
|
164
164
|
)
|
@@ -192,7 +192,7 @@ class TensorboardLogger(LoggerImpl):
|
|
192
192
|
for string_key, string_value in strings.items():
|
193
193
|
writer.add_text(
|
194
194
|
f"{namespace}/{string_key}",
|
195
|
-
string_value,
|
195
|
+
string_value.value,
|
196
196
|
global_step=line.state.num_steps,
|
197
197
|
walltime=walltime,
|
198
198
|
)
|
xax/task/mixins/cpu_stats.py
CHANGED
@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
248
248
|
stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
|
249
249
|
|
250
250
|
if stats is not None:
|
251
|
-
self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
|
252
|
-
self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
|
253
|
-
self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
|
254
|
-
self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
|
255
|
-
self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
|
256
|
-
self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
|
257
|
-
self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
|
258
|
-
self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
|
259
|
-
self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
|
260
|
-
self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
|
251
|
+
self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
|
252
|
+
self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
|
253
|
+
self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
|
254
|
+
self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
|
255
|
+
self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
|
256
|
+
self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
|
257
|
+
self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
|
258
|
+
self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
|
259
|
+
self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
|
260
|
+
self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)
|
261
261
|
|
262
262
|
return state
|
xax/task/mixins/gpu_stats.py
CHANGED
@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
264
264
|
for gpu_stat in stats.values():
|
265
265
|
if gpu_stat is None:
|
266
266
|
continue
|
267
|
-
self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
|
268
|
-
self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
|
269
|
-
self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
|
267
|
+
self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
|
268
|
+
self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
|
269
|
+
self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)
|
270
270
|
|
271
271
|
return state
|
xax/task/mixins/train.py
CHANGED
@@ -218,7 +218,12 @@ class TrainMixin(
|
|
218
218
|
return state.replace(elapsed_time_s=time.time() - state.start_time_s)
|
219
219
|
|
220
220
|
def log_train_step(
|
221
|
-
self,
|
221
|
+
self,
|
222
|
+
model: PyTree,
|
223
|
+
batch: Batch,
|
224
|
+
output: Output,
|
225
|
+
metrics: FrozenDict[str, Array],
|
226
|
+
state: State,
|
222
227
|
) -> None:
|
223
228
|
"""Override this function to do logging during the training phase.
|
224
229
|
|
@@ -234,7 +239,12 @@ class TrainMixin(
|
|
234
239
|
"""
|
235
240
|
|
236
241
|
def log_valid_step(
|
237
|
-
self,
|
242
|
+
self,
|
243
|
+
model: PyTree,
|
244
|
+
batch: Batch,
|
245
|
+
output: Output,
|
246
|
+
metrics: FrozenDict[str, Array],
|
247
|
+
state: State,
|
238
248
|
) -> None:
|
239
249
|
"""Override this function to do logging during the validation phase.
|
240
250
|
|
@@ -252,12 +262,20 @@ class TrainMixin(
|
|
252
262
|
def log_state_timers(self, state: State) -> None:
|
253
263
|
timer = self.state_timers[state.phase]
|
254
264
|
timer.step(state)
|
255
|
-
for
|
256
|
-
|
257
|
-
|
265
|
+
for k, v in timer.log_dict().items():
|
266
|
+
if isinstance(v, tuple):
|
267
|
+
v, secondary = v
|
268
|
+
else:
|
269
|
+
secondary = False
|
270
|
+
self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)
|
258
271
|
|
259
272
|
def log_step(
|
260
|
-
self,
|
273
|
+
self,
|
274
|
+
model: PyTree,
|
275
|
+
batch: Batch,
|
276
|
+
output: Output,
|
277
|
+
metrics: FrozenDict[str, Array],
|
278
|
+
state: State,
|
261
279
|
) -> None:
|
262
280
|
phase = state.phase
|
263
281
|
|
xax/utils/experiments.py
CHANGED
@@ -114,28 +114,15 @@ class StateTimer:
|
|
114
114
|
self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
|
115
115
|
self.iter_timer.step(cur_time)
|
116
116
|
|
117
|
-
def log_dict(self) -> dict[str,
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
"
|
123
|
-
"
|
124
|
-
}
|
125
|
-
|
126
|
-
# Logs sample statistics.
|
127
|
-
logs["⌛ samples"] = {
|
128
|
-
"total": self.sample_timer.steps,
|
129
|
-
"per-second": self.sample_timer.steps_per_second,
|
117
|
+
def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
|
118
|
+
return {
|
119
|
+
"steps": (self.step_timer.steps, True),
|
120
|
+
"steps/second": self.step_timer.steps_per_second,
|
121
|
+
"samples": (self.sample_timer.steps, True),
|
122
|
+
"samples/second": (self.sample_timer.steps_per_second, True),
|
123
|
+
"dt": self.iter_timer.iter_seconds,
|
130
124
|
}
|
131
125
|
|
132
|
-
# Logs full iteration statistics.
|
133
|
-
logs["⌛ dt"] = {
|
134
|
-
"iter": self.iter_timer.iter_seconds,
|
135
|
-
}
|
136
|
-
|
137
|
-
return logs
|
138
|
-
|
139
126
|
|
140
127
|
class IntervalTicker:
|
141
128
|
def __init__(self, interval: float) -> None:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=pSWV5RtPBJynHr7dCqscbnMkETZPUyw8D6MHK4CuS90,14104
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
|
20
|
-
xax/task/logger.py,sha256=
|
20
|
+
xax/task/logger.py,sha256=Upx7cCZvaVIs75CHTfIzYmsuaFRsGu0FvziTZuazT4k,37083
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
23
23
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -26,25 +26,25 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
|
|
26
26
|
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
27
27
|
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
29
|
-
xax/task/loggers/json.py,sha256=
|
29
|
+
xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
|
30
30
|
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
31
|
-
xax/task/loggers/stdout.py,sha256=
|
32
|
-
xax/task/loggers/tensorboard.py,sha256=
|
31
|
+
xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
|
32
|
+
xax/task/loggers/tensorboard.py,sha256=HjR-wiCWe0z3nivRzxEZIltzSzka1828bwxWVmMU5Sk,7718
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
35
|
xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
|
-
xax/task/mixins/cpu_stats.py,sha256=
|
37
|
+
xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
|
38
38
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
39
|
-
xax/task/mixins/gpu_stats.py,sha256=
|
39
|
+
xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
|
40
40
|
xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
|
41
41
|
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=4Xr8b5LFueFh-f3k8MIJMv3M46_Aaf65YwCbjtSBQ-U,26393
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
|
-
xax/utils/experiments.py,sha256=
|
47
|
+
xax/utils/experiments.py,sha256=vm_hWfaty_wEHVdoU2ALiBiGJze7IoDJIfXi6pd_a9I,29360
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.1.
|
62
|
-
xax-0.1.
|
63
|
-
xax-0.1.
|
64
|
-
xax-0.1.
|
65
|
-
xax-0.1.
|
61
|
+
xax-0.1.16.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.1.16.dist-info/METADATA,sha256=gfh7iFi7Wz3fJDf2w1KKs8H0uanhn2HFsR67TvP6uZM,1878
|
63
|
+
xax-0.1.16.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.1.16.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.1.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|