nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1769 @@
1
from __future__ import annotations

import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
import opt_einsum

# Import libraries for comparison functions
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

from flax.core.frozen_dict import FrozenDict
from flax import nnx  # The Flax NNX API.
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
    Dtype,
    Shape,
    Initializer,
    PrecisionLike,
    DotGeneralT,
    ConvGeneralDilatedT,
    PaddingLike,
    LaxPadding,
    PromoteDtypeFn,
    EinsumT,
)

import tensorflow_datasets as tfds
import tensorflow as tf

import optax
62
+
63
# Short alias for the JAX array type, used in annotations below.
Array = jax.Array

# Default initializers
# Shared initializer instances used as default arguments by the layers in
# this module (kernel, bias and alpha parameters respectively).
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()
69
+
70
+ # Helper functions
71
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
    """Canonicalize conv padding into a format ``jax.lax`` accepts.

    Args:
      padding: a string (returned unchanged), a single int (applied to both
        sides of every dimension), or a sequence of length ``rank`` whose
        elements are each an int or a 2-tuple.
      rank: number of spatial dimensions the padding applies to.

    Returns:
      The original string, or a list of ``rank`` ``(low, high)`` pairs.

    Raises:
      ValueError: if ``padding`` matches none of the accepted forms.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding)] * rank

    if isinstance(padding, tp.Sequence) and len(padding) == rank:
        pairs = []
        for entry in padding:
            if isinstance(entry, int):
                pairs.append((entry, entry))
            elif isinstance(entry, tuple) and len(entry) == 2:
                pairs.append(entry)
            else:
                # Unsupported element: fall through to the error below.
                break
        else:
            return pairs

    raise ValueError(
        f'Invalid padding format: {padding}, should be str, int,'
        f' or a sequence of len {rank} where each element is an'
        ' int or pair of ints.'
    )
93
+
94
def _conv_dimension_numbers(input_shape):
    """Build ``lax.ConvDimensionNumbers`` from the rank of ``input_shape``.

    Only ``len(input_shape)`` is used. The input/output spec lists dim 0,
    then the last dim, then the dims in between; the kernel spec lists its
    last dim, then its second-to-last dim, then the leading dims.
    """
    rank = len(input_shape)
    activation_spec = (0, rank - 1) + tuple(range(1, rank - 1))
    kernel_spec = (rank - 1, rank - 2) + tuple(range(rank - 2))
    return lax.ConvDimensionNumbers(
        lhs_spec=activation_spec,
        rhs_spec=kernel_spec,
        out_spec=activation_spec,
    )
101
+
102
class YatConv(Module):
    """Yat convolution module built on ``lax.conv_general_dilated``.

    For every output position and filter the layer computes::

        dot  = <patch, kernel>
        dist = sum(patch**2) + sum(kernel**2) - 2 * dot
        y    = dot**2 / (dist + epsilon)

    An optional bias is then added and, when ``use_alpha`` is set, the result
    is multiplied by ``(sqrt(out_features) / log(1 + out_features)) ** alpha``.

    Example usage::

      >>> import jax.numpy as jnp
      >>> from flax import nnx

      >>> layer = YatConv(in_features=3, out_features=4, kernel_size=(3,),
      ...                 padding='VALID', rngs=nnx.Rngs(0))
      >>> layer.kernel.value.shape
      (3, 3, 4)
      >>> layer.bias.value.shape
      (4,)
      >>> layer(jnp.ones((1, 8, 3))).shape
      (1, 6, 4)

    Args:
      in_features: int or tuple with number of input features.
      out_features: int or tuple with number of output features.
      kernel_size: shape of the convolutional kernel. For 1D convolution,
        the kernel size can be passed as an integer, which will be interpreted
        as a tuple of the single integer. For all other cases, it must be a
        sequence of integers.
      strides: an integer or a sequence of ``n`` integers, representing the
        inter-window strides (default: 1).
      padding: either the string ``'SAME'``, the string ``'VALID'``, the string
        ``'CIRCULAR'`` (periodic boundary conditions), the string ``'REFLECT'``
        (reflection across the padding boundary), or a sequence of ``n``
        ``(low, high)`` integer pairs that give the padding to apply before and
        after each spatial dimension. A single int is interpreted as applying
        the same padding in all dims and passing a single int in a sequence
        causes the same padding to be used on both sides. ``'CAUSAL'`` padding
        for a 1D convolution will left-pad the convolution axis, resulting in
        same-sized output.
      input_dilation: an integer or a sequence of ``n`` integers, giving the
        dilation factor to apply in each spatial dimension of ``inputs``
        (default: 1). Convolution with input dilation ``d`` is equivalent to
        transposed convolution with stride ``d``.
      kernel_dilation: an integer or a sequence of ``n`` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups. ``in_features`` and ``out_features`` must both
        be multiples of it.
      use_bias: whether to add a bias to the output (default: True).
      mask: Optional mask for the weights during masked convolution. The mask
        must be the same shape as the convolution weight matrix.
      dtype: the dtype of the computation (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      precision: numerical precision of the computation see ``jax.lax.Precision``
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      conv_general_dilated: convolution primitive to use (default:
        ``lax.conv_general_dilated``).
      promote_dtype: function to promote the dtype of the arrays to the desired
        dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
        and a ``dtype`` keyword argument, and return a tuple of arrays with the
        promoted dtype.
      epsilon: A small float added to the denominator to prevent division by
        zero (default: 1/137).
      use_alpha: whether to create the ``alpha`` parameter and apply the
        output scaling (default: True).
      alpha_init: initializer for ``alpha``, which has shape ``(1,)``.
      rngs: rng key.
    """

    __data__ = ('kernel', 'bias', 'mask', 'alpha')

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int | tp.Sequence[int],
        strides: tp.Union[None, int, tp.Sequence[int]] = 1,
        *,
        padding: PaddingLike = 'SAME',
        input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
        kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
        feature_group_count: int = 1,
        use_bias: bool = True,
        mask: tp.Optional[Array] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        epsilon: float = 1/137,
        use_alpha: bool = True,
        alpha_init: Initializer = default_alpha_init,
        rngs: rnglib.Rngs,
    ):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,)
        else:
            kernel_size = tuple(kernel_size)

        # Kernel layout: (*spatial, in_features_per_group, out_features).
        self.kernel_shape = kernel_size + (
            in_features // feature_group_count,
            out_features,
        )
        # Keys are drawn in the order kernel, bias, alpha.
        kernel_key = rngs.params()
        self.kernel = nnx.Param(kernel_init(kernel_key, self.kernel_shape, param_dtype))

        self.bias: nnx.Param[jax.Array] | None
        if use_bias:
            bias_shape = (out_features,)
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
        else:
            self.bias = None

        self.alpha: nnx.Param[jax.Array] | None
        if use_alpha:
            alpha_key = rngs.params()
            self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
        else:
            self.alpha = None

        self.in_features = in_features
        self.out_features = out_features
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.input_dilation = input_dilation
        self.kernel_dilation = kernel_dilation
        self.feature_group_count = feature_group_count
        self.use_bias = use_bias
        self.mask = mask
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.conv_general_dilated = conv_general_dilated
        self.promote_dtype = promote_dtype
        self.epsilon = epsilon
        self.use_alpha = use_alpha
        self.alpha_init = alpha_init

    def __call__(self, inputs: Array) -> Array:
        """Apply the Yat convolution to ``inputs``.

        Args:
          inputs: array whose trailing dims are ``(*spatial, in_features)``.
            Any number of leading batch dims (including none) is accepted;
            they are flattened into one for the convolution and restored on
            the output.

        Returns:
          The transformed array, with ``out_features`` as its last dim.

        Raises:
          ValueError: for ``'CAUSAL'`` padding on a non-1D convolution, for a
            mask whose shape differs from the kernel's, or when
            ``out_features`` is not a multiple of ``feature_group_count``.
        """
        assert isinstance(self.kernel_size, tuple)

        def maybe_broadcast(
            x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
        ) -> tuple[int, ...]:
            # None means "1"; a scalar is repeated once per spatial dim.
            if x is None:
                x = 1
            if isinstance(x, int):
                return (x,) * len(self.kernel_size)
            return tuple(x)

        # Collapse zero or several leading batch dims into exactly one.
        num_batch_dimensions = inputs.ndim - (len(self.kernel_size) + 1)
        if num_batch_dimensions != 1:
            input_batch_shape = inputs.shape[:num_batch_dimensions]
            total_batch_size = int(np.prod(input_batch_shape))
            flat_input_shape = (total_batch_size,) + inputs.shape[
                num_batch_dimensions:
            ]
            inputs_flat = jnp.reshape(inputs, flat_input_shape)
        else:
            inputs_flat = inputs
            input_batch_shape = ()

        strides = maybe_broadcast(self.strides)
        input_dilation = maybe_broadcast(self.input_dilation)
        kernel_dilation = maybe_broadcast(self.kernel_dilation)

        padding_lax = canonicalize_padding(self.padding, len(self.kernel_size))
        if padding_lax in ('CIRCULAR', 'REFLECT'):
            # These modes are emulated: pad explicitly, then convolve 'VALID'.
            assert isinstance(padding_lax, str)
            kernel_size_dilated = [
                (k - 1) * d + 1 for k, d in zip(self.kernel_size, kernel_dilation)
            ]
            zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
            pads = (
                zero_pad
                + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
                + [(0, 0)]
            )
            padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
            inputs_flat = jnp.pad(inputs_flat, pads, mode=padding_mode)
            padding_lax = 'VALID'
        elif padding_lax == 'CAUSAL':
            if len(self.kernel_size) != 1:
                raise ValueError(
                    'Causal padding is only implemented for 1D convolutions.'
                )
            left_pad = kernel_dilation[0] * (self.kernel_size[0] - 1)
            pads = [(0, 0), (left_pad, 0), (0, 0)]
            inputs_flat = jnp.pad(inputs_flat, pads)
            padding_lax = 'VALID'

        dimension_numbers = _conv_dimension_numbers(inputs_flat.shape)
        assert self.in_features % self.feature_group_count == 0
        if self.out_features % self.feature_group_count != 0:
            raise ValueError(
                "out_features must be a multiple of feature_group_count and greater or equal."
            )

        kernel_val = self.kernel.value

        current_mask = self.mask
        if current_mask is not None:
            if current_mask.shape != self.kernel_shape:
                raise ValueError(
                    'Mask needs to have the same shape as weights. '
                    f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
                )
            kernel_val *= current_mask

        bias_val = self.bias.value if self.bias is not None else None

        inputs_flat, kernel_val, bias_val = self.promote_dtype(
            (inputs_flat, kernel_val, bias_val), dtype=self.dtype
        )

        # <patch, kernel> for every output position and filter.
        dot_prod_map = self.conv_general_dilated(
            inputs_flat,
            kernel_val,
            strides,
            padding_lax,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        # sum(patch**2): convolve the squared input with an all-ones kernel.
        # The kernel needs one output feature per group, because lax requires
        # the kernel's output-feature dim to be a multiple of
        # feature_group_count (a single output feature only works for 1 group).
        inputs_flat_squared = inputs_flat**2
        kernel_in_channels_for_sum_sq = self.kernel_shape[-2]
        kernel_for_patch_sq_sum_shape = self.kernel_size + (
            kernel_in_channels_for_sum_sq,
            self.feature_group_count,
        )
        kernel_for_patch_sq_sum = jnp.ones(
            kernel_for_patch_sq_sum_shape, dtype=kernel_val.dtype
        )

        patch_sq_sum_map_raw = self.conv_general_dilated(
            inputs_flat_squared,
            kernel_for_patch_sq_sum,
            strides,
            padding_lax,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.feature_group_count > 1:
            # Expand each group's patch norm to all output channels of that
            # group (output channels of a group are contiguous).
            num_out_channels_per_group = self.out_features // self.feature_group_count
            patch_sq_sum_map = jnp.repeat(
                patch_sq_sum_map_raw, num_out_channels_per_group, axis=-1
            )
        else:
            # Single channel; broadcasts against the out_features axis below.
            patch_sq_sum_map = patch_sq_sum_map_raw

        # sum(kernel**2) per filter: reduce every axis except the last.
        reduce_axes_for_kernel_sq = tuple(range(kernel_val.ndim - 1))
        kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes_for_kernel_sq)

        distance_sq_map = patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
        y = dot_prod_map**2 / (distance_sq_map + self.epsilon)

        if self.use_bias and bias_val is not None:
            bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
            y += jnp.reshape(bias_val, bias_reshape_dims)

        if self.use_alpha and self.alpha is not None:
            # alpha is not passed through promote_dtype; dtype promotion in
            # the power/multiply is left to JAX.
            alpha_val = self.alpha.value
            scale = (jnp.sqrt(jnp.array(self.out_features, dtype=y.dtype)) /
                     jnp.log(1 + jnp.array(self.out_features, dtype=y.dtype))) ** alpha_val
            y = y * scale

        if num_batch_dimensions != 1:
            # Restore the caller's original leading batch dims.
            output_shape = input_batch_shape + y.shape[1:]
            y = jnp.reshape(y, output_shape)
        return y
386
+
387
# NOTE: `Array` and the three default initializers below are assigned a
# second time here with the same expressions used earlier in this file.
Array = jax.Array
# Plain-int aliases intended for annotations.
Axis = int
Size = int


default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()
395
+
396
class YatNMN(Module):
    """Yat dense layer applied over the last dimension of the input.

    For each output unit with weight column ``w`` the layer computes::

        dot  = <x, w>
        dist = sum(x**2) + sum(w**2) - 2 * dot
        y    = dot**2 / (dist + epsilon)

    An optional bias is then added and, when ``use_alpha`` is set, the result
    is multiplied by ``(sqrt(out_features) / log(1 + out_features)) ** alpha``.

    Example usage::

      >>> from flax import nnx
      >>> import jax.numpy as jnp

      >>> layer = YatNMN(in_features=3, out_features=4, rngs=nnx.Rngs(0))
      >>> layer.kernel.value.shape
      (3, 4)
      >>> layer.bias.value.shape
      (4,)
      >>> layer(jnp.ones((2, 3))).shape
      (2, 4)

    Args:
      in_features: the number of input features.
      out_features: the number of output features.
      use_bias: whether to add a bias to the output (default: True).
      use_alpha: whether to create the ``alpha`` parameter and apply the
        output scaling (default: True).
      dtype: the dtype of the computation (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      precision: numerical precision of the computation see ``jax.lax.Precision``
        for details.
      kernel_init: initializer function for the weight matrix.
      bias_init: initializer function for the bias.
      alpha_init: initializer function for ``alpha``, which has shape ``(1,)``.
      dot_general: dot product function.
      promote_dtype: function to promote the dtype of the arrays to the desired
        dtype. It is called with the tuple ``(inputs, kernel, bias, alpha)``
        and a ``dtype`` keyword argument, and must return a tuple of arrays
        with the promoted dtype.
      rngs: rng key.
      epsilon: a small float added to the denominator to prevent division by
        zero (default: 1/137).
    """

    # 'alpha' is listed alongside the other parameters, matching YatConv.
    __data__ = ('kernel', 'bias', 'alpha')

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        use_alpha: bool = True,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        alpha_init: Initializer = default_alpha_init,
        dot_general: DotGeneralT = lax.dot_general,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        rngs: rnglib.Rngs,
        epsilon: float = 1/137,
    ):
        # Keys are drawn in the order kernel, bias, alpha.
        kernel_key = rngs.params()
        self.kernel = nnx.Param(
            kernel_init(kernel_key, (in_features, out_features), param_dtype)
        )
        self.bias: nnx.Param[jax.Array] | None
        if use_bias:
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            self.bias = None

        self.alpha: nnx.Param[jax.Array] | None
        if use_alpha:
            alpha_key = rngs.params()
            self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
        else:
            self.alpha = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.use_alpha = use_alpha
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.alpha_init = alpha_init
        self.dot_general = dot_general
        self.promote_dtype = promote_dtype
        self.epsilon = epsilon

    def __call__(self, inputs: Array) -> Array:
        """Apply the Yat transformation to the inputs along the last dimension.

        Args:
          inputs: The nd-array to be transformed; its last dimension is
            contracted against the kernel's first dimension.

        Returns:
          The transformed input, with ``out_features`` as its last dimension.
        """
        kernel = self.kernel.value
        bias = self.bias.value if self.bias is not None else None
        alpha = self.alpha.value if self.alpha is not None else None

        inputs, kernel, bias, alpha = self.promote_dtype(
            (inputs, kernel, bias, alpha), dtype=self.dtype
        )
        # Plain dot product <x, w> for every output unit.
        y = self.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        assert self.use_bias == (bias is not None)
        assert self.use_alpha == (alpha is not None)

        # Squared distance via ||x||^2 + ||w||^2 - 2<x, w>; keepdims makes
        # the two norms broadcast against y's (..., out_features) shape.
        inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
        kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
        distances = inputs_squared_sum + kernel_squared_sum - 2 * y

        # Element-wise Yat response.
        y = y ** 2 / (distances + self.epsilon)

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

        if alpha is not None:
            scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
            y = y * scale

        return y
529
+
530
def loss_fn(model, batch, training: bool = True):
    """Compute the mean softmax cross-entropy loss for one batch.

    Args:
      model: callable invoked as ``model(images, training=...)`` that returns
        logits.
      batch: mapping with an ``'image'`` entry (model input) and a ``'label'``
        entry (integer class labels).
      training: forwarded to the model's ``training`` flag. Defaults to True
        so existing two-argument callers keep their behavior.

    Returns:
      A ``(loss, logits)`` tuple; ``loss`` is the mean over the batch.
    """
    logits = model(batch['image'], training=training)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits


@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step."""
    # has_aux=True: loss_fn returns (loss, logits) and only loss is differentiated.
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
    optimizer.update(grads)  # In-place updates.


@nnx.jit
def eval_step(model, metrics: nnx.MultiMetric, batch):
    """Evaluate a single batch and accumulate the result into ``metrics``."""
    # Run the model in inference mode; evaluation must not use training-only
    # behavior of the model.
    loss, logits = loss_fn(model, batch, training=False)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
549
+
550
+ tf.random.set_seed(0)
551
+
552
# ===== DATASET CONFIGURATIONS =====
# Maps a TFDS dataset name to its settings: class count, input channels,
# train/test split specs, the sample keys holding image and label, and the
# training schedule (epochs, evaluation interval, batch size).
#
# To add a dataset, follow the same shape, e.g.:
#   'some_other_dataset': dict(
#       num_classes=X, input_channels=Y,
#       train_split='train', test_split='validation',
#       image_key='image_data', label_key='class_id',
#       target_image_size=[H, W],  # Optional: for explicit resizing
#   )
DATASET_CONFIGS = {
    'cifar10': dict(
        num_classes=10,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    'cifar100': dict(
        num_classes=100,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    'stl10': dict(
        num_classes=10,
        input_channels=3,
        train_split='train',
        test_split='test',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=200,
        batch_size=128,
    ),
    # EuroSAT has no dedicated test split here: hold out the last 20% of train.
    'eurosat/rgb': dict(
        num_classes=10,
        input_channels=3,
        train_split='train[:80%]',
        test_split='train[80%:]',
        image_key='image',
        label_key='label',  # EuroSAT label key is 'label' in TFDS
        num_epochs=5,
        eval_every=100,
        batch_size=128,
    ),
    'eurosat/all': dict(
        num_classes=10,
        input_channels=13,
        train_split='train[:80%]',
        test_split='train[80%:]',
        image_key='image',
        label_key='label',
        num_epochs=5,
        eval_every=100,
        batch_size=16,  # Smaller batch for more channels
    ),
}
592
+
593
# Original global dataset setup (will be superseded by _train_model_loop for actual training runs)
# These might still be used by some top-level calls if not careful, or for initial exploration.
# NOTE: everything in this section runs at import time, including the
# tfds.builder / tfds.load calls.
_DEFAULT_DATASET_FOR_GLOBALS = 'cifar10'

# Get default training parameters from the default dataset's config or set fallbacks
# (the second argument of each .get() is only used when the key is absent).
_default_config_for_globals = DATASET_CONFIGS.get(_DEFAULT_DATASET_FOR_GLOBALS, {})
_global_num_epochs = _default_config_for_globals.get('num_epochs', 10) # Fallback: 10 epochs
_global_eval_every = _default_config_for_globals.get('eval_every', 200)
_global_batch_size = _default_config_for_globals.get('batch_size', 64)


# Builder and metadata for the default dataset.
_global_ds_builder = tfds.builder(_DEFAULT_DATASET_FOR_GLOBALS)
_global_ds_info = _global_ds_builder.info

# Raw (unbatched, unprocessed) train and test splits of the default dataset.
train_ds_global_tf: tf.data.Dataset = tfds.load(_DEFAULT_DATASET_FOR_GLOBALS, split='train')
test_ds_global_tf: tf.data.Dataset = tfds.load(_DEFAULT_DATASET_FOR_GLOBALS, split='test')

def _global_preprocess(sample):
    """Convert one raw sample into an ``{'image', 'label'}`` dict.

    The image is read from the key configured for the default dataset, cast
    to float32 and divided by 255 (presumably mapping uint8 pixels to
    [0, 1] -- TODO confirm the source dtype). The label is passed through.
    """
    return {
        'image': tf.cast(sample[DATASET_CONFIGS[_DEFAULT_DATASET_FOR_GLOBALS]['image_key']], tf.float32) / 255,
        'label': sample[DATASET_CONFIGS[_DEFAULT_DATASET_FOR_GLOBALS]['label_key']],
    }

train_ds_global_tf = train_ds_global_tf.map(_global_preprocess)
test_ds_global_tf = test_ds_global_tf.map(_global_preprocess)

# Original global TF dataset iterators (used for some analysis functions if they don't reload)
# It's better if analysis functions requiring data get it passed or reload it with correct dataset_name
# Removing .take() from global train_ds to align with epoch-based approach; consumers must manage iterations.
# train_ds repeats indefinitely, so iterating it never terminates on its own.
train_ds = train_ds_global_tf.repeat().shuffle(1024).batch(_global_batch_size, drop_remainder=True).prefetch(1)
# drop_remainder=True: a final partial batch is discarded.
test_ds = test_ds_global_tf.batch(_global_batch_size, drop_remainder=True).prefetch(1)
624
+
625
+ # ===== MODEL COMPARISON FUNCTIONS =====
626
+
627
def compare_training_curves(yat_history, linear_history):
    """
    Compare training curves between YAT and Linear models.

    Draws a 2x2 grid (training loss, test loss, training accuracy, test
    accuracy) with both models overlaid in each panel, then shows the figure.

    Args:
        yat_history: mapping with 'train_loss', 'test_loss', 'train_accuracy'
            and 'test_accuracy' sequences for the YAT model.
        linear_history: same structure for the Linear model.

    Each curve is plotted against its own index range, so the two histories
    (and the individual metrics) may have different lengths.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training Curves Comparison: YAT vs Linear Models', fontsize=16, fontweight='bold')

    # (history key, panel title, y-axis label), in row-major panel order.
    panels = (
        ('train_loss', 'Training Loss', 'Loss'),
        ('test_loss', 'Test Loss', 'Loss'),
        ('train_accuracy', 'Training Accuracy', 'Accuracy'),
        ('test_accuracy', 'Test Accuracy', 'Accuracy'),
    )

    for ax, (key, title, ylabel) in zip(axes.flat, panels):
        yat_values = yat_history[key]
        linear_values = linear_history[key]
        ax.plot(range(len(yat_values)), yat_values, 'b-', label='YAT Model', linewidth=2)
        ax.plot(range(len(linear_values)), linear_values, 'r--', label='Linear Model', linewidth=2)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Evaluation Steps')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("📈 Training curves comparison plotted successfully!")
679
+
680
def print_final_metrics_comparison(yat_history, linear_history):
    """
    Print a detailed comparison table of final metrics.

    Args:
        yat_history: mapping with non-empty 'train_loss', 'test_loss',
            'train_accuracy' and 'test_accuracy' sequences for the YAT model.
        linear_history: same structure for the Linear model.

    Only the last entry of each sequence is used. For accuracy metrics a
    higher value is reported as better, for loss metrics a lower one. Equal
    final test accuracies are reported as a tie.
    """
    metric_names = ['train_loss', 'test_loss', 'train_accuracy', 'test_accuracy']

    print("\n📊 FINAL METRICS COMPARISON")
    print("=" * 60)

    yat_final = {name: yat_history[name][-1] for name in metric_names}
    linear_final = {name: linear_history[name][-1] for name in metric_names}

    print(f"{'Metric':<20} {'YAT Model':<15} {'Linear Model':<15} {'Difference':<15}")
    print("-" * 65)

    for metric in metric_names:
        yat_val = yat_final[metric]
        linear_val = linear_final[metric]
        diff = yat_val - linear_val
        diff_str = f"{diff:+.4f}"
        # Orient the difference so that positive always means "YAT better".
        advantage = diff if 'accuracy' in metric else -diff
        if advantage > 0:
            diff_str += " (YAT better)"
        elif advantage < 0:
            diff_str += " (Linear better)"

        print(f"{metric:<20} {yat_val:<15.4f} {linear_val:<15.4f} {diff_str:<15}")

    # Summary
    print("\n🏆 SUMMARY:")
    yat_acc = yat_final['test_accuracy']
    linear_acc = linear_final['test_accuracy']
    if yat_acc > linear_acc:
        print(f" Better Test Accuracy: YAT Model (by {yat_acc - linear_acc:.4f})")
    elif yat_acc == linear_acc:
        # Do not declare a winner when both models end on the same accuracy.
        print(f" Better Test Accuracy: Tie (both models at {yat_acc:.4f})")
    else:
        print(f" Better Test Accuracy: Linear Model (by {linear_acc - yat_acc:.4f})")

    print(f" YAT Test Accuracy: {yat_acc:.4f}")
    print(f" Linear Test Accuracy: {linear_acc:.4f}")
734
+
735
def analyze_convergence(yat_history, linear_history):
    """
    Analyze convergence speed and stability of both models.

    Args:
        yat_history: mapping with non-empty 'train_accuracy', 'test_accuracy'
            and 'test_loss' sequences for the YAT model.
        linear_history: same structure for the Linear model.

    For each model the following are computed and printed:
      * convergence step: first index whose test accuracy reaches 50% of the
        final test accuracy (0 if none does);
      * stability: standard deviation of test accuracy over the last quarter
        of the history (at least the last entry);
      * overfitting gap: final train accuracy minus final test accuracy;
      * final test loss.
    """
    print("\n🔍 CONVERGENCE ANALYSIS")
    print("=" * 50)

    def calculate_convergence_metrics(history):
        test_acc = history['test_accuracy']
        train_acc = history['train_accuracy']
        test_loss = history['test_loss']

        # Find step where model reaches 50% of final accuracy
        target_acc = 0.5 * test_acc[-1]
        convergence_step = next(
            (i for i, acc in enumerate(test_acc) if acc >= target_acc), 0
        )

        # Calculate stability (variance in last 25% of training).
        # Clamp the window to >= 1: with fewer than 4 entries `len // 4` is 0
        # and `test_acc[-0:]` would select the whole history instead.
        tail_length = max(1, len(test_acc) // 4)
        stability = np.std(test_acc[-tail_length:])

        # Calculate final overfitting (train_acc - test_acc)
        overfitting = train_acc[-1] - test_acc[-1]

        return {
            'convergence_step': convergence_step,
            'stability': stability,
            'overfitting': overfitting,
            'final_loss': test_loss[-1]
        }

    yat_conv = calculate_convergence_metrics(yat_history)
    linear_conv = calculate_convergence_metrics(linear_history)

    print(f"{'Metric':<25} {'YAT Model':<15} {'Linear Model':<15}")
    print("-" * 55)
    print(f"{'Convergence Speed':<25} {yat_conv['convergence_step']:<15} {linear_conv['convergence_step']:<15}")
    print(f"{'Stability (std)':<25} {yat_conv['stability']:<15.4f} {linear_conv['stability']:<15.4f}")
    print(f"{'Overfitting Gap':<25} {yat_conv['overfitting']:<15.4f} {linear_conv['overfitting']:<15.4f}")
    print(f"{'Final Test Loss':<25} {yat_conv['final_loss']:<15.4f} {linear_conv['final_loss']:<15.4f}")

    # Analysis (equal values are reported as ties, not as a Linear win)
    print("\n📋 ANALYSIS:")
    if yat_conv['convergence_step'] < linear_conv['convergence_step']:
        print(f" 🚀 YAT model converges faster (step {yat_conv['convergence_step']} vs {linear_conv['convergence_step']})")
    elif yat_conv['convergence_step'] == linear_conv['convergence_step']:
        print(f" 🚀 Both models converge at the same step (step {yat_conv['convergence_step']})")
    else:
        print(f" 🚀 Linear model converges faster (step {linear_conv['convergence_step']} vs {yat_conv['convergence_step']})")

    if yat_conv['stability'] < linear_conv['stability']:
        print(f" 📈 YAT model is more stable (std: {yat_conv['stability']:.4f} vs {linear_conv['stability']:.4f})")
    elif yat_conv['stability'] == linear_conv['stability']:
        print(f" 📈 Both models are equally stable (std: {yat_conv['stability']:.4f})")
    else:
        print(f" 📈 Linear model is more stable (std: {linear_conv['stability']:.4f} vs {yat_conv['stability']:.4f})")

    if abs(yat_conv['overfitting']) < abs(linear_conv['overfitting']):
        print(f" 🎯 YAT model has less overfitting (gap: {yat_conv['overfitting']:.4f} vs {linear_conv['overfitting']:.4f})")
    elif abs(yat_conv['overfitting']) == abs(linear_conv['overfitting']):
        print(f" 🎯 Both models have the same overfitting gap (gap: {yat_conv['overfitting']:.4f} vs {linear_conv['overfitting']:.4f})")
    else:
        print(f" 🎯 Linear model has less overfitting (gap: {linear_conv['overfitting']:.4f} vs {yat_conv['overfitting']:.4f})")
796
+
797
def detailed_test_evaluation(yat_model, linear_model, test_ds_iter, class_names: list[str]):
    """
    Compare both models on the test set: per-class accuracy and agreement.

    Args:
        yat_model: Callable invoked as ``yat_model(images, training=False)``;
            its output is arg-maxed over axis 1 to obtain predictions.
        linear_model: Callable with the same calling convention.
        test_ds_iter: Object exposing ``as_numpy_iterator()`` which yields
            dicts holding ``'image'`` and ``'label'`` entries (e.g. an
            already batched and preprocessed TFDS dataset).
        class_names: Class names for the current dataset; position ``i`` is
            used as the display name of label ``i``.

    Returns:
        A dict with the 1-D arrays ``yat_predictions``, ``linear_predictions``
        and ``true_labels``, the ``class_names`` that were passed in, and the
        fractions ``agreement``, ``both_correct``, ``yat_correct_linear_wrong``,
        ``linear_correct_yat_wrong`` and ``both_wrong``. The fractions are NaN
        when the dataset yields no samples.
    """
    print("Running detailed test evaluation...")

    num_classes = len(class_names)

    # Keep one array per batch and join once at the end instead of
    # round-tripping every prediction through a Python list.
    yat_chunks, linear_chunks, label_chunks = [], [], []
    for batch in test_ds_iter.as_numpy_iterator():
        batch_images, batch_labels = batch['image'], batch['label']

        yat_logits = yat_model(batch_images, training=False)
        linear_logits = linear_model(batch_images, training=False)

        yat_chunks.append(np.asarray(jnp.argmax(yat_logits, axis=1)))
        linear_chunks.append(np.asarray(jnp.argmax(linear_logits, axis=1)))
        label_chunks.append(np.asarray(batch_labels))

    def _join(chunks):
        # An empty dataset must still produce a well-typed (integer) array.
        return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)

    yat_predictions = _join(yat_chunks)
    linear_predictions = _join(linear_chunks)
    true_labels = _join(label_chunks)

    if true_labels.size == 0:
        print("Warning: test dataset yielded no samples; agreement statistics are undefined (NaN).")

    # Per-class accuracies
    print("\n🎯 PER-CLASS ACCURACY COMPARISON")
    print("=" * 70)
    print(f"{'Class':<12} {'YAT Acc':<10} {'Linear Acc':<12} {'Difference':<15} {'Sample Count':<12}")
    print("-" * 70)

    yat_hits = yat_predictions == true_labels
    linear_hits = linear_predictions == true_labels

    for class_idx in range(num_classes):
        class_mask = true_labels == class_idx
        class_samples = int(np.count_nonzero(class_mask))

        if class_samples > 0:
            yat_class_acc = np.mean(yat_hits[class_mask])
            linear_class_acc = np.mean(linear_hits[class_mask])
            diff_str = f"{yat_class_acc - linear_class_acc:+.4f}"
            print(f"{class_names[class_idx]:<12} {yat_class_acc:<10.4f} {linear_class_acc:<12.4f} {diff_str:<15} {class_samples:<12}")
        elif num_classes <= 20:
            # Only list sample-less classes when the table stays short.
            print(f"{class_names[class_idx]:<12} {'N/A':<10} {'N/A':<12} {'N/A':<15} {class_samples:<12}")

    def _fraction(mask):
        # np.mean of an empty array warns and yields NaN; return NaN quietly.
        return np.mean(mask) if mask.size else float('nan')

    # Model agreement analysis
    agreement = _fraction(yat_predictions == linear_predictions)
    both_correct = _fraction(yat_hits & linear_hits)
    yat_correct_linear_wrong = _fraction(yat_hits & ~linear_hits)
    linear_correct_yat_wrong = _fraction(linear_hits & ~yat_hits)
    both_wrong = _fraction(~yat_hits & ~linear_hits)

    print(f"\n🤝 MODEL AGREEMENT ANALYSIS")
    print("=" * 40)
    print(f"Overall Agreement: {agreement:.4f}")
    print(f"Both Correct: {both_correct:.4f}")
    print(f"YAT Correct, Linear Wrong: {yat_correct_linear_wrong:.4f}")
    print(f"Linear Correct, YAT Wrong: {linear_correct_yat_wrong:.4f}")
    print(f"Both Wrong: {both_wrong:.4f}")

    return {
        'yat_predictions': yat_predictions,
        'linear_predictions': linear_predictions,
        'true_labels': true_labels,
        'class_names': class_names,
        'agreement': agreement,
        'both_correct': both_correct,
        # Extra keys: previously computed and printed but never returned.
        'yat_correct_linear_wrong': yat_correct_linear_wrong,
        'linear_correct_yat_wrong': linear_correct_yat_wrong,
        'both_wrong': both_wrong,
    }
876
+
877
def plot_confusion_matrices(predictions_data):
    """
    Plot the confusion matrices of both models side by side.

    Args:
        predictions_data: Mapping with the keys ``'yat_predictions'``,
            ``'linear_predictions'``, ``'true_labels'`` and ``'class_names'``
            (the structure returned by ``detailed_test_evaluation``).

    Returns:
        None. A matplotlib figure is shown and a status line is printed.
    """
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt
    import seaborn as sns

    true_labels = predictions_data['true_labels']
    class_names = predictions_data['class_names']

    # Pin the label set so the matrix is always len(class_names) square.
    # Without it, a class absent from both the labels and the predictions is
    # dropped from the matrix and the tick labels no longer line up.
    label_ids = list(range(len(class_names)))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    panels = (
        (ax1, predictions_data['yat_predictions'], 'Blues', 'YAT Model - Confusion Matrix'),
        (ax2, predictions_data['linear_predictions'], 'Reds', 'Linear Model - Confusion Matrix'),
    )
    for ax, preds, cmap, title in panels:
        cm = confusion_matrix(true_labels, preds, labels=label_ids)
        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')

    plt.tight_layout()
    plt.show()

    print("📊 Confusion matrices plotted successfully!")
915
+
916
def generate_summary_report(yat_history, linear_history, predictions_data):
    """
    Print a comprehensive summary report of the YAT vs. Linear comparison.

    Args:
        yat_history: Mapping with ``'train_accuracy'`` and ``'test_accuracy'``
            sequences for the YAT model; the last entry of each is used.
        linear_history: Same structure for the Linear model.
        predictions_data: Mapping with ``'class_names'``, ``'true_labels'``,
            ``'yat_predictions'``, ``'linear_predictions'``, ``'agreement'``
            and ``'both_correct'``.

    Returns:
        None. The report is written to stdout. Sections that need a missing
        or empty accuracy history print a notice instead of raising.
    """
    def _final(history, key):
        # Last recorded value, or None when the series is absent or empty.
        values = history.get(key)
        if values is None or len(values) == 0:
            return None
        return values[-1]

    def _fmt(value):
        return "N/A" if value is None else f"{value:.4f}"

    print("\n" + "="*80)
    print(" COMPREHENSIVE SUMMARY REPORT")
    print("="*80)

    # Final metrics
    yat_final_acc = _final(yat_history, 'test_accuracy')
    linear_final_acc = _final(linear_history, 'test_accuracy')
    yat_final_train = _final(yat_history, 'train_accuracy')
    linear_final_train = _final(linear_history, 'train_accuracy')
    have_test_acc = yat_final_acc is not None and linear_final_acc is not None

    print(f"\n🏆 OVERALL WINNER:")
    if not have_test_acc:
        print(" Test accuracy history missing; cannot determine a winner.")
    elif yat_final_acc > linear_final_acc:
        print(f" 🥇 YAT Model wins by {yat_final_acc - linear_final_acc:.4f} accuracy points!")
    elif linear_final_acc > yat_final_acc:
        print(f" 🥇 Linear Model wins by {linear_final_acc - yat_final_acc:.4f} accuracy points!")
    else:
        print(f" 🤝 It's a tie! Both models achieved {yat_final_acc:.4f} accuracy")

    print(f"\n📈 PERFORMANCE SUMMARY:")
    print(f" YAT Model Test Accuracy: {_fmt(yat_final_acc)}")
    print(f" Linear Model Test Accuracy: {_fmt(linear_final_acc)}")
    print(f" Model Agreement: {predictions_data['agreement']:.4f}")
    print(f" Both Models Correct: {predictions_data['both_correct']:.4f}")

    # Best and worst performing classes
    class_names = predictions_data['class_names']
    true_labels = predictions_data['true_labels']
    yat_preds = predictions_data['yat_predictions']
    linear_preds = predictions_data['linear_predictions']

    yat_class_accs = []
    linear_class_accs = []
    for class_idx, class_name in enumerate(class_names):
        class_mask = true_labels == class_idx
        if np.sum(class_mask) > 0:  # classes without samples have no accuracy
            class_truth = true_labels[class_mask]
            yat_class_accs.append((class_name, np.mean(yat_preds[class_mask] == class_truth)))
            linear_class_accs.append((class_name, np.mean(linear_preds[class_mask] == class_truth)))

    # Highest accuracy first
    yat_class_accs.sort(key=lambda item: item[1], reverse=True)
    linear_class_accs.sort(key=lambda item: item[1], reverse=True)

    print(f"\n🎯 BEST PERFORMING CLASSES (Top 3 if available):")
    for name, acc in yat_class_accs[:3]:
        print(f" YAT Model: {name} ({acc:.4f})")
    for name, acc in linear_class_accs[:3]:
        print(f" Linear Model: {name} ({acc:.4f})")

    print(f"\n🎯 WORST PERFORMING CLASSES (Bottom 3 if available):")
    # Walk from the end so the worst class is listed first.
    for name, acc in yat_class_accs[::-1][:3]:
        print(f" YAT Model: {name} ({acc:.4f})")
    for name, acc in linear_class_accs[::-1][:3]:
        print(f" Linear Model: {name} ({acc:.4f})")

    print(f"\n📊 TRAINING CHARACTERISTICS:")
    if have_test_acc and yat_final_train is not None and linear_final_train is not None:
        print(f" YAT Final Train Accuracy: {yat_final_train:.4f}")
        print(f" Linear Final Train Accuracy: {linear_final_train:.4f}")
        print(f" YAT Overfitting Gap: {yat_final_train - yat_final_acc:.4f}")
        print(f" Linear Overfitting Gap: {linear_final_train - linear_final_acc:.4f}")
    else:
        print(" Training/Test accuracy history missing for full overfitting gap analysis.")

    print(f"\n💡 RECOMMENDATIONS:")
    if not have_test_acc:
        print(" Test accuracy history missing; no architecture recommendation.")
    elif yat_final_acc > linear_final_acc:
        print(f" ✅ YAT model architecture shows superior performance")
        print(f" ✅ Consider using YAT layers for similar classification tasks")
    else:
        print(f" ✅ Linear model architecture is sufficient for this task")
        # The report is dataset-agnostic, so do not name a specific dataset.
        print(f" ✅ Standard convolution layers perform well on this dataset")

    if predictions_data['agreement'] > 0.8:
        print(f" 🤝 High model agreement suggests stable learning")
    else:
        print(f" 🔍 Low model agreement suggests different learning patterns")

    print("="*80)
1003
+
1004
+ # ===== COMPLETE IMPLEMENTATION EXAMPLE =====
1005
+
1006
+ # Moved YatCNN class definition to module level
1007
class YatCNN(nnx.Module):
    """YAT CNN model with custom layers.

    Four ``YatConv`` stages, each followed by dropout and 2x2 average
    pooling, then a mean over axes (1, 2) and a ``YatNMN`` output layer.
    """

    def __init__(self, *, num_classes: int, input_channels: int, rngs: nnx.Rngs):
        kernel = (5, 5)
        drop_rate = 0.3

        # Sub-modules are created in the same order as before so the shared
        # `rngs` object is consumed identically.
        self.conv1 = YatConv(input_channels, 32, kernel_size=kernel, rngs=rngs)
        self.conv2 = YatConv(32, 64, kernel_size=kernel, rngs=rngs)
        self.conv3 = YatConv(64, 128, kernel_size=kernel, rngs=rngs)
        self.conv4 = YatConv(128, 128, kernel_size=kernel, rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout3 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout4 = nnx.Dropout(rate=drop_rate, rngs=rngs)

        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.non_linear2 = YatNMN(128, num_classes, use_bias=False, use_alpha=False, rngs=rngs)

    def __call__(self, x, training: bool = False, return_activations_for_layer: tp.Optional[str] = None):
        """Run the network, optionally stopping early at a named layer.

        When ``return_activations_for_layer`` names one of ``conv1``..``conv4``,
        ``global_avg_pool`` or ``final_layer``, the tensor produced at that
        point is returned. An unknown name prints a warning and the final
        output is returned.
        """
        wanted = return_activations_for_layer
        visited = []

        conv_stages = (
            ('conv1', self.conv1, self.dropout1),
            ('conv2', self.conv2, self.dropout2),
            ('conv3', self.conv3, self.dropout3),
            ('conv4', self.conv4, self.dropout4),
        )
        for tag, conv, dropout in conv_stages:
            x = conv(x)
            visited.append(tag)
            if wanted == tag:
                return x
            x = dropout(x, deterministic=not training)
            x = self.avg_pool(x)

        x = jnp.mean(x, axis=(1, 2))
        visited.append('global_avg_pool')
        if wanted == 'global_avg_pool':
            return x

        x = self.non_linear2(x)
        visited.append('final_layer')
        if wanted == 'final_layer':
            return x

        if wanted is not None and wanted not in visited:
            print(f"Warning: Layer '{wanted}' not found in YatCNN. Available: {visited}")
        # Fall back to the final output when the requested layer is unknown.
        return x
1061
+
1062
+ # Moved LinearCNN class definition to module level
1063
class LinearCNN(nnx.Module):
    """Standard CNN model with linear layers.

    Four ``nnx.Conv`` + ReLU stages, each followed by dropout and 2x2 average
    pooling, then a mean over axes (1, 2) and a bias-free ``nnx.Linear`` head.
    """

    def __init__(self, *, num_classes: int, input_channels: int, rngs: nnx.Rngs):
        kernel = (5, 5)
        drop_rate = 0.3

        # Sub-modules are created in the same order as before so the shared
        # `rngs` object is consumed identically.
        self.conv1 = nnx.Conv(input_channels, 32, kernel_size=kernel, rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=kernel, rngs=rngs)
        self.conv3 = nnx.Conv(64, 128, kernel_size=kernel, rngs=rngs)
        self.conv4 = nnx.Conv(128, 128, kernel_size=kernel, rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout3 = nnx.Dropout(rate=drop_rate, rngs=rngs)
        self.dropout4 = nnx.Dropout(rate=drop_rate, rngs=rngs)

        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.linear2 = nnx.Linear(128, num_classes, rngs=rngs, use_bias=False)

    def __call__(self, x, training: bool = False, return_activations_for_layer: tp.Optional[str] = None):
        """Run the network, optionally stopping early at a named layer.

        Recognised names are ``convN_raw`` (before ReLU) and ``convN`` (after
        ReLU) for N in 1..4, plus ``global_avg_pool`` and ``final_layer``.
        An unknown name prints a warning and the final output is returned.
        """
        wanted = return_activations_for_layer
        visited = []

        conv_stages = (
            ('conv1', self.conv1, self.dropout1),
            ('conv2', self.conv2, self.dropout2),
            ('conv3', self.conv3, self.dropout3),
            ('conv4', self.conv4, self.dropout4),
        )
        for tag, conv, dropout in conv_stages:
            raw_tag = f'{tag}_raw'

            x = conv(x)
            visited.append(raw_tag)  # pre-activation output
            if wanted == raw_tag:
                return x

            x = nnx.relu(x)
            visited.append(tag)  # post-ReLU output
            if wanted == tag:
                return x

            x = dropout(x, deterministic=not training)
            x = self.avg_pool(x)

        x = jnp.mean(x, axis=(1, 2))
        visited.append('global_avg_pool')
        if wanted == 'global_avg_pool':
            return x

        x = self.linear2(x)
        visited.append('final_layer')
        if wanted == 'final_layer':
            return x

        if wanted is not None and wanted not in visited:
            print(f"Warning: Layer '{wanted}' not found in LinearCNN. Available: {visited}")
        return x
1128
+
1129
+ # New helper function for the training loop
1130
def _train_model_loop(
    model_class: tp.Type[nnx.Module],
    model_name: str,
    dataset_name: str,
    rng_seed: int,
    learning_rate: float,
    momentum: float,
    optimizer_constructor: tp.Callable,
):
    """Train one model on a TFDS dataset and return it with its metrics history.

    Args:
        model_class: Module class constructed as
            ``model_class(num_classes=..., input_channels=..., rngs=...)``.
        model_name: Display name used in progress messages.
        dataset_name: TFDS dataset name. Settings are read from
            ``DATASET_CONFIGS`` when present; otherwise they are inferred from
            the dataset's TFDS info and the module-level ``_global_*`` defaults.
        rng_seed: Seed passed to ``nnx.Rngs``.
        learning_rate: Forwarded to ``optimizer_constructor``.
        momentum: Forwarded to ``optimizer_constructor``.
        optimizer_constructor: Called as
            ``optimizer_constructor(learning_rate, momentum)``; the result is
            wrapped in ``nnx.Optimizer``.

    Returns:
        ``(model, metrics_history)`` where ``metrics_history`` maps
        ``'train_loss'``, ``'train_accuracy'``, ``'test_loss'`` and
        ``'test_accuracy'`` to lists with one entry per evaluation.

    Raises:
        ValueError: If the dataset is not configured and its info cannot be
            inferred, or if the training split has unknown/infinite cardinality.
    """
    print(f"Initializing {model_name} model for dataset {dataset_name}...")

    config = DATASET_CONFIGS.get(dataset_name)

    if not config:
        # The builder is only needed to infer settings for unconfigured datasets.
        ds_info_for_model = tfds.builder(dataset_name).info
        try:
            num_classes = ds_info_for_model.features['label'].num_classes
            image_shape = ds_info_for_model.features['image'].shape
        except Exception as e:
            # Chain the cause so the underlying lookup failure stays visible.
            raise ValueError(f"Dataset '{dataset_name}' not in configs and could not infer info: {e}") from e
        input_channels = image_shape[-1] if len(image_shape) >= 3 else 1
        train_split_name = 'train'
        test_split_name = 'test'
        image_key = 'image'
        label_key = 'label'
        # Fallback training parameters when the dataset has no config entry.
        current_num_epochs = _global_num_epochs
        current_eval_every = _global_eval_every
        current_batch_size = _global_batch_size
        print(f"Warning: Dataset '{dataset_name}' not in pre-defined configs. Inferred: num_classes={num_classes}, input_channels={input_channels}. Using global defaults for training params.")
    else:
        num_classes = config['num_classes']
        input_channels = config['input_channels']
        train_split_name = config['train_split']
        test_split_name = config['test_split']
        image_key = config['image_key']
        label_key = config['label_key']
        current_num_epochs = config['num_epochs']
        current_eval_every = config['eval_every']
        current_batch_size = config['batch_size']

    model = model_class(num_classes=num_classes, input_channels=input_channels, rngs=nnx.Rngs(rng_seed))
    optimizer = nnx.Optimizer(model, optimizer_constructor(learning_rate, momentum))
    metrics_computer = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average('loss'),
    )

    def preprocess_data_fn(sample):
        # Scale pixel values to [0, 1] and normalise the dict keys.
        image = tf.cast(sample[image_key], tf.float32) / 255.0
        return {'image': image, 'label': sample[label_key]}

    loaded_train_ds = tfds.load(dataset_name, split=train_split_name, as_supervised=False, shuffle_files=True)
    loaded_test_ds = tfds.load(dataset_name, split=test_split_name, as_supervised=False)

    dataset_size = loaded_train_ds.cardinality().numpy()
    if dataset_size == tf.data.UNKNOWN_CARDINALITY or dataset_size == tf.data.INFINITE_CARDINALITY:
        raise ValueError(
            f"Cannot determine dataset size for '{dataset_name}' split '{train_split_name}' for epoch-based training. "
            f"Please ensure the dataset split has a known finite cardinality or revert to step-based training with .take()."
        )
    steps_per_epoch = dataset_size // current_batch_size
    total_expected_steps = current_num_epochs * steps_per_epoch

    print(f"Training {model_name} on {dataset_name} for {current_num_epochs} epochs ({steps_per_epoch} steps/epoch, total {total_expected_steps} steps). Evaluating every {current_eval_every} steps.")

    # Test pipeline is built once; each as_numpy_iterator() call re-iterates it.
    dataset_test_iter = loaded_test_ds.map(preprocess_data_fn, num_parallel_calls=tf.data.AUTOTUNE) \
        .batch(current_batch_size, drop_remainder=True) \
        .prefetch(tf.data.AUTOTUNE)

    metrics_history = {
        'train_loss': [], 'train_accuracy': [],
        'test_loss': [], 'test_accuracy': [],
    }

    def _record(prefix, computed):
        # Append each computed metric to its '<prefix>_<name>' series.
        for metric_name_key, value in computed.items():
            metrics_history[f'{prefix}_{metric_name_key}'].append(value)

    def _evaluate_on_test():
        # Full pass over the test set; leaves the metric accumulator reset.
        for test_batch in dataset_test_iter.as_numpy_iterator():
            eval_step(model, metrics_computer, test_batch)
        _record('test', metrics_computer.compute())
        metrics_computer.reset()

    global_step_counter = 0
    for epoch in range(current_num_epochs):
        print(f" Epoch {epoch + 1}/{current_num_epochs}")
        # Rebuild the pipeline each epoch so the shuffle order is redrawn.
        epoch_train_ds = loaded_train_ds.shuffle(buffer_size=1024) \
            .map(preprocess_data_fn, num_parallel_calls=tf.data.AUTOTUNE) \
            .batch(current_batch_size, drop_remainder=True) \
            .prefetch(tf.data.AUTOTUNE)

        for batch_in_epoch, batch_data in enumerate(epoch_train_ds.as_numpy_iterator()):
            train_step(model, optimizer, metrics_computer, batch_data)

            # Periodic evaluation keyed on the global step. The very last
            # step is excluded because the final evaluation below covers it.
            is_last_batch = epoch == current_num_epochs - 1 and batch_in_epoch == steps_per_epoch - 1
            if global_step_counter > 0 and \
               (global_step_counter % current_eval_every == 0 or global_step_counter == total_expected_steps - 1) and \
               not is_last_batch:

                _record('train', metrics_computer.compute())
                metrics_computer.reset()

                _evaluate_on_test()
                print(f" Step {global_step_counter}: {model_name} Train Acc = {metrics_history['train_accuracy'][-1]:.4f}, Test Acc = {metrics_history['test_accuracy'][-1]:.4f}")

            global_step_counter += 1
            if global_step_counter >= total_expected_steps:
                break  # planned number of steps reached

        if global_step_counter >= total_expected_steps:
            break  # leave the epoch loop as well

    # Final evaluation after all epochs.
    print(f" Performing final evaluation for {model_name} after {current_num_epochs} epochs...")
    # Capture the training metrics accumulated since the last reset.
    computed_train_metrics = metrics_computer.compute()
    if computed_train_metrics and computed_train_metrics.get('loss') is not None:
        _record('train', computed_train_metrics)
    metrics_computer.reset()

    _evaluate_on_test()

    print(f"✅ {model_name} Model Training Complete on {dataset_name} after {current_num_epochs} epochs ({global_step_counter} steps)!")
    if metrics_history['test_accuracy']:
        print(f" Final Test Accuracy: {metrics_history['test_accuracy'][-1]:.4f}")
    else:
        print(f" No test accuracy recorded for {model_name}.")

    return model, metrics_history
1268
+
1269
+
1270
+ # ===== NEW ADVANCED ANALYSIS FUNCTIONS =====
1271
+
1272
def visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16):
    """
    Plot a grid of kernels from one named layer of each model.

    Only the first input-channel slice of every kernel is drawn, min-max
    normalised per kernel.

    Args:
        yat_model: Model exposing the layer as attribute ``layer_name``.
        linear_model: Second model, same requirement.
        layer_name: Attribute name of the layer whose ``kernel`` is plotted.
        num_kernels_to_show: Upper bound on the number of kernels drawn per
            model (capped at the layer's output-channel count).

    Returns:
        None. Figures are shown with matplotlib; problems are reported on
        stdout and the affected model is skipped.
    """
    print(f"\n🎨 VISUALIZING KERNELS FROM LAYER: {layer_name}")
    print("=" * 50)

    def get_kernels(model, layer_name_str):
        # Only the attribute lookup is guarded, so an unrelated AttributeError
        # raised further down is not misreported as "layer not found".
        try:
            layer = getattr(model, layer_name_str)
        except AttributeError:
            print(f"Layer {layer_name_str} not found in {model.__class__.__name__}")
            return None
        kernel = getattr(layer, 'kernel', None)
        if kernel is None:
            print(f"Kernel not found or is None in layer {layer_name_str} of {model.__class__.__name__}")
            return None
        return kernel.value

    yat_kernels = get_kernels(yat_model, layer_name)
    linear_kernels = get_kernels(linear_model, layer_name)

    if yat_kernels is None and linear_kernels is None:
        print("Could not retrieve kernels for either model.")
        return

    def plot_kernel_grid(kernels, model_name_str, num_kernels):
        if kernels is None:
            print(f"No kernels to plot for {model_name_str}")
            return
        if kernels.ndim != 4:
            print(f"Unexpected kernel shape for {model_name_str}: {kernels.shape}")
            return

        out_c = kernels.shape[-1]
        num_kernels = min(num_kernels, out_c)
        if num_kernels <= 0:
            # Nothing to draw; also avoids a zero-column grid below.
            print(f"No kernels to plot for {model_name_str}")
            return
        cols = int(np.ceil(np.sqrt(num_kernels)))
        rows = int(np.ceil(num_kernels / cols))

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, which
        # otherwise comes back as a bare Axes without `.flat`.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
        fig.suptitle(f'{model_name_str} - Layer {layer_name} Kernels (First {num_kernels} of {out_c})', fontsize=16)

        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            if i >= num_kernels:
                continue  # unused trailing cell of the grid
            kernel_slice = kernels[:, :, 0, i]
            low, high = np.min(kernel_slice), np.max(kernel_slice)
            # The epsilon guards against a constant kernel (high == low).
            kernel_slice = (kernel_slice - low) / (high - low + 1e-5)
            ax.imshow(kernel_slice, cmap='viridis')
            ax.set_title(f'Kernel {i+1}')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    if yat_kernels is not None:
        plot_kernel_grid(np.array(yat_kernels), "YAT Model", num_kernels_to_show)
    if linear_kernels is not None:
        plot_kernel_grid(np.array(linear_kernels), "Linear Model", num_kernels_to_show)

    print("🖼️ Kernel visualization complete (if kernels were found and plotted).")
1335
+
1336
+
1337
def get_activation_maps(model, layer_name, input_sample, training=False):
    """
    Return the activations a model produces at ``layer_name``.

    The model is called with ``return_activations_for_layer=layer_name`` and
    its result is passed through unchanged. Whatever the model does for an
    unrecognised layer name is therefore also what the caller receives.

    Returns:
        The model's return value, or ``None`` if the call raised (the error
        is printed).
    """
    try:
        layer_output = model(
            input_sample,
            training=training,
            return_activations_for_layer=layer_name,
        )
    except Exception as e:
        # Best-effort helper: report the failure and signal it with None.
        print(f"Error getting activations for {layer_name} in {model.__class__.__name__}: {e}")
        return None
    return layer_output
1354
+
1355
+
1356
def activation_map_visualization(yat_model, linear_model, test_ds_iter, layer_name='conv1', num_maps_to_show=16):
    """
    Visualize activation maps from a named layer for one sample input.

    Args:
        yat_model: First model; called through ``get_activation_maps``.
        linear_model: Second model; called the same way.
        test_ds_iter: Object exposing ``as_numpy_iterator()`` that yields
            dicts with an ``'image'`` entry (already batched and preprocessed).
        layer_name: Layer whose activations are requested from each model.
        num_maps_to_show: Upper bound on the number of channels drawn.

    Returns:
        None. Figures are shown with matplotlib; problems are reported on
        stdout.
    """
    print(f"\n🗺️ VISUALIZING ACTIVATION MAPS FROM LAYER: {layer_name}")
    print("=" * 50)

    try:
        sample_batch = next(test_ds_iter.as_numpy_iterator())
    except (StopIteration, tf.errors.OutOfRangeError):
        # next() on an empty Python iterator raises StopIteration, so it has
        # to be caught alongside the TensorFlow error for this path to run.
        print("ERROR: Test dataset iterator for activation maps is exhausted. Consider re-creating it or passing a fresh one.")
        # Fallback: use the module-level test_ds, which may be another dataset.
        try:
            print(f"Warning: Falling back to global test_ds for activation maps. This might use data from '{_DEFAULT_DATASET_FOR_GLOBALS}'.")
            sample_batch = next(test_ds.as_numpy_iterator())
        except Exception as e_global:
            print(f"Error: Could not get sample batch for activation maps: {e_global}")
            return

    sample_image = sample_batch['image'][0:1]  # first image, batch dim kept

    yat_activations = get_activation_maps(yat_model, layer_name, sample_image)
    linear_activations = get_activation_maps(linear_model, layer_name, sample_image)

    if yat_activations is None and linear_activations is None:
        print("Could not retrieve activation maps for either model.")
        return

    def plot_activation_grid(activations, model_name_str, num_maps):
        if activations is None:
            print(f"No activation maps to plot for {model_name_str}")
            return

        activations_np = np.array(activations)
        if activations_np.ndim != 4:
            print(f"Unexpected activation shape for {model_name_str}: {activations_np.shape}")
            return
        activations_np = activations_np[0]  # drop the batch dimension

        num_channels = activations_np.shape[-1]
        num_maps = min(num_maps, num_channels)
        if num_maps <= 0:
            # Nothing to draw; also avoids a zero-column grid below.
            print(f"No activation maps to plot for {model_name_str}")
            return
        cols = int(np.ceil(np.sqrt(num_maps)))
        rows = int(np.ceil(num_maps / cols))

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, which
        # otherwise comes back as a bare Axes without `.flat`.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
        fig.suptitle(f'{model_name_str} - Layer {layer_name} Activation Maps (First {num_maps})', fontsize=16)

        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            if i >= num_maps:
                continue  # unused trailing cell of the grid
            ax.imshow(activations_np[:, :, i], cmap='viridis')
            ax.set_title(f'Map {i+1}')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    if yat_activations is not None:
        plot_activation_grid(yat_activations, "YAT Model", num_maps_to_show)
    if linear_activations is not None:
        plot_activation_grid(linear_activations, "Linear Model", num_maps_to_show)

    print("🗺️ Activation map visualization complete (if maps were found and plotted).")
1426
+
1427
+
1428
def saliency_map_analysis(yat_model, linear_model, test_ds_iter, class_names: list[str]):
    """
    Generate and visualize gradient saliency maps for both models.

    One sample is taken from the test data. For each model, the gradient of a
    class logit with respect to the input is computed, for the true class and
    for the model's own predicted class. The saliency map is the maximum
    absolute gradient over the last (channel) axis.

    Args:
        yat_model: Model called as ``model(images, training=False)``.
        linear_model: Second model, same calling convention.
        test_ds_iter: Object exposing ``as_numpy_iterator()`` that yields
            dicts with ``'image'`` and ``'label'`` entries (already batched
            and preprocessed).
        class_names: Class names for the current dataset, indexed by label.

    Returns:
        None. A 2x3 figure is shown with matplotlib and a summary is printed.
    """
    print(f"\n🔥 SALIENCY MAP ANALYSIS for {len(class_names)} classes")
    print("=" * 50)

    try:
        sample_batch = next(test_ds_iter.as_numpy_iterator())
    except (StopIteration, tf.errors.OutOfRangeError):
        # next() on an empty Python iterator raises StopIteration, so it has
        # to be caught alongside the TensorFlow error for this path to run.
        print("ERROR: Test dataset iterator for saliency maps is exhausted. Consider re-creating it or passing a fresh one.")
        # Fallback: use the module-level test_ds, which may be another dataset.
        try:
            print(f"Warning: Falling back to global test_ds for saliency maps. This might use data from '{_DEFAULT_DATASET_FOR_GLOBALS}'.")
            sample_batch = next(test_ds.as_numpy_iterator())
        except Exception as e_global:
            print(f"Error: Could not get sample batch for saliency maps: {e_global}")
            return

    sample_image = sample_batch['image'][0:1]  # first image, batch dim kept
    sample_label = int(sample_batch['label'][0])  # plain int: used as a static jit argument

    true_class_name = class_names[sample_label]

    # The model and the class index are static arguments; only the image is traced.
    @partial(jax.jit, static_argnums=(0, 2))
    def get_saliency_map(model, image_input, class_index=None):
        def model_output_for_grad(img):
            logits = model(img, training=False)
            if class_index is not None:
                # Keep class_index valid for this model's output width.
                num_model_classes = logits.shape[-1]
                if class_index >= num_model_classes:
                    print(f"Warning: class_index {class_index} is out of bounds for model with {num_model_classes} classes. Using 0 instead.")
                    safe_class_index = 0
                else:
                    safe_class_index = class_index
                return logits[0, safe_class_index]
            else:
                return jnp.max(logits[0])  # logit of the predicted class

        grads = jax.grad(model_output_for_grad)(image_input)
        saliency = jnp.max(jnp.abs(grads[0]), axis=-1)
        return saliency

    yat_logits_sample = yat_model(sample_image, training=False)
    yat_predicted_class_idx = int(jnp.argmax(yat_logits_sample, axis=1)[0])
    yat_predicted_class_name = class_names[yat_predicted_class_idx]

    linear_logits_sample = linear_model(sample_image, training=False)
    linear_predicted_class_idx = int(jnp.argmax(linear_logits_sample, axis=1)[0])
    linear_predicted_class_name = class_names[linear_predicted_class_idx]

    print(f"Sample image true class: {true_class_name} (Index: {sample_label})")
    print(f"YAT predicted class: {yat_predicted_class_name} (Index: {yat_predicted_class_idx})")
    print(f"Linear predicted class: {linear_predicted_class_name} (Index: {linear_predicted_class_idx})")

    yat_saliency_true_class = get_saliency_map(yat_model, sample_image, class_index=sample_label)
    linear_saliency_true_class = get_saliency_map(linear_model, sample_image, class_index=sample_label)

    yat_saliency_pred_class = get_saliency_map(yat_model, sample_image, class_index=yat_predicted_class_idx)
    linear_saliency_pred_class = get_saliency_map(linear_model, sample_image, class_index=linear_predicted_class_idx)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f"Saliency Map Comparison (True Class: {true_class_name})", fontsize=16)

    img_display = np.asarray(sample_image[0])
    img_display = (img_display - np.min(img_display)) / (np.max(img_display) - np.min(img_display) + 1e-5)
    if img_display.ndim == 3 and img_display.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; show single-channel input as 2-D.
        img_display = img_display[..., 0]

    rows = (
        ("YAT", yat_saliency_true_class, yat_saliency_pred_class, yat_predicted_class_name),
        ("Linear", linear_saliency_true_class, linear_saliency_pred_class, linear_predicted_class_name),
    )
    for row, (label, saliency_true, saliency_pred, pred_name) in enumerate(rows):
        # cmap only affects 2-D (grayscale) input; RGB images ignore it.
        axes[row, 0].imshow(img_display, cmap='gray')
        axes[row, 0].set_title(f"{label} Input (True: {true_class_name})")
        axes[row, 0].axis('off')

        im_true = axes[row, 1].imshow(np.array(saliency_true), cmap='hot')
        axes[row, 1].set_title(f"{label} Saliency (for True: {true_class_name})")
        axes[row, 1].axis('off')
        fig.colorbar(im_true, ax=axes[row, 1])

        im_pred = axes[row, 2].imshow(np.array(saliency_pred), cmap='hot')
        axes[row, 2].set_title(f"{label} Saliency (for Pred: {pred_name})")
        axes[row, 2].axis('off')
        fig.colorbar(im_pred, ax=axes[row, 2])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    print("🔥 Saliency map analysis and visualization complete.")
1531
+
1532
+
1533
+
1534
+
1535
def run_complete_comparison(dataset_name: str = 'cifar10'):
    """Train the YAT and Linear CNNs on one dataset and run every comparison step.

    The function trains both models through ``_train_model_loop``, then runs
    the curve/metric comparisons, a detailed test-set evaluation, confusion
    matrices, the summary report, and finally the kernel, activation-map and
    saliency-map visualisations.

    Args:
        dataset_name: TFDS dataset name. Per-dataset settings (image/label
            keys, test split, batch size) are read from ``DATASET_CONFIGS``
            when an entry exists; otherwise defaults are used.

    Returns:
        A dict with the keys ``'yat_model'``, ``'linear_model'``,
        ``'yat_metrics_history'``, ``'linear_metrics_history'`` and
        ``'predictions_data'``.
    """
    print("\n" + "="*80)
    print(f" RUNNING COMPLETE MODEL COMPARISON FOR: {dataset_name.upper()}")
    print("="*80)

    # Common training parameters (could be moved to DATASET_CONFIGS if they vary a lot).
    learning_rate = 0.003
    momentum = 0.9

    # Resolve human-readable class names for the analysis functions.
    dataset_config = DATASET_CONFIGS.get(dataset_name, {})
    if dataset_config:
        ds_info_comp = tfds.builder(dataset_name).info
        class_names_comp = ds_info_comp.features[dataset_config.get('label_key', 'label')].names
    else:
        print(f"Warning: Dataset '{dataset_name}' not in DATASET_CONFIGS. Some features might use defaults or fail.")
        ds_info_comp_fallback = tfds.builder(dataset_name).info
        try:
            class_names_comp = ds_info_comp_fallback.features['label'].names
        except (KeyError, AttributeError):
            print(f"Could not infer class names for {dataset_name}, using a placeholder list.")
            try:
                # Build generic names from the class count when it is available.
                num_classes_fallback = ds_info_comp_fallback.features['label'].num_classes
                class_names_comp = [f"Class {i}" for i in range(num_classes_fallback)]
            except Exception:
                # Best-effort last resort: ten generic classes. `Exception` (not a
                # bare `except`) so Ctrl-C / SystemExit still propagate.
                class_names_comp = [f"Class {i}" for i in range(10)]

    # Batch size used to batch the evaluation dataset; falls back to the
    # module-level default when the dataset has no config entry.
    current_batch_size_for_eval = dataset_config.get('batch_size', _global_batch_size)

    # Both models are trained with exactly the same settings.
    shared_train_kwargs = dict(
        dataset_name=dataset_name,
        rng_seed=0,
        learning_rate=learning_rate,
        momentum=momentum,
        optimizer_constructor=optax.adamw,
    )

    # Step 1: Train YAT Model
    print(f"\n🚀 STEP 1: Training YAT Model on {dataset_name}...")
    print("-" * 50)
    yat_model, yat_metrics_history = _train_model_loop(
        model_class=YatCNN,
        model_name="YAT",
        **shared_train_kwargs,
    )

    # Step 2: Train Linear Model
    print(f"\n🚀 STEP 2: Training Linear Model on {dataset_name}...")
    print("-" * 50)
    linear_model, linear_metrics_history = _train_model_loop(
        model_class=LinearCNN,
        model_name="Linear",
        **shared_train_kwargs,
    )

    # Step 3: Run All Comparisons
    print(f"\n📊 STEP 3: Running Complete Comparison Analysis for {dataset_name}...")
    print("-" * 50)

    # 3.1 Compare training curves
    print("\n📈 Comparing training curves...")
    compare_training_curves(yat_metrics_history, linear_metrics_history)

    # 3.2 Print final metrics comparison
    print_final_metrics_comparison(yat_metrics_history, linear_metrics_history)

    # 3.3 Analyze convergence
    analyze_convergence(yat_metrics_history, linear_metrics_history)

    # 3.4 Detailed test evaluation on a freshly loaded test split.
    print("\n🎯 Running detailed test evaluation...")
    eval_image_key = dataset_config.get('image_key', 'image')
    eval_label_key = dataset_config.get('label_key', 'label')
    eval_test_split = dataset_config.get('test_split', 'test')

    def eval_preprocess_fn(sample):
        # Scale pixel values to [0, 1] and normalise the dict keys.
        return {
            'image': tf.cast(sample[eval_image_key], tf.float32) / 255.0,
            'label': sample[eval_label_key]
        }

    # NOTE(review): drop_remainder=True discards a trailing partial batch, so
    # the evaluation may not cover every test example - confirm this is intended.
    current_test_ds_for_eval = (
        tfds.load(dataset_name, split=eval_test_split, as_supervised=False)
        .map(eval_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(current_batch_size_for_eval, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

    predictions_data = detailed_test_evaluation(
        yat_model, linear_model, current_test_ds_for_eval, class_names=class_names_comp
    )

    # 3.5 Plot confusion matrices
    print("\n📊 Plotting confusion matrices...")
    plot_confusion_matrices(predictions_data)

    # 3.6 Generate comprehensive summary report
    generate_summary_report(yat_metrics_history, linear_metrics_history, predictions_data)

    # Step 4: Advanced Analysis
    print("\n🔬 STEP 4: Running Advanced Analysis...")
    print("-" * 50)

    # 4.1 Visualize kernels from 'conv1'
    visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16)

    # 4.2 Visualize activation maps from 'conv1' for a sample of the test set
    activation_map_visualization(yat_model, linear_model, current_test_ds_for_eval, layer_name='conv1', num_maps_to_show=16)

    # 4.3 Saliency map analysis
    saliency_map_analysis(yat_model, linear_model, current_test_ds_for_eval, class_names=class_names_comp)

    print("\n" + "="*80)
    print(f" COMPARISON ANALYSIS FOR {dataset_name.upper()} COMPLETE! ✅")
    print("="*80)

    return {
        'yat_model': yat_model,
        'linear_model': linear_model,
        'yat_metrics_history': yat_metrics_history,
        'linear_metrics_history': linear_metrics_history,
        'predictions_data': predictions_data
    }
1673
+
1674
+ # ===== QUICK START FUNCTIONS =====
1675
+
1676
def quick_comparison_demo():
    """Exercise the comparison helpers on synthetic metric histories.

    Builds two dummy histories (30 points each for train/test loss and
    accuracy), then passes them to ``compare_training_curves``,
    ``print_final_metrics_comparison`` and ``analyze_convergence``. Use this
    to check the comparison functions before running full training.
    """
    import random

    print("\n🎬 RUNNING QUICK COMPARISON DEMO...")
    print("-" * 50)

    # A private generator seeded with 42 yields the same value stream as
    # `random.seed(42)` followed by `random.random()`, but leaves the
    # process-wide RNG state untouched for the caller.
    rng = random.Random(42)

    steps = 30
    # The draw order below (YAT first, then Linear, key by key) is part of
    # the reproducible output - do not reorder.
    yat_dummy = {
        'train_loss': [1.5 - 0.04*i + rng.random()*0.1 for i in range(steps)],
        'train_accuracy': [0.2 + 0.025*i + rng.random()*0.05 for i in range(steps)],
        'test_loss': [1.6 - 0.035*i + rng.random()*0.15 for i in range(steps)],
        'test_accuracy': [0.15 + 0.022*i + rng.random()*0.08 for i in range(steps)]
    }

    linear_dummy = {
        'train_loss': [1.6 - 0.045*i + rng.random()*0.1 for i in range(steps)],
        'train_accuracy': [0.18 + 0.024*i + rng.random()*0.05 for i in range(steps)],
        'test_loss': [1.7 - 0.04*i + rng.random()*0.15 for i in range(steps)],
        'test_accuracy': [0.12 + 0.023*i + rng.random()*0.08 for i in range(steps)]
    }

    print("📈 Comparing dummy training curves...")
    compare_training_curves(yat_dummy, linear_dummy)

    print_final_metrics_comparison(yat_dummy, linear_dummy)
    analyze_convergence(yat_dummy, linear_dummy)

    print("✅ Demo complete! Now you can run the full comparison with real models.")
1710
+
1711
def save_metrics_example():
    """Print a short how-to for keeping and persisting metrics histories.

    Nothing is saved or loaded here; the function only writes an example
    snippet to stdout.
    """
    heading = "\n💾 HOW TO SAVE METRICS DURING TRAINING:"
    divider = "-" * 50
    snippet = """
# After training your YAT model:
yat_metrics_history = metrics_history.copy()

# After training your Linear model:
linear_metrics_history = metrics_history.copy()

# Or save to files:
import pickle
with open('yat_metrics.pkl', 'wb') as f:
    pickle.dump(yat_metrics_history, f)

with open('linear_metrics.pkl', 'wb') as f:
    pickle.dump(linear_metrics_history, f)

# Load later:
with open('yat_metrics.pkl', 'rb') as f:
    yat_metrics_history = pickle.load(f)

with open('linear_metrics.pkl', 'rb') as f:
    linear_metrics_history = pickle.load(f)
"""
    for chunk in (heading, divider, snippet):
        print(chunk)
1739
+
1740
# Usage banner: one print call emitting the same lines, in the same order,
# as the original sequence of individual prints.
print("\n".join((
    "\n" + "="*80,
    "="*80,
    "\n🚀 TO RUN THE COMPLETE COMPARISON (e.g., for CIFAR-10):",
    " results = run_complete_comparison(dataset_name='cifar10')",
    "\n Other examples:",
    " results_cifar100 = run_complete_comparison(dataset_name='cifar100')",
    " results_stl10 = run_complete_comparison(dataset_name='stl10')",
    " results_eurosat_rgb = run_complete_comparison(dataset_name='eurosat/rgb')",
    # " results_eurosat_all = run_complete_comparison(dataset_name='eurosat/all')"  # Might be slow due to 13 channels
    "\n🎬 TO RUN A QUICK DEMO (uses dummy data, not specific dataset):",
    " quick_comparison_demo()",
    "\n💾 TO SEE HOW TO SAVE METRICS:",
    " save_metrics_example()",
    "\n📖 The comparison functions are ready to use:",
    " - compare_training_curves(yat_history, linear_history)",
    " - print_final_metrics_comparison(yat_history, linear_history)",
    " - analyze_convergence(yat_history, linear_history)",
    " - detailed_test_evaluation(yat_model, linear_model, test_ds)",
    " - plot_confusion_matrices(predictions_data)",
    " - generate_summary_report(yat_history, linear_history, predictions_data)",
    "\n ✨ NEW ADVANCED ANALYSIS:",
    " - visualize_kernels(yat_model, linear_model, layer_name='conv1', num_kernels_to_show=16)",
    " - activation_map_visualization(yat_model, linear_model, test_ds_iter, layer_name='conv1', num_maps_to_show=16)",
    " - saliency_map_analysis(yat_model, linear_model, test_ds_iter, class_names=class_names_comp)",
    "="*80,
)))


# NOTE(review): this call appears to sit at module level, so merely importing
# the file would start a full STL-10 training run - confirm whether it should
# live behind an `if __name__ == "__main__":` guard.
results_stl10 = run_complete_comparison(dataset_name='stl10')