nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/examples/language/mingpt.py +1650 -0
- nmn/nnx/examples/vision/cnn_cifar.py +1769 -0
- nmn/nnx/nmn.py +1 -1
- nmn/nnx/yatattention.py +764 -0
- nmn/nnx/yatconv.py +22 -2
- nmn/torch/nmn.py +2 -1
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/METADATA +2 -2
- nmn-0.1.4.dist-info/RECORD +14 -0
- nmn-0.1.2.dist-info/RECORD +0 -11
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/WHEEL +0 -0
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1769 @@
|
|
1
|
+
import typing as tp
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import numpy as np
|
6
|
+
from jax import lax
|
7
|
+
|
8
|
+
# Import libraries for comparison functions
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
import seaborn as sns
|
11
|
+
from sklearn.metrics import confusion_matrix
|
12
|
+
|
13
|
+
from flax import nnx
|
14
|
+
from flax.nnx.module import Module
|
15
|
+
from flax.nnx import rnglib
|
16
|
+
from flax.nnx.nn import dtypes, initializers
|
17
|
+
from flax.typing import (
|
18
|
+
Dtype,
|
19
|
+
Initializer,
|
20
|
+
PrecisionLike,
|
21
|
+
ConvGeneralDilatedT,
|
22
|
+
PaddingLike,
|
23
|
+
LaxPadding,
|
24
|
+
PromoteDtypeFn,
|
25
|
+
)
|
26
|
+
|
27
|
+
from __future__ import annotations
|
28
|
+
|
29
|
+
import typing as tp
|
30
|
+
|
31
|
+
import jax
|
32
|
+
import jax.numpy as jnp
|
33
|
+
import numpy as np
|
34
|
+
from jax import lax
|
35
|
+
import opt_einsum
|
36
|
+
|
37
|
+
from flax.core.frozen_dict import FrozenDict
|
38
|
+
from flax import nnx
|
39
|
+
from flax.nnx import rnglib, variablelib
|
40
|
+
from flax.nnx.module import Module, first_from
|
41
|
+
from flax.nnx.nn import dtypes, initializers
|
42
|
+
from flax.typing import (
|
43
|
+
Dtype,
|
44
|
+
Shape,
|
45
|
+
Initializer,
|
46
|
+
PrecisionLike,
|
47
|
+
DotGeneralT,
|
48
|
+
ConvGeneralDilatedT,
|
49
|
+
PaddingLike,
|
50
|
+
LaxPadding,
|
51
|
+
PromoteDtypeFn,
|
52
|
+
EinsumT,
|
53
|
+
)
|
54
|
+
|
55
|
+
|
56
|
+
import tensorflow_datasets as tfds
|
57
|
+
import tensorflow as tf
|
58
|
+
|
59
|
+
from flax import nnx # The Flax NNX API.
|
60
|
+
from functools import partial
|
61
|
+
import optax
|
62
|
+
|
63
|
+
# Short alias used in type annotations throughout this file.
Array = jax.Array

