pyerualjetwork 2.7.8__py3-none-any.whl → 4.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,826 @@
1
+ import networkx as nx
2
+ import matplotlib.pyplot as plt
3
+ import cupy as cp
4
+ import numpy as np
5
+ from scipy.spatial import ConvexHull
6
+ import seaborn as sns
7
+ from matplotlib.animation import ArtistAnimation
8
+
9
def draw_neural_web(W, ax, G, return_objs=False):
    """
    Visualizes a neural web by drawing the neural network structure.

    Every non-zero weight ``W[i, j]`` is added to ``G`` as an edge between
    ``Output{i}`` and ``Input{j}`` carrying the weight in the ``ltpw`` edge
    attribute. Input neurons are placed in a column at x=0 and output neurons
    in a column at x=1, offset so the output column is centred on the input
    column. Edges are coloured by weight with the ``Blues`` colormap, and the
    axes' spines and ticks are hidden.

    Parameters:
        W : cupy.ndarray or numpy.ndarray
            A 2D array of connection weights with shape
            (num_output_neurons, num_input_neurons).
        ax : matplotlib.axes.Axes
            The matplotlib axes where the graph will be drawn.
        G : networkx.Graph
            The NetworkX graph that receives the edges (modified in place).
        return_objs : bool, optional
            If True, returns the drawn objects. Default is False.

    Returns:
        tuple or None
            If ``return_objs`` is truthy, ``(art1, art2, art3)``:
            art1 - the node artist from ``networkx.draw_networkx_nodes``;
            art2 - the edge artist(s) from ``networkx.draw_networkx_edges``;
            art3 - dict mapping node name to its matplotlib Text label, as
            returned by ``networkx.draw_networkx_labels``.
            Otherwise None.

    Example:
        art1, art2, art3 = draw_neural_web(W, ax, G, return_objs=True)
        plt.show()
    """
    # cp.asnumpy accepts both CuPy (device) and NumPy (host) arrays, so the
    # documented NumPy input works as well as the CuPy arrays used in training.
    W = cp.asnumpy(W)

    # np.nonzero walks the matrix in row-major order, i.e. the same edge
    # insertion order as a nested `for i: for j:` loop.
    for i, j in zip(*np.nonzero(W)):
        G.add_edge(f'Output{i}', f'Input{j}', ltpw=W[i, j])

    weights = [edata['ltpw'] for _, _, edata in G.edges(data=True)]

    num_motor_neurons = W.shape[0]
    num_sensory_neurons = W.shape[1]

    pos = {f'Input{j}': (0, j) for j in range(num_sensory_neurons)}

    # Shift the output column so it is vertically centred on the inputs.
    motor_y_start = (num_sensory_neurons - num_motor_neurons) / 2
    for i in range(num_motor_neurons):
        pos[f'Output{i}'] = (1, motor_y_start + i)

    art1 = nx.draw_networkx_nodes(G, pos, ax=ax, node_size=1000, node_color='lightblue')
    art2 = nx.draw_networkx_edges(G, pos, ax=ax, edge_color=weights, edge_cmap=plt.cm.Blues, width=2)
    art3 = nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('Neural Web')

    if return_objs:
        return art1, art2, art3
