xax 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/core/state.py +26 -1
- xax/requirements.txt +5 -5
- xax/task/base.py +1 -1
- xax/task/logger.py +149 -12
- xax/task/loggers/json.py +12 -4
- xax/task/loggers/stdout.py +21 -16
- xax/task/loggers/tensorboard.py +18 -2
- xax/task/mixins/checkpointing.py +118 -41
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +2 -1
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/train.py +59 -29
- xax/utils/experiments.py +34 -30
- xax/utils/tensorboard.py +91 -3
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/METADATA +6 -6
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/RECORD +20 -20
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/WHEEL +0 -0
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/top_level.txt +0 -0
xax/utils/tensorboard.py
CHANGED
@@ -2,11 +2,12 @@
|
|
2
2
|
|
3
3
|
import functools
|
4
4
|
import io
|
5
|
+
import json
|
5
6
|
import os
|
6
7
|
import tempfile
|
7
8
|
import time
|
8
9
|
from pathlib import Path
|
9
|
-
from typing import Literal, TypedDict
|
10
|
+
from typing import Any, Literal, TypedDict
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
import PIL.Image
|
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
|
|
14
15
|
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
15
16
|
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
16
17
|
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
17
|
-
from tensorboard.compat.proto.summary_pb2 import
|
18
|
+
from tensorboard.compat.proto.summary_pb2 import (
|
19
|
+
HistogramProto,
|
20
|
+
Summary,
|
21
|
+
SummaryMetadata,
|
22
|
+
)
|
18
23
|
from tensorboard.compat.proto.tensor_pb2 import TensorProto
|
19
24
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
25
|
+
from tensorboard.plugins.mesh import metadata as mesh_metadata
|
26
|
+
from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
|
20
27
|
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
|
21
28
|
from tensorboard.summary.writer.event_file_writer import EventFileWriter
|
22
29
|
|
@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | N
|
|
84
91
|
)
|
85
92
|
|
86
93
|
|
94
|
+
def _get_json_config(config_dict: dict[str, Any] | None) -> str:
|
95
|
+
json_config = "{}"
|
96
|
+
if config_dict is not None:
|
97
|
+
json_config = json.dumps(config_dict, sort_keys=True)
|
98
|
+
return json_config
|
99
|
+
|
100
|
+
|
101
|
+
def make_mesh_summary(
    tag: str,
    vertices: np.ndarray,
    colors: np.ndarray | None,
    faces: np.ndarray | None,
    config_dict: dict[str, Any] | None,
    display_name: str | None = None,
    description: str | None = None,
) -> Summary:
    """Builds a TensorBoard mesh-plugin ``Summary`` from vertex/face/color arrays.

    One ``Summary.Value`` is produced per supplied component, in the order
    vertices, faces, colors; components passed as ``None`` are skipped. All
    values share the same components bitmask and JSON config, and each one is
    tagged with ``mesh_metadata.get_instance_name(tag, content_type)``.

    Args:
        tag: Base tag for the mesh.
        vertices: Rank-3 array of vertex data (presumably
            ``(batch, num_vertices, 3)`` as TensorBoard's mesh plugin expects —
            confirm against callers).
        colors: Optional rank-3 array of per-vertex colors.
        faces: Optional rank-3 array of face vertex indices.
        config_dict: Optional config, serialized to sorted-key JSON
            (``"{}"`` when ``None``).
        display_name: Optional display name stored in the plugin metadata.
        description: Optional description stored in the plugin metadata.

    Returns:
        A ``Summary`` holding one tensor value per non-``None`` component.

    Raises:
        ValueError: If any supplied component is not a rank-3 array.
    """
    json_config = _get_json_config(config_dict)

    candidates = [
        ("vertices", vertices, MeshPluginData.VERTEX),
        ("faces", faces, MeshPluginData.FACE),
        ("colors", colors, MeshPluginData.COLOR),
    ]

    present: list[tuple[np.ndarray, int]] = []
    for label, data, content_type in candidates:
        if data is None:
            continue
        array = np.asarray(data)
        # The shape is written verbatim into the TensorShapeProto and the plugin
        # metadata. Reject other ranks up front instead of failing with an
        # IndexError (rank < 3) or emitting a proto whose declared shape
        # disagrees with its flattened data (rank > 3).
        if array.ndim != 3:
            raise ValueError(f"Mesh {label} must be a rank-3 array, got shape {array.shape}")
        present.append((array, content_type))

    # The bitmask advertises exactly which components this summary contains.
    components = mesh_metadata.get_components_bitmask([content_type for _, content_type in present])

    values = []
    for array, content_type in present:
        tensor_metadata = mesh_metadata.create_summary_metadata(
            tag,
            display_name,
            content_type,
            components,
            array.shape,
            description,
            json_config=json_config,
        )
        # Every component (including integer faces and colors) is stored as DT_FLOAT.
        tensor_proto = TensorProto(
            dtype="DT_FLOAT",
            float_val=array.reshape(-1).tolist(),
            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=int(size)) for size in array.shape]),
        )
        values.append(
            Summary.Value(
                tag=mesh_metadata.get_instance_name(tag, content_type),
                tensor=tensor_proto,
                metadata=tensor_metadata,
            )
        )

    return Summary(value=values)
