xax 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/task/logger.py +145 -64
- xax/task/mixins/train.py +0 -1
- xax/utils/logging.py +12 -2
- {xax-0.1.0.dist-info → xax-0.1.1.dist-info}/METADATA +1 -1
- {xax-0.1.0.dist-info → xax-0.1.1.dist-info}/RECORD +9 -9
- {xax-0.1.0.dist-info → xax-0.1.1.dist-info}/WHEEL +0 -0
- {xax-0.1.0.dist-info → xax-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.0.dist-info → xax-0.1.1.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/logger.py
CHANGED
@@ -18,7 +18,7 @@ from abc import ABC, abstractmethod
|
|
18
18
|
from collections import defaultdict
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from types import TracebackType
|
21
|
-
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
|
21
|
+
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
|
22
22
|
|
23
23
|
import jax
|
24
24
|
import jax.numpy as jnp
|
@@ -28,7 +28,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|
28
28
|
from PIL.Image import Image as PILImage
|
29
29
|
|
30
30
|
from xax.core.state import Phase, State
|
31
|
-
from xax.utils.experiments import IntervalTicker
|
31
|
+
from xax.utils.experiments import ContextTimer, IntervalTicker
|
32
32
|
from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
|
33
33
|
|
34
34
|
logger = logging.getLogger(__name__)
|
@@ -236,7 +236,7 @@ class LogHistogram:
|
|
236
236
|
num: int
|
237
237
|
sum: Number
|
238
238
|
sum_squares: Number
|
239
|
-
bucket_limits: list[
|
239
|
+
bucket_limits: list[float]
|
240
240
|
bucket_counts: list[int]
|
241
241
|
|
242
242
|
|
@@ -564,14 +564,14 @@ class Logger:
|
|
564
564
|
Args:
|
565
565
|
state: The current step's state.
|
566
566
|
"""
|
567
|
-
should_log = [
|
567
|
+
should_log = [lg.should_log(state) for lg in self.loggers]
|
568
568
|
if not any(should_log):
|
569
569
|
self.clear()
|
570
570
|
return
|
571
571
|
line = self.pack(state)
|
572
572
|
self.clear()
|
573
|
-
for
|
574
|
-
|
573
|
+
for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
|
574
|
+
lg.write(line)
|
575
575
|
|
576
576
|
def write_error_summary(self, error_summary: str) -> None:
|
577
577
|
for logger in self.loggers:
|
@@ -633,7 +633,10 @@ class Logger:
|
|
633
633
|
|
634
634
|
@functools.lru_cache(maxsize=None)
|
635
635
|
def scalar_future() -> Number:
|
636
|
-
|
636
|
+
with ContextTimer() as timer:
|
637
|
+
value_concrete = value() if callable(value) else value
|
638
|
+
logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
|
639
|
+
return value_concrete
|
637
640
|
|
638
641
|
self.scalars[namespace][key] = scalar_future
|
639
642
|
|
@@ -657,7 +660,9 @@ class Logger:
|
|
657
660
|
|
658
661
|
@functools.lru_cache(maxsize=None)
|
659
662
|
def distribution_future() -> LogDistribution:
|
660
|
-
|
663
|
+
with ContextTimer() as timer:
|
664
|
+
mean, std = value() if callable(value) else value
|
665
|
+
logger.debug("Distribution Key: %s, Time: %s", key, timer.elapsed_time)
|
661
666
|
return LogDistribution(mean=mean, std=std)
|
662
667
|
|
663
668
|
self.distributions[namespace][key] = distribution_future
|
@@ -684,26 +689,84 @@ class Logger:
|
|
684
689
|
|
685
690
|
@functools.lru_cache(maxsize=None)
|
686
691
|
def histogram_future() -> LogHistogram:
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
692
|
+
with ContextTimer() as timer:
|
693
|
+
values = value() if callable(value) else value
|
694
|
+
values = values.reshape(-1) # Must be flat.
|
695
|
+
|
696
|
+
if isinstance(values, Array):
|
697
|
+
counts, limits = jnp.histogram(values, bins=bins)
|
698
|
+
counts, limits = as_numpy(counts), as_numpy(limits)
|
699
|
+
elif isinstance(values, np.ndarray):
|
700
|
+
counts, limits = np.histogram(values, bins=bins)
|
701
|
+
else:
|
702
|
+
raise ValueError(f"Unsupported histogram type: {type(values)}")
|
703
|
+
|
704
|
+
histogram_values = LogHistogram(
|
705
|
+
min=float(values.min()),
|
706
|
+
max=float(values.max()),
|
707
|
+
num=int(values.size),
|
708
|
+
sum=float(values.sum()),
|
709
|
+
sum_squares=float(values.dot(values)),
|
710
|
+
bucket_limits=cast(list[float], limits[1:].tolist()),
|
711
|
+
bucket_counts=cast(list[int], counts.tolist()),
|
712
|
+
)
|
713
|
+
|
714
|
+
logger.debug("Histogram Key: %s, Time: %s", key, timer.elapsed_time)
|
715
|
+
return histogram_values
|
716
|
+
|
717
|
+
self.histograms[namespace][key] = histogram_future
|
718
|
+
|
719
|
+
def log_histogram_raw(
|
720
|
+
self,
|
721
|
+
key: str,
|
722
|
+
counts: Array | np.ndarray,
|
723
|
+
limits: Array | np.ndarray,
|
724
|
+
minv: Number | None = None,
|
725
|
+
maxv: Number | None = None,
|
726
|
+
sumv: Number | None = None,
|
727
|
+
sum_squaresv: Number | None = None,
|
728
|
+
*,
|
729
|
+
namespace: str | None = None,
|
730
|
+
) -> None:
|
731
|
+
"""Logs a histogram from raw counts and limits.
|
732
|
+
|
733
|
+
Args:
|
734
|
+
key: The key being logged
|
735
|
+
counts: The counts of the histogram
|
736
|
+
limits: The limits of the histogram
|
737
|
+
minv: The minimum value of the histogram
|
738
|
+
maxv: The maximum value of the histogram
|
739
|
+
sumv: The sum of the histogram
|
740
|
+
sum_squaresv: The sum of the squares of the histogram
|
741
|
+
namespace: An optional logging namespace
|
742
|
+
"""
|
743
|
+
if not self.active:
|
744
|
+
raise RuntimeError("The logger is not active")
|
745
|
+
namespace = self.resolve_namespace(namespace)
|
746
|
+
|
747
|
+
@functools.lru_cache(maxsize=None)
|
748
|
+
def histogram_future() -> LogHistogram:
|
749
|
+
with ContextTimer() as timer:
|
750
|
+
counts_np = (as_numpy(counts) if isinstance(counts, Array) else counts).astype(int)
|
751
|
+
limits_np = (as_numpy(limits) if isinstance(limits, Array) else limits).astype(float)
|
752
|
+
|
753
|
+
minv_ = counts_np.min() if minv is None else minv
|
754
|
+
maxv_ = counts_np.max() if maxv is None else maxv
|
755
|
+
sumv_ = counts_np.sum() if sumv is None else sumv
|
756
|
+
sum_squaresv_ = counts_np.dot(counts_np) if sum_squaresv is None else sum_squaresv
|
757
|
+
|
758
|
+
histogram_values = LogHistogram(
|
759
|
+
min=float(minv_),
|
760
|
+
max=float(maxv_),
|
761
|
+
num=int(counts_np.size),
|
762
|
+
sum=float(sumv_),
|
763
|
+
sum_squares=float(sum_squaresv_),
|
764
|
+
bucket_limits=cast(list[float], limits_np.tolist()),
|
765
|
+
bucket_counts=cast(list[int], counts_np.tolist()),
|
766
|
+
)
|
767
|
+
|
768
|
+
logger.debug("Raw Histogram Key: %s, Time: %s", key, timer.elapsed_time)
|
769
|
+
return histogram_values
|
707
770
|
|
708
771
|
self.histograms[namespace][key] = histogram_future
|
709
772
|
|
@@ -748,7 +811,10 @@ class Logger:
|
|
748
811
|
|
749
812
|
@functools.lru_cache(maxsize=None)
|
750
813
|
def image_future() -> LogImage:
|
751
|
-
|
814
|
+
with ContextTimer() as timer:
|
815
|
+
image = get_image(value() if callable(value) else value, target_resolution)
|
816
|
+
logger.debug("Image Key: %s, Time: %s", key, timer.elapsed_time)
|
817
|
+
return image
|
752
818
|
|
753
819
|
self.images[namespace][key] = image_future
|
754
820
|
|
@@ -784,15 +850,20 @@ class Logger:
|
|
784
850
|
|
785
851
|
@functools.lru_cache(maxsize=None)
|
786
852
|
def image_future() -> LogImage:
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
853
|
+
with ContextTimer() as timer:
|
854
|
+
image, label = value() if callable(value) else value
|
855
|
+
image = get_image(image, target_resolution)
|
856
|
+
|
857
|
+
image_value = image_with_text(
|
858
|
+
image.image,
|
859
|
+
standardize_text(label, max_line_length),
|
860
|
+
max_num_lines=max_num_lines,
|
861
|
+
line_spacing=line_spacing,
|
862
|
+
centered=centered,
|
863
|
+
)
|
864
|
+
|
865
|
+
logger.debug("Labeled Image Key: %s, Time: %s", key, timer.elapsed_time)
|
866
|
+
return image_value
|
796
867
|
|
797
868
|
self.images[namespace][key] = image_future
|
798
869
|
|
@@ -831,15 +902,18 @@ class Logger:
|
|
831
902
|
|
832
903
|
@functools.lru_cache(maxsize=None)
|
833
904
|
def images_future() -> LogImage:
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
905
|
+
with ContextTimer() as timer:
|
906
|
+
images = value() if callable(value) else value
|
907
|
+
if max_images is not None:
|
908
|
+
images = images[:max_images]
|
909
|
+
if isinstance(images, Array):
|
910
|
+
images = as_numpy(images)
|
911
|
+
if isinstance(images, Sequence):
|
912
|
+
images = list(images)
|
913
|
+
images = [get_image(image, target_resolution) for image in images]
|
914
|
+
tiled = tile_images([img.image for img in images], sep)
|
915
|
+
|
916
|
+
logger.debug("Images Key: %s, Time: %s", key, timer.elapsed_time)
|
843
917
|
return LogImage(image=tiled)
|
844
918
|
|
845
919
|
self.images[namespace][key] = images_future
|
@@ -886,22 +960,25 @@ class Logger:
|
|
886
960
|
|
887
961
|
@functools.lru_cache(maxsize=None)
|
888
962
|
def images_future() -> LogImage:
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
963
|
+
with ContextTimer() as timer:
|
964
|
+
images, labels = value() if callable(value) else value
|
965
|
+
if max_images is not None:
|
966
|
+
images = images[:max_images]
|
967
|
+
labels = labels[:max_images]
|
968
|
+
images = [get_image(image, target_resolution) for image in images]
|
969
|
+
labeled = [
|
970
|
+
image_with_text(
|
971
|
+
img.image,
|
972
|
+
standardize_text(label, max_line_length),
|
973
|
+
max_num_lines=max_num_lines,
|
974
|
+
line_spacing=line_spacing,
|
975
|
+
centered=centered,
|
976
|
+
)
|
977
|
+
for img, label in zip(images, labels)
|
978
|
+
]
|
979
|
+
tiled = tile_images([img.image for img in labeled], sep)
|
980
|
+
|
981
|
+
logger.debug("Labeled Images Key: %s, Time: %s", key, timer.elapsed_time)
|
905
982
|
return LogImage(image=tiled)
|
906
983
|
|
907
984
|
self.images[namespace][key] = images_future
|
@@ -934,7 +1011,11 @@ class Logger:
|
|
934
1011
|
|
935
1012
|
@functools.lru_cache(maxsize=None)
|
936
1013
|
def video_future() -> LogVideo:
|
937
|
-
|
1014
|
+
with ContextTimer() as timer:
|
1015
|
+
video = get_video(value() if callable(value) else value, fps=fps)
|
1016
|
+
|
1017
|
+
logger.debug("Video Key: %s, Time: %s", key, timer.elapsed_time)
|
1018
|
+
return video
|
938
1019
|
|
939
1020
|
self.videos[namespace][key] = video_future
|
940
1021
|
|
xax/task/mixins/train.py
CHANGED
@@ -156,7 +156,6 @@ class TrainConfig(
|
|
156
156
|
valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
|
157
157
|
valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
|
158
158
|
valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
|
159
|
-
batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
|
160
159
|
max_steps: int | None = field(None, help="Maximum number of steps to run")
|
161
160
|
step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
|
162
161
|
random_seed: int = field(1337, help="Random seed for the task")
|
xax/utils/logging.py
CHANGED
@@ -140,7 +140,13 @@ class ColoredFormatter(logging.Formatter):
|
|
140
140
|
return logging.Formatter.format(self, record)
|
141
141
|
|
142
142
|
|
143
|
-
def configure_logging(
|
143
|
+
def configure_logging(
|
144
|
+
prefix: str | None = None,
|
145
|
+
*,
|
146
|
+
rank: int | None = None,
|
147
|
+
world_size: int | None = None,
|
148
|
+
debug: bool | None = None,
|
149
|
+
) -> None:
|
144
150
|
"""Instantiates logging.
|
145
151
|
|
146
152
|
This captures logs and reroutes them to the Toasts module, which is
|
@@ -151,6 +157,7 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
|
|
151
157
|
prefix: An optional prefix to add to the logger
|
152
158
|
rank: The current rank, or None if not using multiprocessing
|
153
159
|
world_size: The total world size, or None if not using multiprocessing
|
160
|
+
debug: Whether to enable debug logging
|
154
161
|
"""
|
155
162
|
if rank is not None or world_size is not None:
|
156
163
|
assert rank is not None and world_size is not None
|
@@ -168,7 +175,10 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
|
|
168
175
|
stream_handler.addFilter(filter)
|
169
176
|
root_logger.addHandler(stream_handler)
|
170
177
|
|
171
|
-
|
178
|
+
if debug is None:
|
179
|
+
root_logger.setLevel(logging._nameToLevel[config.log_level])
|
180
|
+
else:
|
181
|
+
root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
|
172
182
|
|
173
183
|
# Avoid junk logs from other libraries.
|
174
184
|
if config.hide_third_party_logs:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=JyKRACir9b0bkuG93bwxADFrVr-Lo76kenDBJtvb_wQ,13280
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -15,7 +15,7 @@ xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
|
|
15
15
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
16
16
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
|
18
|
-
xax/task/logger.py,sha256=
|
18
|
+
xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
|
19
19
|
xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
|
20
20
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
21
21
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -39,13 +39,13 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
39
39
|
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
40
40
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
41
41
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
42
|
-
xax/task/mixins/train.py,sha256=
|
42
|
+
xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
|
43
43
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
|
45
45
|
xax/utils/experiments.py,sha256=_cwoBaiBxoQ_Tstm0rz7TEqfELqcktmPflb6AP1K0qA,28779
|
46
46
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
47
47
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
48
|
-
xax/utils/logging.py,sha256=
|
48
|
+
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
49
49
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
50
50
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
51
51
|
xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
|
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
|
|
53
53
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
54
54
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
55
55
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
56
|
-
xax-0.1.
|
57
|
-
xax-0.1.
|
58
|
-
xax-0.1.
|
59
|
-
xax-0.1.
|
60
|
-
xax-0.1.
|
56
|
+
xax-0.1.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
57
|
+
xax-0.1.1.dist-info/METADATA,sha256=tJ4ilL3uBbykHBQTHbh-bN6m4hrHqivyyFeuI33ddX4,1877
|
58
|
+
xax-0.1.1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
|
59
|
+
xax-0.1.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
60
|
+
xax-0.1.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|