xax 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff shows the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in that public registry.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.3"
15
+ __version__ = "0.1.5"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/nn/export.py CHANGED
@@ -9,7 +9,14 @@ import jax
9
9
  import tensorflow as tf
10
10
  from jax.experimental import jax2tf
11
11
  from jaxtyping import Array, PyTree
12
- from orbax.export import ExportManager, JaxModule, ServingConfig
12
+
13
+ try:
14
+ from orbax.export import ExportManager, JaxModule, ServingConfig
15
+ except ImportError as e:
16
+ raise ImportError(
17
+ "Please install the package with `orbax` as a dependency, using "
18
+ "'xax[export]` to install the required dependencies."
19
+ ) from e
13
20
 
14
21
  logger = logging.getLogger(__name__)
15
22
 
xax/task/base.py CHANGED
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
153
153
  for base in cls.__orig_bases__:
154
154
  if hasattr(base, "__args__"):
155
155
  for arg in base.__args__:
156
+ if isinstance(arg, TypeVar) and arg.__bound__ is not None:
157
+ arg = arg.__bound__
156
158
  if issubclass(arg, BaseConfig):
157
159
  return arg
158
160
 
@@ -5,6 +5,7 @@ behavior during initialization and training.
5
5
  """
6
6
 
7
7
  import logging
8
+ import sys
8
9
  from dataclasses import dataclass
9
10
  from pathlib import Path
10
11
  from typing import Generic, TypeVar
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
17
18
  logger = logging.getLogger(__name__)
18
19
 
19
20
 
21
+ def get_cache_dir() -> str | None:
22
+ # By default, only cache on MacOS, since Jax caching on Linux is very
23
+ # prone to NaNs.
24
+ match sys.platform:
25
+ case "darwin":
26
+ return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
+ case _:
28
+ return None
29
+
30
+
20
31
  @jax.tree_util.register_dataclass
21
32
  @dataclass
22
33
  class CompileOptions:
@@ -42,7 +53,8 @@ class CompileOptions:
42
53
 
43
54
  # JAX cache options
44
55
  cache_dir: str | None = field(
45
- value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
56
+ # Only cache by default on MacOS systems.
57
+ value=get_cache_dir,
46
58
  help="Directory for JAX compilation cache. If None, caching is disabled",
47
59
  )
48
60
  cache_min_size_bytes: int = field(
@@ -54,7 +66,7 @@ class CompileOptions:
54
66
  help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
55
67
  )
56
68
  cache_enable_xla: str = field(
57
- value="all",
69
+ value="none",
58
70
  help="Which XLA caches to enable",
59
71
  )
60
72
 
xax/task/mixins/train.py CHANGED
@@ -122,7 +122,7 @@ class ValidStepTimer:
122
122
 
123
123
  # Step-based validation.
124
124
  valid_every_n_steps = self.valid_every_n_steps
125
- if valid_every_n_steps is not None and state.num_steps > valid_every_n_steps + self.last_valid_step:
125
+ if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
126
126
  self.last_valid_step = state.num_steps
127
127
  return True
128
128
 
xax/utils/tensorboard.py CHANGED
@@ -258,12 +258,40 @@ class TensorboardWriter:
258
258
  fps: int = 30,
259
259
  ) -> None:
260
260
  assert value.ndim == 4, "Video must be 4D array (T, H, W, C)"
261
- images = [PIL.Image.fromarray(frame) for frame in value]
261
+
262
+ images = [PIL.Image.fromarray(frame).convert("RGB") for frame in value]
263
+ width, height = images[0].size
264
+ big_image = PIL.Image.new("RGB", (width, height * len(images)))
265
+ for i, im in enumerate(images):
266
+ big_image.paste(im, (0, i * height))
267
+
268
+ quantized_big = big_image.quantize(method=PIL.Image.Quantize.MAXCOVERAGE, dither=PIL.Image.Dither.NONE)
269
+ palette = quantized_big.getpalette()
270
+
271
+ processed = []
272
+ for im in images:
273
+ q = im.quantize(
274
+ method=PIL.Image.Quantize.MAXCOVERAGE,
275
+ palette=quantized_big,
276
+ dither=PIL.Image.Dither.NONE,
277
+ )
278
+ processed.append(q)
279
+
280
+ if palette is not None:
281
+ palette[0:3] = [255, 255, 255]
282
+ for im in processed:
283
+ im.putpalette(palette)
262
284
 
263
285
  # Create temporary file for GIF
264
286
  temp_file = tempfile.NamedTemporaryFile(suffix=".gif", delete=False)
265
287
  try:
266
- images[0].save(temp_file.name, save_all=True, append_images=images[1:], duration=int(1000 / fps), loop=0)
288
+ processed[0].save(
289
+ temp_file.name,
290
+ save_all=True,
291
+ append_images=processed[1:],
292
+ duration=int(1000 / fps),
293
+ loop=0,
294
+ )
267
295
  with open(temp_file.name, "rb") as f:
268
296
  video_string = f.read()
269
297
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=n7EXl0pwEPzlw2DjS-3ePgx0VoQnMDnHLVc5exkHGcM,13361
1
+ xax/__init__.py,sha256=Ig2JALbHQOtFHr57jpmPm16L0aYkNTNZ-Upz9jxIyrc,13361
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -8,13 +8,13 @@ xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=1Ck6ycz76dhit2LHX4y2lp3WJSPsDuRt7TK7AxxQhww,4837
11
- xax/nn/export.py,sha256=Do2bLjJTD744mxpQuPYpz8fZ3EIjBLaaZfhp8maNVrg,5303
11
+ xax/nn/export.py,sha256=bu2m-4FDnadEhXDb9zM6SgOZvsf5p4xiee1sFZyNF7c,5510
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
13
  xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
14
14
  xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
15
15
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
16
16
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
17
+ xax/task/base.py,sha256=4fUjrG-llQpeESQuaQbww4M6WR6djjTK89fY20UV9zU,7610
18
18
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
19
19
  xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
20
20
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
@@ -31,7 +31,7 @@ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCn
31
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
32
32
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
33
  xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
34
- xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
34
+ xax/task/mixins/compile.py,sha256=9pVJEUvizu6-6tq0HaMtHGNSi9Yk_mxNyqBFimcfwL0,3683
35
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
36
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
37
37
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
39
39
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
40
40
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
41
41
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
42
- xax/task/mixins/train.py,sha256=8AaBXaopnrxtSZXldyFCE3QX1k5r3IsZMr6O0ICnNnU,22332
42
+ xax/task/mixins/train.py,sha256=vsH_QpyrThlh9AzWnyvDJv58Y8U_516oi8gmMq_0iMg,22333
43
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
45
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -49,12 +49,12 @@ xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
49
49
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
50
50
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
51
51
  xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
52
- xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,16170
52
+ xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
53
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
54
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
56
- xax-0.1.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
- xax-0.1.3.dist-info/METADATA,sha256=m3AyjlRD9C-O2Tp5zH5i5TbEL7bZooeIpypUCYuYPtQ,1877
58
- xax-0.1.3.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
59
- xax-0.1.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
- xax-0.1.3.dist-info/RECORD,,
56
+ xax-0.1.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
+ xax-0.1.5.dist-info/METADATA,sha256=dkvPH-GXRErGzzR0VGsiRisNYBzDVf97nEGG0F2-HjY,1877
58
+ xax-0.1.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
59
+ xax-0.1.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
+ xax-0.1.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.0.2)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5