xax 0.1.0__tar.gz → 0.1.1__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (67)
  1. {xax-0.1.0/xax.egg-info → xax-0.1.1}/PKG-INFO +1 -1
  2. {xax-0.1.0 → xax-0.1.1}/xax/__init__.py +1 -1
  3. {xax-0.1.0 → xax-0.1.1}/xax/task/logger.py +145 -64
  4. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/train.py +0 -1
  5. {xax-0.1.0 → xax-0.1.1}/xax/utils/logging.py +12 -2
  6. {xax-0.1.0 → xax-0.1.1/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.1.0 → xax-0.1.1}/LICENSE +0 -0
  8. {xax-0.1.0 → xax-0.1.1}/MANIFEST.in +0 -0
  9. {xax-0.1.0 → xax-0.1.1}/README.md +0 -0
  10. {xax-0.1.0 → xax-0.1.1}/pyproject.toml +0 -0
  11. {xax-0.1.0 → xax-0.1.1}/setup.cfg +0 -0
  12. {xax-0.1.0 → xax-0.1.1}/setup.py +0 -0
  13. {xax-0.1.0 → xax-0.1.1}/xax/core/__init__.py +0 -0
  14. {xax-0.1.0 → xax-0.1.1}/xax/core/conf.py +0 -0
  15. {xax-0.1.0 → xax-0.1.1}/xax/core/state.py +0 -0
  16. {xax-0.1.0 → xax-0.1.1}/xax/nn/__init__.py +0 -0
  17. {xax-0.1.0 → xax-0.1.1}/xax/nn/embeddings.py +0 -0
  18. {xax-0.1.0 → xax-0.1.1}/xax/nn/equinox.py +0 -0
  19. {xax-0.1.0 → xax-0.1.1}/xax/nn/export.py +0 -0
  20. {xax-0.1.0 → xax-0.1.1}/xax/nn/functions.py +0 -0
  21. {xax-0.1.0 → xax-0.1.1}/xax/nn/geom.py +0 -0
  22. {xax-0.1.0 → xax-0.1.1}/xax/nn/norm.py +0 -0
  23. {xax-0.1.0 → xax-0.1.1}/xax/nn/parallel.py +0 -0
  24. {xax-0.1.0 → xax-0.1.1}/xax/py.typed +0 -0
  25. {xax-0.1.0 → xax-0.1.1}/xax/requirements-dev.txt +0 -0
  26. {xax-0.1.0 → xax-0.1.1}/xax/requirements.txt +0 -0
  27. {xax-0.1.0 → xax-0.1.1}/xax/task/__init__.py +0 -0
  28. {xax-0.1.0 → xax-0.1.1}/xax/task/base.py +0 -0
  29. {xax-0.1.0 → xax-0.1.1}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.1.0 → xax-0.1.1}/xax/task/launchers/base.py +0 -0
  31. {xax-0.1.0 → xax-0.1.1}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.1.0 → xax-0.1.1}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/__init__.py +0 -0
  34. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/callback.py +0 -0
  35. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/json.py +0 -0
  36. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/state.py +0 -0
  37. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/stdout.py +0 -0
  38. {xax-0.1.0 → xax-0.1.1}/xax/task/loggers/tensorboard.py +0 -0
  39. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/__init__.py +0 -0
  40. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/artifacts.py +0 -0
  41. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/checkpointing.py +0 -0
  42. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/process.py +0 -0
  48. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.1.0 → xax-0.1.1}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.1.0 → xax-0.1.1}/xax/task/script.py +0 -0
  51. {xax-0.1.0 → xax-0.1.1}/xax/task/task.py +0 -0
  52. {xax-0.1.0 → xax-0.1.1}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.0 → xax-0.1.1}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.0 → xax-0.1.1}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.0 → xax-0.1.1}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.0 → xax-0.1.1}/xax/utils/experiments.py +0 -0
  57. {xax-0.1.0 → xax-0.1.1}/xax/utils/jax.py +0 -0
  58. {xax-0.1.0 → xax-0.1.1}/xax/utils/jaxpr.py +0 -0
  59. {xax-0.1.0 → xax-0.1.1}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.0 → xax-0.1.1}/xax/utils/profile.py +0 -0
  61. {xax-0.1.0 → xax-0.1.1}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.0 → xax-0.1.1}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.0 → xax-0.1.1}/xax/utils/text.py +0 -0
  64. {xax-0.1.0 → xax-0.1.1}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.0 → xax-0.1.1}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.0 → xax-0.1.1}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.0 → xax-0.1.1}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.0"
15
+ __version__ = "0.1.1"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -18,7 +18,7 @@ from abc import ABC, abstractmethod
18
18
  from collections import defaultdict
