pyerualjetwork 5.37__py3-none-any.whl → 5.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyerualjetwork/__init__.py +12 -11
- pyerualjetwork/cpu/__init__.py +1 -4
- pyerualjetwork/cpu/metrics.py +2 -0
- pyerualjetwork/cpu/visualizations.py +103 -135
- pyerualjetwork/cuda/__init__.py +1 -4
- pyerualjetwork/cuda/activation_functions.py +3 -0
- pyerualjetwork/cuda/data_ops.py +1 -1
- pyerualjetwork/cuda/visualizations.py +4 -240
- pyerualjetwork/{cpu/ene.py → ene.py} +5 -57
- pyerualjetwork/help.py +1 -1
- pyerualjetwork/issue_solver.py +39 -11
- pyerualjetwork/model_ops.py +692 -0
- pyerualjetwork/{cpu/nn.py → nn.py} +211 -86
- pyerualjetwork/{cuda/model_ops.py → old_cuda_model_ops.py} +1 -1
- {pyerualjetwork-5.37.dist-info → pyerualjetwork-5.40.dist-info}/METADATA +8 -7
- pyerualjetwork-5.40.dist-info/RECORD +27 -0
- pyerualjetwork/cuda/ene.py +0 -948
- pyerualjetwork/cuda/nn.py +0 -605
- pyerualjetwork-5.37.dist-info/RECORD +0 -28
- /pyerualjetwork/{cpu/model_ops.py → old_cpu_model_ops.py} +0 -0
- {pyerualjetwork-5.37.dist-info → pyerualjetwork-5.40.dist-info}/WHEEL +0 -0
- {pyerualjetwork-5.37.dist-info → pyerualjetwork-5.40.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@ import networkx as nx
|
|
2
2
|
import matplotlib.pyplot as plt
|
3
3
|
import cupy as cp
|
4
4
|
from scipy.spatial import ConvexHull
|
5
|
-
import seaborn as sns
|
6
5
|
from matplotlib.animation import ArtistAnimation
|
7
6
|
|
8
7
|
def draw_neural_web(W, ax, G, return_objs=False):
|
@@ -87,13 +86,13 @@ def draw_model_architecture(model_name, model_path=''):
|
|
87
86
|
Visualizes the architecture of a neural network model with multiple inputs based on activation functions.
|
88
87
|
"""
|
89
88
|
|
90
|
-
from
|
89
|
+
from ..model_ops import load_model
|
91
90
|
|
92
91
|
model = load_model(model_name=model_name, model_path=model_path)
|
93
92
|
|
94
|
-
W = model
|
95
|
-
activations = model
|
96
|
-
scaler_params = model
|
93
|
+
W = model.weights
|
94
|
+
activations = model.activations
|
95
|
+
scaler_params = model.scaler_params
|
97
96
|
|
98
97
|
# Calculate dimensions based on number of activation functions
|
99
98
|
num_activations = len(activations)
|
@@ -188,9 +187,6 @@ def draw_activations(x_train, activation):
|
|
188
187
|
if activation == 'sigmoid':
|
189
188
|
result = af.Sigmoid(x_train)
|
190
189
|
|
191
|
-
elif activation == 'swish':
|
192
|
-
result = af.swish(x_train)
|
193
|
-
|
194
190
|
elif activation == 'circular':
|
195
191
|
result = af.circular_activation(x_train)
|
196
192
|
|
@@ -206,18 +202,6 @@ def draw_activations(x_train, activation):
|
|
206
202
|
elif activation == 'relu':
|
207
203
|
result = af.Relu(x_train)
|
208
204
|
|
209
|
-
elif activation == 'softplus':
|
210
|
-
result = af.softplus(x_train)
|
211
|
-
|
212
|
-
elif activation == 'elu':
|
213
|
-
result = af.elu(x_train)
|
214
|
-
|
215
|
-
elif activation == 'gelu':
|
216
|
-
result = af.gelu(x_train)
|
217
|
-
|
218
|
-
elif activation == 'selu':
|
219
|
-
result = af.selu(x_train)
|
220
|
-
|
221
205
|
elif activation == 'softmax':
|
222
206
|
result = af.Softmax(x_train)
|
223
207
|
|
@@ -236,24 +220,12 @@ def draw_activations(x_train, activation):
|
|
236
220
|
elif activation == 'dlrelu':
|
237
221
|
result = af.dlrelu(x_train)
|
238
222
|
|
239
|
-
elif activation == 'exsig':
|
240
|
-
result = af.exsig(x_train)
|
241
|
-
|
242
223
|
elif activation == 'sin_plus':
|
243
224
|
result = af.sin_plus(x_train)
|
244
225
|
|
245
226
|
elif activation == 'acos':
|
246
227
|
result = af.acos(x_train, alpha=1.0, beta=0.0)
|
247
228
|
|
248
|
-
elif activation == 'gla':
|
249
|
-
result = af.gla(x_train, alpha=1.0, mu=0.0)
|
250
|
-
|
251
|
-
elif activation == 'srelu':
|
252
|
-
result = af.srelu(x_train)
|
253
|
-
|
254
|
-
elif activation == 'qelu':
|
255
|
-
result = af.qelu(x_train)
|
256
|
-
|
257
229
|
elif activation == 'isra':
|
258
230
|
result = af.isra(x_train)
|
259
231
|
|
@@ -266,54 +238,27 @@ def draw_activations(x_train, activation):
|
|
266
238
|
elif activation == 'bent_identity':
|
267
239
|
result = af.bent_identity(x_train)
|
268
240
|
|
269
|
-
elif activation == 'sech':
|
270
|
-
result = af.sech(x_train)
|
271
|
-
|
272
241
|
elif activation == 'softsign':
|
273
242
|
result = af.softsign(x_train)
|
274
243
|
|
275
244
|
elif activation == 'pwl':
|
276
245
|
result = af.pwl(x_train)
|
277
246
|
|
278
|
-
elif activation == 'cubic':
|
279
|
-
result = af.cubic(x_train)
|
280
|
-
|
281
|
-
elif activation == 'gaussian':
|
282
|
-
result = af.gaussian(x_train)
|
283
|
-
|
284
247
|
elif activation == 'sine':
|
285
248
|
result = af.sine(x_train)
|
286
249
|
|
287
250
|
elif activation == 'tanh_square':
|
288
251
|
result = af.tanh_square(x_train)
|
289
252
|
|
290
|
-
elif activation == 'mod_sigmoid':
|
291
|
-
result = af.mod_sigmoid(x_train)
|
292
|
-
|
293
253
|
elif activation == 'linear':
|
294
254
|
result = x_train
|
295
255
|
|
296
|
-
elif activation == 'quartic':
|
297
|
-
result = af.quartic(x_train)
|
298
|
-
|
299
|
-
elif activation == 'square_quartic':
|
300
|
-
result = af.square_quartic(x_train)
|
301
|
-
|
302
|
-
elif activation == 'cubic_quadratic':
|
303
|
-
result = af.cubic_quadratic(x_train)
|
304
|
-
|
305
|
-
elif activation == 'exp_cubic':
|
306
|
-
result = af.exp_cubic(x_train)
|
307
|
-
|
308
256
|
elif activation == 'sine_square':
|
309
257
|
result = af.sine_square(x_train)
|
310
258
|
|
311
259
|
elif activation == 'logarithmic':
|
312
260
|
result = af.logarithmic(x_train)
|
313
261
|
|
314
|
-
elif activation == 'scaled_cubic':
|
315
|
-
result = af.scaled_cubic(x_train, 1.0)
|
316
|
-
|
317
262
|
elif activation == 'sine_offset':
|
318
263
|
result = af.sine_offset(x_train, 1.0)
|
319
264
|
|
@@ -325,187 +270,6 @@ def draw_activations(x_train, activation):
|
|
325
270
|
print('\rWARNING: error in drawing some activation.', end='')
|
326
271
|
return x_train
|
327
272
|
|
328
|
-
|
329
|
-
def plot_evaluate(x_test, y_test, y_preds, acc_list, W, activations):
|
330
|
-
|
331
|
-
from .metrics import metrics, confusion_matrix, roc_curve
|
332
|
-
from ..ui import loading_bars, initialize_loading_bar
|
333
|
-
from .data_ops import decode_one_hot
|
334
|
-
from .model_ops import predict_model_ram
|
335
|
-
|
336
|
-
bar_format_normal = loading_bars()[0]
|
337
|
-
|
338
|
-
acc = acc_list[len(acc_list) - 1]
|
339
|
-
y_true = decode_one_hot(y_test)
|
340
|
-
|
341
|
-
y_true = cp.array(y_true, copy=True)
|
342
|
-
y_preds = cp.array(y_preds, copy=True)
|
343
|
-
Class = cp.unique(decode_one_hot(y_test))
|
344
|
-
|
345
|
-
precision, recall, f1 = metrics(y_test, y_preds)
|
346
|
-
|
347
|
-
|
348
|
-
cm = confusion_matrix(y_true, y_preds, len(Class))
|
349
|
-
fig, axs = plt.subplots(2, 2, figsize=(16, 12))
|
350
|
-
|
351
|
-
sns.heatmap(cm.get(), annot=True, fmt='d', ax=axs[0, 0])
|
352
|
-
axs[0, 0].set_title("Confusion Matrix")
|
353
|
-
axs[0, 0].set_xlabel("Predicted Class")
|
354
|
-
axs[0, 0].set_ylabel("Actual Class")
|
355
|
-
|
356
|
-
if len(Class) == 2:
|
357
|
-
fpr, tpr, thresholds = roc_curve(y_true, y_preds)
|
358
|
-
|
359
|
-
roc_auc = cp.trapz(tpr, fpr)
|
360
|
-
axs[1, 0].plot(fpr.get(), tpr.get(), color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
361
|
-
axs[1, 0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
362
|
-
axs[1, 0].set_xlim([0.0, 1.0])
|
363
|
-
axs[1, 0].set_ylim([0.0, 1.05])
|
364
|
-
axs[1, 0].set_xlabel('False Positive Rate')
|
365
|
-
axs[1, 0].set_ylabel('True Positive Rate')
|
366
|
-
axs[1, 0].set_title('Receiver Operating Characteristic (ROC) Curve')
|
367
|
-
axs[1, 0].legend(loc="lower right")
|
368
|
-
axs[1, 0].legend(loc="lower right")
|
369
|
-
else:
|
370
|
-
|
371
|
-
for i in range(len(Class)):
|
372
|
-
|
373
|
-
y_true_copy = cp.copy(y_true)
|
374
|
-
y_preds_copy = cp.copy(y_preds)
|
375
|
-
|
376
|
-
y_true_copy[y_true_copy == i] = 0
|
377
|
-
y_true_copy[y_true_copy != 0] = 1
|
378
|
-
|
379
|
-
y_preds_copy[y_preds_copy == i] = 0
|
380
|
-
y_preds_copy[y_preds_copy != 0] = 1
|
381
|
-
|
382
|
-
|
383
|
-
fpr, tpr, thresholds = roc_curve(y_true_copy, y_preds_copy)
|
384
|
-
|
385
|
-
roc_auc = cp.trapz(tpr, fpr)
|
386
|
-
axs[1, 0].plot(fpr.get(), tpr.get(), color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
387
|
-
axs[1, 0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
388
|
-
axs[1, 0].set_xlim([0.0, 1.0])
|
389
|
-
axs[1, 0].set_ylim([0.0, 1.05])
|
390
|
-
axs[1, 0].set_xlabel('False Positive Rate')
|
391
|
-
axs[1, 0].set_ylabel('True Positive Rate')
|
392
|
-
axs[1, 0].set_title('Receiver Operating Characteristic (ROC) Curve')
|
393
|
-
axs[1, 0].legend(loc="lower right")
|
394
|
-
axs[1, 0].legend(loc="lower right")
|
395
|
-
|
396
|
-
|
397
|
-
metric = ['Precision', 'Recall', 'F1 Score', 'Accuracy']
|
398
|
-
values = [precision, recall, f1, acc.get()]
|
399
|
-
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
|
400
|
-
|
401
|
-
|
402
|
-
bars = axs[0, 1].bar(metric, values, color=colors)
|
403
|
-
|
404
|
-
|
405
|
-
for bar, value in zip(bars, values):
|
406
|
-
axs[0, 1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() - 0.05, f'{value:.2f}',
|
407
|
-
ha='center', va='bottom', fontsize=12, color='white', weight='bold')
|
408
|
-
|
409
|
-
axs[0, 1].set_ylim(0, 1)
|
410
|
-
axs[0, 1].set_xlabel('Metrics')
|
411
|
-
axs[0, 1].set_ylabel('Score')
|
412
|
-
axs[0, 1].set_title('Precision, Recall, F1 Score, and Accuracy (Weighted)')
|
413
|
-
axs[0, 1].grid(True, axis='y', linestyle='--', alpha=0.7)
|
414
|
-
|
415
|
-
feature_indices=[0, 1]
|
416
|
-
|
417
|
-
h = .02
|
418
|
-
x_min, x_max = x_test[:, feature_indices[0]].min() - 1, x_test[:, feature_indices[0]].max() + 1
|
419
|
-
y_min, y_max = x_test[:, feature_indices[1]].min() - 1, x_test[:, feature_indices[1]].max() + 1
|
420
|
-
xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
|
421
|
-
cp.arange(y_min, y_max, h))
|
422
|
-
|
423
|
-
grid = cp.c_[xx.ravel(), yy.ravel()]
|
424
|
-
|
425
|
-
grid_full = cp.zeros((grid.shape[0], x_test.shape[1]), dtype=cp.float32)
|
426
|
-
grid_full[:, feature_indices] = grid
|
427
|
-
|
428
|
-
Z = [None] * len(grid_full)
|
429
|
-
|
430
|
-
predict_progress = initialize_loading_bar(total=len(grid_full),leave=False,
|
431
|
-
bar_format=bar_format_normal ,desc="Predicts For Decision Boundary",ncols= 65)
|
432
|
-
|
433
|
-
for i in range(len(grid_full)):
|
434
|
-
|
435
|
-
Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activations=activations))
|
436
|
-
predict_progress.update(1)
|
437
|
-
|
438
|
-
predict_progress.close()
|
439
|
-
|
440
|
-
Z = cp.array(Z)
|
441
|
-
Z = Z.reshape(xx.shape)
|
442
|
-
|
443
|
-
axs[1,1].contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
|
444
|
-
axs[1,1].scatter(x_test[:, feature_indices[0]].get(), x_test[:, feature_indices[1]].get(), c=decode_one_hot(y_test).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
|
445
|
-
axs[1,1].set_xlabel(f'Feature {0 + 1}')
|
446
|
-
axs[1,1].set_ylabel(f'Feature {1 + 1}')
|
447
|
-
axs[1,1].set_title('Decision Boundary')
|
448
|
-
|
449
|
-
plt.show()
|
450
|
-
|
451
|
-
|
452
|
-
def plot_decision_boundary(x, y, activations, W, artist=None, ax=None):
|
453
|
-
|
454
|
-
from .model_ops import predict_model_ram
|
455
|
-
from .data_ops import decode_one_hot
|
456
|
-
|
457
|
-
feature_indices = [0, 1]
|
458
|
-
|
459
|
-
h = .02
|
460
|
-
x_min, x_max = x[:, feature_indices[0]].min() - 1, x[:, feature_indices[0]].max() + 1
|
461
|
-
y_min, y_max = x[:, feature_indices[1]].min() - 1, x[:, feature_indices[1]].max() + 1
|
462
|
-
xx, yy = cp.meshgrid(cp.arange(x_min, x_max, h),
|
463
|
-
cp.arange(y_min, y_max, h))
|
464
|
-
|
465
|
-
grid = cp.c_[xx.ravel(), yy.ravel()]
|
466
|
-
grid_full = cp.zeros((grid.shape[0], x.shape[1]))
|
467
|
-
grid_full[:, feature_indices] = grid
|
468
|
-
|
469
|
-
Z = [None] * len(grid_full)
|
470
|
-
|
471
|
-
for i in range(len(grid_full)):
|
472
|
-
Z[i] = cp.argmax(predict_model_ram(grid_full[i], W=W, activations=activations))
|
473
|
-
|
474
|
-
Z = cp.array(Z, dtype=cp.int32)
|
475
|
-
Z = Z.reshape(xx.shape)
|
476
|
-
|
477
|
-
if ax is None:
|
478
|
-
|
479
|
-
plt.contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
|
480
|
-
plt.scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
|
481
|
-
plt.xlabel(f'Feature {0 + 1}')
|
482
|
-
plt.ylabel(f'Feature {1 + 1}')
|
483
|
-
plt.title('Decision Boundary')
|
484
|
-
|
485
|
-
plt.show()
|
486
|
-
|
487
|
-
else:
|
488
|
-
|
489
|
-
try:
|
490
|
-
art1_1 = ax[1, 0].contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
|
491
|
-
art1_2 = ax[1, 0].scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
|
492
|
-
ax[1, 0].set_xlabel(f'Feature {0 + 1}')
|
493
|
-
ax[1, 0].set_ylabel(f'Feature {1 + 1}')
|
494
|
-
ax[1, 0].set_title('Decision Boundary')
|
495
|
-
|
496
|
-
return art1_1, art1_2
|
497
|
-
|
498
|
-
except:
|
499
|
-
|
500
|
-
art1_1 = ax.contourf(xx.get(), yy.get(), Z.get(), alpha=0.8)
|
501
|
-
art1_2 = ax.scatter(x[:, feature_indices[0]].get(), x[:, feature_indices[1]].get(), c=decode_one_hot(y).get(), edgecolors='k', marker='o', s=20, alpha=0.9)
|
502
|
-
ax.set_xlabel(f'Feature {0 + 1}')
|
503
|
-
ax.set_ylabel(f'Feature {1 + 1}')
|
504
|
-
ax.set_title('Decision Boundary')
|
505
|
-
|
506
|
-
|
507
|
-
return art1_1, art1_2
|
508
|
-
|
509
273
|
|
510
274
|
def plot_decision_space(x, y, y_preds=None, s=100, color='tab20'):
|
511
275
|
|
@@ -1,17 +1,16 @@
|
|
1
1
|
"""
|
2
2
|
|
3
3
|
|
4
|
-
ENE (Eugenic NeuroEvolution)
|
4
|
+
ENE (Eugenic NeuroEvolution)
|
5
5
|
===================================
|
6
6
|
|
7
|
-
This module contains all the functions necessary for implementing and testing the ENE (Eugenic NeuroEvolution)
|
7
|
+
This module contains all the functions necessary for implementing and testing the ENE (Eugenic NeuroEvolution).
|
8
8
|
For more information about the ENE algorithm: https://github.com/HCB06/PyerualJetwork/blob/main/Welcome_to_PLAN/PLAN.pdf
|
9
9
|
|
10
10
|
Module functions:
|
11
11
|
-----------------
|
12
12
|
- evolver()
|
13
13
|
- define_genomes()
|
14
|
-
- evaluate()
|
15
14
|
- cross_over()
|
16
15
|
- mutation()
|
17
16
|
- dominant_parent_selection()
|
@@ -34,9 +33,9 @@ import math
|
|
34
33
|
import copy
|
35
34
|
|
36
35
|
### LIBRARY IMPORTS ###
|
37
|
-
from .data_ops import non_neg_normalization
|
38
|
-
from
|
39
|
-
from .activation_functions import
|
36
|
+
from .cpu.data_ops import non_neg_normalization
|
37
|
+
from .ui import loading_bars, initialize_loading_bar
|
38
|
+
from .cpu.activation_functions import all_activations
|
40
39
|
|
41
40
|
def define_genomes(input_shape, output_shape, population_size, neurons=[], activation_functions=[], dtype=np.float32):
|
42
41
|
"""
|
@@ -518,57 +517,6 @@ def evolver(weights,
|
|
518
517
|
return weights, activations
|
519
518
|
|
520
519
|
|
521
|
-
def evaluate(Input, weights, activations, is_mlp=False):
|
522
|
-
"""
|
523
|
-
Evaluates the performance of a population of genomes, applying different activation functions
|
524
|
-
and weights depending on whether reinforcement learning mode is enabled or not.
|
525
|
-
|
526
|
-
Args:
|
527
|
-
Input (list or numpy.ndarray): A list or 2D numpy array where each element represents
|
528
|
-
a genome (A list of input features for each genome, or a single set of input features for one genome).
|
529
|
-
weights (list or numpy.ndarray): A list or 2D numpy array of weights corresponding to each genome
|
530
|
-
in `x_population`. This determines the strength of connections.
|
531
|
-
activations (list or str): A list where each entry represents an activation function
|
532
|
-
or a potentiation strategy applied to each genome. If only one
|
533
|
-
activation function is used, this can be a single string.
|
534
|
-
is_mlp (bool, optional): Evaluate PLAN model or MLP model ? Default: False (PLAN)
|
535
|
-
|
536
|
-
Returns:
|
537
|
-
list: A list of outputs corresponding to each genome in the population after applying the respective
|
538
|
-
activation function and weights.
|
539
|
-
|
540
|
-
Example:
|
541
|
-
```python
|
542
|
-
outputs = evaluate(Input, weights, activations)
|
543
|
-
```
|
544
|
-
|
545
|
-
- The function returns a list of outputs after processing the population, where each element corresponds to
|
546
|
-
the output for each genome in population.
|
547
|
-
"""
|
548
|
-
### THE OUTPUTS ARE RETURNED, WHERE EACH GENOME'S OUTPUT MATCHES ITS INDEX:
|
549
|
-
|
550
|
-
if isinstance(activations, str):
|
551
|
-
activations = [activations]
|
552
|
-
elif isinstance(activations, list):
|
553
|
-
activations = [item if isinstance(item, list) or isinstance(item, str) else [item] for item in activations]
|
554
|
-
|
555
|
-
|
556
|
-
if is_mlp:
|
557
|
-
layer = Input
|
558
|
-
for i in range(len(weights)):
|
559
|
-
if i != len(weights) - 1 and i != 0: layer = apply_activation(layer, activations[i])
|
560
|
-
|
561
|
-
layer = layer @ weights[i].T
|
562
|
-
|
563
|
-
return layer
|
564
|
-
|
565
|
-
else:
|
566
|
-
|
567
|
-
Input = apply_activation(Input, activations)
|
568
|
-
result = Input @ weights.T
|
569
|
-
|
570
|
-
return result
|
571
|
-
|
572
520
|
def cross_over(first_parent_W,
|
573
521
|
second_parent_W,
|
574
522
|
first_parent_act,
|
pyerualjetwork/help.py
CHANGED
pyerualjetwork/issue_solver.py
CHANGED
@@ -9,7 +9,7 @@ ensuring users are not affected by such problems. PyereualJetwork aims to offer
|
|
9
9
|
|
10
10
|
Module functions:
|
11
11
|
-----------------
|
12
|
-
-
|
12
|
+
- update_model_to_v5_4()
|
13
13
|
|
14
14
|
Examples: https://github.com/HCB06/PyerualJetwork/tree/main/Welcome_to_PyerualJetwork/ExampleCodes
|
15
15
|
|
@@ -22,10 +22,10 @@ PyerualJetwork document: https://github.com/HCB06/PyerualJetwork/blob/main/Welco
|
|
22
22
|
- Contact: tchasancan@gmail.com
|
23
23
|
"""
|
24
24
|
|
25
|
-
def
|
25
|
+
def update_model_to_v5_4(model_name, model_path, is_cuda):
|
26
26
|
|
27
27
|
"""
|
28
|
-
|
28
|
+
update_model_to_v5_4 function helps users for update models from older versions to newer versions.
|
29
29
|
|
30
30
|
:param str model_name: Name of saved model.
|
31
31
|
|
@@ -36,28 +36,34 @@ def update_model_to_v5(model_name, model_path, is_cuda):
|
|
36
36
|
:return: prints terminal if succes.
|
37
37
|
"""
|
38
38
|
|
39
|
+
from .model_ops import save_model
|
40
|
+
|
39
41
|
if is_cuda:
|
40
42
|
|
41
|
-
from
|
43
|
+
from .old_cuda_model_ops import (get_act,
|
42
44
|
get_weights,
|
43
45
|
get_scaler,
|
44
46
|
get_acc,
|
45
47
|
get_model_type,
|
46
48
|
get_weights_type,
|
47
49
|
get_weights_format,
|
48
|
-
|
49
|
-
|
50
|
+
get_model_df,
|
51
|
+
get_act_pot,
|
52
|
+
get_model_version,
|
53
|
+
load_model)
|
50
54
|
else:
|
51
55
|
|
52
|
-
from
|
56
|
+
from .old_cpu_model_ops import (get_act,
|
53
57
|
get_weights,
|
54
58
|
get_scaler,
|
55
59
|
get_acc,
|
56
60
|
get_model_type,
|
57
61
|
get_weights_type,
|
58
62
|
get_weights_format,
|
59
|
-
|
60
|
-
|
63
|
+
get_model_df,
|
64
|
+
get_act_pot,
|
65
|
+
get_model_version,
|
66
|
+
load_model)
|
61
67
|
|
62
68
|
model = load_model(model_name, model_path)
|
63
69
|
|
@@ -68,10 +74,32 @@ def update_model_to_v5(model_name, model_path, is_cuda):
|
|
68
74
|
model_type = model[get_model_type()]
|
69
75
|
weights_type = model[get_weights_type()]
|
70
76
|
weights_format = model[get_weights_format()]
|
77
|
+
model_df = model[get_model_df()]
|
78
|
+
act_pot = model[get_act_pot()]
|
79
|
+
version = model[get_model_version()]
|
80
|
+
|
81
|
+
from .model_ops import get_model_template
|
82
|
+
|
83
|
+
template_model = get_model_template()
|
84
|
+
|
85
|
+
model = template_model(weights,
|
86
|
+
None,
|
87
|
+
test_acc,
|
88
|
+
activations,
|
89
|
+
scaler_params,
|
90
|
+
None,
|
91
|
+
model_type,
|
92
|
+
weights_type,
|
93
|
+
weights_format,
|
94
|
+
device_version,
|
95
|
+
model_df,
|
96
|
+
act_pot
|
97
|
+
)
|
98
|
+
|
71
99
|
|
72
100
|
from .__init__ import __version__
|
73
101
|
device_version = __version__
|
74
102
|
|
75
|
-
save_model("updated_" + model_name,
|
103
|
+
save_model(model, "updated_" + model_name, model_path)
|
76
104
|
|
77
|
-
print(f"\nModel succesfully updated to {device_version}.
|
105
|
+
print(f"\nModel succesfully updated from {version} to {device_version}. In this location: {model_path}\nNOTE: This operation just for compatibility. You may still have perfomance issues in this situation please install model's version of pyerualjetwork.")
|