71
+
72
+
73
def draw_model_architecture(model_name, model_path=''):
    """
    Visualize the architecture of a saved PLAN model and show the figure.

    The model is loaded with ``load_model`` and drawn in three columns:

    * Input layer: one group of ``W.shape[1]`` input nodes per entry of the
      model's activation potentiation list, each labelled with that
      activation's name.
    * Aggregation layer: one "Summed Input" node per input feature; input
      ``i`` of every group is connected to summed node ``i``.
    * Output layer: one node per row of the weight matrix ``W``, each
      connected to every summed node.

    The lower-right corner states whether the model stores standard-scaler
    parameters and shows the model name. The figure is displayed with
    ``plt.show()``.

    :param model_name: Name of the saved model to load and visualize.
    :param model_path: Location where the model is saved; passed unchanged to
        ``load_model``. Defaults to ``''``.
    """
    from .model_operations_cuda import load_model, get_scaler, get_act_pot, get_weights

    model = load_model(model_name=model_name, model_path=model_path)

    W = model[get_weights()]
    activation_potentiation = model[get_act_pot()]
    scaler_params = model[get_scaler()]

    # One group of input nodes is drawn per activation function.
    input_groups = len(activation_potentiation)
    num_inputs = W.shape[1]
    num_outputs = W.shape[0]

    plt.figure(figsize=(15, 10))

    def get_node_positions():
        """Return a dict mapping node name to its (x, y) plot position."""
        positions = {}

        # Input layer positions
        total_height = 0.8  # maximum vertical extent used by a column
        group_height = total_height / input_groups  # vertical space per input group
        input_spacing = min(group_height / (num_inputs + 1), 0.1)  # gap between inputs

        for group in range(input_groups):
            group_start_y = 0.9 - (group * group_height)  # top y coordinate of this group
            for i in range(num_inputs):
                y_pos = group_start_y - ((i + 1) * input_spacing)
                positions[f'input_{group}_{i}'] = (0.2, y_pos)

        # Aggregation layer positions
        agg_spacing = total_height / (num_inputs + 1)
        for i in range(num_inputs):
            positions[f'summed_{i}'] = (0.5, 0.9 - ((i + 1) * agg_spacing))

        # Output layer positions
        output_spacing = total_height / (num_outputs + 1)
        for i in range(num_outputs):
            positions[f'output_{i}'] = (0.8, 0.9 - ((i + 1) * output_spacing))

        return positions

    pos = get_node_positions()

    # Input nodes, their edges to the aggregation layer, and (once) the
    # aggregation nodes themselves.
    for group in range(input_groups):
        for i in range(num_inputs):
            plt.plot(*pos[f'input_{group}_{i}'], 'o', color='lightgreen', markersize=20)
            plt.text(pos[f'input_{group}_{i}'][0] - 0.05, pos[f'input_{group}_{i}'][1],
                     f'Input #{i+1} ({activation_potentiation[group]})', ha='right', va='center')

            # Connect each input directly to its summed input.
            plt.plot([pos[f'input_{group}_{i}'][0], pos[f'summed_{i}'][0]],
                     [pos[f'input_{group}_{i}'][1], pos[f'summed_{i}'][1]], 'k-')

            # Aggregation nodes are shared by all groups; draw them only once.
            if group == 0:
                plt.plot(*pos[f'summed_{i}'], 'o', color='lightgreen', markersize=20)
                plt.text(pos[f'summed_{i}'][0], pos[f'summed_{i}'][1] + 0.02,
                         f'Summed\nInput #{i+1}', ha='center', va='bottom')

    # Output nodes, each fully connected to the aggregation layer.
    for i in range(num_outputs):
        plt.plot(*pos[f'output_{i}'], 'o', color='gold', markersize=20)
        plt.text(pos[f'output_{i}'][0] + 0.05, pos[f'output_{i}'][1],
                 f'Output #{i+1}', ha='left', va='center', color='purple')

        for summed_idx in range(num_inputs):
            plt.plot([pos[f'summed_{summed_idx}'][0], pos[f'output_{i}'][0]],
                     [pos[f'summed_{summed_idx}'][1], pos[f'output_{i}'][1]], 'k-')

    # Column headers
    plt.text(0.2, 0.95, 'Input Layer', ha='center', va='bottom', fontsize=12)
    plt.text(0.5, 0.95, 'Aggregation\nLayer', ha='center', va='bottom', fontsize=12)
    plt.text(0.8, 0.95, 'Output Layer', ha='center', va='bottom', fontsize=12)

    plt.axis('off')

    # Model information
    if scaler_params is None:
        plt.text(0.95, 0.05, 'Standard Scaler=No', fontsize=10, ha='right', va='bottom')
    else:
        plt.text(0.95, 0.05, 'Standard Scaler=Yes', fontsize=10, ha='right', va='bottom')

    plt.text(0.95, 0.1, f"PLAN Model Architecture: {model_name}", fontsize=12, ha='right', va='bottom', fontweight='bold')
    plt.tight_layout()
    plt.show()
