foscat 3.6.1__py3-none-any.whl → 3.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -44,6 +44,7 @@ class foscat_backend:
44
44
 
45
45
  if self.BACKEND == "torch":
46
46
  import torch
47
+ self.torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47
48
 
48
49
  self.BACKEND = self.TORCH
49
50
  self.backend = torch
@@ -382,7 +383,7 @@ class foscat_backend:
382
383
  if self.BACKEND == self.TENSORFLOW:
383
384
  return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
384
385
  if self.BACKEND == self.TORCH:
385
- return self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
386
+ return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
386
387
  if self.BACKEND == self.NUMPY:
387
388
  return self.scipy.sparse.coo_matrix(
388
389
  (w, (indice[:, 0], indice[:, 1])), shape=dense_shape
@@ -392,7 +393,7 @@ class foscat_backend:
392
393
  if self.BACKEND == self.TENSORFLOW:
393
394
  return self.backend.stack(list, axis=axis)
394
395
  if self.BACKEND == self.TORCH:
395
- return self.backend.stack(list, axis=axis)
396
+ return self.backend.stack(list, axis=axis).to(self.torch_device)
396
397
  if self.BACKEND == self.NUMPY:
397
398
  return self.backend.stack(list, axis=axis)
398
399
 
@@ -404,18 +405,53 @@ class foscat_backend:
404
405
  if self.BACKEND == self.NUMPY:
405
406
  return smat.dot(mat)
406
407
 
408
+ # for tensorflow wrapping only
409
+ def periodic_pad(self,x, pad_height, pad_width):
410
+ """
411
+ Applies periodic ('wrap') padding to a 4D TensorFlow tensor (N, H, W, C).
412
+
413
+ Args:
414
+ x (tf.Tensor): Input tensor with shape (batch_size, height, width, channels).
415
+ pad_height (tuple): Tuple (top, bottom) defining the vertical padding size.
416
+ pad_width (tuple): Tuple (left, right) defining the horizontal padding size.
417
+
418
+ Returns:
419
+ tf.Tensor: Tensor with periodic padding applied.
420
+ """
421
+ #Vertical padding: take slices from bottom and top to wrap around
422
+ top_pad = x[:, -pad_height:, :, :] # Top padding from the bottom rows
423
+ bottom_pad = x[:, :pad_height, :, :] # Bottom padding from the top rows
424
+ x_padded = self.backend.concat([top_pad, x, bottom_pad], axis=1) # Concatenate vertically
425
+
426
+ #Horizontal padding: take slices from right and left to wrap around
427
+ left_pad = x_padded[:, :, -pad_width:, :] # Left padding from right columns
428
+ right_pad = x_padded[:, :, :pad_width, :] # Right padding from left columns
429
+
430
+ x_padded = self.backend.concat([left_pad, x_padded, right_pad], axis=2) # Concatenate horizontally
431
+
432
+ return x_padded
433
+
407
434
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
408
435
  if self.BACKEND == self.TENSORFLOW:
409
436
  kx = w.shape[0]
410
437
  ky = w.shape[1]
411
- paddings = self.backend.constant(
412
- [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
413
- )
414
- tmp = self.backend.pad(x, paddings, "SYMMETRIC")
415
- return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
416
- # to be written!!!
438
+ x_padded = self.periodic_pad(x, kx // 2, ky // 2)
439
+ return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
440
+
417
441
  if self.BACKEND == self.TORCH:
418
- return x
442
+ import torch.nn.functional as F
443
+ lx = x.permute(0, 3, 1, 2)
444
+ wx = self.backend.from_numpy(w).to(self.torch_device).permute(3, 2, 0, 1) # de (5, 5, 1, 4) à (4, 1, 5, 5)
445
+
446
+ # Calculer le padding symétrique
447
+ kx, ky = w.shape[0], w.shape[1]
448
+
449
+ # Appliquer le padding
450
+ x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode='circular')
451
+
452
+ # Appliquer la convolution
453
+ return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0,2,3,1)
454
+
419
455
  if self.BACKEND == self.NUMPY:
420
456
  res = np.zeros(
421
457
  [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
@@ -516,6 +552,15 @@ class foscat_backend:
516
552
  return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
517
553
  else:
518
554
  return x.flatten()
555
+
556
+
557
+ def bk_flatten(self, x):
558
+ if self.BACKEND == self.TENSORFLOW:
559
+ return self.backend.flatten(x)
560
+ elif self.BACKEND == self.TORCH:
561
+ return self.backend.flatten(x)
562
+ else:
563
+ return x.flatten()
519
564
 
520
565
  def bk_flatten(self, x):
521
566
  if self.BACKEND == self.TENSORFLOW:
@@ -540,9 +585,10 @@ class foscat_backend:
540
585
 
541
586
  if self.BACKEND == self.TORCH:
542
587
  tmp = self.backend.nn.functional.interpolate(
543
- x, size=shape, mode="bilinear", align_corners=False
588
+ x.permute(0,3,1,2), size=shape, mode="bilinear", align_corners=False
544
589
  )
545
- return self.bk_cast(tmp)
590
+ return self.bk_cast(tmp.permute(0,2,3,1))
591
+
546
592
  if self.BACKEND == self.NUMPY:
547
593
  return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
548
594
 
@@ -592,7 +638,7 @@ class foscat_backend:
592
638
 
593
639
  # ---------------------------------------------−---------
594
640
  # return a tensor size
595
-
641
+
596
642
  def bk_size(self, data):
597
643
  if self.BACKEND == self.TENSORFLOW:
598
644
  return self.backend.size(data)
@@ -600,7 +646,7 @@ class foscat_backend:
600
646
  return data.numel()
601
647
  if self.BACKEND == self.NUMPY:
602
648
  return data.size
603
-
649
+
604
650
  # ---------------------------------------------−---------
605
651
 
606
652
  def iso_mean(self, x, use_2D=False):
@@ -833,15 +879,23 @@ class foscat_backend:
833
879
  if self.BACKEND == self.TENSORFLOW:
834
880
  return self.backend.constant(data)
835
881
  if self.BACKEND == self.TORCH:
836
- return self.backend.constant(data)
882
+ return self.backend.constant(data).to(self.torch_device)
837
883
  if self.BACKEND == self.NUMPY:
838
884
  return data
885
+
886
+ def bk_shape_tensor(self, shape):
887
+ if self.BACKEND == self.TENSORFLOW:
888
+ return self.backend.tensor(shape=shape)
889
+ if self.BACKEND == self.TORCH:
890
+ return self.backend.tensor(shape=shape).to(self.torch_device)
891
+ if self.BACKEND == self.NUMPY:
892
+ return np.zeros(shape)
839
893
 
840
894
  def bk_complex(self, real, imag):
841
895
  if self.BACKEND == self.TENSORFLOW:
842
896
  return self.backend.dtypes.complex(real, imag)
843
897
  if self.BACKEND == self.TORCH:
844
- return self.backend.complex(real, imag)
898
+ return self.backend.complex(real, imag).to(self.torch_device)
845
899
  if self.BACKEND == self.NUMPY:
846
900
  return real + 1j * imag
847
901
 
@@ -879,10 +933,10 @@ class foscat_backend:
879
933
  def bk_repeat(self, data, nn, axis=0):
880
934
  return self.backend.repeat(data, nn, axis=axis)
881
935
 
882
- def bk_tile(self, data, nn,axis=0):
936
+ def bk_tile(self, data, nn, axis=0):
883
937
  if self.BACKEND == self.TENSORFLOW:
884
938
  return self.backend.tile(data, [nn])
885
-
939
+
886
940
  return self.backend.tile(data, nn)
887
941
 
888
942
  def bk_roll(self, data, nn, axis=0):
@@ -918,7 +972,7 @@ class foscat_backend:
918
972
  xi = self.backend.concat(
919
973
  [self.bk_imag(data[k]) for k in range(ndata)]
920
974
  )
921
- return self.backend.complex(xr, xi)
975
+ return self.bk_complex(xr, xi)
922
976
  else:
923
977
  return self.backend.concat(data)
924
978
  else:
@@ -930,7 +984,7 @@ class foscat_backend:
930
984
  xi = self.backend.concat(
931
985
  [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
932
986
  )
933
- return self.backend.complex(xr, xi)
987
+ return self.bk_complex(xr, xi)
934
988
  else:
935
989
  return self.backend.concat(data, axis=axis)
936
990
  else:
@@ -939,30 +993,30 @@ class foscat_backend:
939
993
  else:
940
994
  return np.concatenate(data, axis=axis)
941
995
 
942
- def bk_zeros(self, shape,dtype=None):
996
+ def bk_zeros(self, shape, dtype=None):
943
997
  if self.BACKEND == self.TENSORFLOW:
944
- return self.backend.zeros(shape,dtype=dtype)
998
+ return self.backend.zeros(shape, dtype=dtype)
945
999
  if self.BACKEND == self.TORCH:
946
- return self.backend.zeros(shape,dtype=dtype)
1000
+ return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
947
1001
  if self.BACKEND == self.NUMPY:
948
- return np.zeros(shape,dtype=dtype)
1002
+ return np.zeros(shape, dtype=dtype)
949
1003
 
950
- def bk_gather(self, data,idx):
1004
+ def bk_gather(self, data, idx):
951
1005
  if self.BACKEND == self.TENSORFLOW:
952
- return self.backend.gather(data,idx)
1006
+ return self.backend.gather(data, idx)
953
1007
  if self.BACKEND == self.TORCH:
954
1008
  return data[idx]
955
1009
  if self.BACKEND == self.NUMPY:
956
1010
  return data[idx]
957
-
958
- def bk_reverse(self, data,axis=0):
1011
+
1012
+ def bk_reverse(self, data, axis=0):
959
1013
  if self.BACKEND == self.TENSORFLOW:
960
- return self.backend.reverse(data,axis=[axis])
1014
+ return self.backend.reverse(data, axis=[axis])
961
1015
  if self.BACKEND == self.TORCH:
962
- return self.backend.reverse(data,axis=axis)
1016
+ return self.backend.reverse(data, axis=axis)
963
1017
  if self.BACKEND == self.NUMPY:
964
- return np.reverse(data,axis=axis)
965
-
1018
+ return np.reverse(data, axis=axis)
1019
+
966
1020
  def bk_fft(self, data):
967
1021
  if self.BACKEND == self.TENSORFLOW:
968
1022
  return self.backend.signal.fft(data)
@@ -970,7 +1024,23 @@ class foscat_backend:
970
1024
  return self.backend.fft(data)
971
1025
  if self.BACKEND == self.NUMPY:
972
1026
  return self.backend.fft.fft(data)
973
-
1027
+
1028
+ def bk_fftn(self, data,dim=None):
1029
+ if self.BACKEND == self.TENSORFLOW:
1030
+ return self.backend.signal.fftn(data)
1031
+ if self.BACKEND == self.TORCH:
1032
+ return self.backend.fft.fftn(data,dim=dim)
1033
+ if self.BACKEND == self.NUMPY:
1034
+ return self.backend.fft.fftn(data)
1035
+
1036
+ def bk_ifftn(self, data,dim=None,norm=None):
1037
+ if self.BACKEND == self.TENSORFLOW:
1038
+ return self.backend.signal.ifftn(data)
1039
+ if self.BACKEND == self.TORCH:
1040
+ return self.backend.fft.ifftn(data,dim=dim,norm=norm)
1041
+ if self.BACKEND == self.NUMPY:
1042
+ return self.backend.fft.ifftn(data)
1043
+
974
1044
  def bk_rfft(self, data):
975
1045
  if self.BACKEND == self.TENSORFLOW:
976
1046
  return self.backend.signal.rfft(data)
@@ -979,7 +1049,6 @@ class foscat_backend:
979
1049
  if self.BACKEND == self.NUMPY:
980
1050
  return self.backend.fft.rfft(data)
981
1051
 
982
-
983
1052
  def bk_irfft(self, data):
984
1053
  if self.BACKEND == self.TENSORFLOW:
985
1054
  return self.backend.signal.irfft(data)
@@ -987,7 +1056,7 @@ class foscat_backend:
987
1056
  return self.backend.irfft(data)
988
1057
  if self.BACKEND == self.NUMPY:
989
1058
  return self.backend.fft.irfft(data)
990
-
1059
+
991
1060
  def bk_conjugate(self, data):
992
1061
 
993
1062
  if self.BACKEND == self.TENSORFLOW:
@@ -1022,7 +1091,7 @@ class foscat_backend:
1022
1091
  if x.dtype == self.all_cbk_type:
1023
1092
  xr = self.backend.nn.relu(self.bk_real(x))
1024
1093
  xi = self.backend.nn.relu(self.bk_imag(x))
1025
- return self.backend.complex(xr, xi)
1094
+ return self.bk_complex(xr, xi)
1026
1095
  else:
1027
1096
  return self.backend.nn.relu(x)
1028
1097
  if self.BACKEND == self.TORCH:
@@ -1030,6 +1099,19 @@ class foscat_backend:
1030
1099
  if self.BACKEND == self.NUMPY:
1031
1100
  return (x > 0) * x
1032
1101
 
1102
+ def bk_clip_by_value(self, x,xmin,xmax):
1103
+ if isinstance(x, np.ndarray):
1104
+ x = np.clip(x,xmin,xmax)
1105
+ if self.BACKEND == self.TENSORFLOW:
1106
+ return self.backend.clip_by_value(x,xmin,xmax)
1107
+ if self.BACKEND == self.TORCH:
1108
+ x = self.backend.tensor(x, dtype=self.backend.float32) if not isinstance(x, self.backend.Tensor) else x
1109
+ xmin = self.backend.tensor(xmin, dtype=self.backend.float32) if not isinstance(xmin, self.backend.Tensor) else xmin
1110
+ xmax = self.backend.tensor(xmax, dtype=self.backend.float32) if not isinstance(xmax, self.backend.Tensor) else xmax
1111
+ return self.backend.clamp(x, min=xmin, max=xmax)
1112
+ if self.BACKEND == self.NUMPY:
1113
+ return self.backend.clip(x,xmin,xmax)
1114
+
1033
1115
  def bk_cast(self, x):
1034
1116
  if isinstance(x, np.float64):
1035
1117
  if self.all_bk_type == "float32":
@@ -1041,6 +1123,16 @@ class foscat_backend:
1041
1123
  return np.float64(x)
1042
1124
  else:
1043
1125
  return x
1126
+ if isinstance(x, np.complex128):
1127
+ if self.all_bk_type == "float32":
1128
+ return np.complex64(x)
1129
+ else:
1130
+ return x
1131
+ if isinstance(x, np.complex64):
1132
+ if self.all_bk_type == "float64":
1133
+ return np.complex128(x)
1134
+ else:
1135
+ return x
1044
1136
 
1045
1137
  if isinstance(x, np.int32) or isinstance(x, np.int64) or isinstance(x, int):
1046
1138
  if self.all_bk_type == "float64":
@@ -1058,14 +1150,26 @@ class foscat_backend:
1058
1150
 
1059
1151
  if self.BACKEND == self.TORCH:
1060
1152
  if isinstance(x, np.ndarray):
1061
- x = self.backend.from_numpy(x)
1153
+ x = self.backend.from_numpy(x).to(self.torch_device)
1062
1154
 
1063
1155
  if x.dtype.is_complex:
1064
1156
  out_type = self.all_cbk_type
1065
1157
  else:
1066
1158
  out_type = self.all_bk_type
1067
1159
 
1068
- return x.type(out_type)
1160
+ return x.type(out_type).to(self.torch_device)
1069
1161
 
1070
1162
  if self.BACKEND == self.NUMPY:
1071
1163
  return x.astype(out_type)
1164
+
1165
+ def to_numpy(self,x):
1166
+ if isinstance(x, np.ndarray):
1167
+ return x
1168
+
1169
+ if self.BACKEND == self.NUMPY:
1170
+ return x
1171
+ if self.BACKEND == self.TENSORFLOW:
1172
+ return x.numpy()
1173
+
1174
+ if self.BACKEND == self.TORCH:
1175
+ return x.cpu().numpy()