xax 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/utils/tensorboard.py CHANGED
@@ -2,11 +2,12 @@
2
2
 
3
3
  import functools
4
4
  import io
5
+ import json
5
6
  import os
6
7
  import tempfile
7
8
  import time
8
9
  from pathlib import Path
9
- from typing import Literal, TypedDict
10
+ from typing import Any, Literal, TypedDict
10
11
 
11
12
  import numpy as np
12
13
  import PIL.Image
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
14
15
  from tensorboard.compat.proto.config_pb2 import RunMetadata
15
16
  from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
16
17
  from tensorboard.compat.proto.graph_pb2 import GraphDef
17
- from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
18
+ from tensorboard.compat.proto.summary_pb2 import (
19
+ HistogramProto,
20
+ Summary,
21
+ SummaryMetadata,
22
+ )
18
23
  from tensorboard.compat.proto.tensor_pb2 import TensorProto
19
24
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
25
+ from tensorboard.plugins.mesh import metadata as mesh_metadata
26
+ from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
20
27
  from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
21
28
  from tensorboard.summary.writer.event_file_writer import EventFileWriter
22
29
 
@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | N
84
91
  )
85
92
 
86
93
 
94
+ def _get_json_config(config_dict: dict[str, Any] | None) -> str:
95
+ json_config = "{}"
96
+ if config_dict is not None:
97
+ json_config = json.dumps(config_dict, sort_keys=True)
98
+ return json_config
99
+
100
+
101
def make_mesh_summary(
    tag: str,
    vertices: np.ndarray,
    colors: np.ndarray | None,
    faces: np.ndarray | None,
    config_dict: dict[str, Any] | None,
    display_name: str | None = None,
    description: str | None = None,
) -> Summary:
    """Build a TensorBoard mesh-plugin ``Summary`` from vertex/face/color arrays.

    One ``Summary.Value`` is produced for each component that is not ``None``.
    Each value is tagged with ``mesh_metadata.get_instance_name(tag, content_type)``.
    All values share one component bitmask describing which parts are present.

    Args:
        tag: Base tag for the mesh.
        vertices: Vertex data (required). The mesh plugin typically uses a
            ``(batch, num_vertices, 3)`` layout; the shape is recorded as-is
            and not validated here.
        colors: Optional per-vertex color data, or ``None`` to omit.
        faces: Optional face index data, or ``None`` to omit.
        config_dict: Optional scene configuration, serialized to sorted JSON.
        display_name: Optional display name stored in the plugin metadata.
        description: Optional description stored in the plugin metadata.

    Returns:
        A ``Summary`` containing one float tensor value per provided component.
    """
    json_config = _get_json_config(config_dict)

    # Keep only the components that were supplied. The proto dtype is DT_FLOAT,
    # so coerce each one to a float32 numpy array up front; this also accepts
    # array-likes such as jax arrays and integer face indices.
    present = [
        (np.asarray(array, dtype=np.float32), content_type)
        for array, content_type in (
            (vertices, MeshPluginData.VERTEX),
            (faces, MeshPluginData.FACE),
            (colors, MeshPluginData.COLOR),
        )
        if array is not None
    ]
    components = mesh_metadata.get_components_bitmask([content_type for _, content_type in present])

    summaries = []
    for array, content_type in present:
        shape = [int(dim) for dim in array.shape]
        tensor_metadata = mesh_metadata.create_summary_metadata(
            tag,
            display_name,
            content_type,
            components,
            shape,
            description,
            json_config=json_config,
        )
        tensor_proto = TensorProto(
            dtype="DT_FLOAT",
            float_val=array.reshape(-1).tolist(),
            # Describe every axis of the array rather than assuming rank 3. The
            # old fixed three-dim shape raised IndexError for lower ranks and
            # misdescribed higher-rank data.
            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=dim) for dim in shape]),
        )
        summaries.append(
            Summary.Value(
                tag=mesh_metadata.get_instance_name(tag, content_type),
                tensor=tensor_proto,
                metadata=tensor_metadata,
            )
        )

    return Summary(value=summaries)
154
+
155
+
87
156
  class TensorboardProtobufWriter:
88
157
  def __init__(
89
158
  self,
@@ -454,6 +523,9 @@ class TensorboardWriter:
454
523
  weighted_sum = float((bin_centers * bucket_counts).sum())
455
524
  weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
456
525
 
526
+ # Convert bin edges to list of floats explicitly
527
+ bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
528
+
457
529
  self.add_histogram_raw(
458
530
  tag=tag,
459
531
  min=float(bin_edges[0]),
@@ -461,12 +533,28 @@ class TensorboardWriter:
461
533
  num=int(total_counts),
462
534
  sum=weighted_sum,
463
535
  sum_squares=weighted_sum_squares,
464
- bucket_limits=bin_edges[1:].tolist(), # TensorBoard expects right bin edges
536
+ bucket_limits=bucket_limits, # Now properly typed
465
537
  bucket_counts=bucket_counts.tolist(),
466
538
  global_step=global_step,
467
539
  walltime=walltime,
468
540
  )
469
541
 
542
def add_mesh(
    self,
    tag: str,
    vertices: np.ndarray,
    colors: np.ndarray | None,
    faces: np.ndarray | None,
    config_dict: dict[str, Any] | None,
    global_step: int | None = None,
    walltime: float | None = None,
    display_name: str | None = None,
    description: str | None = None,
) -> None:
    """Write a mesh summary to the underlying protobuf writer.

    Args:
        tag: Base tag for the mesh.
        vertices: Vertex data (required).
        colors: Optional per-vertex color data, or ``None`` to omit.
        faces: Optional face index data, or ``None`` to omit.
        config_dict: Optional scene configuration, serialized to JSON.
        global_step: Optional step recorded with the summary.
        walltime: Optional wall-clock time recorded with the summary.
        display_name: Optional display name stored in the plugin metadata.
        description: Optional description stored in the plugin metadata.
    """
    # display_name/description were previously dropped here even though
    # make_mesh_summary accepts them; forward them explicitly.
    summary = make_mesh_summary(
        tag,
        vertices,
        colors,
        faces,
        config_dict,
        display_name=display_name,
        description=description,
    )
    self.pb_writer.add_summary(
        summary,
        global_step=global_step,
        walltime=walltime,
    )
557
+
470
558
 
471
559
  class TensorboardWriterKwargs(TypedDict):
472
560
  max_queue_size: int
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.15
3
+ Version: 0.2.0
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
10
  Requires-Dist: attrs
11
+ Requires-Dist: chex
12
+ Requires-Dist: dpshdl
13
+ Requires-Dist: equinox
14
+ Requires-Dist: importlib-resources
11
15
  Requires-Dist: jax
12
16
  Requires-Dist: jaxtyping
13
- Requires-Dist: equinox
14
17
  Requires-Dist: optax
15
- Requires-Dist: dpshdl
16
- Requires-Dist: chex
17
- Requires-Dist: importlib-resources
18
- Requires-Dist: cloudpickle
18
+ Requires-Dist: orbax-checkpoint
19
19
  Requires-Dist: pillow
20
20
  Requires-Dist: omegaconf
21
21
  Requires-Dist: gitpython
@@ -1,10 +1,10 @@
1
- xax/__init__.py,sha256=bV2mTcuiVaVNvwgbDgg7dKDkMeuyA0mqF0muU5KZHeg,14104
1
+ xax/__init__.py,sha256=CO9UZlYsYsDL2B6z-Id0Fv0ZSD5uwUZ3eZ6zwwqtJhU,14103
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
- xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
4
+ xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
- xax/core/state.py,sha256=WwW0qDm-be9MMOT-bGWEFvaWF4iq2FP9xRSn1zq_4A8,2507
7
+ xax/core/state.py,sha256=XejW1tGINYFFcNrscK8eZQsq02J7_RXa461QpmyWuLk,3337
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -16,8 +16,8 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
20
- xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
19
+ xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
20
+ xax/task/logger.py,sha256=peGtfnvnBKr9l6tx1V6XAsvPs0HP6ubV_aE7IJtOMNk,40868
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
23
23
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,40 +26,40 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
26
26
  xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
27
27
  xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
29
- xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
29
+ xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
- xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
32
- xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
31
+ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
32
+ xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
- xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
35
+ xax/task/mixins/checkpointing.py,sha256=JHBOdcgmJvhyXldPF5pHRmyPUN9SHcxxngsC1ap4b1E,11468
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
- xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
38
- xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
39
- xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
37
+ xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
38
+ xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
39
+ xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
40
40
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=1hmUx1HIL8HKfwOnupS3Knsw1CiK2YCbIQnUTYyDEms,26157
44
+ xax/task/mixins/train.py,sha256=t8Qyw40ahuJW0SPVgFLljqYbbSc1M_WLop87iwYE41Q,27064
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
- xax/utils/experiments.py,sha256=X6MESZ3z_Z0DLH6NQucuPzibuOc6rZmlf5UZt4in458,29591
47
+ xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
50
50
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
51
51
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
52
52
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
53
53
  xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
54
- xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
54
+ xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
55
55
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
56
56
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
57
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.15.dist-info/METADATA,sha256=i5thFSTL1Zx03UpnCj7f71rxSgs0P3L6ZDd6vYEtM7U,1878
63
- xax-0.1.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.15.dist-info/RECORD,,
61
+ xax-0.2.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.0.dist-info/METADATA,sha256=FyMDy4yB_KQF_IdCMMe_10VWpIEE5g6qEIZuXx-pLgU,1882
63
+ xax-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.2.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.0.dist-info/RECORD,,
File without changes