xax 0.1.4__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.4/xax.egg-info → xax-0.1.5}/PKG-INFO +1 -1
- {xax-0.1.4 → xax-0.1.5}/xax/__init__.py +1 -1
- {xax-0.1.4 → xax-0.1.5}/xax/nn/export.py +8 -1
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/train.py +1 -1
- {xax-0.1.4 → xax-0.1.5}/xax/utils/tensorboard.py +30 -2
- {xax-0.1.4 → xax-0.1.5/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.1.4 → xax-0.1.5}/LICENSE +0 -0
- {xax-0.1.4 → xax-0.1.5}/MANIFEST.in +0 -0
- {xax-0.1.4 → xax-0.1.5}/README.md +0 -0
- {xax-0.1.4 → xax-0.1.5}/pyproject.toml +0 -0
- {xax-0.1.4 → xax-0.1.5}/setup.cfg +0 -0
- {xax-0.1.4 → xax-0.1.5}/setup.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/core/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/core/conf.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/core/state.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/embeddings.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/equinox.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/functions.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/geom.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/norm.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/nn/parallel.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/py.typed +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/requirements-dev.txt +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/requirements.txt +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/base.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/launchers/base.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/logger.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/json.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/state.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/process.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/script.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/task/task.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/data/collate.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/debugging.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/experiments.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/jax.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/logging.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/numpy.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/profile.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/pytree.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax/utils/text.py +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.4 → xax-0.1.5}/xax.egg-info/top_level.txt +0 -0
@@ -9,7 +9,14 @@ import jax
|
|
9
9
|
import tensorflow as tf
|
10
10
|
from jax.experimental import jax2tf
|
11
11
|
from jaxtyping import Array, PyTree
|
12
|
-
|
12
|
+
|
13
|
+
try:
|
14
|
+
from orbax.export import ExportManager, JaxModule, ServingConfig
|
15
|
+
except ImportError as e:
|
16
|
+
raise ImportError(
|
17
|
+
"Please install the package with `orbax` as a dependency, using "
|
18
|
+
"'xax[export]` to install the required dependencies."
|
19
|
+
) from e
|
13
20
|
|
14
21
|
logger = logging.getLogger(__name__)
|
15
22
|
|
@@ -122,7 +122,7 @@ class ValidStepTimer:
|
|
122
122
|
|
123
123
|
# Step-based validation.
|
124
124
|
valid_every_n_steps = self.valid_every_n_steps
|
125
|
-
if valid_every_n_steps is not None and state.num_steps
|
125
|
+
if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
|
126
126
|
self.last_valid_step = state.num_steps
|
127
127
|
return True
|
128
128
|
|
@@ -258,12 +258,40 @@ class TensorboardWriter:
|
|
258
258
|
fps: int = 30,
|
259
259
|
) -> None:
|
260
260
|
assert value.ndim == 4, "Video must be 4D array (T, H, W, C)"
|
261
|
-
|
261
|
+
|
262
|
+
images = [PIL.Image.fromarray(frame).convert("RGB") for frame in value]
|
263
|
+
width, height = images[0].size
|
264
|
+
big_image = PIL.Image.new("RGB", (width, height * len(images)))
|
265
|
+
for i, im in enumerate(images):
|
266
|
+
big_image.paste(im, (0, i * height))
|
267
|
+
|
268
|
+
quantized_big = big_image.quantize(method=PIL.Image.Quantize.MAXCOVERAGE, dither=PIL.Image.Dither.NONE)
|
269
|
+
palette = quantized_big.getpalette()
|
270
|
+
|
271
|
+
processed = []
|
272
|
+
for im in images:
|
273
|
+
q = im.quantize(
|
274
|
+
method=PIL.Image.Quantize.MAXCOVERAGE,
|
275
|
+
palette=quantized_big,
|
276
|
+
dither=PIL.Image.Dither.NONE,
|
277
|
+
)
|
278
|
+
processed.append(q)
|
279
|
+
|
280
|
+
if palette is not None:
|
281
|
+
palette[0:3] = [255, 255, 255]
|
282
|
+
for im in processed:
|
283
|
+
im.putpalette(palette)
|
262
284
|
|
263
285
|
# Create temporary file for GIF
|
264
286
|
temp_file = tempfile.NamedTemporaryFile(suffix=".gif", delete=False)
|
265
287
|
try:
|
266
|
-
|
288
|
+
processed[0].save(
|
289
|
+
temp_file.name,
|
290
|
+
save_all=True,
|
291
|
+
append_images=processed[1:],
|
292
|
+
duration=int(1000 / fps),
|
293
|
+
loop=0,
|
294
|
+
)
|
267
295
|
with open(temp_file.name, "rb") as f:
|
268
296
|
video_string = f.read()
|
269
297
|
|
{xax-0.1.4 → xax-0.1.5}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.1.4 → xax-0.1.5}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|