waveletvit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waveletvit/BandDecoder.py +146 -0
- waveletvit/UWaveletViTNet3D.py +213 -0
- waveletvit/WaveletViT.py +778 -0
- waveletvit/__init__.py +39 -0
- waveletvit-0.0.1.dist-info/METADATA +91 -0
- waveletvit-0.0.1.dist-info/RECORD +8 -0
- waveletvit-0.0.1.dist-info/WHEEL +4 -0
- waveletvit-0.0.1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""BandDecoder: reconstruct full-resolution wavelet bands from compressed grids.
|
|
2
|
+
|
|
3
|
+
Reverse of the separable strided conv tokenization in WaveletViTBlockND.
|
|
4
|
+
Per-level decoder shared across HP bands. Uses axis-by-axis Conv1DTranspose.
|
|
5
|
+
|
|
6
|
+
Copyright 2026 Kishore Kumar Tarafdar.
|
|
7
|
+
Licensed under the Apache License, Version 2.0. See LICENSE for details.
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import List
|
|
12
|
+
|
|
13
|
+
import tensorflow as tf
|
|
14
|
+
from tensorflow import keras
|
|
15
|
+
from tensorflow.keras import layers
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _apply_convtranspose1d_along_axis(x, conv, axis, dims=3):
    """Run a 1-D (transposed) convolution along one spatial axis of an N-D tensor.

    ``x`` is a channels-last tensor of rank ``dims + 2``:
    ``(batch, s_0, ..., s_{dims-1}, channels)``. The spatial axis selected by
    ``axis`` (0-based, counted among spatial axes only) is moved next to the
    batch axis. All remaining spatial axes are folded into the batch, and
    ``conv`` is applied to the resulting ``(batch', length, channels)`` tensor.
    The output is then unfolded and the original axis order restored.

    Args:
        x: rank ``dims + 2`` channels-last tensor.
        conv: callable mapping ``(N, T, C)`` to ``(N, T_out, C_out)``, e.g. a
            ``keras.layers.Conv1DTranspose``.
        axis: spatial axis index in ``[0, dims)``.
        dims: number of spatial axes.

    Returns:
        Tensor with the same axis order as ``x``. The extent of the selected
        axis and the channel count are whatever ``conv`` produced.
    """
    moved = axis + 1
    others = [d for d in range(1, dims + 1) if d != moved]
    forward = [0, moved, *others, dims + 1]
    # The argsort of a permutation is its inverse.
    backward = sorted(range(dims + 2), key=forward.__getitem__)

    moved_x = tf.transpose(x, forward)
    dyn = tf.shape(moved_x)
    batch = dyn[0]
    length = dyn[1]
    channels = dyn[-1]
    folded = tf.reshape(moved_x, [batch * tf.reduce_prod(dyn[2:-1]), length, channels])

    convolved = conv(folded)
    out_dyn = tf.shape(convolved)
    out_len = out_dyn[1]
    out_ch = out_dyn[-1]
    unfolded_shape = tf.concat([dyn[:1], [out_len], dyn[2:-1], [out_ch]], axis=0)
    unfolded = tf.reshape(convolved, unfolded_shape)
    return tf.transpose(unfolded, backward)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@tf.keras.utils.register_keras_serializable(package='waveletvit')
