waveletvit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,146 @@
1
+ """BandDecoder: reconstruct full-resolution wavelet bands from compressed grids.
2
+
3
+ Reverse of the separable strided conv tokenization in WaveletViTBlockND.
4
+ Per-level decoder shared across HP bands. Uses axis-by-axis ConvTranspose1D.
5
+
6
+ Copyright 2026 Kishore Kumar Tarafdar.
7
+ Licensed under the Apache License, Version 2.0. See LICENSE for details.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import List
12
+
13
+ import tensorflow as tf
14
+ from tensorflow import keras
15
+ from tensorflow.keras import layers
16
+
17
+
18
def _apply_convtranspose1d_along_axis(x, conv, axis, dims=3):
    """Apply a 1-D (transposed) convolution along one spatial axis of an N-D tensor.

    The tensor is treated as ``(batch, *spatial, channels)`` with ``dims``
    spatial axes. Every 1-D line running along the chosen spatial axis is fed
    to ``conv`` as an independent ``(T, C)`` sequence, and the result is put
    back into the original axis order.

    Args:
        x: tensor of rank ``dims + 2``.
        conv: callable mapping a rank-3 ``(N, T, C)`` tensor to a rank-3
            ``(N, T_out, C_out)`` tensor (e.g. a Conv1DTranspose layer).
        axis: 0-based index of the spatial axis to convolve along.
        dims: number of spatial axes in ``x``.

    Returns:
        Tensor of rank ``dims + 2`` whose chosen spatial axis has length
        ``T_out`` and whose last axis has length ``C_out``; all other axes are
        unchanged.
    """
    channel_axis = dims + 1
    target = 1 + axis
    others = [s for s in range(1, channel_axis) if s != target]

    # The target axis must sit directly in front of the channel axis before
    # flattening. Only then does a row-major reshape to (N, T, C) produce
    # genuine sequences along the target axis; with the target axis placed
    # ahead of the other spatial axes the reshape would interleave them.
    perm = [0] + others + [target, channel_axis]
    x_t = tf.transpose(x, perm)
    shape_t = tf.shape(x_t)
    lead = shape_t[:-2]  # (batch, *other spatial axes)

    # Explicit product (rather than -1) keeps empty tensors reshapeable.
    x_flat = tf.reshape(x_t, [tf.reduce_prod(lead), shape_t[-2], shape_t[-1]])
    y_flat = conv(x_flat)

    # Restore the leading axes; T_out and C_out come from the conv output.
    y_t = tf.reshape(y_flat, tf.concat([lead, tf.shape(y_flat)[1:]], axis=0))

    # Invert the permutation so the axes return to their original order.
    inv_perm = [0] * (dims + 2)
    for i, p in enumerate(perm):
        inv_perm[p] = i
    return tf.transpose(y_t, inv_perm)
37
+
38
+
39
@tf.keras.utils.register_keras_serializable(package='waveletvit')
class BandDecoder(keras.layers.Layer):
    """Decode compressed 3D grids back to full-resolution wavelet bands.

    Reverse of WaveletViTBlockND's separable strided conv tokenization.
    Each level owns one ConvTranspose1D per spatial axis plus a 1x1 output
    projection; the same layers are applied to every HP band of that level.
    All output projections are zero-initialised, so a freshly built decoder
    emits all-zero bands.

    Args:
        levels: number of DWT levels.
        groups: number of HP bands per level (7 for 3D DWT).
        token_dim: channel dim F of compressed grids.
        stride: spatial stride factor used in tokenization.
        target_spatials: list of target spatial sizes per level, e.g. [32, 16, 8].
            Must provide at least ``levels`` entries.
        out_channels: number of channels in reconstructed wavelet bands.
        name: layer name; also used as the prefix of all sub-layer names.
    """

    def __init__(self, levels: int, groups: int, token_dim: int, stride: int,
                 target_spatials: List[int], out_channels: int = 1,
                 name: str = 'band_decoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.levels = int(levels)
        self.groups = int(groups)
        self.F = int(token_dim)
        self.stride = int(stride)
        self.target_spatials = [int(s) for s in target_spatials]
        self.out_channels = int(out_channels)
        self.dims = 3

        # Per-level: 3 ConvTranspose1D (one per axis) + final Conv1D to raw band channels
        self._up_convs = []
        self._out_convs = []
        for level in range(self.levels):
            # Never upsample by more than the target size of this level.
            s = min(self.stride, self.target_spatials[level])
            self._up_convs.append([
                layers.Conv1DTranspose(self.F, s, strides=s, padding='valid',
                                       activation='relu',
                                       name=f'{name}_l{level}_ax{ax}')
                for ax in range(self.dims)
            ])
            self._out_convs.append(layers.Conv1D(self.out_channels, 1,
                                                 kernel_initializer='zeros',
                                                 bias_initializer='zeros',
                                                 name=f'{name}_l{level}_out'))

        # LP decoder: simple Conv to raw band channels; zero-init for stable start
        self._lp_out = layers.Conv1D(self.out_channels, 1,
                                     kernel_initializer='zeros',
                                     bias_initializer='zeros',
                                     name=f'{name}_lp_out')

    def _project_pointwise(self, conv, t, in_channels):
        """Apply a kernel-size-1 Conv1D to every voxel of a rank-5 tensor.

        Args:
            conv: Conv1D layer producing ``self.out_channels`` channels.
            t: rank-5 tensor ``(B, D1, D2, D3, in_channels)``.
            in_channels: channel count of ``t`` (Python int or scalar tensor).

        Returns:
            Tensor of shape ``(B, D1, D2, D3, self.out_channels)``.
        """
        shape = tf.shape(t)
        # Every voxel becomes its own length-1 sequence for the Conv1D.
        flat = tf.reshape(t, [shape[0] * shape[1] * shape[2] * shape[3], 1, in_channels])
        out = conv(flat)
        return tf.reshape(out, [shape[0], shape[1], shape[2], shape[3], self.out_channels])

    def _decode_hp_band(self, feat, level):
        """Upsample one HP feature grid along each axis, trim, and project it."""
        t = feat
        for ax, conv in enumerate(self._up_convs[level]):
            t = _apply_convtranspose1d_along_axis(t, conv, ax, self.dims)
        # Trim to target spatial size
        target = self.target_spatials[level]
        t = t[:, :target, :target, :target, :]
        # Project to raw band channels
        return self._project_pointwise(self._out_convs[level], t, self.F)

    def call(self, compressed_per_level: List[tf.Tensor], lp_compressed: tf.Tensor):
        """Decode compressed features to raw wavelet bands.

        Args:
            compressed_per_level: one rank-5 tensor per level, laid out as
                (B, comp, comp, comp, groups*F); the last axis stacks the
                features of all HP bands of that level.
            lp_compressed: rank-5 tensor (B, comp_lp, comp_lp, comp_lp, F_lp);
                deepest LP compressed features.

        Returns:
            bands: list[level] of {'lp': Tensor, 'hp': List[Tensor]}; raw
            wavelet bands. Only the deepest level carries a decoded LP band;
            every other level gets an all-zero placeholder shaped like its
            first HP band.
        """
        bands = []
        deepest = self.levels - 1
        for level in range(self.levels):
            # Split the stacked features into one (B, comp, comp, comp, F) grid per band.
            hp_feats = tf.split(compressed_per_level[level], self.groups, axis=-1)
            hp_bands = [self._decode_hp_band(feat, level) for feat in hp_feats]

            if level == deepest:
                # LP: only at deepest level
                lp_out = self._project_pointwise(
                    self._lp_out, lp_compressed, tf.shape(lp_compressed)[-1])
            else:
                # For IDWT, non-deepest LP comes from the reconstruction cascade;
                # set to zeros placeholder
                lp_out = tf.zeros_like(hp_bands[0])

            bands.append({'lp': lp_out, 'hp': hp_bands})

        return bands

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({
            'levels': self.levels, 'groups': self.groups,
            'token_dim': self.F, 'stride': self.stride,
            'target_spatials': self.target_spatials,
            'out_channels': self.out_channels,
        })
        return config
@@ -0,0 +1,213 @@
1
+ """UWaveletViTNet3D: 3D U-Net with a multilevel wavelet-band transformer bottleneck.
2
+
3
+ Copyright 2026 Kishore Kumar Tarafdar.
4
+ Licensed under the Apache License, Version 2.0. See LICENSE for details.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import tensorflow as tf
9
+ from tensorflow import keras
10
+ from tensorflow.keras.layers import (
11
+ Activation,
12
+ Concatenate,
13
+ Conv3D,
14
+ Conv3DTranspose,
15
+ Input,
16
+ MaxPool3D,
17
+ )
18
+
19
+ from waveletvit.WaveletViT import (
20
+ WaveletAssembler3D,
21
+ WaveletTokenizer3D,
22
+ WaveletViTBlock3D,
23
+ )
24
+
25
# Names exported by a star-import of this module.
__all__ = ["waveletvit_bottleneck_factory", "UWaveletViTNet3D"]
29
+
30
+
31
def _default_conv(filters, x, activation="relu"):
    """Stack two 3x3x3 'same'-padded Conv3D layers, then apply one activation.

    No activation is inserted between the two convolutions.
    """
    for _ in range(2):
        x = Conv3D(filters, 3, padding="same")(x)
    activate = Activation(activation)
    return activate(x)