|
154
|
+
|
155
|
+
|
87
156
|
class TensorboardProtobufWriter:
|
88
157
|
def __init__(
|
89
158
|
self,
|
@@ -454,6 +523,9 @@ class TensorboardWriter:
|
|
454
523
|
weighted_sum = float((bin_centers * bucket_counts).sum())
|
455
524
|
weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
|
456
525
|
|
526
|
+
# Convert bin edges to list of floats explicitly
|
527
|
+
bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
|
528
|
+
|
457
529
|
self.add_histogram_raw(
|
458
530
|
tag=tag,
|
459
531
|
min=float(bin_edges[0]),
|
@@ -461,12 +533,28 @@ class TensorboardWriter:
|
|
461
533
|
num=int(total_counts),
|
462
534
|
sum=weighted_sum,
|
463
535
|
sum_squares=weighted_sum_squares,
|
464
|
-
bucket_limits=
|
536
|
+
bucket_limits=bucket_limits, # Now properly typed
|
465
537
|
bucket_counts=bucket_counts.tolist(),
|
466
538
|
global_step=global_step,
|
467
539
|
walltime=walltime,
|
468
540
|
)
|
469
541
|
|
542
|
+
def add_mesh(
    self,
    tag: str,
    vertices: np.ndarray,
    colors: np.ndarray | None,
    faces: np.ndarray | None,
    config_dict: dict[str, Any] | None,
    global_step: int | None = None,
    walltime: float | None = None,
    display_name: str | None = None,
    description: str | None = None,
) -> None:
    """Builds a mesh summary for ``tag`` and writes it via ``self.pb_writer``.

    Args:
        tag: Base tag for the mesh.
        vertices: Vertex array forwarded to ``make_mesh_summary``.
        colors: Optional per-vertex color array (``None`` to omit).
        faces: Optional face index array (``None`` to omit).
        config_dict: Optional config, serialized to JSON in the plugin metadata.
        global_step: Optional step recorded with the summary.
        walltime: Optional wall-clock time recorded with the summary.
        display_name: Optional display name stored in the plugin metadata.
        description: Optional description stored in the plugin metadata.
    """
    summary = make_mesh_summary(
        tag,
        vertices,
        colors,
        faces,
        config_dict,
        display_name=display_name,
        description=description,
    )
    self.pb_writer.add_summary(summary, global_step=global_step, walltime=walltime)
|
557
|
+
|
470
558
|
|
471
559
|
class TensorboardWriterKwargs(TypedDict):
|
472
560
|
max_queue_size: int
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.1.15
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
10
|
Requires-Dist: attrs
|
11
|
+
Requires-Dist: chex
|
12
|
+
Requires-Dist: dpshdl
|
13
|
+
Requires-Dist: equinox
|
14
|
+
Requires-Dist: importlib-resources
|
11
15
|
Requires-Dist: jax
|
12
16
|
Requires-Dist: jaxtyping
|
13
|
-
Requires-Dist: equinox
|
14
17
|
Requires-Dist: optax
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist: chex
|
17
|
-
Requires-Dist: importlib-resources
|
18
|
-
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: orbax-checkpoint
|
19
19
|
Requires-Dist: pillow
|
20
20
|
Requires-Dist: omegaconf
|
21
21
|
Requires-Dist: gitpython
|
@@ -1,10 +1,10 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=CO9UZlYsYsDL2B6z-Id0Fv0ZSD5uwUZ3eZ6zwwqtJhU,14103
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
-
xax/requirements.txt,sha256=
|
4
|
+
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
5
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
|
7
|
-
xax/core/state.py,sha256=
|
7
|
+
xax/core/state.py,sha256=XejW1tGINYFFcNrscK8eZQsq02J7_RXa461QpmyWuLk,3337
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
10
|
xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
|
@@ -16,8 +16,8 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
20
|
-
xax/task/logger.py,sha256=
|
19
|
+
xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
|
20
|
+
xax/task/logger.py,sha256=peGtfnvnBKr9l6tx1V6XAsvPs0HP6ubV_aE7IJtOMNk,40868
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
23
23
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -26,40 +26,40 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
|
|
26
26
|
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
27
27
|
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
29
|
-
xax/task/loggers/json.py,sha256=
|
29
|
+
xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
|
30
30
|
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
31
|
-
xax/task/loggers/stdout.py,sha256=
|
32
|
-
xax/task/loggers/tensorboard.py,sha256=
|
31
|
+
xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
|
32
|
+
xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=JHBOdcgmJvhyXldPF5pHRmyPUN9SHcxxngsC1ap4b1E,11468
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
|
-
xax/task/mixins/cpu_stats.py,sha256=
|
38
|
-
xax/task/mixins/data_loader.py,sha256=
|
39
|
-
xax/task/mixins/gpu_stats.py,sha256=
|
37
|
+
xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
|
38
|
+
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
39
|
+
xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
|
40
40
|
xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
|
41
41
|
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=t8Qyw40ahuJW0SPVgFLljqYbbSc1M_WLop87iwYE41Q,27064
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
|
-
xax/utils/experiments.py,sha256=
|
47
|
+
xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
51
51
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
52
52
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
53
53
|
xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
|
54
|
-
xax/utils/tensorboard.py,sha256=
|
54
|
+
xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
|
55
55
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
56
56
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
57
57
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.
|
62
|
-
xax-0.
|
63
|
-
xax-0.
|
64
|
-
xax-0.
|
65
|
-
xax-0.
|
61
|
+
xax-0.2.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.0.dist-info/METADATA,sha256=FyMDy4yB_KQF_IdCMMe_10VWpIEE5g6qEIZuXx-pLgU,1882
|
63
|
+
xax-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.2.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|