class BandDecoder(keras.layers.Layer):
    """Decode compressed 3D grids back to full-resolution wavelet bands.

    Reverse of WaveletViTBlockND's separable strided conv tokenization.
    Each DWT level owns one decoder that is shared by all ``groups``
    high-pass bands of that level. The decoder consists of three
    ``Conv1DTranspose`` layers (one per spatial axis) plus a kernel-size-1
    ``Conv1D`` projection to ``out_channels``. The deepest low-pass grid
    gets its own pointwise ``Conv1D`` projection.

    Args:
        levels: number of DWT levels.
        groups: number of HP bands per level (7 for 3D DWT).
        token_dim: channel dim F of compressed grids.
        stride: spatial stride factor used in tokenization.
        target_spatials: list of target spatial sizes per level, e.g. [32, 16, 8].
        out_channels: number of channels in reconstructed wavelet bands.
        name: layer name; also used as the prefix of all sub-layer names.
    """

    def __init__(self, levels: int, groups: int, token_dim: int, stride: int,
                 target_spatials: List[int], out_channels: int = 1,
                 name: str = 'band_decoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.levels = int(levels)
        self.groups = int(groups)
        self.F = int(token_dim)
        self.stride = int(stride)
        self.target_spatials = [int(s) for s in target_spatials]
        self.out_channels = int(out_channels)
        self.dims = 3

        # Per level: 3 Conv1DTranspose (one per spatial axis) + final pointwise
        # Conv1D mapping F channels to the raw band channels.
        self._up_convs = []
        self._out_convs = []
        Conv1DT = layers.Conv1DTranspose
        for l in range(self.levels):
            # Per-axis upsampling factor, capped at this level's target size.
            # kernel_size == strides with 'valid' padding scales each axis by s.
            s = min(self.stride, self.target_spatials[l])
            axis_convs = [
                Conv1DT(self.F, s, strides=s, padding='valid', activation='relu',
                        name=f'{name}_l{l}_ax{ax}')
                for ax in range(self.dims)
            ]
            self._up_convs.append(axis_convs)
            # Zero-initialised so the decoder starts by emitting all-zero bands.
            self._out_convs.append(layers.Conv1D(self.out_channels, 1,
                                                 kernel_initializer='zeros', bias_initializer='zeros',
                                                 name=f'{name}_l{l}_out'))

        # LP decoder: simple pointwise Conv to raw band channels; zero-init for stable start
        self._lp_out = layers.Conv1D(self.out_channels, 1,
                                     kernel_initializer='zeros', bias_initializer='zeros',
                                     name=f'{name}_lp_out')

    def _pointwise_project(self, conv, t, channels=None):
        """Apply a kernel-size-1 ``Conv1D`` independently at every voxel.

        ``t`` is a rank-5 ``(B, X, Y, Z, C)`` tensor. It is flattened to
        ``(B*X*Y*Z, 1, C)``, passed through ``conv``, and reshaped back to
        ``(B, X, Y, Z, out_channels)``. ``channels`` gives ``C`` as a Python
        int when it is known; if ``None`` it is read from the dynamic shape.
        """
        shape = tf.shape(t)
        if channels is None:
            channels = shape[-1]
        flat = tf.reshape(t, [shape[0] * shape[1] * shape[2] * shape[3], 1, channels])
        projected = conv(flat)
        return tf.reshape(projected, [shape[0], shape[1], shape[2], shape[3], self.out_channels])

    def call(self, compressed_per_level: List[tf.Tensor], lp_compressed: tf.Tensor):
        """Decode compressed features to raw wavelet bands.

        Args:
            compressed_per_level: one rank-5 tensor per level of shape
                ``(B, c, c, c, groups*F)``: the level's HP band features
                stacked along channels. Rank 5 is required by the per-axis
                transposed convolutions.
            lp_compressed: rank-5 ``(B, c_lp, c_lp, c_lp, F_lp)`` deepest LP
                compressed features.

        Returns:
            bands: list with one dict ``{'lp': Tensor, 'hp': List[Tensor]}``
            per level.
              - ``'hp'`` holds ``groups`` tensors of shape
                ``(B, t, t, t, out_channels)``, where ``t`` is at most
                ``target_spatials[level]``.
              - ``'lp'`` at the deepest level is the projected
                ``lp_compressed``.
              - At every other level, ``'lp'`` is a zero tensor shaped like
                the first HP band.
        """
        bands = []
        deepest = self.levels - 1
        for l in range(self.levels):
            target = self.target_spatials[l]
            up_convs = self._up_convs[l]
            out_conv = self._out_convs[l]

            hp_bands = []
            # All HP bands of a level share one decoder: split the stacked
            # channels into `groups` tensors of F channels each.
            for feat in tf.split(compressed_per_level[l], self.groups, axis=-1):
                t = feat
                for ax, conv in enumerate(up_convs):
                    t = _apply_convtranspose1d_along_axis(t, conv, ax, self.dims)
                # Upsampled extent may exceed the target; crop the tail.
                t = t[:, :target, :target, :target, :]
                hp_bands.append(self._pointwise_project(out_conv, t, self.F))

            if l == deepest:
                lp_out = self._pointwise_project(self._lp_out, lp_compressed)
            else:
                # For IDWT, non-deepest LP comes from the reconstruction
                # cascade; emit a zeros placeholder shaped like the HP bands.
                lp_out = tf.zeros_like(hp_bands[0])

            bands.append({'lp': lp_out, 'hp': hp_bands})

        return bands

    def get_config(self):
        config = super().get_config()
        config.update({
            'levels': self.levels, 'groups': self.groups,
            'token_dim': self.F, 'stride': self.stride,
            'target_spatials': self.target_spatials,
            'out_channels': self.out_channels,
        })
        return config
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""UWaveletViTNet3D: 3D U-Net with a multilevel wavelet-band transformer bottleneck.
|
|
2
|
+
|
|
3
|
+
Copyright 2026 Kishore Kumar Tarafdar.
|
|
4
|
+
Licensed under the Apache License, Version 2.0. See LICENSE for details.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow import keras
|
|
10
|
+
from tensorflow.keras.layers import (
|
|
11
|
+
Activation,
|
|
12
|
+
Concatenate,
|
|
13
|
+
Conv3D,
|
|
14
|
+
Conv3DTranspose,
|
|
15
|
+
Input,
|
|
16
|
+
MaxPool3D,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from waveletvit.WaveletViT import (
|
|
20
|
+
WaveletAssembler3D,
|
|
21
|
+
WaveletTokenizer3D,
|
|
22
|
+
WaveletViTBlock3D,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Public API of this module.
__all__ = ["waveletvit_bottleneck_factory", "UWaveletViTNet3D"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _default_conv(filters, x, activation="relu"):
    """Apply two 3x3x3 'same'-padded Conv3D layers, each followed by ``activation``.

    Args:
        filters: number of output channels of both convolutions.
        x: input tensor (channels-last).
        activation: Keras activation identifier applied after each conv.

    Returns:
        The activated output of the second convolution. Spatial size is
        unchanged because both convs use 'same' padding.
    """
    for _ in range(2):
        x = Conv3D(filters, 3, padding="same")(x)
        # Activate after every conv. Without a nonlinearity in between, two
        # stacked convolutions compose into a single affine operator.
        x = Activation(activation)(x)
    return x
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def waveletvit_bottleneck_factory(
    *,
    levels: int = 1,
    wave: str = "haar",
    token_dim: int = 128,
    stride: int = 4,
    heads: int = 4,
    key_dim: int = 16,
    strength: float = 0.25,
    level_decay: float = 1.0,
    use_orient_mix: bool = True,
    lp_global_if_small: bool = True,
    max_tokens_global: int = 8192,
    qkv_bias: bool = True,
    use_ffn: bool = True,
    ffn_hidden_mult: float = 2.0,
    ffn_dropout: float = 0.0,
    kl_scale: float = 0.0,
    sample_at_inference: bool = False,
    tokenizer_norm: bool = True,
    dwt_compute_dtype: str = "float32",
    prebuild_banks: bool = True,
    pad_to_wavelet_multiple: bool = False,
):
    """Return a callable ``bottleneck(x, training=None)`` implementing a wavelet-band transformer.

    Three layers are created here, once, and shared by every call of the
    returned closure:

    1. a ``WaveletTokenizer3D`` that decomposes the input into DWT bands,
    2. a ``WaveletViTBlock3D`` that modulates those raw coefficient bands,
    3. a ``WaveletAssembler3D`` in ``"feature"`` mode that reconstructs a
       feature map.

    Per the assembler's feature mode, the output is intended to have the same
    channel width as the incoming U-Net bottleneck tensor. ``token_dim``
    controls the internal compressed attention width, not the output channel
    count. The closure casts its result back to the input dtype when the two
    differ.
    """
    # Tokenizer (DWT analysis).
    tokenizer = WaveletTokenizer3D(
        levels=int(levels),
        wave=wave,
        norm=bool(tokenizer_norm),
        pad_to_wavelet_multiple=bool(pad_to_wavelet_multiple),
        dwt_compute_dtype=dwt_compute_dtype,
        prebuild_banks=bool(prebuild_banks),
    )

    # Transformer block operating on the bands.
    block_options = dict(
        heads=int(heads),
        key_dim=int(key_dim),
        strength=float(strength),
        level_decay=float(level_decay),
        use_orient_mix=bool(use_orient_mix),
        lp_global_if_small=bool(lp_global_if_small),
        max_tokens_global=int(max_tokens_global),
        qkv_bias=bool(qkv_bias),
        use_ffn=bool(use_ffn),
        ffn_hidden_mult=float(ffn_hidden_mult),
        ffn_dropout=float(ffn_dropout),
        kl_scale=float(kl_scale),
        sample_at_inference=bool(sample_at_inference),
        token_dim=int(token_dim),
        stride=int(stride),
    )
    block = WaveletViTBlock3D(**block_options)

    # Assembler (DWT synthesis back to a feature map).
    assembler = WaveletAssembler3D(
        levels=int(levels),
        wave=wave,
        mode="feature",
        dwt_compute_dtype=dwt_compute_dtype,
        prebuild_banks=bool(prebuild_banks),
    )

    def bottleneck(x: tf.Tensor, training=None) -> tf.Tensor:
        modulated = block(tokenizer(x), training=training)
        y = assembler(modulated)
        if y.dtype == x.dtype:
            return y
        return tf.cast(y, x.dtype)

    return bottleneck
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def UWaveletViTNet3D(
    input_shape=(32, 32, 32, 1),
    config=(16, 32, 64, 128),
    n_classes=4,
    *,
    token_dim: int = 128,
    levels: int = 1,
    wave: str = "haar",
    stride: int = 4,
    heads: int = 4,
    key_dim: int = 16,
    strength: float = 0.25,
    level_decay: float = 1.0,
    use_orient_mix: bool = True,
    lp_global_if_small: bool = True,
    max_tokens_global: int = 8192,
    qkv_bias: bool = True,
    use_ffn: bool = True,
    ffn_hidden_mult: float = 2.0,
    ffn_dropout: float = 0.0,
    kl_scale: float = 0.0,
    sample_at_inference: bool = False,
    tokenizer_norm: bool = True,
    dwt_compute_dtype: str = "float32",
    prebuild_banks: bool = True,
    output_kernel_regularizer=None,
    one_hot_encode=True,
    residual=False,
):
    """Build a 3D U-Net with one multilevel wavelet-band transformer module at the bottleneck.

    Args:
        input_shape: channels-last input shape ``(X, Y, Z, C)`` without the
            batch axis. Every known spatial size must be divisible by
            ``2 ** len(config)``. Each encoder stage halves the resolution
            with ``MaxPool3D(2)``, and each decoder stage doubles it again
            before concatenating the matching skip connection.
        config: encoder filter counts, one per resolution stage. The last
            entry is also the channel width enforced after the bottleneck.
        n_classes: number of output channels of the final 1x1x1 convolution.
        token_dim, levels, wave, stride, heads, key_dim, strength,
        level_decay, use_orient_mix, lp_global_if_small, max_tokens_global,
        qkv_bias, use_ffn, ffn_hidden_mult, ffn_dropout, kl_scale,
        sample_at_inference, tokenizer_norm, dwt_compute_dtype,
        prebuild_banks: forwarded unchanged to ``waveletvit_bottleneck_factory``.
        output_kernel_regularizer: kernel regularizer of the output conv.
        one_hot_encode: if True the output conv applies softmax; otherwise it
            is linear.
        residual: if True the network input is concatenated to the last
            decoder features before the output conv.

    Returns:
        An uncompiled ``keras.Model`` named ``"CVUWaveletViTNet3D"`` when
        ``kl_scale > 0``, otherwise ``"UWaveletViTNet3D"``.

    Raises:
        ValueError: if a known spatial input size is not divisible by
            ``2 ** len(config)``.
    """
    # Validate up front: otherwise the mismatch only surfaces as an opaque
    # shape error inside a decoder Concatenate.
    pool_factor = 2 ** len(config)
    for size in tuple(input_shape)[:-1]:
        if size is not None and int(size) % pool_factor != 0:
            raise ValueError(
                f"Spatial input size {size} must be divisible by "
                f"2**len(config) = {pool_factor} so that decoder upsampling "
                f"matches the encoder skip connections."
            )

    bottleneck = waveletvit_bottleneck_factory(
        levels=levels,
        wave=wave,
        token_dim=token_dim,
        stride=stride,
        heads=heads,
        key_dim=key_dim,
        strength=strength,
        level_decay=level_decay,
        use_orient_mix=use_orient_mix,
        lp_global_if_small=lp_global_if_small,
        max_tokens_global=max_tokens_global,
        qkv_bias=qkv_bias,
        use_ffn=use_ffn,
        ffn_hidden_mult=ffn_hidden_mult,
        ffn_dropout=ffn_dropout,
        kl_scale=kl_scale,
        sample_at_inference=sample_at_inference,
        tokenizer_norm=tokenizer_norm,
        dwt_compute_dtype=dwt_compute_dtype,
        prebuild_banks=prebuild_banks,
        pad_to_wavelet_multiple=False,
    )

    inputs = Input(shape=input_shape, name="input")
    x = inputs

    # Encoder: conv block, keep the skip, then halve the resolution.
    skips = []
    for filters in config:
        x = _default_conv(int(filters), x)
        skips.append(x)
        x = MaxPool3D(pool_size=2)(x)

    x = bottleneck(x)
    # Guard in case the bottleneck changes the channel width.
    if x.shape[-1] is None or int(x.shape[-1]) != int(config[-1]):
        x = Conv3D(int(config[-1]), 1, padding="same", name="bottleneck_channel_align")(x)

    # Decoder: double the resolution, merge the matching skip, conv block.
    for i, filters in reversed(list(enumerate(config))):
        filters = int(filters)
        x = Conv3DTranspose(filters, 2, strides=2, padding="same")(x)
        x = Concatenate()([x, skips[i]])
        x = _default_conv(filters, x)

    if residual:
        x = Concatenate()([x, inputs])

    outputs = Conv3D(
        n_classes,
        1,
        kernel_regularizer=output_kernel_regularizer,
        activation="softmax" if one_hot_encode else None,
        name="segmentation_logits",
    )(x)
    model_name = "CVUWaveletViTNet3D" if float(kl_scale) > 0.0 else "UWaveletViTNet3D"
    return keras.Model(inputs, outputs, name=model_name)
|
|
194
|
+
|
|
195
|
+
if __name__ == "__main__":
    # Smoke test: build a small model, run one random volume through it,
    # and report the architecture and output shape.
    tf.get_logger().setLevel("ERROR")
    tf.random.set_seed(0)
    side = 32
    demo_model = UWaveletViTNet3D(
        input_shape=(side,) * 3 + (1,),
        config=(16, 32, 64, 128),
        n_classes=4,
        token_dim=64,
        levels=1,
        wave="haar",
        heads=2,
        key_dim=16,
        strength=0.5,
        level_decay=1.0,
    )
    prediction = demo_model(tf.random.normal((1, side, side, side, 1)))
    demo_model.summary()
    print("out:", prediction.shape)
|