35
+
36
+
37
def waveletvit_bottleneck_factory(
    *,
    levels: int = 1,
    wave: str = "haar",
    token_dim: int = 128,
    stride: int = 4,
    heads: int = 4,
    key_dim: int = 16,
    strength: float = 0.25,
    level_decay: float = 1.0,
    use_orient_mix: bool = True,
    lp_global_if_small: bool = True,
    max_tokens_global: int = 8192,
    qkv_bias: bool = True,
    use_ffn: bool = True,
    ffn_hidden_mult: float = 2.0,
    ffn_dropout: float = 0.0,
    kl_scale: float = 0.0,
    sample_at_inference: bool = False,
    tokenizer_norm: bool = True,
    dwt_compute_dtype: str = "float32",
    prebuild_banks: bool = True,
    pad_to_wavelet_multiple: bool = False,
):
    """Build a wavelet-band transformer bottleneck callable for a 3D U-Net.

    Three layers are created once and captured by the returned closure: a
    WaveletTokenizer3D, a WaveletViTBlock3D and a WaveletAssembler3D (in
    ``"feature"`` mode). All arguments are coerced to their annotated
    primitive type and forwarded to those layers. ``token_dim`` controls the
    internal compressed attention width, not the output channel count.

    Returns:
        A function ``bottleneck(x, training=None)`` that runs
        tokenizer -> block -> assembler on ``x`` and returns the result cast
        to ``x.dtype`` when the assembler output has a different dtype.
    """
    # Options shared by the forward (tokenizer) and inverse (assembler) transforms.
    dwt_options = dict(
        levels=int(levels),
        wave=wave,
        dwt_compute_dtype=dwt_compute_dtype,
        prebuild_banks=bool(prebuild_banks),
    )

    tokenizer = WaveletTokenizer3D(
        norm=bool(tokenizer_norm),
        pad_to_wavelet_multiple=bool(pad_to_wavelet_multiple),
        **dwt_options,
    )
    block = WaveletViTBlock3D(
        heads=int(heads),
        key_dim=int(key_dim),
        strength=float(strength),
        level_decay=float(level_decay),
        use_orient_mix=bool(use_orient_mix),
        lp_global_if_small=bool(lp_global_if_small),
        max_tokens_global=int(max_tokens_global),
        qkv_bias=bool(qkv_bias),
        use_ffn=bool(use_ffn),
        ffn_hidden_mult=float(ffn_hidden_mult),
        ffn_dropout=float(ffn_dropout),
        kl_scale=float(kl_scale),
        sample_at_inference=bool(sample_at_inference),
        token_dim=int(token_dim),
        stride=int(stride),
    )
    assembler = WaveletAssembler3D(mode="feature", **dwt_options)

    def bottleneck(x: tf.Tensor, training=None) -> tf.Tensor:
        modulated = block(tokenizer(x), training=training)
        y = assembler(modulated)
        if y.dtype == x.dtype:
            return y
        # The DWT may run in a different compute dtype than the surrounding network.
        return tf.cast(y, x.dtype)

    return bottleneck