183
+
184
+
185
def draw_activations(x_train, activation):
    """
    Apply the activation function named ``activation`` to ``x_train``.

    Names are mapped to functions in ``activation_functions_cuda``; a few
    entries pass fixed extra arguments (e.g. ``sglu`` with ``alpha=1.0``).
    ``'linear'`` returns ``x_train`` itself (the same object).

    Args:
        x_train: Input array passed to the activation function.
        activation (str): Name of the activation to apply.

    Returns:
        The activated array. For an unrecognized name, a warning is printed
        and ``x_train`` is returned unchanged.
    """
    from . import activation_functions_cuda as af

    # Lambdas keep attribute lookup lazy: only the selected function is
    # resolved on `af`, exactly as an if/elif chain would do.
    dispatch = {
        'sigmoid': lambda x: af.Sigmoid(x),
        'swish': lambda x: af.swish(x),
        'circular': lambda x: af.circular_activation(x),
        'mod_circular': lambda x: af.modular_circular_activation(x),
        'tanh_circular': lambda x: af.tanh_circular_activation(x),
        'leaky_relu': lambda x: af.leaky_relu(x),
        'relu': lambda x: af.Relu(x),
        'softplus': lambda x: af.softplus(x),
        'elu': lambda x: af.elu(x),
        'gelu': lambda x: af.gelu(x),
        'selu': lambda x: af.selu(x),
        'softmax': lambda x: af.Softmax(x),
        'tanh': lambda x: af.tanh(x),
        'sinakt': lambda x: af.sinakt(x),
        'p_squared': lambda x: af.p_squared(x),
        'sglu': lambda x: af.sglu(x, alpha=1.0),
        'dlrelu': lambda x: af.dlrelu(x),
        'exsig': lambda x: af.exsig(x),
        'sin_plus': lambda x: af.sin_plus(x),
        'acos': lambda x: af.acos(x, alpha=1.0, beta=0.0),
        'gla': lambda x: af.gla(x, alpha=1.0, mu=0.0),
        'srelu': lambda x: af.srelu(x),
        'qelu': lambda x: af.qelu(x),
        'isra': lambda x: af.isra(x),
        'waveakt': lambda x: af.waveakt(x),
        'arctan': lambda x: af.arctan(x),
        'bent_identity': lambda x: af.bent_identity(x),
        'sech': lambda x: af.sech(x),
        'softsign': lambda x: af.softsign(x),
        'pwl': lambda x: af.pwl(x),
        'cubic': lambda x: af.cubic(x),
        'gaussian': lambda x: af.gaussian(x),
        'sine': lambda x: af.sine(x),
        'tanh_square': lambda x: af.tanh_square(x),
        'mod_sigmoid': lambda x: af.mod_sigmoid(x),
        'linear': lambda x: x,
        'quartic': lambda x: af.quartic(x),
        'square_quartic': lambda x: af.square_quartic(x),
        'cubic_quadratic': lambda x: af.cubic_quadratic(x),
        'exp_cubic': lambda x: af.exp_cubic(x),
        'sine_square': lambda x: af.sine_square(x),
        'logarithmic': lambda x: af.logarithmic(x),
        'scaled_cubic': lambda x: af.scaled_cubic(x, 1.0),
        'sine_offset': lambda x: af.sine_offset(x, 1.0),
        'spiral': lambda x: af.spiral_activation(x),
    }

    func = dispatch.get(activation) if isinstance(activation, str) else None
    if func is None:
        # Unknown activation name: best-effort fallback to the identity
        # instead of failing the whole visualization.
        print('WARNING: error in drawing some activation.')
        return x_train

    return func(x_train)
328
+
329
+
330
def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation):
    """
    Plot an evaluation dashboard for a model on test data and show it.

    Draws a 2x2 figure:
      * [0, 0] confusion-matrix heatmap,
      * [0, 1] bar chart of precision, recall, F1 score and accuracy,
      * [1, 0] ROC curve (two classes) or one-vs-rest ROC curves (otherwise),
      * [1, 1] decision boundary over the first two input features,
    then calls ``plt.show()``.

    Args:
        x_test: Test inputs (CuPy array); columns 0 and 1 are used for the
            decision-boundary panel.
        y_test: One-hot encoded test labels.
        y_preds: Predicted class labels for ``x_test``.
        acc_list: Sequence of accuracy values; the last entry is plotted
            (it must provide ``.get()``).
        W: Model weight matrix passed to ``predict_model_ram``.
        activation_potentiation: Activation list passed to ``predict_model_ram``.
    """
    from .metrics_cuda import metrics, confusion_matrix, roc_curve
    from .ui import loading_bars, initialize_loading_bar
    from .data_operations_cuda import decode_one_hot
    from .model_operations_cuda import predict_model_ram

    bar_format_normal = loading_bars()[0]

    acc = acc_list[-1]
    y_true = cp.array(decode_one_hot(y_test), copy=True)
    y_preds = cp.array(y_preds, copy=True)
    Class = cp.unique(y_true)

    precision, recall, f1 = metrics(y_test, y_preds)

    cm = confusion_matrix(y_true, y_preds, len(Class))
    _, axs = plt.subplots(2, 2, figsize=(16, 12))

    # --- Confusion matrix ---
    sns.heatmap(cm.get(), annot=True, fmt='d', ax=axs[0, 0])
    axs[0, 0].set_title("Confusion Matrix")
    axs[0, 0].set_xlabel("Predicted Class")
    axs[0, 0].set_ylabel("Actual Class")

    # --- ROC curve(s) ---
    roc_ax = axs[1, 0]
    if len(Class) == 2:
        fpr, tpr, _ = roc_curve(y_true, y_preds)
        roc_auc = float(cp.trapz(tpr, fpr))
        roc_ax.plot(fpr.get(), tpr.get(), color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    else:
        for i in range(len(Class)):
            # One-vs-rest binarization with class i as the positive label.
            # Masks are built from the untouched labels; relabelling in place
            # (i -> 0, then non-zero -> 1) left class 0 at 0 together with
            # class i, mixing two classes in every curve for i > 0.
            y_true_bin = (y_true == i).astype(y_true.dtype)
            y_preds_bin = (y_preds == i).astype(y_preds.dtype)

            fpr, tpr, _ = roc_curve(y_true_bin, y_preds_bin)
            roc_auc = float(cp.trapz(tpr, fpr))
            roc_ax.plot(fpr.get(), tpr.get(), lw=2, label=f'Class {i} ROC curve (area = {roc_auc:.2f})')

    # Chance diagonal and axis decoration, drawn once for all curves.
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    roc_ax.legend(loc="lower right")

    # --- Metric bars ---
    metric = ['Precision', 'Recall', 'F1 Score', 'Accuracy']
    values = [precision, recall, f1, acc.get()]
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    bars = axs[0, 1].bar(metric, values, color=colors)

    # Value labels are drawn just inside the top of each bar.
    for bar, value in zip(bars, values):
        axs[0, 1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() - 0.05, f'{value:.2f}',
                       ha='center', va='bottom', fontsize=12, color='white', weight='bold')

    axs[0, 1].set_ylim(0, 1)
    axs[0, 1].set_xlabel('Metrics')
    axs[0, 1].set_ylabel('Score')
    axs[0, 1].set_title('Precision, Recall, F1 Score, and Accuracy (Weighted)')
    axs[0, 1].grid(True, axis='y', linestyle='--', alpha=0.7)

    # --- Decision boundary over features 0 and 1 ---
    feature_indices = [0, 1]

    h = .02  # mesh step
    x_min, x_max = x_test[:, feature_indices[0]].min() - 1, x_test[:, feature_indices[0]].max() + 1
    y_min, y_max = x_test[:, feature_indices[1]].min() - 1, x_test[:, feature_indices[1]].max() + 1
    xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
                         cp.arange(y_min, y_max, h))

    grid = cp.c_[xx.ravel(), yy.ravel()]

    # All other features are held at 0 for the mesh predictions.
    grid_full = cp.zeros((grid.shape[0], x_test.shape[1]), dtype=cp.float32)
    grid_full[:, feature_indices] = grid

    Z = [None] * len(grid_full)

    predict_progress = initialize_loading_bar(total=len(grid_full), leave=False,
                                              bar_format=bar_format_normal, desc="Predicts For Decision Boundary", ncols=65)

    for i in range(len(grid_full)):
        Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))
        predict_progress.update(1)

    predict_progress.close()

    Z = cp.array(Z)
    Z = Z.reshape(xx.shape)

    axs[1, 1].contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
    axs[1, 1].scatter(x_test[:, feature_indices[0]].get(), x_test[:, feature_indices[1]].get(), c=y_true.get(), edgecolors='k', marker='o', s=20, alpha=0.9)
    axs[1, 1].set_xlabel(f'Feature {feature_indices[0] + 1}')
    axs[1, 1].set_ylabel(f'Feature {feature_indices[1] + 1}')
    axs[1, 1].set_title('Decision Boundary')

    plt.show()
451
+
452
+
453
def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=None):
    """
    Plot the model's decision boundary over the first two input features.

    A mesh with step 0.02 spanning features 0 and 1 (padded by 1 on each
    side) is classified point by point with ``predict_model_ram``; all other
    features are set to 0. Predicted classes are drawn as filled contours,
    with the samples scattered on top and coloured by their true class.

    Args:
        x: Input samples (CuPy array); columns 0 and 1 are plotted.
        y: One-hot encoded labels for ``x``.
        activation_potentiation: Activation list passed to ``predict_model_ram``.
        W: Model weight matrix passed to ``predict_model_ram``.
        artist: Unused; kept for backward compatibility with existing callers.
        ax: Where to draw. ``None`` draws on the current pyplot figure and
            calls ``plt.show()``; a single Axes is drawn on directly; any
            other value is indexed as ``ax[1, 0]`` (e.g. the grid returned by
            ``plt.subplots(2, 2)``).

    Returns:
        ``(contour_set, scatter_collection)`` when ``ax`` is given, else None.
    """
    from .model_operations_cuda import predict_model_ram
    from .data_operations_cuda import decode_one_hot

    feature_indices = [0, 1]

    h = .02  # mesh step
    x_min, x_max = x[:, feature_indices[0]].min() - 1, x[:, feature_indices[0]].max() + 1
    y_min, y_max = x[:, feature_indices[1]].min() - 1, x[:, feature_indices[1]].max() + 1
    xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
                         cp.arange(y_min, y_max, h))

    grid = cp.c_[xx.ravel(), yy.ravel()]
    grid_full = cp.zeros((grid.shape[0], x.shape[1]))
    grid_full[:, feature_indices] = grid

    Z = [None] * len(grid_full)

    for i in range(len(grid_full)):
        Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))

    Z = cp.array(Z, dtype=cp.int32)
    Z = Z.reshape(xx.shape)

    # Host copies for matplotlib.
    xx_host, yy_host, Z_host = xx.get(), yy.get(), Z.get()
    feat0 = x[:, feature_indices[0]].get()
    feat1 = x[:, feature_indices[1]].get()
    labels = decode_one_hot(y).get()

    if ax is None:
        plt.contourf(xx_host, yy_host, Z_host, alpha=0.8)
        plt.scatter(feat0, feat1, c=labels, edgecolors='k', marker='o', s=20, alpha=0.9)
        plt.xlabel(f'Feature {feature_indices[0] + 1}')
        plt.ylabel(f'Feature {feature_indices[1] + 1}')
        plt.title('Decision Boundary')

        plt.show()
        return None

    # A single Axes is drawn on directly; otherwise `ax` is treated as a 2x2
    # grid and its lower-left panel is used. This explicit check replaces a
    # bare `except:` that also hid genuine drawing errors.
    target = ax if hasattr(ax, 'contourf') else ax[1, 0]

    art1_1 = target.contourf(xx_host, yy_host, Z_host, alpha=0.8)
    art1_2 = target.scatter(feat0, feat1, c=labels, edgecolors='k', marker='o', s=20, alpha=0.9)
    target.set_xlabel(f'Feature {feature_indices[0] + 1}')
    target.set_ylabel(f'Feature {feature_indices[1] + 1}')
    target.set_title('Decision Boundary')

    return art1_1, art1_2
509
+
510
+
511
def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
    """
    Scatter the data in 2-D and shade the convex hull of each predicted class.

    Inputs with more than two features are projected to two components with
    ``pca``. Points are coloured by true class. For every class with more
    than two predicted members, the convex hull of those points is filled
    with the class colour. The figure is updated with ``plt.draw()`` but not
    shown.

    Args:
        x: Input samples, shape (n_samples, n_features); CuPy or NumPy.
        y: One-hot encoded true labels.
        y_preds: Predicted class label per sample. Defaults to the true labels.
        s: Marker size of the scatter points.
        color: Name of the matplotlib colormap.
    """
    from .metrics_cuda import pca
    from .data_operations_cuda import decode_one_hot

    X_pca = pca(x, n_components=2) if x.shape[1] > 2 else x

    y_labels = decode_one_hot(y)
    # `is None`: `== None` on an array is elementwise and has no truth value.
    if y_preds is None:
        y_preds = y_labels

    # Matplotlib and scipy's ConvexHull need host arrays. cp.asnumpy is a
    # no-op conversion for data that already lives in NumPy.
    X_host = cp.asnumpy(X_pca)
    y_host = cp.asnumpy(y_labels)
    preds_host = cp.asnumpy(y_preds)

    num_classes = len(np.unique(y_host))

    cmap = plt.get_cmap(color)
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    plt.scatter(X_host[:, 0], X_host[:, 1], c=y_host, edgecolor='k', s=s, cmap=cmap, norm=norm)

    for cls in range(num_classes):
        # Keep the projected coordinates in their own float dtype; casting
        # them to the label dtype would truncate them.
        class_points = X_host[preds_host == cls]

        # A 2-D hull needs at least three points.
        if len(class_points) > 2:
            hull = ConvexHull(class_points)
            hull_points = class_points[hull.vertices]

            # Close the polygon by repeating the first vertex.
            hull_points = np.vstack([hull_points, hull_points[0]])

            plt.fill(hull_points[:, 0], hull_points[:, 1], color=cmap(norm(cls)), alpha=0.3, edgecolor='k', label=f'Class {cls} Hull')

    plt.title("Decision Space (Data Distribution)")

    plt.draw()
560
+
561
+
562
def update_neuron_history(LTPW, ax1, row, col, class_count, artist5, fig1, acc=False, loss=False):
    """
    Draw per-neuron weight images for the training animation.

    For each ``j`` in ``range(class_count)``, row ``j`` of ``LTPW`` is
    reshaped to ``(row, col)``, copied to the host with ``.get()`` and shown
    on ``ax1[j]``. Each image artist is appended to ``artist5`` as its own
    single-element frame list. ``fig1``'s suptitle shows the accuracy when
    both ``acc`` and ``loss`` are supplied (``loss`` is only checked, not
    displayed).

    Args:
        LTPW: Weight matrix (CuPy array) whose rows are the neurons' weights.
        ax1: Sequence of Axes, one per neuron.
        row, col: Image shape each weight row is reshaped to.
        class_count (int): Number of neurons to draw.
        artist5 (list): Frame list that receives the new image artists.
        fig1: Figure whose suptitle is updated.
        acc, loss: Values shown in the title; ``False`` means "not given".
    """
    # The title does not depend on the neuron, so build it once (this also
    # keeps it defined when class_count is 0). `is not False` keeps a genuine
    # accuracy/loss of 0 from being mistaken for "not given".
    if acc is not False and loss is not False:
        suptitle_info = ' Accuracy:' + str(acc) + '\n' + '\nNeurons Memory:'
    else:
        suptitle_info = 'Neurons Memory:'

    for j in range(class_count):
        mat = LTPW[j, :].reshape(row, col).get()

        art5 = ax1[j].imshow(mat, interpolation='sinc', cmap='viridis')

        ax1[j].set_aspect('equal')
        ax1[j].set_xticks([])
        ax1[j].set_yticks([])
        ax1[j].set_title(f'{j+1}. Neuron')

        artist5.append([art5])

    fig1.suptitle(suptitle_info, fontsize=16)
586
+
587
+
588
def initialize_visualization_for_fit(val, show_training, neurons_history, x_train, y_train):
    """
    Create the figures and frame lists used to animate training.

    Returns an empty dict when neither ``show_training`` nor
    ``neurons_history`` is set. Otherwise the dict holds a fresh
    ``networkx.Graph`` under ``'G'``, a 2x2 subplot grid titled
    'Train History' under ``'fig'``/``'ax'``, and empty frame lists
    ``'artist1'`` .. ``'artist4'``. With ``neurons_history`` it also holds
    ``'fig1'``/``'ax1'`` (one subplot per column of ``y_train``), an empty
    ``'artist5'`` list, and the ``'row'``/``'col'`` image shape computed by
    ``find_closest_factors`` from the input width.

    Raises:
        ValueError: a visualization was requested while ``val`` is falsy.
    """
    from .data_operations_cuda import find_closest_factors
    objects = {}

    if not (show_training or neurons_history):
        return objects

    if not val:
        raise ValueError("For showing training or neurons history, 'val' parameter must be True.")

    graph = nx.Graph()
    fig, ax = plt.subplots(2, 2)
    fig.suptitle('Train History')
    objects.update(G=graph, fig=fig, ax=ax,
                   artist1=[], artist2=[], artist3=[], artist4=[])

    if neurons_history:
        row, col = find_closest_factors(len(x_train[0]))
        fig1, ax1 = plt.subplots(1, len(y_train[0]), figsize=(18, 14))
        objects.update(fig1=fig1, ax1=ax1, artist5=[], row=row, col=col)

    return objects
622
+
623
+
624
def update_weight_visualization_for_fit(ax, LTPW, artist2):
    """
    Render the current weight matrix on ``ax`` as one animation frame.

    ``LTPW`` is copied to the host with ``.get()``; the resulting image
    artist is appended to ``artist2`` as a single-element frame list.
    """
    frame_image = ax.imshow(LTPW.get(), interpolation='sinc', cmap='viridis')
    artist2.append([frame_image])
628
+
629
def show():
    """Apply ``tight_layout`` to the current figure, then display all figures."""
    plt.tight_layout()
    plt.show()
632
+
633
def update_neural_web_for_fit(W, ax, G, artist):
    """
    Draw the neural web for ``W`` on ``ax`` and record it as one animation frame.

    The frame appended to ``artist`` is
    ``[node artist, edge artist, *label text artists]``.
    """
    nodes, edges, labels = draw_neural_web(W=W, ax=ax, G=G, return_objs=True)
    artist.append([nodes, edges, *labels.values()])
640
+
641
def update_decision_boundary_for_fit(ax, x_val, y_val, activation_potentiation, LTPW, artist1):
    """
    Draw the current decision boundary and record it as one animation frame.

    The frame appended to ``artist1`` holds the contour artist(s) followed by
    the scatter collection returned by ``plot_decision_boundary``.
    """
    from matplotlib.collections import Collection

    art1_1, art1_2 = plot_decision_boundary(x_val, y_val, activation_potentiation, LTPW, artist=artist1, ax=ax)

    # From Matplotlib 3.8 a ContourSet is itself a single Collection and its
    # per-level `.collections` list is deprecated (removed in 3.10). Older
    # versions still need the per-level collections.
    if isinstance(art1_1, Collection):
        contour_artists = [art1_1]
    else:
        contour_artists = list(art1_1.collections)

    artist1.append([*contour_artists, art1_2])
645
+
646
+
647
def update_validation_history_for_fit(ax, val_list, artist3):
    """
    Plot validation accuracy against step index and record the line as a frame.

    Every entry of ``val_list`` is moved to the host with ``.get()``. The
    x-axis runs from 1 to ``len(val_list)`` and the y-axis is fixed to
    [0, 1]. The line list returned by ``ax.plot`` is appended to ``artist3``.
    """
    host_values = [value.get() for value in val_list]
    steps = list(range(1, len(host_values) + 1))

    line_style = dict(linestyle='--', color='g', marker='o', markersize=6,
                      linewidth=2, label='Validation Accuracy')
    lines = ax.plot(steps, host_values, **line_style)

    ax.set_title('Validation History')
    ax.set_xlabel('Time')
    ax.set_ylabel('Validation Accuracy')
    ax.set_ylim([0, 1])

    artist3.append(lines)
668
+
669
+
670
def display_visualization_for_fit(fig, artist_list, interval):
    """
    Build a blitted ArtistAnimation on ``fig`` from ``artist_list`` frames.

    ``interval`` is the delay between frames in milliseconds. The animation
    object is returned to the caller.
    """
    return ArtistAnimation(fig, artist_list, interval=interval, blit=True)
674
+
675
def update_neuron_history_for_learner(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=False, loss=False):
    """
    Draw per-neuron weight images for the learner animation.

    For each ``j`` in ``range(len(class_count))``, row ``j`` of ``LTPW`` is
    reshaped to ``(row, col)``, copied to the host with ``.get()`` and shown
    on ``ax1[j]``. Each image artist is appended to ``artist5`` as its own
    single-element frame list. When both ``acc`` and ``loss`` are supplied,
    ``fig1``'s suptitle shows them prefixed with ``data``.

    Args:
        LTPW: Weight matrix (CuPy array) whose rows are the neurons' weights.
        ax1: Sequence of Axes, one per neuron.
        row, col: Image shape each weight row is reshaped to.
        class_count: Sequence whose length is the number of neurons to draw
            (callers in this module pass a one-hot label row).
        artist5 (list): Frame list that receives the new image artists.
        data (str): Dataset name used as the title prefix.
        fig1: Figure whose suptitle is updated.
        acc, loss: Values shown in the title; ``False`` means "not given".

    Returns:
        list: ``artist5``, the same list object that was passed in.
    """
    # The title does not depend on the neuron, so build it once (this also
    # keeps it defined for an empty class_count). `is not False` keeps a
    # genuine accuracy/loss of 0 from being mistaken for "not given".
    if acc is not False and loss is not False:
        suptitle_info = data + ' Accuracy:' + str(acc) + '\n' + data + ' Loss:' + str(loss) + '\nNeurons Memory:'
    else:
        suptitle_info = 'Neurons Memory:'

    for j in range(len(class_count)):
        mat = LTPW[j, :].reshape(row, col)

        art5 = ax1[j].imshow(mat.get(), interpolation='sinc', cmap='viridis')

        ax1[j].set_aspect('equal')
        ax1[j].set_xticks([])
        ax1[j].set_yticks([])
        ax1[j].set_title(f'{j+1}. Neuron')

        artist5.append([art5])

    fig1.suptitle(suptitle_info, fontsize=16)

    return artist5
701
+
702
def initialize_visualization_for_learner(show_history, neurons_history, neural_web_history, x_train, y_train):
    """
    Create the figures and frame lists used by the learner animations.

    The returned dict contains only the requested entries:

    * ``'history'`` (``show_history``): a 3x1 subplot grid titled
      'Learner History' plus empty ``artist1`` .. ``artist3`` frame lists.
    * ``'neurons'`` (``neurons_history``): one subplot per column of
      ``y_train`` (a single subplot when ``find_closest_factors`` returns a
      row count of 0), an empty ``artists`` list and the ``row``/``col``
      image shape.
    * ``'web'`` (``neural_web_history``): a wide figure, a fresh
      ``networkx.Graph`` under ``'G'`` and an empty ``artists`` list.
    """
    from .data_operations_cuda import find_closest_factors
    viz_objects = {}

    if show_history:
        fig, ax = plt.subplots(3, 1, figsize=(6, 8))
        fig.suptitle('Learner History')
        viz_objects['history'] = dict(fig=fig, ax=ax,
                                      artist1=[], artist2=[], artist3=[])

    if neurons_history:
        row, col = find_closest_factors(len(x_train[0]))
        panel_count = len(y_train[0]) if row != 0 else 1
        fig1, ax1 = plt.subplots(1, panel_count, figsize=(18, 14))
        viz_objects['neurons'] = dict(fig=fig1, ax=ax1, artists=[],
                                      row=row, col=col)

    if neural_web_history:
        graph = nx.Graph()
        fig2, ax2 = plt.subplots(figsize=(18, 4))
        viz_objects['web'] = dict(fig=fig2, ax=ax2, G=graph, artists=[])

    return viz_objects