# Default initializers
# Used as the default values of the `kernel_init`, `bias_init` and
# `alpha_init` constructor arguments of the layers defined in this file.
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()
|
69
|
+
|
70
|
+
# Helper functions
|
71
|
+
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
    """Canonicalize a conv padding spec into a form accepted by jax.lax.

    Args:
        padding: a string (returned unchanged), a single int (applied to
            both sides of every spatial dimension), or a sequence of length
            ``rank`` whose items are ints or 2-tuples of ints.
        rank: number of spatial dimensions.

    Returns:
        The string unchanged, or a list of ``rank`` ``(low, high)`` pairs.

    Raises:
        ValueError: if ``padding`` matches none of the accepted forms.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding) for _ in range(rank)]
    if isinstance(padding, tp.Sequence) and len(padding) == rank:
        pairs = []
        for entry in padding:
            if isinstance(entry, int):
                pairs.append((entry, entry))
                continue
            if not (isinstance(entry, tuple) and len(entry) == 2):
                # Malformed entry: abandon the loop and fall through to the error.
                break
            pairs.append(entry)
        else:
            # Loop finished without `break`, so every entry was valid.
            return pairs
    raise ValueError(
        f'Invalid padding format: {padding}, should be str, int,'
        f' or a sequence of len {rank} where each element is an'
        ' int or pair of ints.'
    )
|
93
|
+
|
94
|
+
def _conv_dimension_numbers(input_shape):
|
95
|
+
"""Computes the dimension numbers based on the input shape."""
|
96
|
+
ndim = len(input_shape)
|
97
|
+
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
98
|
+
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
99
|
+
out_spec = lhs_spec
|
100
|
+
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
101
|
+
|
102
|
+
class YatConv(Module):
    """Yat convolution module built on ``lax.conv_general_dilated``.

    For every output position and filter the layer computes::

        dot  = <patch, kernel>
        dist = sum(patch**2) + sum(kernel**2) - 2 * dot
        y    = dot**2 / (dist + epsilon)

    then adds the bias (if enabled) and multiplies by
    ``(sqrt(out_features) / log(1 + out_features)) ** alpha`` (if enabled).

    Example usage::

        >>> import jax.numpy as jnp
        >>> from flax import nnx

        >>> layer = YatConv(in_features=3, out_features=4, kernel_size=(3,),
        ...                 padding='VALID', rngs=nnx.Rngs(0))
        >>> layer.kernel.value.shape
        (3, 3, 4)
        >>> layer.bias.value.shape
        (4,)
        >>> layer(jnp.ones((1, 8, 3))).shape
        (1, 6, 4)

    Args:
        in_features: number of input features.
        out_features: number of output features.
        kernel_size: shape of the convolutional kernel. For 1D convolution,
            the kernel size can be passed as an integer, which will be
            interpreted as a tuple of the single integer. For all other
            cases, it must be a sequence of integers.
        strides: an integer or a sequence of ``n`` integers, representing
            the inter-window strides (default: 1).
        padding: either the string ``'SAME'``, the string ``'VALID'``, the
            string ``'CIRCULAR'`` (periodic boundary conditions), the string
            ``'REFLECT'`` (reflection across the padding boundary), or a
            sequence of ``n`` ``(low, high)`` integer pairs that give the
            padding to apply before and after each spatial dimension. A
            single int is interpreted as applying the same padding in all
            dims and passing a single int in a sequence causes the same
            padding to be used on both sides. ``'CAUSAL'`` padding for a 1D
            convolution will left-pad the convolution axis, resulting in
            same-sized output.
        input_dilation: an integer or a sequence of ``n`` integers, giving
            the dilation factor to apply in each spatial dimension of
            ``inputs`` (default: 1). Convolution with input dilation ``d``
            is equivalent to transposed convolution with stride ``d``.
        kernel_dilation: an integer or a sequence of ``n`` integers, giving
            the dilation factor to apply in each spatial dimension of the
            convolution kernel (default: 1). Convolution with kernel
            dilation is also known as 'atrous convolution'.
        feature_group_count: integer, default 1. If specified divides the
            input features into groups. ``in_features`` and ``out_features``
            must both be multiples of it.
        use_bias: whether to add a bias to the output (default: True).
        mask: Optional mask for the weights during masked convolution. The
            mask must be the same shape as the convolution weight matrix.
        dtype: the dtype of the computation (default: infer from input and
            params).
        param_dtype: the dtype passed to parameter initializers
            (default: float32).
        precision: numerical precision of the computation see
            ``jax.lax.Precision`` for details.
        kernel_init: initializer for the convolutional kernel.
        bias_init: initializer for the bias.
        conv_general_dilated: the convolution primitive to call
            (default: ``lax.conv_general_dilated``).
        promote_dtype: function to promote the dtype of the arrays to the
            desired dtype. The function should accept a tuple of
            ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and
            return a tuple of arrays with the promoted dtype.
        epsilon: float added to the denominator to prevent division by zero
            (default: 1/137).
        use_alpha: whether to create and apply the ``alpha`` scaling
            exponent (default: True).
        alpha_init: initializer for ``alpha`` (shape ``(1,)``).
        rngs: rng key.
    """

    __data__ = ('kernel', 'bias', 'mask', 'alpha')

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int | tp.Sequence[int],
        strides: tp.Union[None, int, tp.Sequence[int]] = 1,
        *,
        padding: PaddingLike = 'SAME',
        input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
        kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
        feature_group_count: int = 1,
        use_bias: bool = True,
        mask: tp.Optional[Array] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        epsilon: float = 1/137,
        use_alpha: bool = True,
        alpha_init: Initializer = default_alpha_init,

        rngs: rnglib.Rngs,
    ):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,)
        else:
            kernel_size = tuple(kernel_size)

        # Kernel layout: spatial dims, input channels per group, output channels.
        self.kernel_shape = kernel_size + (
            in_features // feature_group_count,
            out_features,
        )
        kernel_key = rngs.params()
        self.kernel = nnx.Param(kernel_init(kernel_key, self.kernel_shape, param_dtype))

        self.bias: nnx.Param[jax.Array] | None
        if use_bias:
            bias_shape = (out_features,)
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
        else:
            self.bias = None

        self.alpha: nnx.Param[jax.Array] | None
        if use_alpha:
            alpha_key = rngs.params()
            self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
        else:
            self.alpha = None

        self.in_features = in_features
        self.out_features = out_features
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.input_dilation = input_dilation
        self.kernel_dilation = kernel_dilation
        self.feature_group_count = feature_group_count
        self.use_bias = use_bias
        self.mask = mask
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.conv_general_dilated = conv_general_dilated
        self.promote_dtype = promote_dtype
        self.epsilon = epsilon
        self.use_alpha = use_alpha
        self.alpha_init = alpha_init

    def __call__(self, inputs: Array) -> Array:
        """Apply the Yat convolution to ``inputs``.

        Args:
            inputs: array whose trailing ``len(kernel_size) + 1`` axes are
                the spatial dims followed by the feature axis. Any number
                of leading axes (including none) is flattened into a single
                batch axis for the convolution and restored on the output.

        Returns:
            The transformed array, with ``out_features`` on the last axis.

        Raises:
            ValueError: if ``'CAUSAL'`` padding is used with a non-1D
                kernel, if the mask shape differs from the kernel shape, or
                if ``out_features`` is not a positive multiple of
                ``feature_group_count``.
        """
        assert isinstance(self.kernel_size, tuple)

        def maybe_broadcast(
            x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
        ) -> tuple[int, ...]:
            # None -> 1; a scalar is repeated once per spatial dimension.
            if x is None:
                x = 1
            if isinstance(x, int):
                return (x,) * len(self.kernel_size)
            return tuple(x)

        # Collapse zero or several leading batch axes into exactly one.
        num_batch_dimensions = inputs.ndim - (len(self.kernel_size) + 1)
        if num_batch_dimensions != 1:
            input_batch_shape = inputs.shape[:num_batch_dimensions]
            total_batch_size = int(np.prod(input_batch_shape))
            flat_input_shape = (total_batch_size,) + inputs.shape[
                num_batch_dimensions:
            ]
            inputs_flat = jnp.reshape(inputs, flat_input_shape)
        else:
            inputs_flat = inputs
            input_batch_shape = ()

        strides = maybe_broadcast(self.strides)
        input_dilation = maybe_broadcast(self.input_dilation)
        kernel_dilation = maybe_broadcast(self.kernel_dilation)

        # CIRCULAR / REFLECT / CAUSAL are done by padding the input here and
        # then running the convolution itself with 'VALID' padding.
        padding_lax = canonicalize_padding(self.padding, len(self.kernel_size))
        if padding_lax in ('CIRCULAR', 'REFLECT'):
            assert isinstance(padding_lax, str)
            kernel_size_dilated = [
                (k - 1) * d + 1 for k, d in zip(self.kernel_size, kernel_dilation)
            ]
            zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
            pads = (
                zero_pad
                + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
                + [(0, 0)]
            )
            padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
            inputs_flat = jnp.pad(inputs_flat, pads, mode=padding_mode)
            padding_lax = 'VALID'
        elif padding_lax == 'CAUSAL':
            if len(self.kernel_size) != 1:
                raise ValueError(
                    'Causal padding is only implemented for 1D convolutions.'
                )
            left_pad = kernel_dilation[0] * (self.kernel_size[0] - 1)
            pads = [(0, 0), (left_pad, 0), (0, 0)]
            inputs_flat = jnp.pad(inputs_flat, pads)
            padding_lax = 'VALID'

        dimension_numbers = _conv_dimension_numbers(inputs_flat.shape)
        assert self.in_features % self.feature_group_count == 0

        # Validate grouping before any convolution runs so the error names
        # the actual problem.
        num_out_channels_per_group = self.out_features // self.feature_group_count
        if self.feature_group_count > 1 and (
            num_out_channels_per_group == 0
            or self.out_features % self.feature_group_count != 0
        ):
            raise ValueError(
                "out_features must be a multiple of feature_group_count and greater or equal."
            )

        kernel_val = self.kernel.value

        current_mask = self.mask
        if current_mask is not None:
            if current_mask.shape != self.kernel_shape:
                raise ValueError(
                    'Mask needs to have the same shape as weights. '
                    f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
                )
            kernel_val *= current_mask

        bias_val = self.bias.value if self.bias is not None else None

        inputs_flat, kernel_val, bias_val = self.promote_dtype(
            (inputs_flat, kernel_val, bias_val), dtype=self.dtype
        )

        # <patch, kernel> for every output position and filter.
        dot_prod_map = self.conv_general_dilated(
            inputs_flat,
            kernel_val,
            strides,
            padding_lax,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        # sum(patch**2): convolve the squared input with an all-ones kernel.
        # A grouped convolution needs an output channel count divisible by
        # feature_group_count, so the ones-kernel has one output channel per
        # group (a single channel when feature_group_count == 1).
        ones_kernel_shape = self.kernel_size + (
            self.kernel_shape[-2],
            self.feature_group_count,
        )
        ones_kernel = jnp.ones(ones_kernel_shape, dtype=kernel_val.dtype)

        patch_sq_sum_map_raw = self.conv_general_dilated(
            inputs_flat**2,
            ones_kernel,
            strides,
            padding_lax,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.feature_group_count > 1:
            # Expand the per-group sums so each output channel is paired with
            # the squared norm of its own group's patch.
            patch_sq_sum_map = jnp.repeat(
                patch_sq_sum_map_raw, num_out_channels_per_group, axis=-1
            )
        else:
            patch_sq_sum_map = patch_sq_sum_map_raw

        # sum(kernel**2) per filter: reduce over every axis but the last.
        reduce_axes_for_kernel_sq = tuple(range(kernel_val.ndim - 1))
        kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes_for_kernel_sq)

        distance_sq_map = patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
        y = dot_prod_map**2 / (distance_sq_map + self.epsilon)

        if self.use_bias and bias_val is not None:
            bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
            y += jnp.reshape(bias_val, bias_reshape_dims)

        if self.use_alpha and self.alpha is not None:
            # alpha is not passed through promote_dtype; JAX's own type
            # promotion applies in the power below.
            alpha_val = self.alpha.value
            scale = (jnp.sqrt(jnp.array(self.out_features, dtype=y.dtype)) /
                     jnp.log(1 + jnp.array(self.out_features, dtype=y.dtype))) ** alpha_val
            y = y * scale

        # Restore the caller's leading batch axes.
        if num_batch_dimensions != 1:
            output_shape = input_batch_shape + y.shape[1:]
            y = jnp.reshape(y, output_shape)
        return y
|
386
|
+
|
387
|
+
# NOTE(review): `Array` and the three `default_*_init` names below are
# assigned the same expressions earlier in this file; these statements
# re-bind them. Only `Axis` and `Size` are new here.
Array = jax.Array
Axis = int
Size = int


# Defaults for the `kernel_init`, `bias_init` and `alpha_init` arguments of
# the layer defined below.
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()
|
395
|
+
|
396
|
+
class YatNMN(Module):
    """Yat dense layer applied over the last dimension of the input.

    With ``dot = inputs @ kernel`` the layer returns::

        dot**2 / (sum(inputs**2) + sum(kernel**2) - 2 * dot + epsilon)

    plus the bias (if enabled), multiplied by
    ``(sqrt(out_features) / log(1 + out_features)) ** alpha`` (if enabled).

    Example usage::

        >>> from flax import nnx
        >>> import jax.numpy as jnp

        >>> layer = YatNMN(in_features=3, out_features=4, rngs=nnx.Rngs(0))
        >>> layer.kernel.value.shape
        (3, 4)
        >>> layer(jnp.ones((2, 3))).shape
        (2, 4)

    Args:
        in_features: the number of input features.
        out_features: the number of output features.
        use_bias: whether to add a bias to the output (default: True).
        use_alpha: whether to create and apply the ``alpha`` scaling
            exponent (default: True).
        dtype: the dtype of the computation (default: infer from input and
            params).
        param_dtype: the dtype passed to parameter initializers
            (default: float32).
        precision: numerical precision of the computation see
            ``jax.lax.Precision`` for details.
        kernel_init: initializer function for the weight matrix.
        bias_init: initializer function for the bias.
        alpha_init: initializer function for ``alpha`` (shape ``(1,)``).
        dot_general: dot product function.
        promote_dtype: function to promote the dtype of the arrays to the
            desired dtype. It is called with the tuple
            ``(inputs, kernel, bias, alpha)`` and a ``dtype`` keyword
            argument, and must return a tuple of arrays with the promoted
            dtype.
        rngs: rng key.
        epsilon: float added to the denominator (default: 1/137).
    """

    __data__ = ('kernel', 'bias')

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        use_alpha: bool = True,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        alpha_init: Initializer = default_alpha_init,
        dot_general: DotGeneralT = lax.dot_general,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        rngs: rnglib.Rngs,
        epsilon: float = 1/137,
    ):
        # Plain configuration first; it does not touch the RNG stream.
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.use_alpha = use_alpha
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.dot_general = dot_general
        self.promote_dtype = promote_dtype
        self.epsilon = epsilon

        # Parameters: keys are drawn in the order kernel, bias, alpha.
        weight_shape = (in_features, out_features)
        self.kernel = nnx.Param(kernel_init(rngs.params(), weight_shape, param_dtype))

        self.bias: nnx.Param[jax.Array] | None = None
        if use_bias:
            self.bias = nnx.Param(bias_init(rngs.params(), (out_features,), param_dtype))

        self.alpha: nnx.Param[jax.Array] | None = None
        if use_alpha:
            self.alpha = nnx.Param(alpha_init(rngs.params(), (1,), param_dtype))

    def __call__(self, inputs: Array) -> Array:
        """Apply the Yat transformation along the last dimension of ``inputs``.

        Args:
            inputs: The nd-array to be transformed.

        Returns:
            The transformed input.
        """
        weights = self.kernel.value
        offset = None if self.bias is None else self.bias.value
        exponent = None if self.alpha is None else self.alpha.value

        inputs, weights, offset, exponent = self.promote_dtype(
            (inputs, weights, offset, exponent), dtype=self.dtype
        )

        # Contract the last input axis with the first kernel axis.
        contraction = (((inputs.ndim - 1,), (0,)), ((), ()))
        dot = self.dot_general(
            inputs, weights, contraction, precision=self.precision
        )

        assert self.use_bias == (offset is not None)
        assert self.use_alpha == (exponent is not None)

        # Squared distance between each input vector and each kernel column.
        input_norm_sq = jnp.sum(inputs**2, axis=-1, keepdims=True)
        kernel_norm_sq = jnp.sum(weights**2, axis=0, keepdims=True)
        sq_distance = input_norm_sq + kernel_norm_sq - 2 * dot

        out = dot ** 2 / (sq_distance + self.epsilon)

        if offset is not None:
            broadcast_shape = (1,) * (out.ndim - 1) + (-1,)
            out = out + jnp.reshape(offset, broadcast_shape)

        if exponent is not None:
            base = jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)
            out = out * base ** exponent

        return out
|
529
|
+
|
530
|
+
def loss_fn(model, batch, training: bool = True):
    """Compute the mean softmax cross-entropy loss for one batch.

    Args:
        model: callable invoked as ``model(images, training=...)`` that
            returns logits.
        batch: mapping with an ``'image'`` entry (model input) and a
            ``'label'`` entry (integer class labels).
        training: forwarded to the model. Defaults to ``True`` so existing
            two-argument callers keep their behavior; pass ``False`` when
            evaluating.

    Returns:
        A ``(loss, logits)`` tuple; ``loss`` is the mean over the batch.
    """
    logits = model(batch['image'], training=training)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits
|
536
|
+
|
537
|
+
@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run one optimization step on ``batch``.

    Differentiates ``loss_fn`` with respect to the model, records the loss,
    logits and labels in ``metrics``, then applies the gradients through
    ``optimizer``. Both updates happen in place; nothing is returned.
    """
    loss_and_grad = nnx.value_and_grad(loss_fn, has_aux=True)
    (step_loss, step_logits), gradients = loss_and_grad(model, batch)
    metrics.update(loss=step_loss, logits=step_logits, labels=batch['label'])
    optimizer.update(gradients)