108
+
109
+
110
def UWaveletViTNet3D(
    input_shape=(32, 32, 32, 1),
    config=(16, 32, 64, 128),
    n_classes=4,
    *,
    token_dim: int = 128,
    levels: int = 1,
    wave: str = "haar",
    stride: int = 4,
    heads: int = 4,
    key_dim: int = 16,
    strength: float = 0.25,
    level_decay: float = 1.0,
    use_orient_mix: bool = True,
    lp_global_if_small: bool = True,
    max_tokens_global: int = 8192,
    qkv_bias: bool = True,
    use_ffn: bool = True,
    ffn_hidden_mult: float = 2.0,
    ffn_dropout: float = 0.0,
    kl_scale: float = 0.0,
    sample_at_inference: bool = False,
    tokenizer_norm: bool = True,
    dwt_compute_dtype: str = "float32",
    prebuild_banks: bool = True,
    output_kernel_regularizer=None,
    one_hot_encode=True,
    residual=False,
):
    """Build a 3D U-Net with one multilevel wavelet-band transformer module at the bottleneck.

    The encoder runs one double-conv stage followed by a 2x max-pool per
    entry of ``config``; the decoder mirrors it with stride-2 transposed
    convolutions and skip concatenations.

    Args:
        input_shape: per-sample shape ``(D, H, W, C)`` (channels last).
        config: filter count of each encoder/decoder stage, shallow to deep.
        n_classes: number of output channels of the final 1x1x1 convolution.
        output_kernel_regularizer: kernel regularizer of the final convolution.
        one_hot_encode: if true the final convolution uses a softmax
            activation, otherwise it is linear. The layer is named
            ``"segmentation_logits"`` in both cases.
        residual: if true, the network input is concatenated to the decoder
            output before the final convolution.
        token_dim, levels, wave, stride, heads, key_dim, strength,
        level_decay, use_orient_mix, lp_global_if_small, max_tokens_global,
        qkv_bias, use_ffn, ffn_hidden_mult, ffn_dropout, kl_scale,
        sample_at_inference, tokenizer_norm, dwt_compute_dtype,
        prebuild_banks: forwarded unchanged to
            ``waveletvit_bottleneck_factory``.

    Returns:
        A ``keras.Model`` named ``"CVUWaveletViTNet3D"`` when ``kl_scale`` is
        positive and ``"UWaveletViTNet3D"`` otherwise.

    Raises:
        ValueError: if a known spatial size in ``input_shape`` is not
            divisible by ``2 ** len(config)``.
    """
    # Each stage halves the resolution with a flooring max-pool and the
    # decoder doubles it back, so any other size ends in a skip-connection
    # shape mismatch deep inside model construction. Fail early instead.
    down_factor = 2 ** len(config)
    for size in tuple(input_shape)[:-1]:
        if size is not None and int(size) % down_factor:
            raise ValueError(
                f"Spatial sizes in input_shape {tuple(input_shape)} must be "
                f"divisible by {down_factor} (2 ** len(config)) so that the "
                f"decoder resolutions match the encoder skip connections."
            )

    bottleneck = waveletvit_bottleneck_factory(
        levels=levels,
        wave=wave,
        token_dim=token_dim,
        stride=stride,
        heads=heads,
        key_dim=key_dim,
        strength=strength,
        level_decay=level_decay,
        use_orient_mix=use_orient_mix,
        lp_global_if_small=lp_global_if_small,
        max_tokens_global=max_tokens_global,
        qkv_bias=qkv_bias,
        use_ffn=use_ffn,
        ffn_hidden_mult=ffn_hidden_mult,
        ffn_dropout=ffn_dropout,
        kl_scale=kl_scale,
        sample_at_inference=sample_at_inference,
        tokenizer_norm=tokenizer_norm,
        dwt_compute_dtype=dwt_compute_dtype,
        prebuild_banks=prebuild_banks,
        pad_to_wavelet_multiple=False,
    )

    inputs = Input(shape=input_shape, name="input")

    # Encoder: remember the pre-pooling activations for the skip connections.
    x = inputs
    skips = []
    for filters in config:
        x = _default_conv(int(filters), x)
        skips.append(x)
        x = MaxPool3D(pool_size=2)(x)

    x = bottleneck(x)
    # Re-align the channel width if the bottleneck changed it or left it unknown.
    if x.shape[-1] is None or int(x.shape[-1]) != int(config[-1]):
        x = Conv3D(int(config[-1]), 1, padding="same", name="bottleneck_channel_align")(x)

    # Decoder: walk the stages deep to shallow, pairing each with its skip.
    for i, filters in reversed(list(enumerate(config))):
        filters = int(filters)
        x = Conv3DTranspose(filters, 2, strides=2, padding="same")(x)
        x = Concatenate()([x, skips[i]])
        x = _default_conv(filters, x)

    if residual:
        x = Concatenate()([x, inputs])

    outputs = Conv3D(
        n_classes,
        1,
        kernel_regularizer=output_kernel_regularizer,
        activation="softmax" if one_hot_encode else None,
        name="segmentation_logits",
    )(x)
    model_name = "CVUWaveletViTNet3D" if float(kl_scale) > 0.0 else "UWaveletViTNet3D"
    return keras.Model(inputs, outputs, name=model_name)
194
+
195
if __name__ == "__main__":
    # Smoke test: build a small model, push one random volume through it and
    # print the model summary and the output shape.
    tf.get_logger().setLevel("ERROR")
    tf.random.set_seed(0)  # seed before building so the run is repeatable

    n = 32
    demo_options = dict(
        input_shape=(n, n, n, 1),
        config=(16, 32, 64, 128),
        n_classes=4,
        token_dim=64,
        levels=1,
        wave="haar",
        heads=2,
        key_dim=16,
        strength=0.5,
        level_decay=1.0,
    )
    model = UWaveletViTNet3D(**demo_options)

    # The sample is drawn after the model is built, keeping the random stream order fixed.
    sample = tf.random.normal((1, n, n, n, 1))
    y = model(sample)
    model.summary()
    print("out:", y.shape)