pyerualjetwork 4.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,799 @@
1
+ import networkx as nx
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from scipy.spatial import ConvexHull
5
+ import seaborn as sns
6
+ from matplotlib.animation import ArtistAnimation
7
+
8
def draw_neural_web(W, ax, G, return_objs=False):
    """
    Draw a weight matrix as a two-column network of input and output neurons.

    Every non-zero entry ``W[i, j]`` is added to ``G`` as an edge between the
    nodes ``'Output{i}'`` and ``'Input{j}'``; the weight is stored in the
    ``ltpw`` edge attribute. Input nodes are placed at x=0 and output nodes at
    x=1 (vertically centred against the inputs). Edges are coloured by weight
    with the ``Blues`` colormap.

    Parameters:
        W : numpy.ndarray
            2D weight matrix; rows are output neurons, columns are input neurons.
        ax : matplotlib.axes.Axes
            Axes to draw on. Its frame and ticks are hidden and its title is
            set to 'Neural Web'.
        G : networkx.Graph
            Graph populated in place. Edges are only added or updated, never
            removed, so a graph reused across calls keeps edges from earlier
            matrices.
        return_objs : bool, optional
            If True, return the drawn artists. Default is False.

    Returns:
        None when ``return_objs`` is falsy, otherwise a tuple ``(art1, art2, art3)``:
        art1 : node collection returned by ``networkx.draw_networkx_nodes``.
        art2 : edge artist(s) returned by ``networkx.draw_networkx_edges``.
        art3 : dict mapping each node to its ``matplotlib.text.Text`` label,
            as returned by ``networkx.draw_networkx_labels``.

    Example:
        art1, art2, art3 = draw_neural_web(W, ax, G, return_objs=True)
        plt.show()
    """
    # np.nonzero yields indices in row-major order, i.e. the same order as a
    # nested loop over rows then columns.
    for i, j in zip(*np.nonzero(W)):
        G.add_edge(f'Output{i}', f'Input{j}', ltpw=W[i, j])

    edges = G.edges(data=True)
    weights = [edata['ltpw'] for _, _, edata in edges]

    num_motor_neurons, num_sensory_neurons = W.shape

    # Inputs form a column at x=0; outputs sit at x=1, centred on the inputs.
    pos = {f'Input{j}': (0, j) for j in range(num_sensory_neurons)}
    motor_y_start = (num_sensory_neurons - num_motor_neurons) / 2
    pos.update({f'Output{i}': (1, motor_y_start + i) for i in range(num_motor_neurons)})

    art1 = nx.draw_networkx_nodes(G, pos, ax=ax, node_size=1000, node_color='lightblue')
    art2 = nx.draw_networkx_edges(G, pos, ax=ax, edge_color=weights, edge_cmap=plt.cm.Blues, width=2)
    art3 = nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('Neural Web')

    if return_objs:
        return art1, art2, art3
    return None
69
+
70
+
71
def draw_model_architecture(model_name, model_path='', style='basic'):
    """
    Visualizes the architecture of a saved model.

    The model is loaded with ``load_model``. The figure contains one panel per
    activation potentiation (the activation's curve over x in [-100, 100]),
    followed by either the weight graph (``style='detailed'``) or two circles
    showing the input and output sizes (``style='basic'``). Summary text about
    the scaler, aggregation layers and potentiation layer is written in the
    bottom-left corner. The figure is then displayed with ``plt.show()``.

    Parameters
    ----------
    model_name : str
        Name of the model to load; passed to ``load_model``.

    model_path : str, optional
        Path passed to ``load_model``. Default is ''.

    style : str, optional
        The style of the visualization.
        Options:
        - 'basic': input and output sizes drawn as circles.
        - 'detailed': the full weight graph drawn with ``draw_neural_web``.
        Default is 'basic'.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``style`` is neither 'basic' nor 'detailed'.

    Examples
    --------
    >>> draw_model_architecture("MyModel", "path/to/model", style='detailed')
    """
    from .plan import get_scaler, get_act_pot, get_weights
    from .model_operations import load_model

    # Number of axes after the activation panels: one weight graph ('detailed')
    # or separate input and output circles ('basic'). Validate before loading
    # so a bad style fails fast instead of with an UnboundLocalError later.
    if style == 'detailed':
        col = 1
    elif style == 'basic':
        col = 2
    else:
        raise ValueError(f"style must be 'basic' or 'detailed', got {style!r}")

    model = load_model(model_name=model_name, model_path=model_path)

    W = model[get_weights()]
    activation_potentiation = model[get_act_pot()]
    scaler_params = model[get_scaler()]

    text_1 = f"Input Shape:\n{W.shape[1]}"
    text_2 = f"Output Shape:\n{W.shape[0]}"

    bottom_left_text = 'Standard Scaler=No' if scaler_params is None else 'Standard Scaler=Yes'

    # A single 'linear' potentiation is reported as zero aggregation layers.
    only_linear = len(activation_potentiation) == 1 and activation_potentiation[0] == 'linear'
    aggregation_layers = 0 if only_linear else len(activation_potentiation)
    bottom_left_text_1 = f'Aggregation Layers(Aggregates All Conversions)={aggregation_layers}'
    bottom_left_text_2 = 'Potentiation Layer(Fully Connected)=1'

    num_middle_axes = len(activation_potentiation)

    # squeeze=False keeps `axes` indexable even when only one subplot exists
    # (e.g. style='detailed' with no activations).
    fig, axes = plt.subplots(1, num_middle_axes + col,
                             figsize=(5 * (num_middle_axes + 2), 5), squeeze=False)
    axes = axes[0]

    fig.suptitle("Model Architecture", fontsize=16, fontweight='bold')

    for i, activation in enumerate(activation_potentiation):
        # A fresh x per panel, in case an activation modifies its input in place.
        x = np.linspace(-100, 100, 100)
        y = draw_activations(x, activation)

        axes[i].plot(x, y, color='b', markersize=6, linewidth=2, label='Activations Over Depth')
        axes[i].set_title(activation)

        for side in ('top', 'right', 'left', 'bottom'):
            axes[i].spines[side].set_visible(False)
        axes[i].get_xaxis().set_visible(False)
        axes[i].get_yaxis().set_visible(False)

        # Arrow linking this activation panel to the next one.
        if i < num_middle_axes - 1:
            axes[i].annotate('', xy=(1.05, 0.5), xytext=(0.95, 0.5),
                             xycoords='axes fraction', textcoords='axes fraction',
                             arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    if style == 'detailed':
        G = nx.Graph()
        draw_neural_web(W=W, ax=axes[num_middle_axes], G=G)

    else:  # style == 'basic' (validated above)
        circle1 = plt.Circle((0.5, 0.5), 0.4, color='skyblue', ec='black', lw=1.5)
        axes[num_middle_axes].add_patch(circle1)
        axes[num_middle_axes].text(0.5, 0.5, text_1, ha='center', va='center', fontsize=12)
        axes[num_middle_axes].set_xlim(0, 1)
        axes[num_middle_axes].set_ylim(0, 1)
        axes[num_middle_axes].axis('off')

        circle2 = plt.Circle((0.5, 0.5), 0.4, color='lightcoral', ec='black', lw=1.5)
        axes[-1].add_patch(circle2)
        axes[-1].text(0.5, 0.5, text_2, ha='center', va='center', fontsize=12)
        axes[-1].set_xlim(0, 1)
        axes[-1].set_ylim(0, 1)
        axes[-1].axis('off')

    fig.text(0.01, 0, bottom_left_text, ha='left', va='bottom', fontsize=10)
    fig.text(0.01, 0.04, bottom_left_text_1, ha='left', va='bottom', fontsize=10)
    fig.text(0.01, 0.08, bottom_left_text_2, ha='left', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.show()
195
+
196
+
197
# Maps each supported activation name to
# (function name in the package's activation_functions module,
#  extra positional args, extra keyword args).
# Functions are looked up by name at call time, so only the selected one
# has to exist in the module.
_ACTIVATION_DISPATCH = {
    'sigmoid': ('Sigmoid', (), {}),
    'swish': ('swish', (), {}),
    'circular': ('circular_activation', (), {}),
    'mod_circular': ('modular_circular_activation', (), {}),
    'tanh_circular': ('tanh_circular_activation', (), {}),
    'leaky_relu': ('leaky_relu', (), {}),
    'relu': ('Relu', (), {}),
    'softplus': ('softplus', (), {}),
    'elu': ('elu', (), {}),
    'gelu': ('gelu', (), {}),
    'selu': ('selu', (), {}),
    'softmax': ('Softmax', (), {}),
    'tanh': ('tanh', (), {}),
    'sinakt': ('sinakt', (), {}),
    'p_squared': ('p_squared', (), {}),
    'sglu': ('sglu', (), {'alpha': 1.0}),
    'dlrelu': ('dlrelu', (), {}),
    'exsig': ('exsig', (), {}),
    'sin_plus': ('sin_plus', (), {}),
    'acos': ('acos', (), {'alpha': 1.0, 'beta': 0.0}),
    'gla': ('gla', (), {'alpha': 1.0, 'mu': 0.0}),
    'srelu': ('srelu', (), {}),
    'qelu': ('qelu', (), {}),
    'isra': ('isra', (), {}),
    'waveakt': ('waveakt', (), {}),
    'arctan': ('arctan', (), {}),
    'bent_identity': ('bent_identity', (), {}),
    'sech': ('sech', (), {}),
    'softsign': ('softsign', (), {}),
    'pwl': ('pwl', (), {}),
    'cubic': ('cubic', (), {}),
    'gaussian': ('gaussian', (), {}),
    'sine': ('sine', (), {}),
    'tanh_square': ('tanh_square', (), {}),
    'mod_sigmoid': ('mod_sigmoid', (), {}),
    'quartic': ('quartic', (), {}),
    'square_quartic': ('square_quartic', (), {}),
    'cubic_quadratic': ('cubic_quadratic', (), {}),
    'exp_cubic': ('exp_cubic', (), {}),
    'sine_square': ('sine_square', (), {}),
    'logarithmic': ('logarithmic', (), {}),
    'scaled_cubic': ('scaled_cubic', (1.0,), {}),
    'sine_offset': ('sine_offset', (1.0,), {}),
    'spiral': ('spiral_activation', (), {}),
}


def draw_activations(x_train, activation):
    """
    Apply the activation function named ``activation`` to ``x_train``.

    'linear' returns ``x_train`` itself, unchanged. Every other supported
    name is resolved through ``_ACTIVATION_DISPATCH``. The matching function
    of the package's ``activation_functions`` module is called with
    ``x_train`` plus the fixed extra arguments listed there.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    from . import activation_functions as af

    if activation == 'linear':
        return x_train

    try:
        func_name, args, kwargs = _ACTIVATION_DISPATCH[activation]
    except KeyError:
        raise ValueError(f"Unknown activation function: {activation!r}") from None

    return getattr(af, func_name)(x_train, *args, **kwargs)
337
+
338
+
339
def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation):
    """
    Show a 2x2 evaluation dashboard and block on ``plt.show()``.

    Panels:
        [0, 0] confusion matrix heatmap,
        [0, 1] bar chart of precision, recall, F1 score and accuracy,
        [1, 0] ROC curve (one one-vs-rest curve per class when there are not
               exactly two classes),
        [1, 1] decision boundary over the first two columns of ``x_test``.

    Parameters:
        x_test : numpy.ndarray
            Test features, indexed as ``x_test[:, 0]`` / ``x_test[:, 1]``, so it
            must be 2D with at least two columns.
        y_test : one-hot encoded test labels (decoded with ``decode_one_hot``).
        y_preds : predicted labels; presumably class indices comparable to the
            decoded ``y_test`` (confirm against callers).
        acc_list : sequence of accuracy values; only the last one is plotted.
        W : weight matrix forwarded to ``predict_model_ram``.
        activation_potentiation : activations forwarded to ``predict_model_ram``.

    Errors raised while computing the decision boundary are printed, and the
    remaining panels are still shown.
    """
    from .metrics import metrics, confusion_matrix, roc_curve
    from .ui import loading_bars, initialize_loading_bar
    from .data_operations import decode_one_hot
    from .model_operations import predict_model_ram

    # NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    bar_format_normal = loading_bars()[0]

    acc = acc_list[-1]
    y_true = np.array(decode_one_hot(y_test))
    y_preds = np.array(y_preds)
    classes = np.unique(y_true)

    precision, recall, f1 = metrics(y_test, y_preds)

    cm = confusion_matrix(y_true, y_preds, len(classes))
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))

    sns.heatmap(cm, annot=True, fmt='d', ax=axs[0, 0])
    axs[0, 0].set_title("Confusion Matrix")
    axs[0, 0].set_xlabel("Predicted Class")
    axs[0, 0].set_ylabel("Actual Class")

    roc_ax = axs[1, 0]
    if len(classes) == 2:
        fpr, tpr, _ = roc_curve(y_true, y_preds)
        roc_auc = trapezoid(tpr, fpr)
        roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    else:
        for cls in classes:
            # One-vs-rest: the current class is positive (1), all others are
            # negative (0). Building fresh masks avoids in-place relabelling,
            # which lumped class 0 together with `cls` for every cls != 0.
            y_true_bin = (y_true == cls).astype(int)
            y_preds_bin = (y_preds == cls).astype(int)

            fpr, tpr, _ = roc_curve(y_true_bin, y_preds_bin)
            roc_auc = trapezoid(tpr, fpr)
            roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')

    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    roc_ax.legend(loc="lower right")

    metric = ['Precision', 'Recall', 'F1 Score', 'Accuracy']
    values = [precision, recall, f1, acc]
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    bars = axs[0, 1].bar(metric, values, color=colors)

    # Value label drawn just inside the top of each bar.
    for bar, value in zip(bars, values):
        axs[0, 1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() - 0.05, f'{value:.2f}',
                       ha='center', va='bottom', fontsize=12, color='white', weight='bold')

    axs[0, 1].set_ylim(0, 1)
    axs[0, 1].set_xlabel('Metrics')
    axs[0, 1].set_ylabel('Score')
    axs[0, 1].set_title('Precision, Recall, F1 Score, and Accuracy (Weighted)')
    axs[0, 1].grid(True, axis='y', linestyle='--', alpha=0.7)

    # Decision boundary: a mesh with step h over features 0 and 1, padded by 1.
    feature_indices = [0, 1]
    h = .02
    x_min, x_max = x_test[:, feature_indices[0]].min() - 1, x_test[:, feature_indices[0]].max() + 1
    y_min, y_max = x_test[:, feature_indices[1]].min() - 1, x_test[:, feature_indices[1]].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    grid = np.c_[xx.ravel(), yy.ravel()]

    try:
        # All features other than the two plotted ones are held at 0.
        grid_full = np.zeros((grid.shape[0], x_test.shape[1]))
        grid_full[:, feature_indices] = grid

        Z = np.empty(len(grid_full), dtype=int)

        predict_progress = initialize_loading_bar(total=len(grid_full), leave=False,
                                                  bar_format=bar_format_normal,
                                                  desc="Predicts For Decision Boundary", ncols=65)
        try:
            for i, point in enumerate(grid_full):
                Z[i] = np.argmax(predict_model_ram(point, W=W, activation_potentiation=activation_potentiation))
                predict_progress.update(1)
        finally:
            predict_progress.close()

        Z = Z.reshape(xx.shape)

        axs[1, 1].contourf(xx, yy, Z, alpha=0.8)
        axs[1, 1].scatter(x_test[:, feature_indices[0]], x_test[:, feature_indices[1]], c=y_true,
                          edgecolors='k', marker='o', s=20, alpha=0.9)
        axs[1, 1].set_xlabel(f'Feature {0 + 1}')
        axs[1, 1].set_ylabel(f'Feature {1 + 1}')
        axs[1, 1].set_title('Decision Boundary')

    except Exception as e:
        # Best effort: report the failure but still show the other panels.
        print(f"Hata oluştu: {e}")

    plt.show()
466
+
467
+
468
def _select_boundary_axes(ax):
    """
    Return the Axes that ``plot_decision_boundary`` should draw on.

    Tries ``ax[1, 0]`` (a 2x2 grid), then ``ax[0]`` (a 1-D sequence of Axes),
    and finally treats ``ax`` itself as a single Axes.
    """
    try:
        return ax[1, 0]
    except (IndexError, TypeError):
        pass
    try:
        return ax[0]
    except (IndexError, TypeError):
        return ax


def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=None):
    """
    Plot the model's decision regions over the first two features of ``x``.

    A mesh with step 0.02 covering features 0 and 1 (padded by 1 on each side)
    is classified point by point with ``predict_model_ram``. All other feature
    columns are 0 on the mesh. The predicted class regions are filled with
    ``contourf``, and the samples are scattered on top, coloured by
    ``decode_one_hot(y)``.

    Parameters:
        x : numpy.ndarray — 2D features with at least two columns.
        y : one-hot labels for ``x``.
        activation_potentiation, W : forwarded to ``predict_model_ram``.
        artist : unused; kept for backward compatibility.
        ax : None to draw on the current pyplot figure and call ``plt.show()``.
            Otherwise one of: a 2x2 Axes grid (draws on ``ax[1, 0]``), a 1-D
            sequence of Axes (draws on ``ax[0]``), or a single Axes.

    Returns:
        ``(contour_set, scatter_collection)`` when ``ax`` is given, else None.
    """
    from .model_operations import predict_model_ram
    from .data_operations import decode_one_hot

    feature_indices = [0, 1]

    h = .02
    x_min, x_max = x[:, feature_indices[0]].min() - 1, x[:, feature_indices[0]].max() + 1
    y_min, y_max = x[:, feature_indices[1]].min() - 1, x[:, feature_indices[1]].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    grid = np.c_[xx.ravel(), yy.ravel()]
    grid_full = np.zeros((grid.shape[0], x.shape[1]))
    grid_full[:, feature_indices] = grid

    Z = np.array([
        np.argmax(predict_model_ram(point, W=W, activation_potentiation=activation_potentiation))
        for point in grid_full
    ]).reshape(xx.shape)

    labels = decode_one_hot(y)

    if ax is None:
        plt.contourf(xx, yy, Z, alpha=0.8)
        plt.scatter(x[:, feature_indices[0]], x[:, feature_indices[1]], c=labels,
                    edgecolors='k', marker='o', s=20, alpha=0.9)
        plt.xlabel(f'Feature {0 + 1}')
        plt.ylabel(f'Feature {1 + 1}')
        plt.title('Decision Boundary')

        plt.show()
        return None

    # Pick the target Axes up front, instead of trying to draw and catching
    # every exception, so real drawing errors are not swallowed.
    target = _select_boundary_axes(ax)

    art1_1 = target.contourf(xx, yy, Z, alpha=0.8)
    art1_2 = target.scatter(x[:, feature_indices[0]], x[:, feature_indices[1]], c=labels,
                            edgecolors='k', marker='o', s=20, alpha=0.9)
    target.set_xlabel(f'Feature {0 + 1}')
    target.set_ylabel(f'Feature {1 + 1}')
    target.set_title('Decision Boundary')

    return art1_1, art1_2
524
+
525
+
526
def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
    """
    Scatter the samples in 2D and shade the convex hull of each predicted class.

    Inputs with more than two features are projected with
    ``pca(x, n_components=2)``. Points are coloured by their true class. For
    each class ``cls`` in ``range(num_classes)`` (num_classes = number of
    distinct true labels), a translucent hull is drawn around the points
    predicted as ``cls``. Classes with three or fewer points, or whose points
    do not span a 2D area, get no hull. The plot is rendered with
    ``plt.draw()`` but not shown.

    Parameters:
        x : numpy.ndarray — 2D features.
        y : one-hot labels, decoded with ``decode_one_hot``.
        y_preds : predicted class indices, one per sample. Defaults to the true
            labels, in which case the hulls show the data distribution.
        s : NOTE(review): currently unused — markers are always drawn with
            size 50.
        color : name of the Matplotlib colormap to use.
    """
    from .metrics import pca
    from .data_operations import decode_one_hot

    X_pca = pca(x, n_components=2) if x.shape[1] > 2 else x
    X_pca = np.asarray(X_pca)

    y = decode_one_hot(y)

    # `is None`, not `== None`: a NumPy y_preds would compare element-wise,
    # and `if` on that result raises ValueError.
    if y_preds is None:
        y_preds = y
    y_preds = np.asarray(y_preds)

    num_classes = len(np.unique(y))

    cmap = plt.get_cmap(color)
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, edgecolor='k', s=50, cmap=cmap, norm=norm)

    for cls in range(num_classes):
        class_points = X_pca[y_preds == cls]

        if len(class_points) <= 2:
            continue

        # Qhull raises on collinear or coincident points, so skip classes
        # whose points do not span a 2D area.
        if np.linalg.matrix_rank(class_points - class_points.mean(axis=0)) < 2:
            continue

        hull = ConvexHull(class_points)
        hull_points = class_points[hull.vertices]

        # Close the polygon by repeating the first vertex.
        hull_points = np.vstack([hull_points, hull_points[0]])

        plt.fill(hull_points[:, 0], hull_points[:, 1], color=cmap(norm(cls)), alpha=0.3,
                 edgecolor='k', label=f'Class {cls} Hull')

    plt.title("Decision Space (Data Distribution)")

    plt.draw()
575
+
576
+
577
def neuron_history(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=False, loss=False):
    """
    Draw each output neuron's weight row as a ``row`` x ``col`` image.

    For every index ``j`` in ``range(len(class_count))``, ``LTPW[j, :]`` is
    reshaped to ``(row, col)`` and shown on ``ax1[j]``. The resulting image
    is appended to ``artist5`` as a one-artist animation frame.

    Parameters:
        LTPW : numpy.ndarray — weight matrix; ``row * col`` must equal its
            number of columns.
        ax1 : sequence of Axes, indexed by neuron.
        row, col : image shape for each neuron.
        class_count : sized object; ``len(class_count)`` neurons are drawn.
        artist5 : list that receives one ``[image]`` frame per neuron
            (mutated in place).
        data : label prefix used in the figure title.
        fig1 : Figure whose suptitle is set when at least one neuron is drawn.
        acc, loss : when both are given (anything other than the default
            ``False``), they are included in the suptitle.

    Returns:
        ``artist5``.
    """
    # Identity check: acc=0.0 or loss=0 compare equal to False with `!=`,
    # but they are real values that should still be shown.
    if acc is not False and loss is not False:
        suptitle_info = data + ' Accuracy:' + str(acc) + '\n' + data + ' Loss:' + str(loss) + '\nNeurons Memory:'
    else:
        suptitle_info = 'Neurons Memory:'

    num_neurons = len(class_count)

    for j in range(num_neurons):
        mat = LTPW[j, :].reshape(row, col)
        neuron_ax = ax1[j]

        art5 = neuron_ax.imshow(mat, interpolation='sinc', cmap='viridis')

        neuron_ax.set_aspect('equal')
        neuron_ax.set_xticks([])
        neuron_ax.set_yticks([])
        neuron_ax.set_title(f'{j+1}. Neuron')

        artist5.append([art5])

    if num_neurons:
        fig1.suptitle(suptitle_info, fontsize=16)

    return artist5
603
+
604
+
605
def initialize_visualization_for_fit(val, show_training, neurons_history, x_train, y_train):
    """
    Create the figures requested for visualising ``fit`` and return them in a dict.

    If ``show_training`` is set (which requires ``val``), the dict gets:
    a networkx graph under 'G', a 2x2 figure under 'fig'/'ax' titled
    'Train History', and empty frame lists 'artist1'..'artist4'.

    If ``neurons_history`` is set, the dict gets: 'fig1'/'ax1' (one subplot
    per distinct value of ``y_train``), an empty 'artist5' list, and the image
    shape 'row'/'col' from ``find_closest_factors``.

    Raises:
        ValueError: if ``show_training`` is set while ``val`` is falsy.
    """
    from .data_operations import find_closest_factors

    visualization_objects = {}

    if show_training:
        if not val:
            raise ValueError("For showing training, 'val' parameter must be True.")

        graph = nx.Graph()
        fig, ax = plt.subplots(2, 2)
        fig.suptitle('Train History')

        visualization_objects['G'] = graph
        visualization_objects['fig'] = fig
        visualization_objects['ax'] = ax
        visualization_objects.update(
            (frame_key, []) for frame_key in ('artist1', 'artist2', 'artist3', 'artist4')
        )

    if neurons_history:
        row, col = find_closest_factors(len(x_train[0]))
        # NOTE(review): set(y_train) needs hashable items; a 2D one-hot NumPy
        # array would raise TypeError here — confirm the label format callers pass.
        panel_count = len(set(y_train))
        fig1, ax1 = plt.subplots(1, panel_count, figsize=(18, 14))

        visualization_objects['fig1'] = fig1
        visualization_objects['ax1'] = ax1
        visualization_objects['artist5'] = []
        visualization_objects['row'] = row
        visualization_objects['col'] = col

    return visualization_objects
639
+
640
+
641
def update_weight_visualization_for_fit(ax, LTPW, artist2):
    """Render ``LTPW`` as an image on ``ax`` and record it as a new animation frame in ``artist2``."""
    frame = [ax.imshow(LTPW, interpolation='sinc', cmap='viridis')]
    artist2.append(frame)
645
+
646
+
647
def update_decision_boundary_for_fit(ax, x_val, y_val, activation_potentiation, LTPW, artist1):
    """
    Draw the current decision boundary and record it as one frame in ``artist1``.

    ``ax`` is forwarded to ``plot_decision_boundary``. The frame holds the
    contour artist(s) followed by the scatter of the validation samples.
    """
    from matplotlib.collections import Collection

    art1_1, art1_2 = plot_decision_boundary(x_val, y_val, activation_potentiation, LTPW, artist=artist1, ax=ax)

    # Since Matplotlib 3.8 a ContourSet is itself a Collection, and its
    # `collections` attribute is deprecated (removed in 3.10). Older versions
    # only expose the per-level collections.
    if isinstance(art1_1, Collection):
        contour_artists = [art1_1]
    else:
        contour_artists = list(art1_1.collections)

    artist1.append([*contour_artists, art1_2])
651
+
652
+
653
def update_validation_history_for_fit(ax, val_list, artist3):
    """
    Plot validation accuracy against epoch number (1..N) on ``ax``.

    The y-axis is fixed to [0, 1]. The list of lines returned by ``ax.plot``
    is appended to ``artist3`` as one animation frame.
    """
    epochs = [index + 1 for index, _ in enumerate(val_list)]
    line_style = dict(
        linestyle='--',
        color='g',
        marker='o',
        markersize=6,
        linewidth=2,
        label='Validation Accuracy',
    )
    lines = ax.plot(epochs, val_list, **line_style)

    for setter, value in (
        (ax.set_title, 'Validation History'),
        (ax.set_xlabel, 'Time'),
        (ax.set_ylabel, 'Validation Accuracy'),
        (ax.set_ylim, [0, 1]),
    ):
        setter(value)

    artist3.append(lines)
671
+
672
+
673
def display_visualization_for_fit(fig, artist_list, interval):
    """
    Animate ``artist_list`` on ``fig`` and show it.

    ``artist_list`` holds one list of artists per frame; ``interval`` is the
    delay between frames passed to ``ArtistAnimation``.
    """
    animation_options = {'interval': interval, 'blit': True}
    # Matplotlib requires a live reference to the animation until plt.show()
    # runs; otherwise it may be garbage-collected and never play.
    animation = ArtistAnimation(fig, artist_list, **animation_options)
    plt.tight_layout()
    plt.show()
678
+
679
+
680
+
681
def initialize_visualization_for_learner(show_history, neurons_history, neural_web_history, x_train, y_train):
    """
    Create the figures used to visualise the learner and return them in a dict.

    Each key is present only when its flag is set:
        'history' : 3x1 figure ('Learner History') with empty frame lists
                    'artist1'..'artist3'.
        'neurons' : figure with ``len(y_train[0])`` subplots (a single subplot
                    when ``find_closest_factors`` returns row == 0), an empty
                    'artists' list and the 'row'/'col' image shape.
        'web'     : 18x4 figure, an empty networkx graph 'G' and an empty
                    'artists' list.
    """
    from .data_operations import find_closest_factors

    viz_objects = {}

    if show_history:
        history_fig, history_ax = plt.subplots(3, 1, figsize=(6, 8))
        history_fig.suptitle('Learner History')
        viz_objects['history'] = dict(
            fig=history_fig,
            ax=history_ax,
            artist1=[],
            artist2=[],
            artist3=[],
        )

    if neurons_history:
        row, col = find_closest_factors(len(x_train[0]))
        panel_count = len(y_train[0]) if row != 0 else 1
        neurons_fig, neurons_ax = plt.subplots(1, panel_count, figsize=(18, 14))
        viz_objects['neurons'] = dict(
            fig=neurons_fig,
            ax=neurons_ax,
            artists=[],
            row=row,
            col=col,
        )

    if neural_web_history:
        web_graph = nx.Graph()
        web_fig, web_ax = plt.subplots(figsize=(18, 4))
        viz_objects['web'] = dict(
            fig=web_fig,
            ax=web_ax,
            G=web_graph,
            artists=[],
        )

    return viz_objects
722
+
723
def update_history_plots_for_learner(viz_objects, depth_list, loss_list, best_acc_per_depth_list, x_train, final_activations):
    """
    Append one frame to each learner history animation.

    Does nothing unless ``viz_objects`` has a 'history' entry. Otherwise:
      * ax[0] (red) plots loss against depth.
      * ax[1] (green) plots best accuracy against depth.
      * ax[2] (blue) plots the identity plus the sum of every activation in
        ``final_activations``, evaluated on ``len(x_train)`` evenly spaced
        points between ``min(x_train)`` and ``max(x_train)``.

    The line lists returned by ``plot`` are appended to 'artist1', 'artist2'
    and 'artist3' respectively.
    """
    if 'history' not in viz_objects:
        return

    history = viz_objects['history']
    axes = history['ax']

    panels = (
        (axes[0], best_acc_per_depth_list if False else loss_list, 'r', 'Test Loss Over Depth', 'artist1'),
        (axes[1], best_acc_per_depth_list, 'g', 'Test Accuracy Over Depth', 'artist2'),
    )
    for panel_ax, values, colour, title, frame_key in panels:
        lines = panel_ax.plot(depth_list, values, color=colour, markersize=6, linewidth=2)
        panel_ax.set_title(title)
        history[frame_key].append(lines)

    sample_points = np.linspace(np.min(x_train), np.max(x_train), len(x_train))
    potentiation_shape = sample_points.copy()
    for activation in final_activations:
        potentiation_shape += draw_activations(sample_points, activation)

    shape_lines = axes[2].plot(sample_points, potentiation_shape, color='b', markersize=6, linewidth=2)
    axes[2].set_title('Potentiation Shape Over Depth')
    history['artist3'].append(shape_lines)
749
+
750
def display_visualizations_for_learner(viz_objects, best_weights, data, best_acc, test_loss, y_train, interval):
    """
    Play the final learner animations, with one ``plt.show()`` per section.

    A section is processed only when its key is in ``viz_objects``:
      * 'history': the last recorded frame of each of the three history plots
        is repeated 30 more times so the final state lingers, then all three
        frame lists are animated. If any list has no recorded frame, the
        section is skipped, because there is no last frame to repeat.
      * 'neurons': ``neuron_history`` is called 10 times with a copy of
        ``best_weights`` to build the frames, which are then animated.
      * 'web': the neural web of ``best_weights`` is drawn 30 times to build
        the frames, which are then animated.

    ``interval`` is the delay between frames passed to ``ArtistAnimation``.
    """
    if 'history' in viz_objects:
        hist = viz_objects['history']
        frame_lists = (hist['artist1'], hist['artist2'], hist['artist3'])

        if all(frame_lists):
            for frames in frame_lists:
                frames.extend([frames[-1]] * 30)

            # Keep references so the animations survive until plt.show().
            history_animations = [
                ArtistAnimation(hist['fig'], frames, interval=interval, blit=True)
                for frames in frame_lists
            ]
            plt.tight_layout()
            plt.show()

    if 'neurons' in viz_objects:
        neurons = viz_objects['neurons']
        for _ in range(10):
            neurons['artists'] = neuron_history(
                np.copy(best_weights),
                neurons['ax'],
                neurons['row'],
                neurons['col'],
                y_train[0],
                neurons['artists'],
                data=data,
                fig1=neurons['fig'],
                acc=best_acc,
                loss=test_loss
            )

        neurons_animation = ArtistAnimation(neurons['fig'], neurons['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()

    if 'web' in viz_objects:
        web = viz_objects['web']
        for _ in range(30):
            art5_1, art5_2, art5_3 = draw_neural_web(
                W=best_weights,
                ax=web['ax'],
                G=web['G'],
                return_objs=True
            )
            # art5_3 is the dict of node labels returned by draw_networkx_labels.
            web['artists'].append([art5_1, art5_2, *art5_3.values()])

        web_animation = ArtistAnimation(web['fig'], web['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()