pyerualjetwork 4.1.4__py3-none-any.whl → 4.1.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the registry.
@@ -1,6 +1,7 @@
  import networkx as nx
  import matplotlib.pyplot as plt
  import cupy as cp
+ import numpy as np
  from scipy.spatial import ConvexHull
  import seaborn as sns
  from matplotlib.animation import ArtistAnimation
@@ -31,7 +32,7 @@ def draw_neural_web(W, ax, G, return_objs=False):
      art1, art2, art3 = draw_neural_web(W, ax, G, return_objs=True)
      plt.show()
      """
-
+     W = W.get()
      for i in range(W.shape[0]):
          for j in range(W.shape[1]):
              if W[i, j] != 0:
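
Note: the `W = W.get()` added here (together with the new `import numpy as np` above) is the recurring 4.1.6 pattern: device arrays are copied to the host before CPU-only libraries such as NetworkX and Matplotlib touch them. A minimal sketch of that pattern, with a stand-in weight matrix:

import cupy as cp
import numpy as np

W_gpu = cp.random.rand(3, 4)   # stand-in for the model's weight matrix on the GPU
W_cpu = W_gpu.get()            # explicit device-to-host copy; returns np.ndarray
assert isinstance(W_cpu, np.ndarray)
W_cpu2 = cp.asnumpy(W_gpu)     # equivalent; also passes NumPy input through unchanged
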
@@ -40,6 +41,7 @@ def draw_neural_web(W, ax, G, return_objs=False):
      edges = G.edges(data=True)
      weights = [edata['ltpw'] for _, _, edata in edges]
      pos = {}
+
      num_motor_neurons = W.shape[0]
      num_sensory_neurons = W.shape[1]

@@ -68,128 +70,114 @@ def draw_neural_web(W, ax, G, return_objs=False):
      return art1, art2, art3


- def draw_model_architecture(model_name, model_path='', style='basic'):
+ def draw_model_architecture(model_name, model_path=''):
      """
-     Visualizes the architecture of a neural network model.
-
-     Parameters
-     ----------
-     model_name : str
-         The name of the model to be visualized, which will be displayed in the title or label.
-
-     model_path : str
-         The file path to the model, from which the architecture is loaded. Default is ''
-
-     style : str, optional
-         The style of the visualization.
-         Options:
-         - 'basic': Displays a simplified view of the model architecture.
-         - 'detailed': Shows a more comprehensive view, including layer details and parameters.
-         Default is 'basic'.
-
-     Returns
-     -------
-     None
-         Draws and displays the architecture of the specified model.
-
-
-     Examples
-     --------
-     >>> draw_model_architecture("MyModel", "path/to/model", style='detailed')
+     Visualizes the architecture of a neural network model with multiple
+     input groups based on its activation functions.
+
+     :param model_name: Name of the model to visualize; shown in the figure title.
+     :param model_path: Path from which the model is loaded. Default is ''.
      """
-     from .plan_cuda import get_scaler, get_act_pot, get_weights
-     from .model_operations_cuda import load_model
+
+     from .model_operations_cuda import load_model, get_scaler, get_act_pot, get_weights

      model = load_model(model_name=model_name, model_path=model_path)

      W = model[get_weights()]
      activation_potentiation = model[get_act_pot()]
      scaler_params = model[get_scaler()]
-
-     text_1 = f"Input Shape:\n{W.shape[1]}"
-     text_2 = f"Output Shape:\n{W.shape[0]}"
-
-     if scaler_params is None:
-         bottom_left_text = 'Standard Scaler=No'
-     else:
-         bottom_left_text = 'Standard Scaler=Yes'
-
-     if len(activation_potentiation) != 1 or (len(activation_potentiation) == 1 and activation_potentiation[0] != 'linear'):
-
-         bottom_left_text_1 = f'Aggregation Layers(Aggregates All Conversions)={len(activation_potentiation)}'
-
-     else:
-
-         bottom_left_text_1 = 'Aggregation Layers(Aggregates All Conversions)=0'
-
-     bottom_left_text_2 = 'Potentiation Layer(Fully Connected)=1'
-
-     if scaler_params is None:
-         bottom_left_text = 'Standard Scaler=No'
-     else:
-         bottom_left_text = 'Standard Scaler=Yes'
-
-     num_middle_axes = len(activation_potentiation)
-
-     if style == 'detailed':
-
-         col = 1
-
-     elif style == 'basic':

-         col = 2
-
-     fig, axes = plt.subplots(1, num_middle_axes + col, figsize=(5 * (num_middle_axes + 2), 5))
-
-     fig.suptitle("Model Architecture", fontsize=16, fontweight='bold')
-
-     for i, activation in enumerate(activation_potentiation):
-         x = cp.linspace(-100, 100, 100)
-         translated_x_train = draw_activations(x, activation)
-         y = translated_x_train
+     # Calculate dimensions based on number of activation functions
+     num_activations = len(activation_potentiation)
+     input_groups = num_activations  # Number of input groups equals number of activations
+     num_inputs = W.shape[1]

-         axes[i].plot(x, y, color='b', markersize=6, linewidth=2, label='Activations Over Depth')
-         axes[i].set_title(activation_potentiation[i])
-
-         axes[i].spines['top'].set_visible(False)
-         axes[i].spines['right'].set_visible(False)
-         axes[i].spines['left'].set_visible(False)
-         axes[i].spines['bottom'].set_visible(False)
-         axes[i].get_xaxis().set_visible(False)
-         axes[i].get_yaxis().set_visible(False)
+     # Create figure
+     fig = plt.figure(figsize=(15, 10))
+
+     # Calculate positions for nodes
+     def get_node_positions():
+         positions = {}

-
-         if i < num_middle_axes - 1:
-             axes[i].annotate('', xy=(1.05, 0.5), xytext=(0.95, 0.5),
-                              xycoords='axes fraction', textcoords='axes fraction',
-                              arrowprops=dict(arrowstyle="->", color='black', lw=1.5))
+         # Input layer positions
+         total_height = 0.8  # maximum vertical span
+         group_height = total_height / input_groups  # vertical span reserved for each group
+         input_spacing = min(group_height / (num_inputs + 1), 0.1)  # spacing between inputs
+
+         for group in range(input_groups):
+             group_start_y = 0.9 - (group * group_height)  # starting y coordinate of the group
+             for i in range(num_inputs):
+                 y_pos = group_start_y - ((i + 1) * input_spacing)
+                 positions[f'input_{group}_{i}'] = (0.2, y_pos)
+
+         # Aggregation layer positions
+         agg_spacing = total_height / (num_inputs + 1)
+         for i in range(num_inputs):
+             positions[f'summed_{i}'] = (0.5, 0.9 - ((i + 1) * agg_spacing))
+
+         # Output layer positions
+         output_spacing = total_height / (W.shape[0] + 1)
+         for i in range(W.shape[0]):
+             positions[f'output_{i}'] = (0.8, 0.9 - ((i + 1) * output_spacing))
+
+         return positions
+
+     # Draw the network
+     pos = get_node_positions()
+
+     # Draw nodes
+     for group in range(input_groups):
+         # Draw input nodes
+         for i in range(num_inputs):
+             plt.plot(*pos[f'input_{group}_{i}'], 'o', color='lightgreen', markersize=20)
+             plt.text(pos[f'input_{group}_{i}'][0] - 0.05, pos[f'input_{group}_{i}'][1],
+                      f'Input #{i+1} ({activation_potentiation[group]})', ha='right', va='center')
+
+             # Draw connections from input to summed input directly
+             plt.plot([pos[f'input_{group}_{i}'][0], pos[f'summed_{i}'][0]],
+                      [pos[f'input_{group}_{i}'][1], pos[f'summed_{i}'][1]], 'k-')
+             # Draw aggregation nodes
+             if group == 0:
+                 plt.plot(*pos[f'summed_{i}'], 'o', color='lightgreen', markersize=20)
+                 plt.text(pos[f'summed_{i}'][0], pos[f'summed_{i}'][1] + 0.02,
+                          f'Summed\nInput #{i+1}', ha='center', va='bottom')
+
+     # Draw output nodes and connections
+     for i in range(W.shape[0]):
+         plt.plot(*pos[f'output_{i}'], 'o', color='gold', markersize=20)
+         plt.text(pos[f'output_{i}'][0] + 0.05, pos[f'output_{i}'][1],
+                  f'Output #{i+1}', ha='left', va='center', color='purple')
+
+         # Connect all aggregation nodes to each output
+         for group in range(num_inputs):
+             plt.plot([pos[f'summed_{group}'][0], pos[f'output_{i}'][0]],
+                      [pos[f'summed_{group}'][1], pos[f'output_{i}'][1]], 'k-')

-     if style == 'detailed':
+     # Add labels and annotations
+     plt.text(0.2, 0.95, 'Input Layer', ha='center', va='bottom', fontsize=12)
+     plt.text(0.5, 0.95, 'Aggregation\nLayer', ha='center', va='bottom', fontsize=12)
+     plt.text(0.8, 0.95, 'Output Layer', ha='center', va='bottom', fontsize=12)

-         G = nx.Graph()
-         draw_neural_web(W=W, ax=axes[num_middle_axes], G=G)
-
-     elif style == 'basic':
-
-         circle1 = plt.Circle((0.5, 0.5), 0.4, color='skyblue', ec='black', lw=1.5)
-         axes[num_middle_axes].add_patch(circle1)
-         axes[num_middle_axes].text(0.5, 0.5, text_1, ha='center', va='center', fontsize=12)
-         axes[num_middle_axes].set_xlim(0, 1)
-         axes[num_middle_axes].set_ylim(0, 1)
-         axes[num_middle_axes].axis('off')
-
-         circle2 = plt.Circle((0.5, 0.5), 0.4, color='lightcoral', ec='black', lw=1.5)
-         axes[-1].add_patch(circle2)
-         axes[-1].text(0.5, 0.5, text_2, ha='center', va='center', fontsize=12)
-         axes[-1].set_xlim(0, 1)
-         axes[-1].set_ylim(0, 1)
-         axes[-1].axis('off')
-
-
-     fig.text(0.01, 0, bottom_left_text, ha='left', va='bottom', fontsize=10)
-     fig.text(0.01, 0.04, bottom_left_text_1, ha='left', va='bottom', fontsize=10)
-     fig.text(0.01, 0.08, bottom_left_text_2, ha='left', va='bottom', fontsize=10)
+     # Remove axes
+     plt.axis('off')
+
+     # Add model information
+     if scaler_params is None:
+         plt.text(0.95, 0.05, 'Standard Scaler=No', fontsize=10, ha='right', va='bottom')
+     else:
+         plt.text(0.95, 0.05, 'Standard Scaler=Yes', fontsize=10, ha='right', va='bottom')

+     # Add model architecture title
+     plt.text(0.95, 0.1, f"PLAN Model Architecture: {model_name}", fontsize=12, ha='right', va='bottom', fontweight='bold')
      plt.tight_layout()
      plt.show()

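Note: the rewritten function lays nodes out in normalized figure coordinates: input groups at x=0.2, aggregation nodes at x=0.5, outputs at x=0.8, each column spread over a 0.8-high vertical band. A standalone sketch of that spacing arithmetic, with made-up shapes (num_inputs, num_outputs, input_groups here are illustrative, not read from a model):

# Recompute the node positions the new draw_model_architecture uses.
num_inputs, num_outputs, input_groups = 4, 3, 2
total_height = 0.8                           # usable vertical span
group_height = total_height / input_groups   # band per activation group
input_spacing = min(group_height / (num_inputs + 1), 0.1)

positions = {}
for group in range(input_groups):
    group_start_y = 0.9 - group * group_height
    for i in range(num_inputs):
        positions[f'input_{group}_{i}'] = (0.2, group_start_y - (i + 1) * input_spacing)

agg_spacing = total_height / (num_inputs + 1)
for i in range(num_inputs):
    positions[f'summed_{i}'] = (0.5, 0.9 - (i + 1) * agg_spacing)

out_spacing = total_height / (num_outputs + 1)
for i in range(num_outputs):
    positions[f'output_{i}'] = (0.8, 0.9 - (i + 1) * out_spacing)
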
@@ -348,8 +336,8 @@ def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation)
      acc = acc_list[len(acc_list) - 1]
      y_true = decode_one_hot(y_test)

-     y_true = cp.array(y_true)
-     y_preds = cp.array(y_preds)
+     y_true = cp.array(y_true, copy=True)
+     y_preds = cp.array(y_preds, copy=True)
      Class = cp.unique(decode_one_hot(y_test))

      precision, recall, f1 = metrics(y_test, y_preds)
@@ -358,7 +346,7 @@ def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation)
      cm = confusion_matrix(y_true, y_preds, len(Class))
      fig, axs = plt.subplots(2, 2, figsize=(16, 12))

-     sns.heatmap(cm, annot=True, fmt='d', ax=axs[0, 0])
+     sns.heatmap(cm.get(), annot=True, fmt='d', ax=axs[0, 0])
      axs[0, 0].set_title("Confusion Matrix")
      axs[0, 0].set_xlabel("Predicted Class")
      axs[0, 0].set_ylabel("Actual Class")
@@ -393,7 +381,7 @@ def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation)
      fpr, tpr, thresholds = roc_curve(y_true_copy, y_preds_copy)

      roc_auc = cp.trapz(tpr, fpr)
-     axs[1, 0].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
+     axs[1, 0].plot(fpr.get(), tpr.get(), color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
      axs[1, 0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
      axs[1, 0].set_xlim([0.0, 1.0])
      axs[1, 0].set_ylim([0.0, 1.05])
@@ -405,7 +393,7 @@ def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation)


      metric = ['Precision', 'Recall', 'F1 Score', 'Accuracy']
-     values = [precision, recall, f1, acc]
+     values = [precision, recall, f1, acc.get()]
      colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']


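Note: the `.get()` calls threaded through plot_evaluate all serve the same purpose: seaborn and Matplotlib render on the CPU and cannot consume CuPy arrays. A self-contained sketch of the pattern (cm, fpr, tpr are synthetic stand-ins for the function's device-side results):

import cupy as cp
import matplotlib.pyplot as plt
import seaborn as sns

cm = cp.array([[5, 1], [2, 7]])      # stand-in confusion matrix on the GPU
fpr = cp.linspace(0, 1, 10)
tpr = cp.sqrt(fpr)                   # placeholder curve
roc_auc = cp.trapz(tpr, fpr)         # stays a device scalar

fig, (ax0, ax1) = plt.subplots(1, 2)
sns.heatmap(cm.get(), annot=True, fmt='d', ax=ax0)   # host copy for seaborn
ax1.plot(fpr.get(), tpr.get(), label=f'ROC curve (area = {float(roc_auc):.2f})')
ax1.legend()
plt.show()
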
@@ -432,35 +420,29 @@ def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activation_potentiation)

      grid = cp.c_[xx.ravel(), yy.ravel()]

-     try:
-
-         grid_full = cp.zeros((grid.shape[0], x_test.shape[1]))
-         grid_full[:, feature_indices] = grid
-
-         Z = [None] * len(grid_full)
-
-         predict_progress = initialize_loading_bar(total=len(grid_full),leave=False,
-             bar_format=bar_format_normal ,desc="Predicts For Decision Boundary",ncols= 65)
+     grid_full = cp.zeros((grid.shape[0], x_test.shape[1]), dtype=cp.float32)
+     grid_full[:, feature_indices] = grid
+
+     Z = [None] * len(grid_full)

-         for i in range(len(grid_full)):
+     predict_progress = initialize_loading_bar(total=len(grid_full),leave=False,
+         bar_format=bar_format_normal ,desc="Predicts For Decision Boundary",ncols= 65)

-             Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))
-             predict_progress.update(1)
+     for i in range(len(grid_full)):

-         predict_progress.close()
+         Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))
+         predict_progress.update(1)

-         Z = cp.array(Z)
-         Z = Z.reshape(xx.shape)
+     predict_progress.close()

-         axs[1,1].contourf(xx, yy, Z, alpha=0.8)
-         axs[1,1].scatter(x_test[:, feature_indices[0]], x_test[:, feature_indices[1]], c=decode_one_hot(y_test), edgecolors='k', marker='o', s=20, alpha=0.9)
-         axs[1,1].set_xlabel(f'Feature {0 + 1}')
-         axs[1,1].set_ylabel(f'Feature {1 + 1}')
-         axs[1,1].set_title('Decision Boundary')
+     Z = cp.array(Z)
+     Z = Z.reshape(xx.shape)

-     except Exception as e:
-         # Actions to take when an error occurs
-         print(f"An error occurred: {e}")
+     axs[1,1].contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
+     axs[1,1].scatter(x_test[:, feature_indices[0]].get(), x_test[:, feature_indices[1]].get(), c=decode_one_hot(y_test).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
+     axs[1,1].set_xlabel(f'Feature {0 + 1}')
+     axs[1,1].set_ylabel(f'Feature {1 + 1}')
+     axs[1,1].set_title('Decision Boundary')

      plt.show()

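Note: for reference, the decision-boundary recipe this hunk restructures: mesh the two selected features, zero-pad the remaining columns, predict once per grid point, then reshape for contourf. A sketch under stated assumptions (toy_predict stands in for predict_model_ram; shapes are invented):

import cupy as cp
import matplotlib.pyplot as plt

x_test = cp.random.rand(50, 4)                  # hypothetical test features
feature_indices = [0, 1]                        # the two plotted features

x_min = float(x_test[:, 0].min()) - 1
x_max = float(x_test[:, 0].max()) + 1
y_min = float(x_test[:, 1].min()) - 1
y_max = float(x_test[:, 1].max()) + 1
xx, yy = cp.meshgrid(cp.arange(x_min, x_max, 0.5), cp.arange(y_min, y_max, 0.5))

grid = cp.c_[xx.ravel(), yy.ravel()]            # every grid point as a row
grid_full = cp.zeros((grid.shape[0], x_test.shape[1]), dtype=cp.float32)
grid_full[:, feature_indices] = grid            # unplotted features stay at zero

def toy_predict(row):                           # stand-in for predict_model_ram
    return cp.argmax(cp.stack([row.sum(), -row.sum()]))

Z = cp.array([toy_predict(r) for r in grid_full]).reshape(xx.shape)
plt.contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)   # host copies for Matplotlib
plt.show()
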
@@ -487,13 +469,13 @@ def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=Non
      for i in range(len(grid_full)):
          Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activation_potentiation=activation_potentiation))

-     Z = cp.array(Z)
+     Z = cp.array(Z, dtype=cp.int32)
      Z = Z.reshape(xx.shape)

      if ax is None:

-         plt.contourf(xx, yy, Z, alpha=0.8)
-         plt.scatter(x[:, feature_indices[0]], x[:, feature_indices[1]], c=decode_one_hot(y), edgecolors='k', marker='o', s=20, alpha=0.9)
+         plt.contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
+         plt.scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
          plt.xlabel(f'Feature {0 + 1}')
          plt.ylabel(f'Feature {1 + 1}')
          plt.title('Decision Boundary')
@@ -503,8 +485,8 @@ def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=Non
      else:

          try:
-             art1_1 = ax[1, 0].contourf(xx, yy, Z, alpha=0.8)
-             art1_2 = ax[1, 0].scatter(x[:, feature_indices[0]], x[:, feature_indices[1]], c=decode_one_hot(y), edgecolors='k', marker='o', s=20, alpha=0.9)
+             art1_1 = ax[1, 0].contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
+             art1_2 = ax[1, 0].scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
              ax[1, 0].set_xlabel(f'Feature {0 + 1}')
              ax[1, 0].set_ylabel(f'Feature {1 + 1}')
              ax[1, 0].set_title('Decision Boundary')
@@ -513,11 +495,11 @@ def plot_decision_boundary(x, y, activation_potentiation, W, artist=None, ax=Non

          except:

-             art1_1 = ax[0].contourf(xx, yy, Z, alpha=0.8)
-             art1_2 = ax[0].scatter(x[:, feature_indices[0]], x[:, feature_indices[1]], c=decode_one_hot(y), edgecolors='k', marker='o', s=20, alpha=0.9)
-             ax[0].set_xlabel(f'Feature {0 + 1}')
-             ax[0].set_ylabel(f'Feature {1 + 1}')
-             ax[0].set_title('Decision Boundary')
+             art1_1 = ax.contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
+             art1_2 = ax.scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
+             ax.set_xlabel(f'Feature {0 + 1}')
+             ax.set_ylabel(f'Feature {1 + 1}')
+             ax.set_title('Decision Boundary')


      return art1_1, art1_2
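
Note: the fallback branch now draws on `ax` directly instead of `ax[0]`, which matches the case where the caller passes a single Axes rather than an indexable grid. A small illustration of that dispatch (the helper name target_axes is hypothetical, not from the package):

import matplotlib.pyplot as plt

def target_axes(ax):                # hypothetical helper mirroring the try/except above
    try:
        return ax[1, 0]             # 2x2 grid returned by plt.subplots(2, 2)
    except TypeError:
        return ax                   # a bare Axes object is not subscriptable

fig_grid, grid = plt.subplots(2, 2)
fig_one, single = plt.subplots()
assert target_axes(grid) is grid[1, 0]
assert target_axes(single) is single
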
@@ -558,7 +540,7 @@ def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
              if y_preds[i] == cls:
                  class_points.append(X_pca[i])

-         class_points = cp.array(class_points)
+         class_points = cp.array(class_points, dtype=y.dtype)


          if len(class_points) > 2:
@@ -574,20 +556,20 @@ def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
      plt.draw()


- def neuron_history(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=False, loss=False):
+ def update_neuron_history(LTPW, ax1, row, col, class_count, artist5, fig1, acc=False, loss=False):

-     for j in range(len(class_count)):
+     for j in range(class_count):

          if acc != False and loss != False:
-             suptitle_info = data + ' Accuracy:' + str(acc) + '\n' + data + ' Loss:' + str(loss) + '\nNeurons Memory:'
+             suptitle_info = ' Accuracy:' + str(acc) + '\n' + '\nNeurons Memory:'
          else:
              suptitle_info = 'Neurons Memory:'

-         mat = LTPW[j,:].reshape(row, col)
+         mat = LTPW[j,:].reshape(row, col).get()

          title_info = f'{j+1}. Neuron'

-         art5 = ax1[j].imshow(mat.get(), interpolation='sinc', cmap='viridis')
+         art5 = ax1[j].imshow(mat, interpolation='sinc', cmap='viridis')

          ax1[j].set_aspect('equal')
          ax1[j].set_xticks([])
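
Note: each neuron's flattened weight row is reshaped to a near-square row × col image before imshow; 4.1.6 simply moves the `.get()` onto `mat` so the host copy happens once. A self-contained sketch, assuming 784 input features so the factor pair is 28 × 28:

import cupy as cp
import matplotlib.pyplot as plt

LTPW = cp.random.rand(3, 784)                   # hypothetical weights: 3 neurons
row, col = 28, 28                               # what find_closest_factors(784) would pick

fig1, ax1 = plt.subplots(1, 3)
for j in range(3):
    mat = LTPW[j, :].reshape(row, col).get()    # device -> host image, as in 4.1.6
    ax1[j].imshow(mat, interpolation='sinc', cmap='viridis')
    ax1[j].set_title(f'{j + 1}. Neuron')
    ax1[j].set_xticks([])
    ax1[j].set_yticks([])
plt.show()
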
@@ -599,17 +581,15 @@ def neuron_history(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=Fa

      fig1.suptitle(suptitle_info, fontsize=16)

-     return artist5
-

  def initialize_visualization_for_fit(val, show_training, neurons_history, x_train, y_train):
      """Initializes the visualization setup based on the parameters."""
-     from .data_operations import find_closest_factors
+     from .data_operations_cuda import find_closest_factors
      visualization_objects = {}

-     if show_training:
+     if show_training or neurons_history:
          if not val:
-             raise ValueError("For showing training, 'val' parameter must be True.")
+             raise ValueError("For showing training or neurons history, 'val' parameter must be True.")

          G = nx.Graph()
          fig, ax = plt.subplots(2, 2)
@@ -626,7 +606,7 @@ def initialize_visualization_for_fit(val, show_training, neurons_history, x_trai

      if neurons_history:
          row, col = find_closest_factors(len(x_train[0]))
-         fig1, ax1 = plt.subplots(1, len(set(y_train)), figsize=(18, 14))
+         fig1, ax1 = plt.subplots(1, len(y_train[0]), figsize=(18, 14))
          visualization_objects.update({
              'fig1': fig1,
              'ax1': ax1,
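
Note: the subplot-count change reflects the label format: y_train rows are one-hot CuPy vectors, so the class count is the row width, whereas `len(set(y_train))` would require hashable elements. A minimal sketch, assuming one-hot labels:

import cupy as cp

y_train = cp.array([[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1],
                    [1, 0, 0]])     # hypothetical one-hot labels

class_count = len(y_train[0])       # 3 columns -> 3 classes, as used for the subplots
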
@@ -640,10 +620,21 @@

  def update_weight_visualization_for_fit(ax, LTPW, artist2):
      """Updates the weight visualization plot."""
-     art2 = ax.imshow(LTPW, interpolation='sinc', cmap='viridis')
+     art2 = ax.imshow(LTPW.get(), interpolation='sinc', cmap='viridis')
      artist2.append([art2])

+ def show():
+     plt.tight_layout()
+     plt.show()

+ def update_neural_web_for_fit(W, ax, G, artist):
+     """
+     Updates the neural web visualization during fitting.
+     """
+     art5_1, art5_2, art5_3 = draw_neural_web(W=W, ax=ax, G=G, return_objs=True)
+     art5_list = [art5_1] + [art5_2] + list(art5_3.values())
+     artist.append(art5_list)
+
  def update_decision_boundary_for_fit(ax, x_val, y_val, activation_potentiation, LTPW, artist1):
      """Updates the decision boundary visualization."""
      art1_1, art1_2 = plot_decision_boundary(x_val, y_val, activation_potentiation, LTPW, artist=artist1, ax=ax)
@@ -652,10 +643,13 @@ def update_decision_boundary_for_fit(ax, x_val, y_val, activation_potentiation,

  def update_validation_history_for_fit(ax, val_list, artist3):
      """Updates the validation accuracy history plot."""
-     period = list(range(1, len(val_list) + 1))
+     val_list_cpu = []
+     for i in range(len(val_list)):
+         val_list_cpu.append(val_list[i].get())
+     period = list(range(1, len(val_list_cpu) + 1))
      art3 = ax.plot(
          period,
-         val_list,
+         val_list_cpu,
          linestyle='--',
          color='g',
          marker='o',
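
Note: the new loop materializes each 0-d CuPy accuracy on the host before plotting. An equivalent, slightly more compact conversion (a sketch; val_list is assumed to hold device scalars):

import cupy as cp

val_list = [cp.asarray(v) for v in (0.71, 0.79, 0.84)]   # stand-in device scalars
val_list_cpu = [float(v) for v in val_list]              # or: cp.asnumpy(cp.stack(val_list))
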
@@ -673,14 +667,38 @@ def update_validation_history_for_fit(ax, val_list, artist3):

  def display_visualization_for_fit(fig, artist_list, interval):
      """Displays the animation for the given artist list."""
      ani = ArtistAnimation(fig, artist_list, interval=interval, blit=True)
-     plt.tight_layout()
-     plt.show()
+     return ani

+ def update_neuron_history_for_learner(LTPW, ax1, row, col, class_count, artist5, data, fig1, acc=False, loss=False):

+     for j in range(len(class_count)):
+
+         if acc != False and loss != False:
+             suptitle_info = data + ' Accuracy:' + str(acc) + '\n' + data + ' Loss:' + str(loss) + '\nNeurons Memory:'
+         else:
+             suptitle_info = 'Neurons Memory:'
+
+         mat = LTPW[j,:].reshape(row, col)
+
+         title_info = f'{j+1}. Neuron'
+
+         art5 = ax1[j].imshow(mat.get(), interpolation='sinc', cmap='viridis')
+
+         ax1[j].set_aspect('equal')
+         ax1[j].set_xticks([])
+         ax1[j].set_yticks([])
+         ax1[j].set_title(title_info)
+
+
+         artist5.append([art5])
+
+     fig1.suptitle(suptitle_info, fontsize=16)
+
+     return artist5

  def initialize_visualization_for_learner(show_history, neurons_history, neural_web_history, x_train, y_train):
      """Initialize all visualization components"""
-     from .data_operations import find_closest_factors
+     from .data_operations_cuda import find_closest_factors
      viz_objects = {}

      if show_history:
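
Note: returning the ArtistAnimation instead of calling plt.show() shifts a responsibility to the caller: Matplotlib animations are timer-driven and stop updating if the object is garbage-collected, so the returned value must stay bound. A self-contained illustration of the new flow:

import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation

fig, ax = plt.subplots()
frames = [[ax.imshow([[i]], cmap='viridis', animated=True)] for i in range(5)]

ani = ArtistAnimation(fig, frames, interval=200, blit=True)  # keep 'ani' referenced
plt.tight_layout()
plt.show()                          # the animation runs only while 'ani' is alive
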
@@ -726,13 +744,19 @@ def update_history_plots_for_learner(viz_objects, depth_list, loss_list, best_ac
          return

      hist = viz_objects['history']
-
+     for i in range(len(loss_list)):
+         loss_list[i] = loss_list[i].get()
+
      # Loss plot
      art1 = hist['ax'][0].plot(depth_list, loss_list, color='r', markersize=6, linewidth=2)
      hist['ax'][0].set_title('Test Loss Over Depth')
      hist['artist1'].append(art1)

      # Accuracy plot
+
+     for i in range(len(best_acc_per_depth_list)):
+         best_acc_per_depth_list[i] = best_acc_per_depth_list[i].get()
+
      art2 = hist['ax'][1].plot(depth_list, best_acc_per_depth_list, color='g', markersize=6, linewidth=2)
      hist['ax'][1].set_title('Test Accuracy Over Depth')
      hist['artist2'].append(art2)
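
Note: these loops rebind the caller's loss_list and best_acc_per_depth_list entries to host values in place. A non-mutating alternative sketch (not the package's code), in case a caller reuses the device-side lists:

import cupy as cp

loss_list = [cp.asarray(v) for v in (0.9, 0.5, 0.3)]   # stand-in device losses
loss_host = [float(v) for v in loss_list]              # plotting copy; loss_list untouched
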
@@ -743,7 +767,7 @@
      for activation in final_activations:
          translated_x_train += draw_activations(x, activation)

-     art3 = hist['ax'][2].plot(x, translated_x_train, color='b', markersize=6, linewidth=2)
+     art3 = hist['ax'][2].plot(x.get(), translated_x_train.get(), color='b', markersize=6, linewidth=2)
      hist['ax'][2].set_title('Potentiation Shape Over Depth')
      hist['artist3'].append(art3)

@@ -765,8 +789,8 @@ def display_visualizations_for_learner(viz_objects, best_weights, data, best_acc
      if 'neurons' in viz_objects:
          neurons = viz_objects['neurons']
          for _ in range(10):
-             neurons['artists'] = neuron_history(
-                 cp.copy(best_weights),
+             neurons['artists'] = update_neuron_history_for_learner(
+                 cp.copy(best_weights),
                  neurons['ax'],
                  neurons['row'],
                  neurons['col'],