sciml 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sciml/__init__.py +1 -1
- sciml/ccc.py +35 -35
- sciml/metrics.py +122 -122
- sciml/models.py +530 -796
- sciml/pipelines.py +225 -225
- sciml/regress2.py +216 -216
- {sciml-0.0.11.dist-info → sciml-0.0.12.dist-info}/LICENSE +21 -21
- {sciml-0.0.11.dist-info → sciml-0.0.12.dist-info}/METADATA +13 -13
- sciml-0.0.12.dist-info/RECORD +11 -0
- {sciml-0.0.11.dist-info → sciml-0.0.12.dist-info}/WHEEL +1 -1
- sciml-0.0.11.dist-info/RECORD +0 -11
- {sciml-0.0.11.dist-info → sciml-0.0.12.dist-info}/top_level.txt +0 -0
sciml/models.py
CHANGED
@@ -1,796 +1,530 @@
- import numpy as np
- import copy
- import itertools
- from scipy import ndimage
- from xgboost import XGBRegressor
- from sklearn.metrics import mean_squared_error
- from sklearn.model_selection import train_test_split
-
- class SmartForest4D:
-     """
[removed lines 11-530, covering the rest of the class docstring and the first half of the class body, were not captured in this extraction]
-                     best_model = model
-
-             preds = best_model.predict(X).reshape(-1, 1)
-             layer.append(best_model)
-             layer_outputs.append(preds)
-
-         output = np.hstack(layer_outputs)
-         return layer, output
-
-     def fit(self, X, y, X_val=None, y_val=None):
-         y = y.ravel()
-         X_current = self._prepare_input(X, apply_forget=True)
-         X_val_current = self._prepare_input(X_val, apply_forget=True) if X_val is not None else None
-
-         no_improve_rounds = 0
-
-         for layer_index in range(self.max_layers):
-             if self.verbose:
-                 print(f"Training Layer {layer_index + 1}")
-
-             layer, output = self._fit_layer(X_current, y, X_val_current, y_val, layer_index)
-             self.layers.append(layer)
-             X_current = np.hstack([X_current, output])
-
-             if X_val is not None:
-                 val_outputs = []
-                 for reg in layer:
-                     n_features = reg.n_features_in_
-                     preds = reg.predict(X_val_current[:, :n_features]).reshape(-1, 1)
-                     val_outputs.append(preds)
-                 val_output = np.hstack(val_outputs)
-                 X_val_current = np.hstack([X_val_current, val_output])
-
-                 y_pred = self.predict(X_val)
-                 rmse = np.sqrt(mean_squared_error(y_val, y_pred))
-                 if self.verbose:
-                     print(f"Validation RMSE: {rmse:.4f}")
-
-                 if rmse < self.best_rmse:
-                     self.best_rmse = rmse
-                     self.best_model = copy.deepcopy(self.layers)
-                     no_improve_rounds = 0
-                     if self.verbose:
-                         print(f"✅ New best RMSE: {self.best_rmse:.4f}")
-                 else:
-                     no_improve_rounds += 1
-                     if no_improve_rounds >= self.early_stopping_rounds:
-                         if self.verbose:
-                             print("Early stopping triggered.")
-                         break
-
-     def predict(self, X):
-         X_current = self._prepare_input(X, apply_forget=True)
-
-         for layer in self.layers:
-             layer_outputs = []
-             for reg in layer:
-                 n_features = reg.n_features_in_
-                 preds = reg.predict(X_current[:, :n_features]).reshape(-1, 1)
-                 layer_outputs.append(preds)
-             output = np.hstack(layer_outputs)
-             X_current = np.hstack([X_current, output])
-
-         final_outputs = []
-         for reg in self.layers[-1]:
-             n_features = reg.n_features_in_
-             final_outputs.append(reg.predict(X_current[:, :n_features]).reshape(-1, 1))
-         return np.mean(np.hstack(final_outputs), axis=1)
-
-     def get_best_model(self):
-         return self.best_model, self.best_rmse
-
- """
- # ============================== Test Example ==============================
- import numpy as np
- import copy
- import itertools
- from scipy import ndimage
- from xgboost import XGBRegressor
- from sklearn.metrics import mean_squared_error
- from sklearn.model_selection import train_test_split
-
- # Generate synthetic 4D data: (samples, time, spatial, features)
- # time order is like [t (today), t - 1 (yesterday), t -2, ...]
- n_samples = 200
- n_time = 5
- n_spatial = 4
- n_features = 5
-
- np.random.seed(42)
- X = np.random.rand(n_samples, n_time, n_spatial, n_features)
- y = X[:, :3, :2, :4].mean(axis=(1, 2, 3)) + 0.1 * np.random.randn(n_samples)
- y = y.ravel()
-
- # Split
- X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=42)
-
- # Train model
- model = SmartForest4D(
-     n_estimators_per_layer=5,
-     max_layers=20,
-     early_stopping_rounds=5,
-     spatial_h = 2,
-     spatial_w = 2,
-     forget_factor=0.1,
-     verbose=1
- )
- model.fit(X_train, y_train, X_val, y_val)
-
- # Predict
- y_pred = model.predict(X_val)
- rmse = np.sqrt(mean_squared_error(y_val, y_pred))
- print("\n✅ Final RMSE on validation set:", rmse)
-
-
- # Output best model and RMSE
- best_model, best_rmse = model.get_best_model()
- print("\nBest validation RMSE:", best_rmse)
- """
-
- # ============================================================================================================================================================
- # Function mode
-
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
- from tensorflow.keras.models import load_model
-
- def srcnn(learning_rate=0.001):
-     """
-     Builds and compiles a Super-Resolution Convolutional Neural Network (SRCNN) model
-     that fuses features from both low-resolution and high-resolution images.
-
-     This model uses two parallel input streams:
-     - A low-resolution input which undergoes upscaling through convolutional layers.
-     - A high-resolution input from which texture features are extracted and fused with the low-resolution stream.
-
-     Args:
-         save_path (str, optional): Path to save the compiled model. If None, the model is not saved.
-         learning_rate (float): Learning rate for the Adam optimizer.
-
-     Returns:
-         keras.Model: A compiled Keras model ready for training.
-     """
-     # Input layers
-     lowres_input = layers.Input(shape=(None, None, 1)) # Low-resolution input
-     highres_input = layers.Input(shape=(None, None, 1)) # High-resolution image
-
-     # Feature extraction from high-resolution image
-     highres_features = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(highres_input)
-     highres_features = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(highres_features)
-
-     # Processing low-resoltuion input
-     x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(lowres_input)
-     x = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(x)
-
-     # Fusion of high-resolution image textures
-     fusion = layers.Concatenate()([x, highres_features])
-     fusion = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(fusion)
-     fusion = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(fusion)
-
-     # Output
-     output = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(fusion)
-
-     model = keras.Model(inputs=[lowres_input, highres_input], outputs=output)
-     model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss="mse")
-
-     return model
-
- def print_model(model):
-     return model.summary()
-
- def train(lowres_data, highres_data, epochs=100, batch_size=1, verbose=1, save_path=None):
-     model = srcnn()
-     # Train SRCNN
-     model.fit([modis_data_1, s2_data], s2_data, epochs=epochs, batch_size=batch_size, verbose=verbose)
-     # Save the complete model
-     # Recommended in newer versions of Keras (TensorFlow 2.11+): e.g., 'texture_fusion_model.keras'
-     if save_path: model.save(save_path)
-
- def apply(model, lowres_data_app, highres_data):
-     super_resolved = model.predict([lowres_data_app, highres_data]).squeeze()
-     super_resolved = xr.DataArray(
-         super_resolved,
-         dims = ("latitude", "longitude"),
-         coords={"latitude": highres_data.latitude, "longitude": highres_data.longitude},
-         name="super_res"
-     )
-     return super_resolved
-
- def load_model(save_path):
-     model = load_model('texture_fusion_model.keras')
-
- # ------------------------------------------------------------------------------------------------------------------------------------------------------------
- # Class mode
-
- import numpy as np
- import xarray as xr
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
- from tensorflow.keras.callbacks import EarlyStopping
-
- class TextureFusionSRCNN:
-     def __init__(self, learning_rate=0.001):
-         self.learning_rate = learning_rate
-         self.model = self._build_model()
-
-     def _build_model(self):
-         # Input layers
-         lowres_input = layers.Input(shape=(None, None, 1)) # Low-resolution input
-         highres_input = layers.Input(shape=(None, None, 1)) # High-resolution image
-
-         # Feature extraction from high-resolution image
-         highres_features = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(highres_input)
-         highres_features = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(highres_features)
-
-         # Processing low-resolution input
-         x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(lowres_input)
-         x = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(x)
-
-         # Fusion of high-resolution image textures
-         fusion = layers.Concatenate()([x, highres_features])
-         fusion = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(fusion)
-         fusion = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(fusion)
-
-         # Output
-         output = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(fusion)
-
-         model = keras.Model(inputs=[lowres_input, highres_input], outputs=output)
-         model.compile(optimizer=keras.optimizers.Adam(learning_rate=self.learning_rate), loss="mse")
-
-         return model
-
-     def summary(self):
-         return self.model.summary()
-
-     def train(self, lowres_data, highres_data, epochs=100, batch_size=1, verbose=1, save_path=None):
-         early_stop = EarlyStopping(
-             monitor='loss', # You can change to 'val_loss' if you add validation
-             patience=10, # Number of epochs with no improvement after which training will be stopped
-             restore_best_weights=True
-         )
-
-         self.model.fit(
-             [lowres_data, highres_data], highres_data,
-             epochs=epochs,
-             batch_size=batch_size,
-             verbose=verbose,
-             callbacks=[early_stop]
-         )
-
-         if save_path:
-             self.model.save(save_path)
-
-     def apply(self, lowres_data_app, highres_data):
-         super_resolved = self.model.predict([lowres_data_app, highres_data]).squeeze()
-         return super_resolved
-
-     @staticmethod
-     def load(save_path):
-         model = keras.models.load_model(save_path)
-         instance = TextureFusionSRCNN()
-         instance.model = model
-         return instance
-
+ import numpy as np
+ import copy
+ import itertools
+ from scipy import ndimage
+ from xgboost import XGBRegressor
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+ from sklearn.model_selection import train_test_split
+
+ class SmartForest4D:
+     """
+     SmartForest4D is an ensemble learning model designed to handle complex 4D input data
+     (samples, time, spatial, features). It integrates ideas from gradient-boosted decision trees
+     (XGBoost) with LSTM-style forget gates and spatial max pooling.
+
+     The model builds layers of regressors, each layer taking the previous output as part of its
+     input (deep forest style). A forget gate mechanism is applied along the time dimension to
+     emphasize recent temporal information. Spatial max pooling is used to reduce dimensionality
+     across spatial units before flattening and feeding into the regressors.
+
+     Parameters:
+     -----------
+     n_estimators_per_layer : int
+         Number of XGBoost regressors per layer.
+
+     max_layers : int
+         Maximum number of layers in the deep forest.
+
+     early_stopping_rounds : int
+         Number of rounds without improvement on the validation set before early stopping.
+
+     param_grid : dict
+         Dictionary of hyperparameter lists to search over for XGBoost.
+
+     use_gpu : bool
+         Whether to use GPU for training XGBoost models.
+
+     gpu_id : int
+         GPU device ID to use if use_gpu is True.
+
+     kernel: np.ndarray
+         Convolutional kernel for spatial processing.
+     # ===============================
+     # 0. Do nothing
+     # ===============================
+
+     identity_kernel = np.array([
+         [0, 0, 0],
+         [0, 1, 0],
+         [0, 0, 0]
+     ])
+
+     # ===============================
+     # 1. Sobel Edge Detection Kernels
+     # ===============================
+
+     sobel_x = np.array([
+         [-1, 0, 1],
+         [-2, 0, 2],
+         [-1, 0, 1]
+     ])
+
+     sobel_y = np.array([
+         [-1, -2, -1],
+         [ 0, 0, 0],
+         [ 1, 2, 1]
+     ])
+
+     # ===============================
+     # 2. Gaussian Blur Kernel (3x3)
+     # ===============================
+     gaussian_kernel = (1/16) * np.array([
+         [1, 2, 1],
+         [2, 4, 2],
+         [1, 2, 1]
+     ])
+
+     # ===============================
+     # 3. Morphological Structuring Element (3x3 cross)
+     # Used in binary dilation/erosion
+     # ===============================
+     morph_kernel = np.array([
+         [0, 1, 0],
+         [1, 1, 1],
+         [0, 1, 0]
+     ])
+
+     # ===============================
+     # 4. Sharpening Kernel
+     # Enhances edges and contrast
+     # ===============================
+     sharpen_kernel = np.array([
+         [ 0, -1, 0],
+         [-1, 5, -1],
+         [ 0, -1, 0]
+     ])
+
+     # ===============================
+     # 5. Embossing Kernel
+     # Creates a 3D-like shadowed effect
+     # ===============================
+     emboss_kernel = np.array([
+         [-2, -1, 0],
+         [-1, 1, 1],
+         [ 0, 1, 2]
+     ])
+
+     spatial_h : int
+         The height of the 2D grid for the flattened spatial dimension.
+
+     spatial_w : int
+         The width of the 2D grid for the flattened spatial dimension.
+
+     forget_factor : float
+         Exponential decay rate applied along the time axis. Higher values mean stronger forgetting.
+
+     verbose : int
+         Verbosity level for training output.
+     eval_metric : str
+         Statistical metric for evaluating model performance.
+
+     Attributes:
+     -----------
+     layers : list
+         List of trained layers, each containing a list of regressors.
+
+     best_model : list
+         The set of layers corresponding to the best validation RMSE seen during training.
+
+     best_score : float
+         The best metric e.g., lowest RMSE achieved on the validation set.
+
+     Methods:
+     --------
+     fit(X, y, X_val=None, y_val=None):
+         Train the SmartForest4D model on the given 4D input data.
+
+     predict(X):
+         Predict targets for new 4D input data using the trained model.
+
+     get_best_model():
+         Return the best set of layers and corresponding RMSE.
+
+     Notes:
+     ------
+     - The product of spatial_h and spatial_w must equal spatial_size (spatial_h * spatial_w = spatial_size).
+
+     Example:
+     --------
+     >>> model = SmartForest4D(n_estimators_per_layer=5, max_layers=10, early_stopping_rounds=3, forget_factor=0.3, verbose=1)
+     >>> model.fit(X_train, y_train, X_val, y_val)
+     >>> y_pred = model.predict(X_val)
+     >>> best_model, best_rmse = model.get_best_model()
+     """
+     def __init__(self, n_estimators_per_layer=5, max_layers=10, early_stopping_rounds=3, param_grid=None,
+                  use_gpu=False, gpu_id=0, kernel = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), spatial_h=None, spatial_w=None,
+                  forget_factor=0.0, verbose=1, eval_metric='rmse'):
+         self.n_estimators_per_layer = n_estimators_per_layer
+         self.max_layers = max_layers
+         self.early_stopping_rounds = early_stopping_rounds
+         self.param_grid = param_grid or {
+             "objective": ["reg:squarederror"],
+             "random_state": [42],
+             'n_estimators': [100],
+             'max_depth': [6],
+             'min_child_weight': [4],
+             'subsample': [0.8],
+             'colsample_bytree': [0.8],
+             'gamma': [0],
+             'reg_alpha': [0],
+             'reg_lambda': [1],
+             'learning_rate': [0.05],
+         }
+         self.use_gpu = use_gpu
+         self.gpu_id = gpu_id
+         self.kernel = kernel
+         self.spatial_h = spatial_h
+         self.spatial_w = spatial_w
+         self.forget_factor = forget_factor
+         self.layers = []
+         self.best_model = None
+         self.verbose = verbose
+         self.eval_metric = eval_metric.lower()
+         self.best_score = float("inf") if self.eval_metric != 'r2' else float("-inf")
+         if (self.spatial_h is None) or (self.spatial_w is None):
+             raise ValueError("Please specify spatial_h and spatial_w")
+
+     def _evaluate(self, y_true, y_pred):
+         if self.eval_metric == 'rmse':
+             return np.sqrt(mean_squared_error(y_true, y_pred))
+         elif self.eval_metric == 'nrmse':
+             return np.sqrt(mean_squared_error(y_true, y_pred)) / np.mean(np.abs(y_true))
+         elif self.eval_metric == 'mae':
+             return mean_absolute_error(y_true, y_pred)
+         elif self.eval_metric == 'mape':
+             return np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-8, None))) * 100
+         elif self.eval_metric == 'r2':
+             return r2_score(y_true, y_pred)
+         else:
+             raise ValueError(f"Unknown evaluation metric: {self.eval_metric}")
+
+     def _get_param_combinations(self):
+         keys, values = zip(*self.param_grid.items())
+         return [dict(zip(keys, v)) for v in itertools.product(*values)]
+
+     def _prepare_input(self, X, y=None, apply_forget=False, layer_index=0):
+         if X.ndim == 2:
+             X = X[:, np.newaxis, np.newaxis, :]
+         elif X.ndim == 3:
+             X = X[:, :, np.newaxis, :]
+         elif X.ndim == 4:
+             pass
+         else:
+             raise ValueError("Input must be 2D, 3D, or 4D.")
+
+         n_samples, n_time, n_spatial, n_features = X.shape
+
+         if apply_forget and self.forget_factor > 0:
+             decay = np.exp(-self.forget_factor * np.arange(n_time))[::-1]
+             decay = decay / decay.sum()
+             decay = decay.reshape(1, n_time, 1, 1)
+             X = X * decay
+
+         if n_spatial != 1:
+             if self.spatial_h * self.spatial_w != n_spatial: raise ValueError("spatial_h * spatial_w != n_spatial")
+             X_out = np.zeros_like(X)
+             for sample in range(X.shape[0]):
+                 for t in range(X.shape[1]):
+                     for f in range(X.shape[3]):
+                         spatial_2d = X[sample, t, :, f].reshape(self.spatial_h, self.spatial_w)
+                         filtered = ndimage.convolve(spatial_2d, self.kernel, mode='constant', cval=0.0)
+                         X_out[sample, t, :, f] = filtered.reshape(n_spatial)
+             X = X_out; del(X_out)
+         X_pooled = X.max(axis=2)
+         X_flattened = X_pooled.reshape(n_samples, -1)
+         return X_flattened
+
+     def _fit_layer(self, X, y, X_val=None, y_val=None, layer_index=0):
+         layer = []
+         layer_outputs = []
+         param_combos = self._get_param_combinations()
+
+         for i in range(self.n_estimators_per_layer):
+             best_metric = float('inf')
+             best_model = None
+
+             for params in param_combos:
+                 if self.use_gpu:
+                     params['tree_method'] = 'hist'
+                     params['device'] = 'cuda'
+
+                 params = params.copy()
+                 params['random_state'] = i
+
+                 model = XGBRegressor(**params)
+                 model.fit(X, y)
+
+                 if X_val is not None:
+                     preds_val = model.predict(X_val)
+                     metric = self._evaluate(y_val, preds_val)
+                     if metric < best_metric:
+                         best_metric = metric
+                         best_model = model
+                 else:
+                     best_model = model
+
+             preds = best_model.predict(X).reshape(-1, 1)
+             layer.append(best_model)
+             layer_outputs.append(preds)
+
+         output = np.hstack(layer_outputs)
+         return layer, output
+
+     def fit(self, X, y, X_val=None, y_val=None):
+         y = y.ravel()
+         X_current = self._prepare_input(X, apply_forget=True)
+         X_val_current = self._prepare_input(X_val, apply_forget=True) if X_val is not None else None
+
+         no_improve_rounds = 0
+
+         for layer_index in range(self.max_layers):
+             if self.verbose:
+                 print(f"Training Layer {layer_index + 1}")
+
+             layer, output = self._fit_layer(X_current, y, X_val_current, y_val, layer_index)
+             self.layers.append(layer)
+             X_current = np.hstack([X_current, output])
+
+             if X_val is not None:
+                 val_outputs = []
+                 for reg in layer:
+                     n_features = reg.n_features_in_
+                     preds = reg.predict(X_val_current[:, :n_features]).reshape(-1, 1)
+                     val_outputs.append(preds)
+                 val_output = np.hstack(val_outputs)
+                 X_val_current = np.hstack([X_val_current, val_output])
+
+                 y_pred = self.predict(X_val)
+                 score = self._evaluate(y_val, y_pred)
+                 if self.verbose:
+                     print(f"Validation {self.eval_metric.upper()}: {score:.4f}")
+
+                 improvement = (score < self.best_score) if self.eval_metric != 'r2' else (score > self.best_score)
+                 if improvement:
+                     self.best_score = score
+                     self.best_model = copy.deepcopy(self.layers)
+                     no_improve_rounds = 0
+                     if self.verbose:
+                         print(f"\u2705 New best {self.eval_metric.upper()}: {self.best_score:.4f}")
+                 else:
+                     no_improve_rounds += 1
+                     if no_improve_rounds >= self.early_stopping_rounds:
+                         if self.verbose:
+                             print("Early stopping triggered.")
+                         break
+
+     def predict(self, X):
+         X_current = self._prepare_input(X, apply_forget=True)
+
+         for layer in self.layers:
+             layer_outputs = []
+             for reg in layer:
+                 n_features = reg.n_features_in_
+                 preds = reg.predict(X_current[:, :n_features]).reshape(-1, 1)
+                 layer_outputs.append(preds)
+             output = np.hstack(layer_outputs)
+             X_current = np.hstack([X_current, output])
+
+         final_outputs = []
+         for reg in self.layers[-1]:
+             n_features = reg.n_features_in_
+             final_outputs.append(reg.predict(X_current[:, :n_features]).reshape(-1, 1))
+         return np.mean(np.hstack(final_outputs), axis=1)
+
+     def get_best_model(self):
+         return self.best_model, self.best_score
+
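
As an aside, the forget gate described in the class docstring above comes down to a few lines of NumPy. The following illustrative sketch (an editor's addition, not part of the package source) reproduces the decay weights that _prepare_input applies along the time axis, assuming forget_factor=0.1 and n_time=5:

    import numpy as np

    forget_factor = 0.1  # assumed value, matching the bundled example below
    n_time = 5
    decay = np.exp(-forget_factor * np.arange(n_time))[::-1]
    decay = decay / decay.sum()  # per-time-step weights, summing to 1
    print(decay)  # ~[0.162, 0.179, 0.198, 0.219, 0.242]

Note that after the reversal the largest weight falls on the last index of the time axis.
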
+ """
+ # ============================== Test Example ==============================
+ import numpy as np
+ import copy
+ import itertools
+ from scipy import ndimage
+ from xgboost import XGBRegressor
+ from sklearn.metrics import mean_squared_error
+ from sklearn.model_selection import train_test_split
+
+ # Generate synthetic 4D data: (samples, time, spatial, features)
+ # time order is like [t (today), t - 1 (yesterday), t -2, ...]
+ n_samples = 200
+ n_time = 5
+ n_spatial = 4
+ n_features = 5
+
+ np.random.seed(42)
+ X = np.random.rand(n_samples, n_time, n_spatial, n_features)
+ y = X[:, :3, :2, :4].mean(axis=(1, 2, 3)) + 0.1 * np.random.randn(n_samples)
+ y = y.ravel()
+
+ # Split
+ X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=42)
+
+ # Train model
+ model = SmartForest4D(
+     n_estimators_per_layer=5,
+     max_layers=20,
+     early_stopping_rounds=5,
+     spatial_h = 2,
+     spatial_w = 2,
+     forget_factor=0.1,
+     verbose=1
+ )
+ model.fit(X_train, y_train, X_val, y_val)
+
+ # Predict
+ y_pred = model.predict(X_val)
+ rmse = np.sqrt(mean_squared_error(y_val, y_pred))
+ print("\n✅ Final RMSE on validation set:", rmse)
+
+
+ # Output best model and RMSE
+ best_model, best_rmse = model.get_best_model()
+ print("\nBest validation RMSE:", best_rmse)
+ """
+
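
The bundled example above leaves eval_metric at its default ('rmse'). As a minimal sketch (an editor's addition, assuming the same X_train/X_val/y_train/y_val as in that example), switching the metric only changes the score that drives model selection and early stopping:

    model_mae = SmartForest4D(
        n_estimators_per_layer=5,
        max_layers=20,
        early_stopping_rounds=5,
        spatial_h=2,
        spatial_w=2,
        forget_factor=0.1,
        eval_metric='mae',  # one of 'rmse', 'nrmse', 'mae', 'mape', 'r2'
        verbose=1,
    )
    model_mae.fit(X_train, y_train, X_val, y_val)
    best_layers, best_mae = model_mae.get_best_model()

For 'r2', fit treats higher scores as improvements, as the improvement check above shows.
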
+ # ============================================================================================================================================================
+ # Function mode
+
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+ from tensorflow.keras.models import load_model
+
+ def srcnn(learning_rate=0.001):
+     """
+     Builds and compiles a Super-Resolution Convolutional Neural Network (SRCNN) model
+     that fuses features from both low-resolution and high-resolution images.
+
+     This model uses two parallel input streams:
+     - A low-resolution input which undergoes upscaling through convolutional layers.
+     - A high-resolution input from which texture features are extracted and fused with the low-resolution stream.
+
+     Args:
+         save_path (str, optional): Path to save the compiled model. If None, the model is not saved.
+         learning_rate (float): Learning rate for the Adam optimizer.
+
+     Returns:
+         keras.Model: A compiled Keras model ready for training.
+     """
+     # Input layers
+     lowres_input = layers.Input(shape=(None, None, 1)) # Low-resolution input
+     highres_input = layers.Input(shape=(None, None, 1)) # High-resolution image
+
+     # Feature extraction from high-resolution image
+     highres_features = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(highres_input)
+     highres_features = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(highres_features)
+
+     # Processing low-resoltuion input
+     x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(lowres_input)
+     x = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(x)
+
+     # Fusion of high-resolution image textures
+     fusion = layers.Concatenate()([x, highres_features])
+     fusion = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(fusion)
+     fusion = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(fusion)
+
+     # Output
+     output = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(fusion)
+
+     model = keras.Model(inputs=[lowres_input, highres_input], outputs=output)
+     model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss="mse")
+
+     return model
+
+ def print_model(model):
+     return model.summary()
+
+ def train(lowres_data, highres_data, epochs=100, batch_size=1, verbose=1, save_path=None):
+     model = srcnn()
+     # Train SRCNN
+     model.fit([lowres_data, highres_data], highres_data, epochs=epochs, batch_size=batch_size, verbose=verbose)
+     # Save the complete model
+     # Recommended in newer versions of Keras (TensorFlow 2.11+): e.g., 'texture_fusion_model.keras'
+     if save_path: model.save(save_path)
+
+ def apply(model, lowres_data_app, highres_data):
+     super_resolved = model.predict([lowres_data_app, highres_data]).squeeze()
+     super_resolved = xr.DataArray(
+         super_resolved,
+         dims = ("latitude", "longitude"),
+         coords={"latitude": highres_data.latitude, "longitude": highres_data.longitude},
+         name="super_res"
+     )
+     return super_resolved
+
+ def load_model(save_path):
+     model = load_model('texture_fusion_model.keras')
+
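
The function-mode API above can be exercised end to end on small random arrays. The sketch below is an editor's addition with assumed shapes; any height and width works because both inputs are declared with shape=(None, None, 1). Note that apply additionally relies on xarray (imported as xr only in the class-mode section below) and on highres_data carrying latitude/longitude coordinates:

    import numpy as np

    lowres = np.random.rand(2, 32, 32, 1).astype("float32")   # assumed toy batch
    highres = np.random.rand(2, 32, 32, 1).astype("float32")

    model = srcnn(learning_rate=0.001)
    model.fit([lowres, highres], highres, epochs=1, batch_size=1, verbose=0)
    pred = model.predict([lowres, highres]).squeeze()
    print(pred.shape)  # (2, 32, 32)
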
+ # ------------------------------------------------------------------------------------------------------------------------------------------------------------
+ # Class mode
+
+ import numpy as np
+ import xarray as xr
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+ from tensorflow.keras.callbacks import EarlyStopping
+
+ class TextureFusionSRCNN:
+     def __init__(self, learning_rate=0.001):
+         self.learning_rate = learning_rate
+         self.model = self._build_model()
+
+     def _build_model(self):
+         # Input layers
+         lowres_input = layers.Input(shape=(None, None, 1)) # Low-resolution input
+         highres_input = layers.Input(shape=(None, None, 1)) # High-resolution image
+
+         # Feature extraction from high-resolution image
+         highres_features = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(highres_input)
+         highres_features = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(highres_features)
+
+         # Processing low-resolution input
+         x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(lowres_input)
+         x = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(x)
+
+         # Fusion of high-resolution image textures
+         fusion = layers.Concatenate()([x, highres_features])
+         fusion = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(fusion)
+         fusion = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(fusion)
+
+         # Output
+         output = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(fusion)
+
+         model = keras.Model(inputs=[lowres_input, highres_input], outputs=output)
+         model.compile(optimizer=keras.optimizers.Adam(learning_rate=self.learning_rate), loss="mse")
+
+         return model
+
+     def summary(self):
+         return self.model.summary()
+
+     def train(self, lowres_data, highres_data, epochs=100, batch_size=1, verbose=1, save_path=None):
+         early_stop = EarlyStopping(
+             monitor='loss', # You can change to 'val_loss' if you add validation
+             patience=10, # Number of epochs with no improvement after which training will be stopped
+             restore_best_weights=True
+         )
+
+         self.model.fit(
+             [lowres_data, highres_data], highres_data,
+             epochs=epochs,
+             batch_size=batch_size,
+             verbose=verbose,
+             callbacks=[early_stop]
+         )
+
+         if save_path:
+             self.model.save(save_path)
+
+     def apply(self, lowres_data_app, highres_data):
+         super_resolved = self.model.predict([lowres_data_app, highres_data]).squeeze()
+         return super_resolved
+
+     @staticmethod
+     def load(save_path):
+         model = keras.models.load_model(save_path)
+         instance = TextureFusionSRCNN()
+         instance.model = model
+         return instance
+
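
Finally, a minimal round trip through the class-mode API, again as an editor's sketch with assumed toy shapes and a throwaway save path:

    import numpy as np

    lowres = np.random.rand(1, 64, 64, 1).astype("float32")   # assumed toy input
    highres = np.random.rand(1, 64, 64, 1).astype("float32")

    sr = TextureFusionSRCNN(learning_rate=0.001)
    sr.train(lowres, highres, epochs=1, batch_size=1, verbose=0,
             save_path="texture_fusion_model.keras")  # hypothetical path
    restored = TextureFusionSRCNN.load("texture_fusion_model.keras")
    out = restored.apply(lowres, highres)  # squeezed NumPy array
    print(out.shape)  # (64, 64)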