743
+
744
def update_history_plots_for_learner(viz_objects, depth_list, loss_list, best_acc_per_depth_list, x_train, final_activations):
    """
    Append one frame to each of the three learner-history plots.

    Does nothing when ``viz_objects`` has no ``'history'`` entry. Otherwise it
    draws on ``viz_objects['history']['ax']``:

    * ``[0]`` test loss vs. depth,
    * ``[1]`` best accuracy vs. depth,
    * ``[2]`` ``x + sum(draw_activations(x, a) for a in final_activations)``
      over ``len(x_train)`` evenly spaced points spanning ``x_train``'s range,

    and appends each returned line list to ``artist1``/``artist2``/``artist3``.

    ``loss_list`` and ``best_acc_per_depth_list`` are copied to the host for
    plotting; the caller's lists are not modified.
    """
    if 'history' not in viz_objects:
        return

    hist = viz_objects['history']

    # Convert copies instead of overwriting the caller's lists with `.get()`
    # results: host values have no `.get()`, so an in-place conversion breaks
    # the next call made with the same (growing) lists. cp.asnumpy accepts
    # both CuPy and host values.
    loss_host = [cp.asnumpy(value) for value in loss_list]
    acc_host = [cp.asnumpy(value) for value in best_acc_per_depth_list]

    # Loss plot
    art1 = hist['ax'][0].plot(depth_list, loss_host, color='r', markersize=6, linewidth=2)
    hist['ax'][0].set_title('Test Loss Over Depth')
    hist['artist1'].append(art1)

    # Accuracy plot
    art2 = hist['ax'][1].plot(depth_list, acc_host, color='g', markersize=6, linewidth=2)
    hist['ax'][1].set_title('Test Accuracy Over Depth')
    hist['artist2'].append(art2)

    # Activation shape plot
    x = cp.linspace(cp.min(x_train), cp.max(x_train), len(x_train))
    translated_x_train = cp.copy(x)
    for activation in final_activations:
        translated_x_train += draw_activations(x, activation)

    art3 = hist['ax'][2].plot(x.get(), translated_x_train.get(), color='b', markersize=6, linewidth=2)
    hist['ax'][2].set_title('Potentiation Shape Over Depth')
    hist['artist3'].append(art3)
776
+
777
def display_visualizations_for_learner(viz_objects, best_weights, data, best_acc, test_loss, y_train, interval):
    """
    Show the final learner animations for each entry present in ``viz_objects``.

    Before each animation is built, its final frame is held by padding:

    * ``'history'``: the last frame of each of the three plots is repeated
      30 more times;
    * ``'neurons'``: 10 more frames of ``best_weights`` neuron images are drawn;
    * ``'web'``: 30 more neural-web frames of ``best_weights`` are drawn.

    Each animation is a blitted ArtistAnimation with ``interval`` ms between
    frames, displayed with ``plt.tight_layout()`` and ``plt.show()``.
    """
    if 'history' in viz_objects:
        hist = viz_objects['history']
        frame_keys = ('artist1', 'artist2', 'artist3')
        for _ in range(30):
            for key in frame_keys:
                hist[key].append(hist[key][-1])

        # Bound to a local so the animation objects stay referenced while
        # plt.show() runs.
        history_animations = [
            ArtistAnimation(hist['fig'], hist[key], interval=interval, blit=True)
            for key in frame_keys
        ]
        plt.tight_layout()
        plt.show()

    if 'neurons' in viz_objects:
        neurons = viz_objects['neurons']
        for _ in range(10):
            neurons['artists'] = update_neuron_history_for_learner(
                cp.copy(best_weights),
                neurons['ax'],
                neurons['row'],
                neurons['col'],
                y_train[0],
                neurons['artists'],
                data=data,
                fig1=neurons['fig'],
                acc=best_acc,
                loss=test_loss,
            )

        neuron_animation = ArtistAnimation(neurons['fig'], neurons['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()

    if 'web' in viz_objects:
        web = viz_objects['web']
        for _ in range(30):
            nodes, edges, labels = draw_neural_web(
                W=best_weights,
                ax=web['ax'],
                G=web['G'],
                return_objs=True,
            )
            web['artists'].append([nodes, edges, *labels.values()])

        web_animation = ArtistAnimation(web['fig'], web['artists'], interval=interval, blit=True)
        plt.tight_layout()
        plt.show()