19
19
  from dataclasses import dataclass
20
20
  from types import TracebackType
21
- from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
21
+ from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
22
22
 
23
23
  import jax
24
24
  import jax.numpy as jnp
@@ -28,7 +28,7 @@ from PIL import Image, ImageDraw, ImageFont
28
28
  from PIL.Image import Image as PILImage
29
29
 
30
30
  from xax.core.state import Phase, State
31
- from xax.utils.experiments import IntervalTicker
31
+ from xax.utils.experiments import ContextTimer, IntervalTicker
32
32
  from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
33
33
 
34
34
  logger = logging.getLogger(__name__)
@@ -236,7 +236,7 @@ class LogHistogram:
236
236
  num: int
237
237
  sum: Number
238
238
  sum_squares: Number
239
- bucket_limits: list[Number]
239
+ bucket_limits: list[float]
240
240
  bucket_counts: list[int]
241
241
 
242
242
 
@@ -564,14 +564,14 @@ class Logger:
564
564
  Args:
565
565
  state: The current step's state.
566
566
  """
567
- should_log = [logger.should_log(state) for logger in self.loggers]
567
+ should_log = [lg.should_log(state) for lg in self.loggers]
568
568
  if not any(should_log):
569
569
  self.clear()
570
570
  return
571
571
  line = self.pack(state)
572
572
  self.clear()
573
- for logger in (logger for logger, should_log in zip(self.loggers, should_log) if should_log):
574
- logger.write(line)
573
+ for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
574
+ lg.write(line)
575
575
 
576
576
  def write_error_summary(self, error_summary: str) -> None:
577
577
  for logger in self.loggers:
@@ -633,7 +633,10 @@ class Logger:
633
633
 
634
634
  @functools.lru_cache(maxsize=None)
635
635
  def scalar_future() -> Number:
636
- return value() if callable(value) else value
636
+ with ContextTimer() as timer:
637
+ value_concrete = value() if callable(value) else value
638
+ logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
639
+ return value_concrete
637
640
 
638
641
  self.scalars[namespace][key] = scalar_future
639
642
 
@@ -657,7 +660,9 @@ class Logger:
657
660
 
658
661
  @functools.lru_cache(maxsize=None)
659
662
  def distribution_future() -> LogDistribution:
660
- mean, std = value() if callable(value) else value
663
+ with ContextTimer() as timer:
664
+ mean, std = value() if callable(value) else value
665
+ logger.debug("Distribution Key: %s, Time: %s", key, timer.elapsed_time)
661
666
  return LogDistribution(mean=mean, std=std)
662
667
 
663
668
  self.distributions[namespace][key] = distribution_future
@@ -684,26 +689,84 @@ class Logger:
684
689
 
685
690
  @functools.lru_cache(maxsize=None)
686
691
  def histogram_future() -> LogHistogram:
687
- values = value() if callable(value) else value
688
- values = values.reshape(-1) # Must be flat.
689
-
690
- if isinstance(values, Array):
691
- counts, limits = jnp.histogram(values, bins=bins)
692
- counts, limits = as_numpy(counts), as_numpy(limits)
693
- elif isinstance(values, np.ndarray):
694
- counts, limits = np.histogram(values, bins=bins)
695
- else:
696
- raise ValueError(f"Unsupported histogram type: {type(values)}")
697
-
698
- return LogHistogram(
699
- min=float(values.min()),
700
- max=float(values.max()),
701
- num=int(values.size),
702
- sum=float(values.sum()),
703
- sum_squares=float(values.dot(values)),
704
- bucket_limits=limits[1:].tolist(),
705
- bucket_counts=counts.tolist(),
706
- )
692
+ with ContextTimer() as timer:
693
+ values = value() if callable(value) else value
694
+ values = values.reshape(-1) # Must be flat.
695
+
696
+ if isinstance(values, Array):
697
+ counts, limits = jnp.histogram(values, bins=bins)
698
+ counts, limits = as_numpy(counts), as_numpy(limits)
699
+ elif isinstance(values, np.ndarray):
700
+ counts, limits = np.histogram(values, bins=bins)
701
+ else:
702
+ raise ValueError(f"Unsupported histogram type: {type(values)}")
703
+
704
+ histogram_values = LogHistogram(
705
+ min=float(values.min()),
706
+ max=float(values.max()),
707
+ num=int(values.size),
708
+ sum=float(values.sum()),
709
+ sum_squares=float(values.dot(values)),
710
+ bucket_limits=cast(list[float], limits[1:].tolist()),
711
+ bucket_counts=cast(list[int], counts.tolist()),
712
+ )
713
+
714
+ logger.debug("Histogram Key: %s, Time: %s", key, timer.elapsed_time)
715
+ return histogram_values
716
+
717
+ self.histograms[namespace][key] = histogram_future
718
+
719
+ def log_histogram_raw(
720
+ self,
721
+ key: str,
722
+ counts: Array | np.ndarray,
723
+ limits: Array | np.ndarray,
724
+ minv: Number | None = None,
725
+ maxv: Number | None = None,
726
+ sumv: Number | None = None,
727
+ sum_squaresv: Number | None = None,
728
+ *,
729
+ namespace: str | None = None,
730
+ ) -> None:
731
+ """Logs a histogram from raw counts and limits.
732
+
733
+ Args:
734
+ key: The key being logged
735
+ counts: The counts of the histogram
736
+ limits: The limits of the histogram
737
+ minv: The minimum value of the histogram
738
+ maxv: The maximum value of the histogram
739
+ sumv: The sum of the histogram
740
+ sum_squaresv: The sum of the squares of the histogram
741
+ namespace: An optional logging namespace
742
+ """
743
+ if not self.active:
744
+ raise RuntimeError("The logger is not active")
745
+ namespace = self.resolve_namespace(namespace)
746
+
747
+ @functools.lru_cache(maxsize=None)
748
+ def histogram_future() -> LogHistogram:
749
+ with ContextTimer() as timer:
750
+ counts_np = (as_numpy(counts) if isinstance(counts, Array) else counts).astype(int)
751
+ limits_np = (as_numpy(limits) if isinstance(limits, Array) else limits).astype(float)
752
+
753
+ minv_ = counts_np.min() if minv is None else minv
754
+ maxv_ = counts_np.max() if maxv is None else maxv
755
+ sumv_ = counts_np.sum() if sumv is None else sumv
756
+ sum_squaresv_ = counts_np.dot(counts_np) if sum_squaresv is None else sum_squaresv
757
+
758
+ histogram_values = LogHistogram(
759
+ min=float(minv_),
760
+ max=float(maxv_),
761
+ num=int(counts_np.size),
762
+ sum=float(sumv_),
763
+ sum_squares=float(sum_squaresv_),
764
+ bucket_limits=cast(list[float], limits_np.tolist()),
765
+ bucket_counts=cast(list[int], counts_np.tolist()),
766
+ )
767
+
768
+ logger.debug("Raw Histogram Key: %s, Time: %s", key, timer.elapsed_time)
769
+ return histogram_values
707
770
 
708
771
  self.histograms[namespace][key] = histogram_future
709
772
 
@@ -748,7 +811,10 @@ class Logger:
748
811
 
749
812
  @functools.lru_cache(maxsize=None)
750
813
  def image_future() -> LogImage:
751
- return get_image(value() if callable(value) else value, target_resolution)
814
+ with ContextTimer() as timer:
815
+ image = get_image(value() if callable(value) else value, target_resolution)
816
+ logger.debug("Image Key: %s, Time: %s", key, timer.elapsed_time)
817
+ return image
752
818
 
753
819
  self.images[namespace][key] = image_future
754
820
 
@@ -784,15 +850,20 @@ class Logger:
784
850
 
785
851
  @functools.lru_cache(maxsize=None)
786
852
  def image_future() -> LogImage:
787
- image, label = value() if callable(value) else value
788
- image = get_image(image, target_resolution)
789
- return image_with_text(
790
- image.image,
791
- standardize_text(label, max_line_length),
792
- max_num_lines=max_num_lines,
793
- line_spacing=line_spacing,
794
- centered=centered,
795
- )
853
+ with ContextTimer() as timer:
854
+ image, label = value() if callable(value) else value
855
+ image = get_image(image, target_resolution)
856
+
857
+ image_value = image_with_text(
858
+ image.image,
859
+ standardize_text(label, max_line_length),
860
+ max_num_lines=max_num_lines,
861
+ line_spacing=line_spacing,
862
+ centered=centered,
863
+ )
864
+
865
+ logger.debug("Labeled Image Key: %s, Time: %s", key, timer.elapsed_time)
866
+ return image_value
796
867
 
797
868
  self.images[namespace][key] = image_future
798
869
 
@@ -831,15 +902,18 @@ class Logger:
831
902
 
832
903
  @functools.lru_cache(maxsize=None)
833
904
  def images_future() -> LogImage:
834
- images = value() if callable(value) else value
835
- if max_images is not None:
836
- images = images[:max_images]
837
- if isinstance(images, Array):
838
- images = as_numpy(images)
839
- if isinstance(images, Sequence):
840
- images = list(images)
841
- images = [get_image(image, target_resolution) for image in images]
842
- tiled = tile_images([img.image for img in images], sep)
905
+ with ContextTimer() as timer:
906
+ images = value() if callable(value) else value
907
+ if max_images is not None:
908
+ images = images[:max_images]
909
+ if isinstance(images, Array):
910
+ images = as_numpy(images)
911
+ if isinstance(images, Sequence):
912
+ images = list(images)
913
+ images = [get_image(image, target_resolution) for image in images]
914
+ tiled = tile_images([img.image for img in images], sep)
915
+
916
+ logger.debug("Images Key: %s, Time: %s", key, timer.elapsed_time)
843
917
  return LogImage(image=tiled)
844
918
 
845
919
  self.images[namespace][key] = images_future
@@ -886,22 +960,25 @@ class Logger:
886
960
 
887
961
  @functools.lru_cache(maxsize=None)
888
962
  def images_future() -> LogImage:
889
- images, labels = value() if callable(value) else value
890
- if max_images is not None:
891
- images = images[:max_images]
892
- labels = labels[:max_images]
893
- images = [get_image(image, target_resolution) for image in images]
894
- labeled = [
895
- image_with_text(
896
- img.image,
897
- standardize_text(label, max_line_length),
898
- max_num_lines=max_num_lines,
899
- line_spacing=line_spacing,
900
- centered=centered,
901
- )
902
- for img, label in zip(images, labels)
903
- ]
904
- tiled = tile_images([img.image for img in labeled], sep)
963
+ with ContextTimer() as timer:
964
+ images, labels = value() if callable(value) else value
965
+ if max_images is not None:
966
+ images = images[:max_images]
967
+ labels = labels[:max_images]
968
+ images = [get_image(image, target_resolution) for image in images]
969
+ labeled = [
970
+ image_with_text(
971
+ img.image,
972
+ standardize_text(label, max_line_length),
973
+ max_num_lines=max_num_lines,
974
+ line_spacing=line_spacing,
975
+ centered=centered,
976
+ )
977
+ for img, label in zip(images, labels)
978
+ ]
979
+ tiled = tile_images([img.image for img in labeled], sep)
980
+
981
+ logger.debug("Labeled Images Key: %s, Time: %s", key, timer.elapsed_time)
905
982
  return LogImage(image=tiled)
906
983
 
907
984
  self.images[namespace][key] = images_future
@@ -934,7 +1011,11 @@ class Logger:
934
1011
 
935
1012
  @functools.lru_cache(maxsize=None)
936
1013
  def video_future() -> LogVideo:
937
- return get_video(value() if callable(value) else value, fps=fps)
1014
+ with ContextTimer() as timer:
1015
+ video = get_video(value() if callable(value) else value, fps=fps)
1016
+
1017
+ logger.debug("Video Key: %s, Time: %s", key, timer.elapsed_time)
1018
+ return video
938
1019
 
939
1020
  self.videos[namespace][key] = video_future
940
1021
 
@@ -156,7 +156,6 @@ class TrainConfig(
156
156
  valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
157
157
  valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
158
158
  valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
159
- batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
160
159
  max_steps: int | None = field(None, help="Maximum number of steps to run")
161
160
  step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
162
161
  random_seed: int = field(1337, help="Random seed for the task")
@@ -140,7 +140,13 @@ class ColoredFormatter(logging.Formatter):
140
140
  return logging.Formatter.format(self, record)
141
141
 
142
142
 
143
- def configure_logging(prefix: str | None = None, *, rank: int | None = None, world_size: int | None = None) -> None:
143
+ def configure_logging(
144
+ prefix: str | None = None,
145
+ *,
146
+ rank: int | None = None,
147
+ world_size: int | None = None,
148
+ debug: bool | None = None,
149
+ ) -> None:
144
150
  """Instantiates logging.
145
151
 
146
152
  This captures logs and reroutes them to the Toasts module, which is
@@ -151,6 +157,7 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
151
157
  prefix: An optional prefix to add to the logger
152
158
  rank: The current rank, or None if not using multiprocessing
153
159
  world_size: The total world size, or None if not using multiprocessing
160
+ debug: Whether to enable debug logging
154
161
  """
155
162
  if rank is not None or world_size is not None:
156
163
  assert rank is not None and world_size is not None
@@ -168,7 +175,10 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
168
175
  stream_handler.addFilter(filter)
169
176
  root_logger.addHandler(stream_handler)
170
177
 
171
- root_logger.setLevel(logging._nameToLevel[config.log_level])
178
+ if debug is None:
179
+ root_logger.setLevel(logging._nameToLevel[config.log_level])
180
+ else:
181
+ root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
172
182
 
173
183
  # Avoid junk logs from other libraries.
174
184
  if config.hide_third_party_logs:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes