pyerualjetwork 4.0.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,799 @@
1
+ import networkx as nx
2
+ import matplotlib.pyplot as plt
3
+ import cupy as cp
4
+ from scipy.spatial import ConvexHull
5
+ import seaborn as sns
6
+ from matplotlib.animation import ArtistAnimation
7
+
8
def draw_neural_web(W, ax, G, return_objs=False):
    """
    Visualizes a neural web by drawing the neural network structure.

    Each non-zero weight ``W[i, j]`` becomes an edge between the nodes
    ``'Output{i}'`` and ``'Input{j}'``; the weight is stored under the edge
    attribute ``'ltpw'``. Input nodes are placed at x=0 (y = 0..inputs-1) and
    output nodes at x=1, shifted so they are vertically centred against the
    inputs. Edges are coloured by weight with the ``Blues`` colormap.

    Parameters:
        W : cupy.ndarray or numpy.ndarray
            A 2D array of connection weights with shape (outputs, inputs).
            It is copied to host memory first, because matplotlib/networkx
            cannot consume cupy arrays.
        ax : matplotlib.axes.Axes
            The matplotlib axes where the graph will be drawn.
        G : networkx.Graph
            The graph that receives the edges; it is modified in place.
        return_objs : bool, optional
            If True, returns the drawn objects. Default is False.

    Returns:
        tuple or None
            ``(art1, art2, art3)`` when ``return_objs`` is truthy, otherwise None:
            art1 -- node collection returned by ``nx.draw_networkx_nodes``
            art2 -- edge artist(s) returned by ``nx.draw_networkx_edges``
            art3 -- dict mapping each node to its ``matplotlib.text.Text`` label
                    (returned by ``nx.draw_networkx_labels``)

    Example:
        art1, art2, art3 = draw_neural_web(W, ax, G, return_objs=True)
        plt.show()
    """
    # Single host copy: indexing a cupy array element-wise costs one device
    # transfer per element, and cupy scalars cannot be used as matplotlib colors.
    W_host = cp.asnumpy(W)
    num_motor_neurons = W_host.shape[0]
    num_sensory_neurons = W_host.shape[1]

    for i in range(num_motor_neurons):
        for j in range(num_sensory_neurons):
            weight = W_host[i, j]
            if weight != 0:
                G.add_edge(f'Output{i}', f'Input{j}', ltpw=float(weight))

    weights = [edata['ltpw'] for _, _, edata in G.edges(data=True)]

    pos = {f'Input{j}': (0, j) for j in range(num_sensory_neurons)}
    # Centre the (usually smaller) output column against the input column.
    motor_y_start = (num_sensory_neurons - num_motor_neurons) / 2
    for i in range(num_motor_neurons):
        pos[f'Output{i}'] = (1, motor_y_start + i)

    art1 = nx.draw_networkx_nodes(G, pos, ax=ax, node_size=1000, node_color='lightblue')
    art2 = nx.draw_networkx_edges(G, pos, ax=ax, edge_color=weights, edge_cmap=plt.cm.Blues, width=2)
    art3 = nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('Neural Web')

    if return_objs:
        return art1, art2, art3
    return None
69
+
70
+
71
def draw_model_architecture(model_name, model_path='', style='basic'):
    """
    Visualizes the architecture of a neural network model.

    The figure contains one panel per activation function in the model's
    activation potentiation (each showing that activation's curve over
    [-100, 100]), followed by either a neural-web panel ('detailed') or an
    input-shape and an output-shape panel ('basic'). Text in the bottom-left
    corner reports whether a standard scaler is stored, the number of
    aggregation layers and the single fully connected potentiation layer.

    Parameters
    ----------
    model_name : str
        The name of the model to load and visualize.

    model_path : str
        The file path to the model, from which the architecture is loaded. Default is ''

    style : str, optional
        The style of the visualization.
        Options:
        - 'basic': Displays a simplified view of the model architecture.
        - 'detailed': Shows a more comprehensive view, including the neural web.
        Default is 'basic'.

    Returns
    -------
    None
        Draws and displays the architecture of the specified model.

    Raises
    ------
    ValueError
        If ``style`` is neither 'basic' nor 'detailed'.

    Examples
    --------
    >>> draw_model_architecture("MyModel", "path/to/model", style='detailed')
    """
    from .plan_cuda import get_scaler, get_act_pot, get_weights
    from .model_operations_cuda import load_model

    # Number of panels after the activation panels: 1 neural web, or 2 circles.
    if style == 'detailed':
        col = 1
    elif style == 'basic':
        col = 2
    else:
        raise ValueError(f"style must be 'basic' or 'detailed', got {style!r}")

    model = load_model(model_name=model_name, model_path=model_path)

    W = model[get_weights()]
    activation_potentiation = model[get_act_pot()]
    scaler_params = model[get_scaler()]

    text_1 = f"Input Shape:\n{W.shape[1]}"
    text_2 = f"Output Shape:\n{W.shape[0]}"

    bottom_left_text = 'Standard Scaler=No' if scaler_params is None else 'Standard Scaler=Yes'

    # A lone 'linear' activation is reported as zero aggregation layers.
    if len(activation_potentiation) == 1 and activation_potentiation[0] == 'linear':
        aggregation_layers = 0
    else:
        aggregation_layers = len(activation_potentiation)
    bottom_left_text_1 = f'Aggregation Layers(Aggregates All Conversions)={aggregation_layers}'
    bottom_left_text_2 = 'Potentiation Layer(Fully Connected)=1'

    num_middle_axes = len(activation_potentiation)

    # squeeze=False keeps `axes` indexable even when a single panel is created.
    fig, axes = plt.subplots(1, num_middle_axes + col, figsize=(5 * (num_middle_axes + 2), 5), squeeze=False)
    axes = axes[0]

    fig.suptitle("Model Architecture", fontsize=16, fontweight='bold')

    for i, activation in enumerate(activation_potentiation):
        # Fresh input per activation, in case an activation modifies its input.
        x = cp.linspace(-100, 100, 100)
        y = draw_activations(x, activation)

        # matplotlib needs host arrays; cupy refuses implicit conversion.
        axes[i].plot(cp.asnumpy(x), cp.asnumpy(y), color='b', markersize=6, linewidth=2, label='Activations Over Depth')
        axes[i].set_title(activation)

        for side in ('top', 'right', 'left', 'bottom'):
            axes[i].spines[side].set_visible(False)
        axes[i].get_xaxis().set_visible(False)
        axes[i].get_yaxis().set_visible(False)

        # Arrow from each activation panel to the next one.
        if i < num_middle_axes - 1:
            axes[i].annotate('', xy=(1.05, 0.5), xytext=(0.95, 0.5),
                             xycoords='axes fraction', textcoords='axes fraction',
                             arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    if style == 'detailed':
        G = nx.Graph()
        draw_neural_web(W=W, ax=axes[num_middle_axes], G=G)

    else:  # style == 'basic'
        circle1 = plt.Circle((0.5, 0.5), 0.4, color='skyblue', ec='black', lw=1.5)
        axes[num_middle_axes].add_patch(circle1)
        axes[num_middle_axes].text(0.5, 0.5, text_1, ha='center', va='center', fontsize=12)
        axes[num_middle_axes].set_xlim(0, 1)
        axes[num_middle_axes].set_ylim(0, 1)
        axes[num_middle_axes].axis('off')

        circle2 = plt.Circle((0.5, 0.5), 0.4, color='lightcoral', ec='black', lw=1.5)
        axes[-1].add_patch(circle2)
        axes[-1].text(0.5, 0.5, text_2, ha='center', va='center', fontsize=12)
        axes[-1].set_xlim(0, 1)
        axes[-1].set_ylim(0, 1)
        axes[-1].axis('off')

    fig.text(0.01, 0, bottom_left_text, ha='left', va='bottom', fontsize=10)
    fig.text(0.01, 0.04, bottom_left_text_1, ha='left', va='bottom', fontsize=10)
    fig.text(0.01, 0.08, bottom_left_text_2, ha='left', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.show()
195
+
196
+
197
def draw_activations(x_train, activation):
    """
    Apply the activation function named ``activation`` to ``x_train``.

    Parameters:
        x_train : array-like
            Values passed to the activation function.
        activation : str
            Activation name, e.g. 'relu', 'sigmoid', 'tanh', 'linear'.

    Returns:
        The value returned by the matching function in
        ``activation_functions_cuda``; 'linear' returns ``x_train`` itself.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    from . import activation_functions_cuda as af

    # Lambdas delay the attribute lookup on `af` until the selected activation
    # is actually called, so only that one function has to exist in `af`.
    activations = {
        'sigmoid': lambda x: af.Sigmoid(x),
        'swish': lambda x: af.swish(x),
        'circular': lambda x: af.circular_activation(x),
        'mod_circular': lambda x: af.modular_circular_activation(x),
        'tanh_circular': lambda x: af.tanh_circular_activation(x),
        'leaky_relu': lambda x: af.leaky_relu(x),
        'relu': lambda x: af.Relu(x),
        'softplus': lambda x: af.softplus(x),
        'elu': lambda x: af.elu(x),
        'gelu': lambda x: af.gelu(x),
        'selu': lambda x: af.selu(x),
        'softmax': lambda x: af.Softmax(x),
        'tanh': lambda x: af.tanh(x),
        'sinakt': lambda x: af.sinakt(x),
        'p_squared': lambda x: af.p_squared(x),
        'sglu': lambda x: af.sglu(x, alpha=1.0),
        'dlrelu': lambda x: af.dlrelu(x),
        'exsig': lambda x: af.exsig(x),
        'sin_plus': lambda x: af.sin_plus(x),
        'acos': lambda x: af.acos(x, alpha=1.0, beta=0.0),
        'gla': lambda x: af.gla(x, alpha=1.0, mu=0.0),
        'srelu': lambda x: af.srelu(x),
        'qelu': lambda x: af.qelu(x),
        'isra': lambda x: af.isra(x),
        'waveakt': lambda x: af.waveakt(x),
        'arctan': lambda x: af.arctan(x),
        'bent_identity': lambda x: af.bent_identity(x),
        'sech': lambda x: af.sech(x),
        'softsign': lambda x: af.softsign(x),
        'pwl': lambda x: af.pwl(x),
        'cubic': lambda x: af.cubic(x),
        'gaussian': lambda x: af.gaussian(x),
        'sine': lambda x: af.sine(x),
        'tanh_square': lambda x: af.tanh_square(x),
        'mod_sigmoid': lambda x: af.mod_sigmoid(x),
        'linear': lambda x: x,
        'quartic': lambda x: af.quartic(x),
        'square_quartic': lambda x: af.square_quartic(x),
        'cubic_quadratic': lambda x: af.cubic_quadratic(x),
        'exp_cubic': lambda x: af.exp_cubic(x),
        'sine_square': lambda x: af.sine_square(x),
        'logarithmic': lambda x: af.logarithmic(x),
        'scaled_cubic': lambda x: af.scaled_cubic(x, 1.0),
        'sine_offset': lambda x: af.sine_offset(x, 1.0),
        'spiral': lambda x: af.spiral_activation(x),
    }

    try:
        func = activations[activation]
    except KeyError:
        # Previously an unknown name fell through every branch and surfaced
        # as an UnboundLocalError on `result`.
        raise ValueError(f"Unsupported activation function: {activation!r}") from None

    return func(x_train)
337
+
338
+
339
def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation):
    """
    Show a 2x2 evaluation dashboard for a model.

    Panels: confusion matrix (top left), precision/recall/F1/accuracy bars
    (top right), ROC curve(s) (bottom left) and the decision boundary over
    features 0 and 1 (bottom right).

    Parameters:
        x_test : cupy.ndarray
            Test inputs of shape (samples, features); columns 0 and 1 span the
            decision-boundary plane.
        y_test : cupy.ndarray
            Test labels in the encoding accepted by ``decode_one_hot``
            (presumably one-hot).
        y_preds : array-like
            Predicted labels, compared element-wise with the decoded ``y_test``
            (assumption: integer class ids 0..k-1 -- confirm with callers).
        acc_list : sequence
            Accuracy history; only the last entry is plotted.
        W, activation_potentiation :
            Forwarded to ``predict_model_ram`` to classify the boundary grid.

    Notes:
        Device arrays are copied to the host with ``cp.asnumpy`` before being
        handed to seaborn/matplotlib, which cannot consume cupy arrays.
        A failure while drawing the decision boundary is printed and the
        other panels are still shown.
    """
    from .metrics_cuda import metrics, confusion_matrix, roc_curve
    from .ui import loading_bars, initialize_loading_bar
    from .data_operations_cuda import decode_one_hot
    from .model_operations_cuda import predict_model_ram

    bar_format_normal = loading_bars()[0]

    acc = acc_list[-1]
    y_true = cp.array(decode_one_hot(y_test))
    y_preds = cp.array(y_preds)
    Class = cp.unique(y_true)
    num_classes = len(Class)

    precision, recall, f1 = metrics(y_test, y_preds)

    cm = confusion_matrix(y_true, y_preds, num_classes)
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))

    # --- Confusion matrix ---------------------------------------------------
    sns.heatmap(cp.asnumpy(cm), annot=True, fmt='d', ax=axs[0, 0])
    axs[0, 0].set_title("Confusion Matrix")
    axs[0, 0].set_xlabel("Predicted Class")
    axs[0, 0].set_ylabel("Actual Class")

    # --- ROC curve(s) -------------------------------------------------------
    if num_classes == 2:
        roc_inputs = [(y_true, y_preds)]
    else:
        # One-vs-rest: class i is the positive label (1), all other classes 0.
        # Built with comparisons so class 0 is never mixed up with class i.
        roc_inputs = [
            ((y_true == i).astype(y_true.dtype), (y_preds == i).astype(y_preds.dtype))
            for i in range(num_classes)
        ]

    roc_ax = axs[1, 0]
    for curve_true, curve_preds in roc_inputs:
        fpr, tpr, _thresholds = roc_curve(curve_true, curve_preds)
        roc_auc = float(cp.trapz(tpr, fpr))
        roc_ax.plot(cp.asnumpy(fpr), cp.asnumpy(tpr), color='darkorange', lw=2,
                    label=f'ROC curve (area = {roc_auc:.2f})')
        roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    roc_ax.legend(loc="lower right")

    # --- Metric bars --------------------------------------------------------
    metric = ['Precision', 'Recall', 'F1 Score', 'Accuracy']
    # float() turns cupy/numpy scalars into plain numbers for matplotlib.
    values = [float(v) for v in (precision, recall, f1, acc)]
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    bars = axs[0, 1].bar(metric, values, color=colors)
    for bar, value in zip(bars, values):
        axs[0, 1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() - 0.05, f'{value:.2f}',
                       ha='center', va='bottom', fontsize=12, color='white', weight='bold')

    axs[0, 1].set_ylim(0, 1)
    axs[0, 1].set_xlabel('Metrics')
    axs[0, 1].set_ylabel('Score')
    axs[0, 1].set_title('Precision, Recall, F1 Score, and Accuracy (Weighted)')
    axs[0, 1].grid(True, axis='y', linestyle='--', alpha=0.7)

    # --- Decision boundary over features 0 and 1 ----------------------------
    feature_indices = [0, 1]

    h = .02
    x_min = float(x_test[:, feature_indices[0]].min()) - 1
    x_max = float(x_test[:, feature_indices[0]].max()) + 1
    y_min = float(x_test[:, feature_indices[1]].min()) - 1
    y_max = float(x_test[:, feature_indices[1]].max()) + 1
    xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
                         cp.arange(y_min, y_max, h))

    grid = cp.c_[xx.ravel(), yy.ravel()]

    try:
        # Every feature other than the two plotted ones is held at 0.
        grid_full = cp.zeros((grid.shape[0], x_test.shape[1]))
        grid_full[:, feature_indices] = grid

        Z = [None] * len(grid_full)

        predict_progress = initialize_loading_bar(total=len(grid_full), leave=False,
                                                  bar_format=bar_format_normal, desc="Predicts For Decision Boundary", ncols=65)

        for i in range(len(grid_full)):
            Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))
            predict_progress.update(1)

        predict_progress.close()

        Z = cp.array(Z).reshape(xx.shape)

        axs[1, 1].contourf(cp.asnumpy(xx), cp.asnumpy(yy), cp.asnumpy(Z), alpha=0.8)
        axs[1, 1].scatter(cp.asnumpy(x_test[:, feature_indices[0]]), cp.asnumpy(x_test[:, feature_indices[1]]),
                          c=cp.asnumpy(y_true), edgecolors='k', marker='o', s=20, alpha=0.9)
        axs[1, 1].set_xlabel(f'Feature {feature_indices[0] + 1}')
        axs[1, 1].set_ylabel(f'Feature {feature_indices[1] + 1}')
        axs[1, 1].set_title('Decision Boundary')

    except Exception as e:
        # Best effort: report the error and still show the remaining panels.
        print(f"Hata oluştu: {e}")

    plt.show()
466
+
467
+
468
def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=None):
    """
    Plot the model's decision regions over features 0 and 1 of ``x``.

    A grid with step 0.02 covering features 0 and 1 (padded by 1 on each
    side) is classified point by point with ``predict_model_ram``; all other
    features are held at 0. The class regions are drawn with ``contourf`` and
    the samples of ``x`` are scattered on top, coloured by ``decode_one_hot(y)``.

    Parameters:
        x : cupy.ndarray
            Inputs of shape (samples, features).
        y : array-like
            Labels in the encoding accepted by ``decode_one_hot``.
        activation_potentiation, W :
            Forwarded to ``predict_model_ram``.
        artist : any, optional
            Unused; kept for backward compatibility.
        ax : None, Axes, 2-D array of Axes, or sequence of Axes, optional
            None draws on the current pyplot figure and shows it. A single
            Axes is drawn on directly. A 2-D array (e.g. from
            ``plt.subplots(2, 2)``) uses ``ax[1, 0]``; any other sequence
            uses ``ax[0]``.

    Returns:
        None when ``ax`` is None, otherwise ``(contour_set, path_collection)``.
    """
    from .model_operations_cuda import predict_model_ram
    from .data_operations_cuda import decode_one_hot

    feature_indices = [0, 1]

    h = .02
    x_min, x_max = x[:, feature_indices[0]].min() - 1, x[:, feature_indices[0]].max() + 1
    y_min, y_max = x[:, feature_indices[1]].min() - 1, x[:, feature_indices[1]].max() + 1
    xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
                         cp.arange(y_min, y_max, h))

    grid = cp.c_[xx.ravel(), yy.ravel()]
    grid_full = cp.zeros((grid.shape[0], x.shape[1]))
    grid_full[:, feature_indices] = grid

    Z = [None] * len(grid_full)
    for i in range(len(grid_full)):
        Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))

    Z = cp.array(Z).reshape(xx.shape)

    # Host copies: matplotlib cannot consume cupy arrays.
    xx_h, yy_h, Z_h = cp.asnumpy(xx), cp.asnumpy(yy), cp.asnumpy(Z)
    x_h = cp.asnumpy(x)
    point_colors = cp.asnumpy(decode_one_hot(y))
    x_label = f'Feature {feature_indices[0] + 1}'
    y_label = f'Feature {feature_indices[1] + 1}'

    if ax is None:
        plt.contourf(xx_h, yy_h, Z_h, alpha=0.8)
        plt.scatter(x_h[:, feature_indices[0]], x_h[:, feature_indices[1]], c=point_colors, edgecolors='k', marker='o', s=20, alpha=0.9)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title('Decision Boundary')

        plt.show()
        return None

    # Pick the target axes explicitly instead of relying on a bare `except`,
    # which also swallowed genuine drawing errors.
    if hasattr(ax, 'contourf'):
        target = ax
    elif getattr(ax, 'ndim', 1) >= 2:
        target = ax[1, 0]
    else:
        target = ax[0]

    art1_1 = target.contourf(xx_h, yy_h, Z_h, alpha=0.8)
    art1_2 = target.scatter(x_h[:, feature_indices[0]], x_h[:, feature_indices[1]], c=point_colors, edgecolors='k', marker='o', s=20, alpha=0.9)
    target.set_xlabel(x_label)
    target.set_ylabel(y_label)
    target.set_title('Decision Boundary')

    return art1_1, art1_2
524
+
525
+
526
def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
    """
    Scatter the data in 2-D and outline each predicted class with its convex hull.

    Inputs with more than two features are projected with ``pca`` to two
    components. Points are coloured by their true label; for every class id
    ``0..num_classes-1`` whose predicted points number at least three, the
    convex hull of those points is filled semi-transparently.

    Parameters:
        x : array-like
            Inputs of shape (samples, features).
        y : array-like
            Labels in the encoding accepted by ``decode_one_hot``.
        y_preds : array-like, optional
            Predicted class ids per sample. Defaults to the decoded ``y``.
        s : float, optional
            Marker size for the scatter plot. Default is 100.
        color : str, optional
            Name of the matplotlib colormap. Default is 'tab20'.

    Returns:
        None. Draws on the current pyplot figure (``plt.draw()``).
    """
    # Relative imports: absolute `metrics` / `data_operations` do not resolve
    # when this module is imported from the installed package.
    from .metrics import pca
    from .data_operations_cuda import decode_one_hot

    X_pca = pca(x, n_components=2) if x.shape[1] > 2 else x

    # `is None`: `== None` on an array is element-wise and ambiguous in `if`.
    if y_preds is None:
        y_preds = decode_one_hot(y)

    y = decode_one_hot(y)
    num_classes = len(cp.unique(y))

    cmap = plt.get_cmap(color)
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    # matplotlib and scipy need host (numpy) arrays.
    points = cp.asnumpy(X_pca)
    labels = cp.asnumpy(y)
    predicted = cp.asnumpy(y_preds)

    plt.scatter(points[:, 0], points[:, 1], c=labels, edgecolor='k', s=s, cmap=cmap, norm=norm)

    for cls in range(num_classes):
        class_points = points[predicted == cls]

        # A 2-D convex hull needs at least three points.
        if len(class_points) > 2:
            hull = ConvexHull(class_points)
            hull_points = class_points[hull.vertices]
            # plt.fill closes the polygon itself; no need to repeat the first vertex.
            plt.fill(hull_points[:, 0], hull_points[:, 1], color=cmap(norm(cls)), alpha=0.3, edgecolor='k', label=f'Class {cls} Hull')

    plt.title("Decision Space (Data Distribution)")

    plt.draw()
575
+
576
+
577
def neuron_history(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=False, loss=False):
    """
    Draw each output neuron's weight row as a ``row x col`` image.

    Parameters:
        LTPW : 2-D weight array; row ``j`` is reshaped to ``(row, col)`` for
            neuron ``j``, so ``row * col`` must equal ``LTPW.shape[1]``.
        ax1 : indexable sequence of Axes with at least ``len(class_count)`` entries.
        row, col : int
            Image dimensions.
        class_count : sized object; its length is the number of neurons drawn.
        artist5 : list
            Receives one ``[image]`` frame per neuron.
        data : str
            Prefix for the accuracy/loss lines of the figure title.
        fig1 : matplotlib Figure whose suptitle is set.
        acc, loss : optional
            Shown in the title only when both are supplied (not ``False``).

    Returns:
        ``artist5``, the same list with the new frames appended.
    """
    # Compare with the False sentinel by identity so a genuine 0 / 0.0 is shown.
    if acc is not False and loss is not False:
        suptitle_info = data + ' Accuracy:' + str(acc) + '\n' + data + ' Loss:' + str(loss) + '\nNeurons Memory:'
    else:
        suptitle_info = 'Neurons Memory:'

    for j in range(len(class_count)):
        # imshow needs a host array; cupy refuses implicit conversion.
        mat = cp.asnumpy(LTPW[j, :]).reshape(row, col)

        art5 = ax1[j].imshow(mat, interpolation='sinc', cmap='viridis')

        ax1[j].set_aspect('equal')
        ax1[j].set_xticks([])
        ax1[j].set_yticks([])
        ax1[j].set_title(f'{j+1}. Neuron')

        artist5.append([art5])

    fig1.suptitle(suptitle_info, fontsize=16)

    return artist5
603
+
604
+
605
def initialize_visualization_for_fit(val, show_training, neurons_history, x_train, y_train):
    """Initializes the visualization setup based on the parameters.

    Parameters:
        val : bool
            Must be truthy when ``show_training`` is requested.
        show_training : bool
            Create a 2x2 'Train History' figure, a networkx graph and empty
            artist lists ``artist1``..``artist4``.
        neurons_history : bool
            Create a 1 x n_classes figure for per-neuron weight images.
        x_train : array-like
            Training inputs; ``len(x_train[0])`` is factored into the image grid.
        y_train : array-like
            Training labels: one-hot rows (class count = row length) or flat
            labels (class count = number of distinct values).

    Returns:
        dict with keys 'G', 'fig', 'ax', 'artist1'..'artist4' when
        ``show_training`` is set, and 'fig1', 'ax1', 'artist5', 'row', 'col'
        when ``neurons_history`` is set.

    Raises:
        ValueError: if ``show_training`` is set but ``val`` is not.
    """
    visualization_objects = {}

    if show_training:
        if not val:
            raise ValueError("For showing training, 'val' parameter must be True.")

        G = nx.Graph()
        fig, ax = plt.subplots(2, 2)
        fig.suptitle('Train History')
        visualization_objects.update({
            'G': G,
            'fig': fig,
            'ax': ax,
            'artist1': [],
            'artist2': [],
            'artist3': [],
            'artist4': []
        })

    if neurons_history:
        # Relative import: the absolute `data_operations` does not resolve
        # from the installed package.
        from .data_operations import find_closest_factors

        row, col = find_closest_factors(len(x_train[0]))

        # `set(y_train)` fails for cupy arrays and one-hot (2-D) labels
        # because array rows are unhashable.
        labels = cp.asnumpy(y_train)
        if labels.ndim > 1:
            num_classes = labels.shape[1]
        else:
            num_classes = len(set(labels.tolist()))

        fig1, ax1 = plt.subplots(1, num_classes, figsize=(18, 14))
        visualization_objects.update({
            'fig1': fig1,
            'ax1': ax1,
            'artist5': [],
            'row': row,
            'col': col
        })

    return visualization_objects
639
+
640
+
641
def update_weight_visualization_for_fit(ax, LTPW, artist2):
    """Updates the weight visualization plot.

    Draws ``LTPW`` on ``ax`` with ``imshow`` and appends ``[image]`` to
    ``artist2`` as one animation frame. The weights are copied to host memory
    first because ``imshow`` cannot consume cupy arrays (``cp.asnumpy`` also
    accepts numpy input unchanged).
    """
    art2 = ax.imshow(cp.asnumpy(LTPW), interpolation='sinc', cmap='viridis')
    artist2.append([art2])
645
+
646
+
647
def update_decision_boundary_for_fit(ax, x_val, y_val, activation_potentiation, LTPW, artist1):
    """Updates the decision boundary visualization.

    Draws the decision boundary via ``plot_decision_boundary`` and appends one
    animation frame (the contour artists plus the scatter) to ``artist1``.
    """
    import matplotlib.collections as mcoll

    art1_1, art1_2 = plot_decision_boundary(x_val, y_val, activation_potentiation, LTPW, artist=artist1, ax=ax)

    # Since matplotlib 3.8 a ContourSet is itself a Collection and its legacy
    # `.collections` attribute is deprecated (removed in 3.10).
    if isinstance(art1_1, mcoll.Collection):
        contour_artists = [art1_1]
    else:
        contour_artists = list(art1_1.collections)

    artist1.append([*contour_artists, art1_2])
651
+
652
+
653
def update_validation_history_for_fit(ax, val_list, artist3):
    """Updates the validation accuracy history plot.

    Parameters:
        ax : matplotlib Axes that receives the curve.
        val_list : sequence of validation accuracies; entry k is drawn at x = k + 1.
        artist3 : list to which the list of Line2D objects returned by
            ``ax.plot`` is appended (one animation frame).
    """
    steps = [k + 1 for k in range(len(val_list))]
    line_style = dict(
        linestyle='--',
        color='g',
        marker='o',
        markersize=6,
        linewidth=2,
        label='Validation Accuracy',
    )
    frame = ax.plot(steps, val_list, **line_style)

    ax.set_title('Validation History')
    ax.set_xlabel('Time')
    ax.set_ylabel('Validation Accuracy')
    ax.set_ylim([0, 1])

    artist3.append(frame)
671
+
672
+
673
def display_visualization_for_fit(fig, artist_list, interval):
    """Displays the animation for the given artist list.

    Parameters:
        fig : figure the frames were drawn on.
        artist_list : list of frames, each a list of artists.
        interval : delay between frames in milliseconds.
    """
    # Bind the animation to a name: matplotlib only animates while a
    # reference to the Animation object is alive during plt.show().
    animation = ArtistAnimation(fig, artist_list, interval=interval, blit=True)

    plt.tight_layout()
    plt.show()
678
+
679
+
680
+
681
def initialize_visualization_for_learner(show_history, neurons_history, neural_web_history, x_train, y_train):
    """Initialize all visualization components.

    Parameters:
        show_history : bool
            Create a 3x1 'Learner History' figure with artist lists
            'artist1'..'artist3' under key 'history'.
        neurons_history : bool
            Create the per-neuron figure under key 'neurons'; it has one panel
            per entry of ``y_train[0]``, or a single panel when
            ``find_closest_factors`` returns a row count of 0.
        neural_web_history : bool
            Create a networkx graph and figure under key 'web'.
        x_train : array-like
            Training inputs; ``len(x_train[0])`` is factored into the image grid.
        y_train : array-like
            Training labels; ``len(y_train[0])`` gives the neuron panel count.

    Returns:
        dict with any of the keys 'history', 'neurons', 'web'.
    """
    viz_objects = {}

    if show_history:
        fig, ax = plt.subplots(3, 1, figsize=(6, 8))
        fig.suptitle('Learner History')
        viz_objects['history'] = {
            'fig': fig,
            'ax': ax,
            'artist1': [],
            'artist2': [],
            'artist3': []
        }

    if neurons_history:
        # Relative import: the absolute `data_operations` does not resolve
        # from the installed package.
        from .data_operations import find_closest_factors

        row, col = find_closest_factors(len(x_train[0]))
        num_panels = len(y_train[0]) if row != 0 else 1
        fig1, ax1 = plt.subplots(1, num_panels, figsize=(18, 14))
        viz_objects['neurons'] = {
            'fig': fig1,
            'ax': ax1,
            'artists': [],
            'row': row,
            'col': col
        }

    if neural_web_history:
        G = nx.Graph()
        fig2, ax2 = plt.subplots(figsize=(18, 4))
        viz_objects['web'] = {
            'fig': fig2,
            'ax': ax2,
            'G': G,
            'artists': []
        }

    return viz_objects
722
+
723
def update_history_plots_for_learner(viz_objects, depth_list, loss_list, best_acc_per_depth_list, x_train, final_activations):
    """Update history visualization plots.

    Does nothing unless ``viz_objects`` has a 'history' entry. Otherwise it
    appends one frame to each of the three history panels:
    loss over depth, accuracy over depth, and the combined potentiation shape,
    which is ``x`` plus the sum of every activation in ``final_activations``
    applied to ``x`` (``x`` spans ``[min(x_train), max(x_train)]`` with
    ``len(x_train)`` points).

    Returns:
        None.
    """
    if 'history' not in viz_objects:
        return

    hist = viz_objects['history']
    axes = hist['ax']

    # Loss plot
    art1 = axes[0].plot(depth_list, loss_list, color='r', markersize=6, linewidth=2)
    axes[0].set_title('Test Loss Over Depth')
    hist['artist1'].append(art1)

    # Accuracy plot
    art2 = axes[1].plot(depth_list, best_acc_per_depth_list, color='g', markersize=6, linewidth=2)
    axes[1].set_title('Test Accuracy Over Depth')
    hist['artist2'].append(art2)

    # Activation shape plot
    x = cp.linspace(cp.min(x_train), cp.max(x_train), len(x_train))
    translated_x_train = cp.copy(x)
    for activation in final_activations:
        translated_x_train += draw_activations(x, activation)

    # matplotlib needs host arrays; cupy refuses implicit conversion.
    art3 = axes[2].plot(cp.asnumpy(x), cp.asnumpy(translated_x_train), color='b', markersize=6, linewidth=2)
    axes[2].set_title('Potentiation Shape Over Depth')
    hist['artist3'].append(art3)
749
+
750
def display_visualizations_for_learner(viz_objects, best_weights, data, best_acc, test_loss, y_train, interval):
    """Display all final visualizations.

    For each key present in ``viz_objects`` ('history', 'neurons', 'web'),
    builds an ``ArtistAnimation`` from the collected frames and shows it.

    Parameters:
        viz_objects : dict produced by ``initialize_visualization_for_learner``.
        best_weights : weight matrix drawn by the 'neurons' and 'web' views.
        data : str label prefix for the neuron figure title.
        best_acc, test_loss : values shown in the neuron figure title.
        y_train : training labels; ``y_train[0]`` sets the neuron count.
        interval : delay between animation frames in milliseconds.

    Returns:
        None.
    """
    if 'history' in viz_objects:
        hist = viz_objects['history']
        history_keys = ('artist1', 'artist2', 'artist3')

        # Repeat the last frame 30 times so the final curves stay on screen.
        # Empty lists are skipped: `frames[-1]` would raise IndexError.
        for key in history_keys:
            frames = hist[key]
            if frames:
                frames.extend([frames[-1]] * 30)

        # Keep references alive until plt.show() returns, otherwise the
        # animations can be garbage collected and stop updating.
        history_animations = [
            ArtistAnimation(hist['fig'], hist[key], interval=interval, blit=True)
            for key in history_keys
            if hist[key]
        ]
        plt.tight_layout()
        plt.show()

    if 'neurons' in viz_objects:
        neurons = viz_objects['neurons']
        for _ in range(10):
            neurons['artists'] = neuron_history(
                cp.copy(best_weights),
                neurons['ax'],
                neurons['row'],
                neurons['col'],
                y_train[0],
                neurons['artists'],
                data=data,
                fig1=neurons['fig'],
                acc=best_acc,
                loss=test_loss
            )

        ani4 = ArtistAnimation(neurons['fig'], neurons['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()

    if 'web' in viz_objects:
        web = viz_objects['web']
        for _ in range(30):
            art5_1, art5_2, art5_3 = draw_neural_web(
                W=best_weights,
                ax=web['ax'],
                G=web['G'],
                return_objs=True
            )
            # art5_3 maps node -> Text label; the frame needs the Text artists.
            art5_list = [art5_1] + [art5_2] + list(art5_3.values())
            web['artists'].append(art5_list)

        ani5 = ArtistAnimation(web['fig'], web['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()