|
544
|
+
|
545
|
+
@nnx.jit
def eval_step(model, metrics: nnx.MultiMetric, batch):
    """Evaluate ``model`` on one batch and accumulate the results in ``metrics``.

    The model is called with ``training=False``. Evaluation previously went
    through ``loss_fn``, which hard-codes ``training=True``.
    NOTE(review): presumably ``training`` controls dropout or similar
    train-only behavior in the models; confirm against their definitions.

    Args:
        model: callable invoked as ``model(images, training=False)``.
        metrics: metric accumulator, updated in place.
        batch: mapping with ``'image'`` and ``'label'`` entries.
    """
    logits = model(batch['image'], training=False)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
|
549
|
+
|
550
|
+
# Set TensorFlow's global random seed at import time.
# NOTE(review): presumably intended to make the tf.data shuffling below
# reproducible; it does not seed JAX or NumPy.
tf.random.set_seed(0)
|
551
|
+
|
552
|
+
# ===== DATASET CONFIGURATIONS =====
|
553
|
+
# Per-dataset settings keyed by TFDS dataset name. Every entry carries the
# class count, channel count, TFDS split strings, the feature keys holding
# the image and the label, and the training-loop hyperparameters.
DATASET_CONFIGS = {
    'cifar10': dict(
        num_classes=10,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    'cifar100': dict(
        num_classes=100,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    'stl10': dict(
        num_classes=10,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    # The EuroSAT entries use a percentage slice of 'train' as the test set.
    'eurosat/rgb': dict(
        num_classes=10,
        input_channels=3,
        train_split='train[:80%]',
        test_split='train[80%:]',
        image_key='image',
        label_key='label',  # EuroSAT label key is 'label' in TFDS
        num_epochs=5,
        eval_every=100,
        batch_size=128,
    ),
    'eurosat/all': dict(
        num_classes=10,
        input_channels=13,
        train_split='train[:80%]',
        test_split='train[80%:]',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=100,
        batch_size=16,  # Smaller batch for more channels
    ),
    # Template for a dataset that needs explicit image resizing:
    # 'some_other_dataset': dict(
    #     num_classes=X, input_channels=Y,
    #     train_split='train', test_split='validation',
    #     image_key='image_data', label_key='class_id',
    #     target_image_size=[H, W],  # Optional: for explicit resizing
    # ),
}
|
592
|
+
|
593
|
+
# Original global dataset setup (will be superseded by _train_model_loop for actual training runs)
|
594
|
+
# These might still be used by some top-level calls if not careful, or for initial exploration.
|
595
|
+
# Dataset behind the module-level globals below.
_DEFAULT_DATASET_FOR_GLOBALS = 'cifar10'

# Get default training parameters from the default dataset's config or set fallbacks
_default_config_for_globals = DATASET_CONFIGS.get(_DEFAULT_DATASET_FOR_GLOBALS, {})
_global_num_epochs = _default_config_for_globals.get('num_epochs', 10)  # Default to 10 epochs
_global_eval_every = _default_config_for_globals.get('eval_every', 200)
_global_batch_size = _default_config_for_globals.get('batch_size', 64)


_global_ds_builder = tfds.builder(_DEFAULT_DATASET_FOR_GLOBALS)
_global_ds_info = _global_ds_builder.info

# Take the split names from the dataset's config instead of hard-coding
# 'train'/'test', so the globals stay valid if the default dataset is changed
# to one whose config uses other splits (e.g. the eurosat entries).
train_ds_global_tf: tf.data.Dataset = tfds.load(
    _DEFAULT_DATASET_FOR_GLOBALS,
    split=_default_config_for_globals.get('train_split', 'train'),
)
test_ds_global_tf: tf.data.Dataset = tfds.load(
    _DEFAULT_DATASET_FOR_GLOBALS,
    split=_default_config_for_globals.get('test_split', 'test'),
)
|
609
|
+
|
610
|
+
def _global_preprocess(sample):
    """Turn one raw sample of the default dataset into an image/label dict.

    The image is cast to float32 and divided by 255; the label is passed
    through unchanged. The source keys are read from the default dataset's
    entry in ``DATASET_CONFIGS``.
    """
    default_cfg = DATASET_CONFIGS[_DEFAULT_DATASET_FOR_GLOBALS]
    pixels = tf.cast(sample[default_cfg['image_key']], tf.float32)
    return {
        'image': pixels / 255,
        'label': sample[default_cfg['label_key']],
    }
|
615
|
+
|
616
|
+
# Apply the preprocessing to both global splits (train first, then test).
train_ds_global_tf, test_ds_global_tf = (
    split.map(_global_preprocess)
    for split in (train_ds_global_tf, test_ds_global_tf)
)

# Module-level batched datasets, kept for analysis helpers that do not load
# their own data. Helpers that need data should preferably receive it as an
# argument or reload it with the right dataset name.
# The training stream repeats indefinitely (no .take()), so consumers must
# bound their own iteration.
train_ds = (
    train_ds_global_tf
    .repeat()
    .shuffle(1024)
    .batch(_global_batch_size, drop_remainder=True)
    .prefetch(1)
)
test_ds = (
    test_ds_global_tf
    .batch(_global_batch_size, drop_remainder=True)
    .prefetch(1)
)
|
624
|
+
|
625
|
+
# ===== MODEL COMPARISON FUNCTIONS =====
|
626
|
+
|
627
|
+
def compare_training_curves(yat_history, linear_history):
    """Plot YAT vs Linear training curves on a 2x2 grid.

    Panels, in row-major order: training loss, test loss, training accuracy,
    test accuracy. Each series is plotted against its own index range, so the
    two histories (or individual metrics) may differ in length.

    Args:
        yat_history: mapping with ``'train_loss'``, ``'test_loss'``,
            ``'train_accuracy'`` and ``'test_accuracy'`` sequences for the
            YAT model.
        linear_history: same structure for the Linear model.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training Curves Comparison: YAT vs Linear Models', fontsize=16, fontweight='bold')

    # (history key, panel title, y-axis label), in row-major panel order.
    panels = (
        ('train_loss', 'Training Loss', 'Loss'),
        ('test_loss', 'Test Loss', 'Loss'),
        ('train_accuracy', 'Training Accuracy', 'Accuracy'),
        ('test_accuracy', 'Test Accuracy', 'Accuracy'),
    )

    for ax, (key, title, ylabel) in zip(axes.flat, panels):
        yat_values = yat_history[key]
        linear_values = linear_history[key]
        ax.plot(range(len(yat_values)), yat_values, 'b-', label='YAT Model', linewidth=2)
        ax.plot(range(len(linear_values)), linear_values, 'r--', label='Linear Model', linewidth=2)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Evaluation Steps')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("📈 Training curves comparison plotted successfully!")
|
679
|
+
|
680
|
+
def print_final_metrics_comparison(yat_history, linear_history):
    """Print a comparison table of the final (last-recorded) metrics.

    For each of train/test loss and accuracy the table shows both models'
    final value and the signed difference (YAT minus Linear), annotated with
    which model is better: higher accuracy or lower loss. A summary then
    names the model with the higher final test accuracy, or reports a tie
    when the two are equal.

    Args:
        yat_history: mapping with non-empty ``'train_loss'``,
            ``'train_accuracy'``, ``'test_loss'`` and ``'test_accuracy'``
            sequences for the YAT model.
        linear_history: same structure for the Linear model.
    """
    metric_names = ['train_loss', 'test_loss', 'train_accuracy', 'test_accuracy']

    print("\n📊 FINAL METRICS COMPARISON")
    print("=" * 60)

    yat_final = {name: yat_history[name][-1] for name in metric_names}
    linear_final = {name: linear_history[name][-1] for name in metric_names}

    print(f"{'Metric':<20} {'YAT Model':<15} {'Linear Model':<15} {'Difference':<15}")
    print("-" * 65)

    for metric in metric_names:
        yat_val = yat_final[metric]
        linear_val = linear_final[metric]
        diff = yat_val - linear_val
        diff_str = f"{diff:+.4f}"
        # Higher is better for accuracy, lower is better for loss.
        yat_is_ahead = diff > 0 if 'accuracy' in metric else diff < 0
        if diff != 0:
            diff_str += " (YAT better)" if yat_is_ahead else " (Linear better)"

        print(f"{metric:<20} {yat_val:<15.4f} {linear_val:<15.4f} {diff_str:<15}")

    # Summary
    print("\n🏆 SUMMARY:")
    yat_acc = yat_final['test_accuracy']
    linear_acc = linear_final['test_accuracy']
    if yat_acc > linear_acc:
        print(f" Better Test Accuracy: YAT Model (by {yat_acc - linear_acc:.4f})")
    elif linear_acc > yat_acc:
        print(f" Better Test Accuracy: Linear Model (by {linear_acc - yat_acc:.4f})")
    else:
        # Equal accuracies are not credited to either model.
        print(" Better Test Accuracy: Tie (both models equal)")

    print(f" YAT Test Accuracy: {yat_acc:.4f}")
    print(f" Linear Test Accuracy: {linear_acc:.4f}")
|
734
|
+
|
735
|
+
def analyze_convergence(yat_history, linear_history):
    """
    Analyze convergence speed and stability of both models.

    Prints a side-by-side table of four figures derived from each history,
    followed by a short verdict on convergence speed, stability and
    overfitting gap. Exact ties are reported as ties instead of being
    credited to the Linear model.

    Args:
        yat_history: Mapping with 'test_accuracy', 'train_accuracy' and
            'test_loss' sequences (one entry per evaluation) for the YAT model.
        linear_history: Same structure for the Linear model.

    Raises:
        IndexError: If a required sequence is an empty list.
    """
    print("\n🔍 CONVERGENCE ANALYSIS")
    print("=" * 50)

    def calculate_convergence_metrics(history):
        """Derive convergence step, stability, overfitting gap and final loss."""
        test_acc = history['test_accuracy']
        train_acc = history['train_accuracy']
        test_loss = history['test_loss']

        # First evaluation index at which test accuracy reaches half of its
        # final value; stays 0 if no entry satisfies the threshold.
        target_acc = 0.5 * test_acc[-1]
        convergence_step = next(
            (i for i, acc in enumerate(test_acc) if acc >= target_acc), 0
        )

        # Spread of test accuracy over the last quarter of the evaluations.
        # The window is clamped to at least one entry: a length of 0 would turn
        # the slice into [-0:], i.e. the whole history.
        tail_len = max(1, len(test_acc) // 4)
        stability = np.std(test_acc[-tail_len:])

        return {
            'convergence_step': convergence_step,
            'stability': stability,
            # Final train/test accuracy gap.
            'overfitting': train_acc[-1] - test_acc[-1],
            'final_loss': test_loss[-1],
        }

    yat_conv = calculate_convergence_metrics(yat_history)
    linear_conv = calculate_convergence_metrics(linear_history)

    print(f"{'Metric':<25} {'YAT Model':<15} {'Linear Model':<15}")
    print("-" * 55)
    print(f"{'Convergence Speed':<25} {yat_conv['convergence_step']:<15} {linear_conv['convergence_step']:<15}")
    print(f"{'Stability (std)':<25} {yat_conv['stability']:<15.4f} {linear_conv['stability']:<15.4f}")
    print(f"{'Overfitting Gap':<25} {yat_conv['overfitting']:<15.4f} {linear_conv['overfitting']:<15.4f}")
    print(f"{'Final Test Loss':<25} {yat_conv['final_loss']:<15.4f} {linear_conv['final_loss']:<15.4f}")

    # Analysis
    print("\n📋 ANALYSIS:")

    yat_step = yat_conv['convergence_step']
    linear_step = linear_conv['convergence_step']
    if yat_step < linear_step:
        print(f" 🚀 YAT model converges faster (step {yat_step} vs {linear_step})")
    elif linear_step < yat_step:
        print(f" 🚀 Linear model converges faster (step {linear_step} vs {yat_step})")
    else:
        print(f" 🚀 Both models converge at the same step ({yat_step})")

    yat_std = yat_conv['stability']
    linear_std = linear_conv['stability']
    if yat_std < linear_std:
        print(f" 📈 YAT model is more stable (std: {yat_std:.4f} vs {linear_std:.4f})")
    elif linear_std < yat_std:
        print(f" 📈 Linear model is more stable (std: {linear_std:.4f} vs {yat_std:.4f})")
    else:
        print(f" 📈 Both models are equally stable (std: {yat_std:.4f})")

    yat_gap = yat_conv['overfitting']
    linear_gap = linear_conv['overfitting']
    # Gaps are compared by magnitude, so a negative gap is not rewarded.
    if abs(yat_gap) < abs(linear_gap):
        print(f" 🎯 YAT model has less overfitting (gap: {yat_gap:.4f} vs {linear_gap:.4f})")
    elif abs(linear_gap) < abs(yat_gap):
        print(f" 🎯 Linear model has less overfitting (gap: {linear_gap:.4f} vs {yat_gap:.4f})")
    else:
        print(f" 🎯 Both models show the same overfitting gap magnitude ({abs(yat_gap):.4f})")
|
796
|
+
|
797
|
+
def detailed_test_evaluation(yat_model, linear_model, test_ds_iter, class_names: list[str]):
    """
    Evaluate both models on the test set and print per-class and agreement tables.

    Args:
        yat_model: Callable invoked as ``model(images, training=False)`` that
            returns per-class scores; predictions are the argmax over axis 1.
        linear_model: Second model, called the same way.
        test_ds_iter: Object exposing ``as_numpy_iterator()`` which yields
            batches indexable by 'image' and 'label'.
        class_names: Display name for each class index.

    Returns:
        Dict with the prediction arrays of both models, the true labels, the
        class names, the overall agreement rate and the both-correct rate.
    """
    print("Running detailed test evaluation...")

    n_classes = len(class_names)
    collected = {'yat': [], 'linear': [], 'truth': []}

    # Run every test batch through both models and accumulate flat lists.
    for batch in test_ds_iter.as_numpy_iterator():
        images = batch['image']
        labels = batch['label']
        for key, net in (('yat', yat_model), ('linear', linear_model)):
            scores = net(images, training=False)
            collected[key].extend(jnp.argmax(scores, axis=1).tolist())
        collected['truth'].extend(labels.tolist())

    yat_predictions = np.array(collected['yat'])
    linear_predictions = np.array(collected['linear'])
    true_labels = np.array(collected['truth'])

    # Element-wise correctness, reused by both tables below.
    yat_hit = yat_predictions == true_labels
    linear_hit = linear_predictions == true_labels

    print("\n🎯 PER-CLASS ACCURACY COMPARISON")
    print("=" * 70)
    print(f"{'Class':<12} {'YAT Acc':<10} {'Linear Acc':<12} {'Difference':<15} {'Sample Count':<12}")
    print("-" * 70)

    for idx, label_name in enumerate(class_names):
        in_class = true_labels == idx
        count = np.sum(in_class)

        if count > 0:
            yat_acc = np.mean(yat_hit[in_class])
            linear_acc = np.mean(linear_hit[in_class])
            delta = f"{yat_acc - linear_acc:+.4f}"
            print(f"{label_name:<12} {yat_acc:<10.4f} {linear_acc:<12.4f} {delta:<15} {count:<12}")
        elif n_classes <= 20:
            # Empty classes are only listed when the table stays short.
            print(f"{label_name:<12} {'N/A':<10} {'N/A':<12} {'N/A':<15} {count:<12}")

    # Agreement breakdown between the two models.
    agreement = np.mean(yat_predictions == linear_predictions)
    both_correct = np.mean(yat_hit & linear_hit)
    only_yat_right = np.mean(yat_hit & ~linear_hit)
    only_linear_right = np.mean(linear_hit & ~yat_hit)
    both_wrong = np.mean(~yat_hit & ~linear_hit)

    print(f"\n🤝 MODEL AGREEMENT ANALYSIS")
    print("=" * 40)
    print(f"Overall Agreement: {agreement:.4f}")
    print(f"Both Correct: {both_correct:.4f}")
    print(f"YAT Correct, Linear Wrong: {only_yat_right:.4f}")
    print(f"Linear Correct, YAT Wrong: {only_linear_right:.4f}")
    print(f"Both Wrong: {both_wrong:.4f}")

    return {
        'yat_predictions': yat_predictions,
        'linear_predictions': linear_predictions,
        'true_labels': true_labels,
        'class_names': class_names,
        'agreement': agreement,
        'both_correct': both_correct,
    }
|
876
|
+
|
877
|
+
def plot_confusion_matrices(predictions_data):
    """
    Plot confusion matrices for both models side by side.

    Args:
        predictions_data: Dict with 'yat_predictions', 'linear_predictions',
            'true_labels' and 'class_names' entries (the structure returned by
            ``detailed_test_evaluation``).
    """
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt
    import seaborn as sns

    yat_preds = predictions_data['yat_predictions']
    linear_preds = predictions_data['linear_predictions']
    true_labels = predictions_data['true_labels']
    class_names = predictions_data['class_names']

    # Pin the label set to every class index. Without it the matrix only
    # covers labels that actually occur, so a class missing from both the
    # truth and the predictions would shrink the matrix and shift the
    # class-name tick labels onto the wrong rows and columns.
    label_ids = list(range(len(class_names)))
    yat_cm = confusion_matrix(true_labels, yat_preds, labels=label_ids)
    linear_cm = confusion_matrix(true_labels, linear_preds, labels=label_ids)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # One heatmap per model; only the matrix, colour map and title differ.
    panels = (
        (ax1, yat_cm, 'Blues', 'YAT Model - Confusion Matrix'),
        (ax2, linear_cm, 'Reds', 'Linear Model - Confusion Matrix'),
    )
    for ax, matrix, cmap, title in panels:
        sns.heatmap(matrix, annot=True, fmt='d', cmap=cmap,
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')

    plt.tight_layout()
    plt.show()

    print("📊 Confusion matrices plotted successfully!")
|
915
|
+
|
916
|
+
def generate_summary_report(yat_history, linear_history, predictions_data):
    """
    Generate a comprehensive summary report of the comparison.

    Prints the overall winner, headline accuracies, the best and worst
    classes for each model, the train/test gap and a short recommendation.
    The recommendation text no longer names a specific dataset, because the
    class names (and therefore the dataset) are supplied by the caller.

    Args:
        yat_history: Mapping with 'train_accuracy' and 'test_accuracy'
            sequences for the YAT model.
        linear_history: Same structure for the Linear model.
        predictions_data: Dict with 'class_names', 'true_labels',
            'yat_predictions', 'linear_predictions', 'agreement' and
            'both_correct' entries.

    Raises:
        IndexError: If a 'test_accuracy' sequence is an empty list.
    """
    print("\n" + "="*80)
    print(" COMPREHENSIVE SUMMARY REPORT")
    print("="*80)

    # Final metrics
    yat_final_acc = yat_history['test_accuracy'][-1]
    linear_final_acc = linear_history['test_accuracy'][-1]

    print(f"\n🏆 OVERALL WINNER:")
    if yat_final_acc == linear_final_acc:
        print(f" 🤝 It's a tie! Both models achieved {yat_final_acc:.4f} accuracy")
    else:
        if yat_final_acc > linear_final_acc:
            winner = "YAT Model"
        else:
            winner = "Linear Model"
        margin = abs(yat_final_acc - linear_final_acc)
        print(f" 🥇 {winner} wins by {margin:.4f} accuracy points!")

    print(f"\n📈 PERFORMANCE SUMMARY:")
    print(f" YAT Model Test Accuracy: {yat_final_acc:.4f}")
    print(f" Linear Model Test Accuracy: {linear_final_acc:.4f}")
    print(f" Model Agreement: {predictions_data['agreement']:.4f}")
    print(f" Both Models Correct: {predictions_data['both_correct']:.4f}")

    # Best and worst performing classes
    class_names = predictions_data['class_names']
    true_labels = predictions_data['true_labels']

    def ranked_class_accuracies(preds):
        """Return (class name, accuracy) pairs, best first, skipping empty classes."""
        pairs = []
        for class_idx, class_name in enumerate(class_names):
            class_mask = true_labels == class_idx
            if np.sum(class_mask) > 0:
                acc = np.mean(preds[class_mask] == true_labels[class_mask])
                pairs.append((class_name, acc))
        pairs.sort(key=lambda pair: pair[1], reverse=True)
        return pairs

    ranked = (
        ("YAT Model", ranked_class_accuracies(predictions_data['yat_predictions'])),
        ("Linear Model", ranked_class_accuracies(predictions_data['linear_predictions'])),
    )

    print(f"\n🎯 BEST PERFORMING CLASSES (Top 3 if available):")
    for label, pairs in ranked:
        for class_name, acc in pairs[:3]:
            print(f" {label}: {class_name} ({acc:.4f})")

    print(f"\n🎯 WORST PERFORMING CLASSES (Bottom 3 if available):")
    for label, pairs in ranked:
        # Walk from the worst class upwards.
        for class_name, acc in pairs[::-1][:3]:
            print(f" {label}: {class_name} ({acc:.4f})")

    print(f"\n📊 TRAINING CHARACTERISTICS:")
    if yat_history['train_accuracy'] and linear_history['train_accuracy'] and \
            yat_history['test_accuracy'] and linear_history['test_accuracy']:
        yat_train = yat_history['train_accuracy'][-1]
        linear_train = linear_history['train_accuracy'][-1]
        print(f" YAT Final Train Accuracy: {yat_train:.4f}")
        print(f" Linear Final Train Accuracy: {linear_train:.4f}")
        print(f" YAT Overfitting Gap: {yat_train - yat_history['test_accuracy'][-1]:.4f}")
        print(f" Linear Overfitting Gap: {linear_train - linear_history['test_accuracy'][-1]:.4f}")
    else:
        print(" Training/Test accuracy history missing for full overfitting gap analysis.")

    print(f"\n💡 RECOMMENDATIONS:")
    if yat_final_acc > linear_final_acc:
        print(f" ✅ YAT model architecture shows superior performance")
        print(f" ✅ Consider using YAT layers for similar classification tasks")
    else:
        print(f" ✅ Linear model architecture is sufficient for this task")
        # The dataset is chosen by the caller, so it is not named here.
        print(f" ✅ Standard convolution layers perform well on this dataset")

    if predictions_data['agreement'] > 0.8:
        print(f" 🤝 High model agreement suggests stable learning")
    else:
        print(f" 🔍 Low model agreement suggests different learning patterns")

    print("="*80)
|
1003
|
+
|
1004
|
+
# ===== COMPLETE IMPLEMENTATION EXAMPLE =====
|
1005
|
+
|
1006
|
+
# Moved YatCNN class definition to module level
|
1007
|
+
class YatCNN(nnx.Module):
    """CNN classifier built from four YatConv stages and a YatNMN output layer.

    Each stage applies a YatConv, dropout and 2x2 average pooling. The result
    is averaged over axes (1, 2) and passed to the YatNMN head.
    """

    def __init__(self, *, num_classes: int, input_channels: int, rngs: nnx.Rngs):
        # Sub-modules are created in a fixed order (conv1..conv4, then
        # dropout1..dropout4, then the head), all sharing the same `rngs`.
        channel_plan = ((input_channels, 32), (32, 64), (64, 128), (128, 128))
        for stage, (c_in, c_out) in enumerate(channel_plan, start=1):
            setattr(self, f'conv{stage}', YatConv(c_in, c_out, kernel_size=(5, 5), rngs=rngs))
        for stage in range(1, 5):
            setattr(self, f'dropout{stage}', nnx.Dropout(rate=0.3, rngs=rngs))

        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.non_linear2 = YatNMN(128, num_classes, use_bias=False, use_alpha=False, rngs=rngs)

    def __call__(self, x, training: bool = False, return_activations_for_layer: tp.Optional[str] = None):
        """Run the forward pass.

        Args:
            x: Input batch passed to the first convolution.
            training: When False, dropout runs in deterministic mode.
            return_activations_for_layer: If set to 'conv1'..'conv4',
                'global_avg_pool' or 'final_layer', the activation recorded
                under that name is returned as soon as it is computed.

        Returns:
            The requested activation, or the final layer output. An unknown
            layer name prints a warning and returns the final output.
        """
        wanted = return_activations_for_layer
        seen = {}

        for stage in range(1, 5):
            tag = f'conv{stage}'
            x = getattr(self, tag)(x)
            seen[tag] = x
            if wanted == tag:
                return x
            x = getattr(self, f'dropout{stage}')(x, deterministic=not training)
            x = self.avg_pool(x)

        x = jnp.mean(x, axis=(1, 2))
        seen['global_avg_pool'] = x
        if wanted == 'global_avg_pool':
            return x

        x = self.non_linear2(x)
        seen['final_layer'] = x
        if wanted == 'final_layer':
            return x

        if wanted is not None and wanted not in seen:
            print(f"Warning: Layer '{return_activations_for_layer}' not found in YatCNN. Available: {list(seen.keys())}")
        # Fall back to the final output when the requested layer is unknown.
        return x
|
1061
|
+
|
1062
|
+
# Moved LinearCNN class definition to module level
|
1063
|
+
class LinearCNN(nnx.Module):
    """Baseline CNN built from four nnx.Conv + ReLU stages and a linear head.

    Each stage applies a convolution, ReLU, dropout and 2x2 average pooling.
    The result is averaged over axes (1, 2) and passed to a bias-free
    nnx.Linear layer.
    """

    def __init__(self, *, num_classes: int, input_channels: int, rngs: nnx.Rngs):
        # Sub-modules are created in a fixed order (conv1..conv4, then
        # dropout1..dropout4, then the head), all sharing the same `rngs`.
        channel_plan = ((input_channels, 32), (32, 64), (64, 128), (128, 128))
        for stage, (c_in, c_out) in enumerate(channel_plan, start=1):
            setattr(self, f'conv{stage}', nnx.Conv(c_in, c_out, kernel_size=(5, 5), rngs=rngs))
        for stage in range(1, 5):
            setattr(self, f'dropout{stage}', nnx.Dropout(rate=0.3, rngs=rngs))

        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.linear2 = nnx.Linear(128, num_classes, rngs=rngs, use_bias=False)

    def __call__(self, x, training: bool = False, return_activations_for_layer: tp.Optional[str] = None):
        """Run the forward pass.

        Args:
            x: Input batch passed to the first convolution.
            training: When False, dropout runs in deterministic mode.
            return_activations_for_layer: If set to 'convN_raw' (before ReLU),
                'convN' (after ReLU), 'global_avg_pool' or 'final_layer', the
                activation recorded under that name is returned as soon as it
                is computed.

        Returns:
            The requested activation, or the final layer output. An unknown
            layer name prints a warning and returns the final output.
        """
        wanted = return_activations_for_layer
        seen = {}

        for stage in range(1, 5):
            tag = f'conv{stage}'
            raw_tag = f'{tag}_raw'

            x = getattr(self, tag)(x)
            seen[raw_tag] = x  # convolution output before ReLU
            if wanted == raw_tag:
                return x

            x = nnx.relu(x)
            seen[tag] = x  # output after ReLU
            if wanted == tag:
                return x

            x = getattr(self, f'dropout{stage}')(x, deterministic=not training)
            x = self.avg_pool(x)

        x = jnp.mean(x, axis=(1, 2))
        seen['global_avg_pool'] = x
        if wanted == 'global_avg_pool':
            return x

        x = self.linear2(x)
        seen['final_layer'] = x
        if wanted == 'final_layer':
            return x

        if wanted is not None and wanted not in seen:
            print(f"Warning: Layer '{return_activations_for_layer}' not found in LinearCNN. Available: {list(seen.keys())}")
        return x
|
1128
|
+
|
1129
|
+
# New helper function for the training loop
|
1130
|
+
def _train_model_loop(
    model_class: tp.Type[nnx.Module],
    model_name: str,
    dataset_name: str,
    rng_seed: int,
    learning_rate: float,
    momentum: float,
    optimizer_constructor: tp.Callable,
):
    """Helper function to train a model and return it with its metrics history.

    Args:
        model_class: Class instantiated as
            ``model_class(num_classes=..., input_channels=..., rngs=...)``.
        model_name: Label used in progress messages.
        dataset_name: Name passed to ``tfds.builder`` / ``tfds.load``. It is
            also looked up in ``DATASET_CONFIGS``; when absent, the class
            count and channel count are inferred from the builder info and
            the module-level default training parameters are used.
        rng_seed: Seed for the ``nnx.Rngs`` given to the model.
        learning_rate: First argument to ``optimizer_constructor``.
        momentum: Second argument to ``optimizer_constructor``.
        optimizer_constructor: Callable returning the optimizer
            transformation wrapped by ``nnx.Optimizer``.

    Returns:
        ``(model, metrics_history)`` where ``metrics_history`` maps
        'train_loss', 'train_accuracy', 'test_loss' and 'test_accuracy' to
        lists with one entry per evaluation.

    Raises:
        ValueError: If the dataset is not configured and its info cannot be
            inferred, if the train split has unknown or infinite cardinality,
            or if the train split holds fewer examples than one batch.
    """
    print(f"Initializing {model_name} model for dataset {dataset_name}...")

    config = DATASET_CONFIGS.get(dataset_name)
    ds_builder = tfds.builder(dataset_name)
    ds_info_for_model = ds_builder.info

    if not config:
        try:
            num_classes = ds_info_for_model.features['label'].num_classes
            image_shape = ds_info_for_model.features['image'].shape
            input_channels = image_shape[-1] if len(image_shape) >= 3 else 1
            train_split_name = 'train'
            test_split_name = 'test'
            image_key = 'image'
            label_key = 'label'
            # Fallback training parameters if not in config
            current_num_epochs = _global_num_epochs
            current_eval_every = _global_eval_every
            current_batch_size = _global_batch_size
            print(f"Warning: Dataset '{dataset_name}' not in pre-defined configs. Inferred: num_classes={num_classes}, input_channels={input_channels}. Using global defaults for training params.")
        except Exception as e:
            # Chain the cause so the underlying failure stays attached.
            raise ValueError(f"Dataset '{dataset_name}' not in configs and could not infer info: {e}") from e
    else:
        num_classes = config['num_classes']
        input_channels = config['input_channels']
        train_split_name = config['train_split']
        test_split_name = config['test_split']
        image_key = config['image_key']
        label_key = config['label_key']
        current_num_epochs = config['num_epochs']
        current_eval_every = config['eval_every']
        current_batch_size = config['batch_size']

    model = model_class(num_classes=num_classes, input_channels=input_channels, rngs=nnx.Rngs(rng_seed))
    optimizer = nnx.Optimizer(model, optimizer_constructor(learning_rate, momentum))
    metrics_computer = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average('loss'),
    )

    def preprocess_data_fn(sample):
        """Scale the image to [0, 1] floats and normalise the dict keys."""
        image = tf.cast(sample[image_key], tf.float32) / 255.0
        return {'image': image, 'label': sample[label_key]}

    loaded_train_ds = tfds.load(dataset_name, split=train_split_name, as_supervised=False, shuffle_files=True)
    loaded_test_ds = tfds.load(dataset_name, split=test_split_name, as_supervised=False)

    dataset_size = loaded_train_ds.cardinality().numpy()
    if dataset_size == tf.data.UNKNOWN_CARDINALITY or dataset_size == tf.data.INFINITE_CARDINALITY:
        raise ValueError(
            f"Cannot determine dataset size for '{dataset_name}' split '{train_split_name}' for epoch-based training. "
            f"Please ensure the dataset split has a known finite cardinality or revert to step-based training with .take()."
        )
    steps_per_epoch = dataset_size // current_batch_size
    if steps_per_epoch == 0:
        # Batches use drop_remainder=True, so a split smaller than one batch
        # yields no training steps and metrics computed from nothing.
        raise ValueError(
            f"Dataset '{dataset_name}' split '{train_split_name}' has {dataset_size} examples, "
            f"fewer than one batch of {current_batch_size}; no training steps would run."
        )
    total_expected_steps = current_num_epochs * steps_per_epoch

    print(f"Training {model_name} on {dataset_name} for {current_num_epochs} epochs ({steps_per_epoch} steps/epoch, total {total_expected_steps} steps). Evaluating every {current_eval_every} steps.")

    # Test dataset pipeline (created once, iterated at every evaluation)
    dataset_test_iter = loaded_test_ds.map(preprocess_data_fn, num_parallel_calls=tf.data.AUTOTUNE) \
        .batch(current_batch_size, drop_remainder=True) \
        .prefetch(tf.data.AUTOTUNE)

    metrics_history = {
        'train_loss': [], 'train_accuracy': [],
        'test_loss': [], 'test_accuracy': [],
    }

    def record_metrics(prefix):
        """Append the currently accumulated metrics under `prefix`, then reset."""
        for metric_name_key, value in metrics_computer.compute().items():
            metrics_history[f'{prefix}_{metric_name_key}'].append(value)
        metrics_computer.reset()

    def run_test_evaluation():
        """Accumulate metrics over the whole test pipeline and record them."""
        for test_batch in dataset_test_iter.as_numpy_iterator():
            eval_step(model, metrics_computer, test_batch)
        record_metrics('test')

    global_step_counter = 0
    for epoch in range(current_num_epochs):
        print(f" Epoch {epoch + 1}/{current_num_epochs}")
        # Rebuild the pipeline each epoch so the shuffle is re-applied.
        epoch_train_ds = loaded_train_ds.shuffle(buffer_size=1024) \
            .map(preprocess_data_fn, num_parallel_calls=tf.data.AUTOTUNE) \
            .batch(current_batch_size, drop_remainder=True) \
            .prefetch(tf.data.AUTOTUNE)

        for batch_in_epoch, batch_data in enumerate(epoch_train_ds.as_numpy_iterator()):
            train_step(model, optimizer, metrics_computer, batch_data)

            # Periodic evaluation keyed on the global step. The very last
            # training step is excluded because the final evaluation after
            # the loop covers it.
            if global_step_counter > 0 and \
                    (global_step_counter % current_eval_every == 0 or global_step_counter == total_expected_steps - 1) and \
                    not (epoch == current_num_epochs - 1 and batch_in_epoch == steps_per_epoch - 1):
                record_metrics('train')
                run_test_evaluation()
                print(f" Step {global_step_counter}: {model_name} Train Acc = {metrics_history['train_accuracy'][-1]:.4f}, Test Acc = {metrics_history['test_accuracy'][-1]:.4f}")

            global_step_counter += 1
            if global_step_counter >= total_expected_steps:
                break  # Stop once the planned number of steps has run

        if global_step_counter >= total_expected_steps:
            break  # Exit epoch loop as well

    # Final evaluation at the end of all epochs
    print(f" Performing final evaluation for {model_name} after {current_num_epochs} epochs...")
    # Capture the train metrics accumulated since the last reset
    computed_train_metrics = metrics_computer.compute()
    if computed_train_metrics and computed_train_metrics.get('loss') is not None:
        for metric_name_key, value in computed_train_metrics.items():
            metrics_history[f'train_{metric_name_key}'].append(value)
    metrics_computer.reset()  # Reset for final test eval

    run_test_evaluation()

    print(f"✅ {model_name} Model Training Complete on {dataset_name} after {current_num_epochs} epochs ({global_step_counter} steps)!")
    if metrics_history['test_accuracy']:
        print(f" Final Test Accuracy: {metrics_history['test_accuracy'][-1]:.4f}")
    else:
        print(f" No test accuracy recorded for {model_name}.")

    return model, metrics_history
|
1268
|
+
|
1269
|
+
|
1270
|
+
# ===== NEW ADVANCED ANALYSIS FUNCTIONS =====
|
1271
|
+
|
1272
|
+
def visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16):
    """
    Visualize the kernels of one convolutional layer for both models.

    Only the slice for input channel 0 of each kernel is drawn, min-max
    normalised per kernel.

    Args:
        yat_model: Model whose attribute `layer_name` is inspected.
        linear_model: Second model, inspected the same way.
        layer_name: Attribute name of the layer on both models.
        num_kernels_to_show: Upper bound on the number of output-channel
            kernels drawn per model.
    """
    print(f"\n🎨 VISUALIZING KERNELS FROM LAYER: {layer_name}")
    print("=" * 50)

    def get_kernels(model, layer_name_str):
        """Return the layer's kernel value, or None (with a message) if unavailable."""
        try:
            layer = getattr(model, layer_name_str)
            if hasattr(layer, 'kernel') and layer.kernel is not None:
                return layer.kernel.value
            print(f"Kernel not found or is None in layer {layer_name_str} of {model.__class__.__name__}")
            return None
        except AttributeError:
            print(f"Layer {layer_name_str} not found in {model.__class__.__name__}")
            return None

    yat_kernels = get_kernels(yat_model, layer_name)
    linear_kernels = get_kernels(linear_model, layer_name)

    if yat_kernels is None and linear_kernels is None:
        print("Could not retrieve kernels for either model.")
        return

    def plot_kernel_grid(kernels, model_name_str, num_kernels):
        """Draw up to `num_kernels` kernels of one model in a near-square grid."""
        if kernels is None:
            print(f"No kernels to plot for {model_name_str}")
            return

        # Unpacking requires a 4-D kernel; the last axis is indexed per kernel.
        kh, kw, in_c, out_c = kernels.shape
        num_kernels = min(num_kernels, out_c)
        if num_kernels < 1:
            # Avoids a zero-sized grid (0 / 0 when computing the row count).
            print(f"No kernels to plot for {model_name_str}")
            return

        cols = int(np.ceil(np.sqrt(num_kernels)))
        rows = int(np.ceil(num_kernels / cols))

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so the
        # flat indexing below also works when a single kernel is shown.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
        fig.suptitle(f'{model_name_str} - Layer {layer_name} Kernels (First {num_kernels} of {out_c})', fontsize=16)

        cells = axes.ravel()
        for i, ax in enumerate(cells):
            if i < num_kernels:
                kernel_slice = kernels[:, :, 0, i]
                low = np.min(kernel_slice)
                # The epsilon keeps a constant kernel from dividing by zero.
                kernel_slice = (kernel_slice - low) / (np.max(kernel_slice) - low + 1e-5)
                ax.imshow(kernel_slice, cmap='viridis')
                ax.set_title(f'Kernel {i+1}')
            # Axes are hidden on used and unused cells alike.
            ax.axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    if yat_kernels is not None:
        plot_kernel_grid(np.array(yat_kernels), "YAT Model", num_kernels_to_show)
    if linear_kernels is not None:
        plot_kernel_grid(np.array(linear_kernels), "Linear Model", num_kernels_to_show)

    print("🖼️ Kernel visualization complete (if kernels were found and plotted).")
|
1335
|
+
|
1336
|
+
|
1337
|
+
def get_activation_maps(model, layer_name, input_sample, training=False):
    """
    Fetch the activations of one layer by calling the model.

    The model is invoked with ``return_activations_for_layer=layer_name``
    and whatever it returns is handed back unchanged. The models in this
    file return their final output (after printing a warning) when the layer
    name is unknown, so a mistyped name does not produce ``None``.

    Returns:
        The model's return value, or ``None`` if the call raised; the error
        is printed in that case.
    """
    call_kwargs = {
        'training': training,
        'return_activations_for_layer': layer_name,
    }
    try:
        maps = model(input_sample, **call_kwargs)
    except Exception as err:
        print(f"Error getting activations for {layer_name} in {model.__class__.__name__}: {err}")
        maps = None
    return maps
|
1354
|
+
|
1355
|
+
|
1356
|
+
def activation_map_visualization(yat_model, linear_model, test_ds_iter, layer_name='conv1', num_maps_to_show=16):
    """
    Visualize activation maps from a specified layer for a sample input.

    The first image of the first batch is passed through both models and the
    requested layer's channels are drawn as a grid per model.

    Args:
        yat_model: First model; queried through ``get_activation_maps``.
        linear_model: Second model, queried the same way.
        test_ds_iter: An iterable TFDS dataset (already batched and
            preprocessed) exposing ``as_numpy_iterator()``.
        layer_name: Layer whose activations are requested.
        num_maps_to_show: Upper bound on the number of channels drawn.
    """
    print(f"\n🗺️ VISUALIZING ACTIVATION MAPS FROM LAYER: {layer_name}")
    print("=" * 50)

    try:
        sample_batch = next(test_ds_iter.as_numpy_iterator())
    except (StopIteration, tf.errors.OutOfRangeError):
        # An exhausted Python iterator signals the end with StopIteration, so
        # it has to be caught here for the fallback below to be reachable.
        print("ERROR: Test dataset iterator for activation maps is exhausted. Consider re-creating it or passing a fresh one.")
        # Fallback: try the global test_ds, which may belong to another dataset
        try:
            print(f"Warning: Falling back to global test_ds for activation maps. This might use data from '{_DEFAULT_DATASET_FOR_GLOBALS}'.")
            sample_batch = next(test_ds.as_numpy_iterator())  # Global test_ds
        except Exception as e_global:
            print(f"Error: Could not get sample batch for activation maps: {e_global}")
            return

    sample_image = sample_batch['image'][0:1]  # First image, batch dim kept

    yat_activations = get_activation_maps(yat_model, layer_name, sample_image)
    linear_activations = get_activation_maps(linear_model, layer_name, sample_image)

    if yat_activations is None and linear_activations is None:
        print("Could not retrieve activation maps for either model.")
        return

    def plot_activation_grid(activations, model_name_str, num_maps):
        """Draw up to `num_maps` channels of one model's activations."""
        if activations is None:
            print(f"No activation maps to plot for {model_name_str}")
            return

        activations_np = np.array(activations)
        if activations_np.ndim != 4:
            # Only 4-D activations are plotted; anything else is reported.
            print(f"Unexpected activation shape for {model_name_str}: {activations_np.shape}")
            return
        activations_np = activations_np[0]

        num_channels = activations_np.shape[-1]
        num_maps = min(num_maps, num_channels)
        if num_maps < 1:
            # Avoids a zero-sized grid (0 / 0 when computing the row count).
            print(f"No activation maps to plot for {model_name_str}")
            return

        cols = int(np.ceil(np.sqrt(num_maps)))
        rows = int(np.ceil(num_maps / cols))

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so the
        # flat indexing below also works when a single map is shown.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
        fig.suptitle(f'{model_name_str} - Layer {layer_name} Activation Maps (First {num_maps})', fontsize=16)

        for i, ax in enumerate(axes.ravel()):
            if i < num_maps:
                ax.imshow(activations_np[:, :, i], cmap='viridis')
                ax.set_title(f'Map {i+1}')
            # Axes are hidden on used and unused cells alike.
            ax.axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    if yat_activations is not None:
        plot_activation_grid(yat_activations, "YAT Model", num_maps_to_show)
    if linear_activations is not None:
        plot_activation_grid(linear_activations, "Linear Model", num_maps_to_show)

    print("🗺️ Activation map visualization complete (if maps were found and plotted).")
|
1426
|
+
|
1427
|
+
|
1428
|
+
def saliency_map_analysis(yat_model, linear_model, test_ds_iter, class_names: list[str]):
    """Generate and display gradient saliency maps for both models.

    One sample (the first image of the first batch from ``test_ds_iter``) is
    used. For each model, saliency is computed with respect to the logit of
    the true class and the logit of that model's predicted class, and shown in
    a 2x3 figure (input, true-class saliency, predicted-class saliency).

    Args:
        yat_model: Model shown in the first row of the figure.
        linear_model: Model shown in the second row of the figure.
        test_ds_iter: Dataset object exposing ``as_numpy_iterator()``; each
            batch must be a mapping with ``'image'`` and ``'label'`` entries.
        class_names: Class names indexed by label / logit position.

    Returns:
        None. Results are printed and shown through matplotlib.
    """
    print(f"\n🔥 SALIENCY MAP ANALYSIS for {len(class_names)} classes")
    print("=" * 50)

    try:
        sample_batch = next(test_ds_iter.as_numpy_iterator())
    except (StopIteration, tf.errors.OutOfRangeError):
        # next() on an empty/exhausted Python iterator signals StopIteration;
        # OutOfRangeError is still accepted for callers that raise it.
        print("ERROR: Test dataset iterator for saliency maps is exhausted. Consider re-creating it or passing a fresh one.")
        # Fallback: try the global test_ds, warning that it may belong to a different dataset.
        try:
            print(f"Warning: Falling back to global test_ds for saliency maps. This might use data from '{_DEFAULT_DATASET_FOR_GLOBALS}'.")
            sample_batch = next(test_ds.as_numpy_iterator())  # Global test_ds
        except Exception as e_global:
            print(f"Error: Could not get sample batch for saliency maps: {e_global}")
            return

    sample_image = sample_batch['image'][0:1]  # First image only, batch dim kept
    sample_label = int(sample_batch['label'][0])  # Plain Python int (used as a static jit argument)
    true_class_name = class_names[sample_label]

    # model (arg 0) and class_index (arg 2) are static: both are used in Python-level control flow.
    @partial(jax.jit, static_argnums=(0, 2))
    def get_saliency_map(model, image_input, class_index=None):
        """Return max-over-channels |d logit / d image| for the first image."""
        def model_output_for_grad(img):
            logits = model(img, training=False)
            if class_index is None:
                return jnp.max(logits[0])  # Logit for the predicted class
            num_model_classes = logits.shape[-1]
            if class_index >= num_model_classes:
                print(f"Warning: class_index {class_index} is out of bounds for model with {num_model_classes} classes. Using 0 instead.")
                return logits[0, 0]
            return logits[0, class_index]

        grads = jax.grad(model_output_for_grad)(image_input)
        return jnp.max(jnp.abs(grads[0]), axis=-1)

    yat_logits_sample = yat_model(sample_image, training=False)
    yat_predicted_class_idx = int(jnp.argmax(yat_logits_sample, axis=1)[0])
    yat_predicted_class_name = class_names[yat_predicted_class_idx]

    linear_logits_sample = linear_model(sample_image, training=False)
    linear_predicted_class_idx = int(jnp.argmax(linear_logits_sample, axis=1)[0])
    linear_predicted_class_name = class_names[linear_predicted_class_idx]

    print(f"Sample image true class: {true_class_name} (Index: {sample_label})")
    print(f"YAT predicted class: {yat_predicted_class_name} (Index: {yat_predicted_class_idx})")
    print(f"Linear predicted class: {linear_predicted_class_name} (Index: {linear_predicted_class_idx})")

    yat_saliency_true_class = get_saliency_map(yat_model, sample_image, class_index=sample_label)
    linear_saliency_true_class = get_saliency_map(linear_model, sample_image, class_index=sample_label)

    yat_saliency_pred_class = get_saliency_map(yat_model, sample_image, class_index=yat_predicted_class_idx)
    linear_saliency_pred_class = get_saliency_map(linear_model, sample_image, class_index=linear_predicted_class_idx)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f"Saliency Map Comparison (True Class: {true_class_name})", fontsize=16)

    # Min-max normalise the input so it can be displayed regardless of its value range.
    img_display = sample_image[0]
    img_display = (img_display - np.min(img_display)) / (np.max(img_display) - np.min(img_display) + 1e-5)

    # One figure row per model: (label, true-class saliency, predicted-class saliency, predicted name).
    figure_rows = (
        ("YAT", yat_saliency_true_class, yat_saliency_pred_class, yat_predicted_class_name),
        ("Linear", linear_saliency_true_class, linear_saliency_pred_class, linear_predicted_class_name),
    )
    for row_idx, (label, saliency_true, saliency_pred, predicted_name) in enumerate(figure_rows):
        input_ax = axes[row_idx, 0]
        input_ax.imshow(img_display)
        input_ax.set_title(f"{label} Input (True: {true_class_name})")
        input_ax.axis('off')

        panels = (
            (axes[row_idx, 1], saliency_true, f"{label} Saliency (for True: {true_class_name})"),
            (axes[row_idx, 2], saliency_pred, f"{label} Saliency (for Pred: {predicted_name})"),
        )
        for ax, saliency, title in panels:
            image = ax.imshow(np.array(saliency), cmap='hot')
            ax.set_title(title)
            ax.axis('off')
            fig.colorbar(image, ax=ax)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    print("🔥 Saliency map analysis and visualization complete.")
|
1531
|
+
|
1532
|
+
|
1533
|
+
|
1534
|
+
|
1535
|
+
def run_complete_comparison(dataset_name: str = 'cifar10'):
    """Train the YAT and Linear models on a dataset and run every comparison.

    Steps: train ``YatCNN`` and ``LinearCNN`` through ``_train_model_loop``,
    compare their metric histories, evaluate both on a freshly loaded test
    split, then run the kernel / activation-map / saliency visualisations.

    Args:
        dataset_name: TFDS dataset name. Settings are read from
            ``DATASET_CONFIGS`` when present; otherwise defaults are used.

    Returns:
        dict with keys ``'yat_model'``, ``'linear_model'``,
        ``'yat_metrics_history'``, ``'linear_metrics_history'`` and
        ``'predictions_data'``.
    """
    print("\n" + "="*80)
    print(f" RUNNING COMPLETE MODEL COMPARISON FOR: {dataset_name.upper()}")
    print("="*80)

    # Common training parameters (could be moved to DATASET_CONFIGS if they vary a lot)
    learning_rate = 0.003
    momentum = 0.9

    # Dataset settings; an empty dict means "not configured, use defaults".
    dataset_config = DATASET_CONFIGS.get(dataset_name, {})
    if not dataset_config:
        print(f"Warning: Dataset '{dataset_name}' not in DATASET_CONFIGS. Some features might use defaults or fail.")

    ds_info_comp = tfds.builder(dataset_name).info
    if dataset_config:
        class_names_comp = ds_info_comp.features[dataset_config.get('label_key', 'label')].names
    else:
        # No config: try the default 'label' feature, then generic names.
        try:
            class_names_comp = ds_info_comp.features['label'].names
        except (KeyError, AttributeError):
            print(f"Could not infer class names for {dataset_name}, using a placeholder list.")
            try:
                num_classes_fallback = ds_info_comp.features['label'].num_classes
                class_names_comp = [f"Class {i}" for i in range(num_classes_fallback)]
            except Exception:
                # Last resort: ten generic class names.
                class_names_comp = [f"Class {i}" for i in range(10)]

    # Batch size used to batch the evaluation dataset below.
    current_batch_size_for_eval = dataset_config.get('batch_size', _global_batch_size)

    # Step 1: Train YAT Model
    print(f"\n🚀 STEP 1: Training YAT Model on {dataset_name}...")
    print("-" * 50)

    yat_model, yat_metrics_history = _train_model_loop(
        model_class=YatCNN,
        model_name="YAT",
        dataset_name=dataset_name,
        rng_seed=0,
        learning_rate=learning_rate,
        momentum=momentum,
        optimizer_constructor=optax.adamw,
    )

    # Step 2: Train Linear Model
    print(f"\n🚀 STEP 2: Training Linear Model on {dataset_name}...")
    print("-" * 50)

    linear_model, linear_metrics_history = _train_model_loop(
        model_class=LinearCNN,
        model_name="Linear",
        dataset_name=dataset_name,
        rng_seed=0,
        learning_rate=learning_rate,
        momentum=momentum,
        optimizer_constructor=optax.adamw,
    )

    # Step 3: Run All Comparisons
    print(f"\n📊 STEP 3: Running Complete Comparison Analysis for {dataset_name}...")
    print("-" * 50)

    # 3.1 Compare training curves
    print("\n📈 Comparing training curves...")
    compare_training_curves(yat_metrics_history, linear_metrics_history)

    # 3.2 Print final metrics comparison
    print_final_metrics_comparison(yat_metrics_history, linear_metrics_history)

    # 3.3 Analyze convergence
    analyze_convergence(yat_metrics_history, linear_metrics_history)

    # 3.4 Detailed test evaluation on a freshly loaded test split
    print("\n🎯 Running detailed test evaluation...")
    eval_image_key = dataset_config.get('image_key', 'image')
    eval_label_key = dataset_config.get('label_key', 'label')
    eval_test_split = dataset_config.get('test_split', 'test')

    def eval_preprocess_fn(sample):
        """Scale the image to float32 in [0, 1] and normalise the dict keys."""
        return {
            'image': tf.cast(sample[eval_image_key], tf.float32) / 255.0,
            'label': sample[eval_label_key],
        }

    current_test_ds_for_eval = (
        tfds.load(dataset_name, split=eval_test_split, as_supervised=False)
        .map(eval_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(current_batch_size_for_eval, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

    predictions_data = detailed_test_evaluation(yat_model, linear_model, current_test_ds_for_eval, class_names=class_names_comp)

    # 3.5 Plot confusion matrices
    print("\n📊 Plotting confusion matrices...")
    plot_confusion_matrices(predictions_data)

    # 3.6 Generate comprehensive summary report
    generate_summary_report(yat_metrics_history, linear_metrics_history, predictions_data)

    # Step 4: Advanced Analysis
    print("\n🔬 STEP 4: Running Advanced Analysis...")
    print("-" * 50)

    # 4.1 Visualize kernels from 'conv1'
    visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16)

    # 4.2 Visualize activation maps from 'conv1' for a sample of the test set
    activation_map_visualization(yat_model, linear_model, current_test_ds_for_eval, layer_name='conv1', num_maps_to_show=16)

    # 4.3 Saliency map analysis
    saliency_map_analysis(yat_model, linear_model, current_test_ds_for_eval, class_names=class_names_comp)

    print("\n" + "="*80)
    print(f" COMPARISON ANALYSIS FOR {dataset_name.upper()} COMPLETE! ✅")
    print("="*80)

    return {
        'yat_model': yat_model,
        'linear_model': linear_model,
        'yat_metrics_history': yat_metrics_history,
        'linear_metrics_history': linear_metrics_history,
        'predictions_data': predictions_data,
    }
|
1673
|
+
|
1674
|
+
# ===== QUICK START FUNCTIONS =====
|
1675
|
+
|
1676
|
+
def quick_comparison_demo():
    """Exercise the history-comparison helpers on synthetic metric curves.

    Builds two seeded, randomly jittered metric histories (30 points each) and
    feeds them to ``compare_training_curves``,
    ``print_final_metrics_comparison`` and ``analyze_convergence``. No model
    is trained and no dataset is loaded.
    """
    print("\n🎬 RUNNING QUICK COMPARISON DEMO...")
    print("-" * 50)

    # Synthetic metric histories; the fixed seed keeps the demo reproducible.
    import random
    random.seed(42)

    steps = 30

    def noisy_curve(trend, jitter):
        # One random draw per point, consumed in step order.
        return [trend(step) + random.random()*jitter for step in range(steps)]

    # Curves are generated in the same order as the dict keys so the random
    # stream is consumed deterministically.
    yat_dummy = {
        'train_loss': noisy_curve(lambda i: 1.5 - 0.04*i, 0.1),
        'train_accuracy': noisy_curve(lambda i: 0.2 + 0.025*i, 0.05),
        'test_loss': noisy_curve(lambda i: 1.6 - 0.035*i, 0.15),
        'test_accuracy': noisy_curve(lambda i: 0.15 + 0.022*i, 0.08),
    }

    linear_dummy = {
        'train_loss': noisy_curve(lambda i: 1.6 - 0.045*i, 0.1),
        'train_accuracy': noisy_curve(lambda i: 0.18 + 0.024*i, 0.05),
        'test_loss': noisy_curve(lambda i: 1.7 - 0.04*i, 0.15),
        'test_accuracy': noisy_curve(lambda i: 0.12 + 0.023*i, 0.08),
    }

    print("📈 Comparing dummy training curves...")
    compare_training_curves(yat_dummy, linear_dummy)

    print_final_metrics_comparison(yat_dummy, linear_dummy)
    analyze_convergence(yat_dummy, linear_dummy)

    print("✅ Demo complete! Now you can run the full comparison with real models.")
|
1710
|
+
|
1711
|
+
def save_metrics_example():
    """Print a code snippet showing how to keep and pickle metric histories.

    Nothing is saved or loaded here; the function only prints example code.
    """
    # NOTE(review): whitespace inside the snippet literal was reconstructed —
    # confirm it against the original file.
    heading = "\n💾 HOW TO SAVE METRICS DURING TRAINING:"
    divider = "-" * 50
    snippet = """
    # After training your YAT model:
    yat_metrics_history = metrics_history.copy()

    # After training your Linear model:
    linear_metrics_history = metrics_history.copy()

    # Or save to files:
    import pickle
    with open('yat_metrics.pkl', 'wb') as f:
        pickle.dump(yat_metrics_history, f)

    with open('linear_metrics.pkl', 'wb') as f:
        pickle.dump(linear_metrics_history, f)

    # Load later:
    with open('yat_metrics.pkl', 'rb') as f:
        yat_metrics_history = pickle.load(f)

    with open('linear_metrics.pkl', 'rb') as f:
        linear_metrics_history = pickle.load(f)
    """

    print(heading)
    print(divider)
    print(snippet)
|
1739
|
+
|
1740
|
+
# Usage banner: lists the entry points and comparison helpers of this script.
# NOTE(review): assumed to be module-level code (indentation is not visible here) — confirm.
print("\n".join((
    "\n" + "="*80,
    "="*80,
    "\n🚀 TO RUN THE COMPLETE COMPARISON (e.g., for CIFAR-10):",
    " results = run_complete_comparison(dataset_name='cifar10')",
    "\n Other examples:",
    " results_cifar100 = run_complete_comparison(dataset_name='cifar100')",
    " results_stl10 = run_complete_comparison(dataset_name='stl10')",
    " results_eurosat_rgb = run_complete_comparison(dataset_name='eurosat/rgb')",
    # " results_eurosat_all = run_complete_comparison(dataset_name='eurosat/all')"  # Might be slow due to 13 channels
    "\n🎬 TO RUN A QUICK DEMO (uses dummy data, not specific dataset):",
    " quick_comparison_demo()",
    "\n💾 TO SEE HOW TO SAVE METRICS:",
    " save_metrics_example()",
    "\n📖 The comparison functions are ready to use:",
    " - compare_training_curves(yat_history, linear_history)",
    " - print_final_metrics_comparison(yat_history, linear_history)",
    " - analyze_convergence(yat_history, linear_history)",
    " - detailed_test_evaluation(yat_model, linear_model, test_ds)",
    " - plot_confusion_matrices(predictions_data)",
    " - generate_summary_report(yat_history, linear_history, predictions_data)",
    "\n ✨ NEW ADVANCED ANALYSIS:",
    " - visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16)",
    " - activation_map_visualization(yat_model, linear_model, test_ds_iter, layer_name='conv1', num_maps_to_show=16)",
    " - saliency_map_analysis(yat_model, linear_model, test_ds_iter, class_names=class_names_comp)",
    "="*80,
)))

# Runs the full STL-10 comparison when this statement executes.
results_stl10 = run_complete_comparison(dataset_name='stl10')
|