foscat 3.6.0__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -44,6 +44,7 @@ class foscat_backend:
44
44
 
45
45
  if self.BACKEND == "torch":
46
46
  import torch
47
+ self.torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47
48
 
48
49
  self.BACKEND = self.TORCH
49
50
  self.backend = torch
@@ -404,18 +405,53 @@ class foscat_backend:
404
405
  if self.BACKEND == self.NUMPY:
405
406
  return smat.dot(mat)
406
407
 
408
+ # for tensorflow wrapping only
409
+ def periodic_pad(self,x, pad_height, pad_width):
410
+ """
411
+ Applies periodic ('wrap') padding to a 4D TensorFlow tensor (N, H, W, C).
412
+
413
+ Args:
414
+ x (tf.Tensor): Input tensor with shape (batch_size, height, width, channels).
415
+ pad_height (tuple): Tuple (top, bottom) defining the vertical padding size.
416
+ pad_width (tuple): Tuple (left, right) defining the horizontal padding size.
417
+
418
+ Returns:
419
+ tf.Tensor: Tensor with periodic padding applied.
420
+ """
421
+ #Vertical padding: take slices from bottom and top to wrap around
422
+ top_pad = x[:, -pad_height:, :, :] # Top padding from the bottom rows
423
+ bottom_pad = x[:, :pad_height, :, :] # Bottom padding from the top rows
424
+ x_padded = self.backend.concat([top_pad, x, bottom_pad], axis=1) # Concatenate vertically
425
+
426
+ #Horizontal padding: take slices from right and left to wrap around
427
+ left_pad = x_padded[:, :, -pad_width:, :] # Left padding from right columns
428
+ right_pad = x_padded[:, :, :pad_width, :] # Right padding from left columns
429
+
430
+ x_padded = self.backend.concat([left_pad, x_padded, right_pad], axis=2) # Concatenate horizontally
431
+
432
+ return x_padded
433
+
407
434
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
408
435
  if self.BACKEND == self.TENSORFLOW:
409
436
  kx = w.shape[0]
410
437
  ky = w.shape[1]
411
- paddings = self.backend.constant(
412
- [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
413
- )
414
- tmp = self.backend.pad(x, paddings, "SYMMETRIC")
415
- return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
416
- # to be written!!!
438
+ x_padded = self.periodic_pad(x, kx // 2, ky // 2)
439
+ return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
440
+
417
441
  if self.BACKEND == self.TORCH:
418
- return x
442
+ import torch.nn.functional as F
443
+ lx = x.permute(0, 3, 1, 2)
444
+ wx = self.backend.from_numpy(w).to(self.torch_device).permute(3, 2, 0, 1) # de (5, 5, 1, 4) à (4, 1, 5, 5)
445
+
446
+ # Calculer le padding symétrique
447
+ kx, ky = w.shape[0], w.shape[1]
448
+
449
+ # Appliquer le padding
450
+ x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode='circular')
451
+
452
+ # Appliquer la convolution
453
+ return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0,2,3,1)
454
+
419
455
  if self.BACKEND == self.NUMPY:
420
456
  res = np.zeros(
421
457
  [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
@@ -540,9 +576,10 @@ class foscat_backend:
540
576
 
541
577
  if self.BACKEND == self.TORCH:
542
578
  tmp = self.backend.nn.functional.interpolate(
543
- x, size=shape, mode="bilinear", align_corners=False
579
+ x.permute(0,3,1,2), size=shape, mode="bilinear", align_corners=False
544
580
  )
545
- return self.bk_cast(tmp)
581
+ return self.bk_cast(tmp.permute(0,2,3,1))
582
+
546
583
  if self.BACKEND == self.NUMPY:
547
584
  return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
548
585
 
@@ -592,7 +629,7 @@ class foscat_backend:
592
629
 
593
630
  # ---------------------------------------------−---------
594
631
  # return a tensor size
595
-
632
+
596
633
  def bk_size(self, data):
597
634
  if self.BACKEND == self.TENSORFLOW:
598
635
  return self.backend.size(data)
@@ -600,7 +637,7 @@ class foscat_backend:
600
637
  return data.numel()
601
638
  if self.BACKEND == self.NUMPY:
602
639
  return data.size
603
-
640
+
604
641
  # ---------------------------------------------−---------
605
642
 
606
643
  def iso_mean(self, x, use_2D=False):
@@ -836,6 +873,14 @@ class foscat_backend:
836
873
  return self.backend.constant(data)
837
874
  if self.BACKEND == self.NUMPY:
838
875
  return data
876
+
877
+ def bk_shape_tensor(self, shape):
878
+ if self.BACKEND == self.TENSORFLOW:
879
+ return self.backend.tensor(shape=shape)
880
+ if self.BACKEND == self.TORCH:
881
+ return self.backend.tensor(shape=shape)
882
+ if self.BACKEND == self.NUMPY:
883
+ return np.zeros(shape)
839
884
 
840
885
  def bk_complex(self, real, imag):
841
886
  if self.BACKEND == self.TENSORFLOW:
@@ -879,10 +924,10 @@ class foscat_backend:
879
924
  def bk_repeat(self, data, nn, axis=0):
880
925
  return self.backend.repeat(data, nn, axis=axis)
881
926
 
882
- def bk_tile(self, data, nn,axis=0):
927
+ def bk_tile(self, data, nn, axis=0):
883
928
  if self.BACKEND == self.TENSORFLOW:
884
929
  return self.backend.tile(data, [nn])
885
-
930
+
886
931
  return self.backend.tile(data, nn)
887
932
 
888
933
  def bk_roll(self, data, nn, axis=0):
@@ -939,30 +984,30 @@ class foscat_backend:
939
984
  else:
940
985
  return np.concatenate(data, axis=axis)
941
986
 
942
- def bk_zeros(self, shape,dtype=None):
987
+ def bk_zeros(self, shape, dtype=None):
943
988
  if self.BACKEND == self.TENSORFLOW:
944
- return self.backend.zeros(shape,dtype=dtype)
989
+ return self.backend.zeros(shape, dtype=dtype)
945
990
  if self.BACKEND == self.TORCH:
946
- return self.backend.zeros(shape,dtype=dtype)
991
+ return self.backend.zeros(shape, dtype=dtype)
947
992
  if self.BACKEND == self.NUMPY:
948
- return np.zeros(shape,dtype=dtype)
993
+ return np.zeros(shape, dtype=dtype)
949
994
 
950
- def bk_gather(self, data,idx):
995
+ def bk_gather(self, data, idx):
951
996
  if self.BACKEND == self.TENSORFLOW:
952
- return self.backend.gather(data,idx)
997
+ return self.backend.gather(data, idx)
953
998
  if self.BACKEND == self.TORCH:
954
999
  return data[idx]
955
1000
  if self.BACKEND == self.NUMPY:
956
1001
  return data[idx]
957
-
958
- def bk_reverse(self, data,axis=0):
1002
+
1003
+ def bk_reverse(self, data, axis=0):
959
1004
  if self.BACKEND == self.TENSORFLOW:
960
- return self.backend.reverse(data,axis=[axis])
1005
+ return self.backend.reverse(data, axis=[axis])
961
1006
  if self.BACKEND == self.TORCH:
962
- return self.backend.reverse(data,axis=axis)
1007
+ return self.backend.reverse(data, axis=axis)
963
1008
  if self.BACKEND == self.NUMPY:
964
- return np.reverse(data,axis=axis)
965
-
1009
+ return np.reverse(data, axis=axis)
1010
+
966
1011
  def bk_fft(self, data):
967
1012
  if self.BACKEND == self.TENSORFLOW:
968
1013
  return self.backend.signal.fft(data)
@@ -970,7 +1015,7 @@ class foscat_backend:
970
1015
  return self.backend.fft(data)
971
1016
  if self.BACKEND == self.NUMPY:
972
1017
  return self.backend.fft.fft(data)
973
-
1018
+
974
1019
  def bk_rfft(self, data):
975
1020
  if self.BACKEND == self.TENSORFLOW:
976
1021
  return self.backend.signal.rfft(data)
@@ -978,6 +1023,15 @@ class foscat_backend:
978
1023
  return self.backend.rfft(data)
979
1024
  if self.BACKEND == self.NUMPY:
980
1025
  return self.backend.fft.rfft(data)
1026
+
1027
+ def bk_irfft(self, data):
1028
+ if self.BACKEND == self.TENSORFLOW:
1029
+ return self.backend.signal.irfft(data)
1030
+ if self.BACKEND == self.TORCH:
1031
+ return self.backend.irfft(data)
1032
+ if self.BACKEND == self.NUMPY:
1033
+ return self.backend.fft.irfft(data)
1034
+
981
1035
  def bk_conjugate(self, data):
982
1036
 
983
1037
  if self.BACKEND == self.TENSORFLOW:
@@ -1020,6 +1074,19 @@ class foscat_backend:
1020
1074
  if self.BACKEND == self.NUMPY:
1021
1075
  return (x > 0) * x
1022
1076
 
1077
+ def bk_clip_by_value(self, x,xmin,xmax):
1078
+ if isinstance(x, np.ndarray):
1079
+ x = np.clip(x,xmin,xmax)
1080
+ if self.BACKEND == self.TENSORFLOW:
1081
+ return self.backend.clip_by_value(x,xmin,xmax)
1082
+ if self.BACKEND == self.TORCH:
1083
+ x = self.backend.tensor(x, dtype=self.backend.float32) if not isinstance(x, self.backend.Tensor) else x
1084
+ xmin = self.backend.tensor(xmin, dtype=self.backend.float32) if not isinstance(xmin, self.backend.Tensor) else xmin
1085
+ xmax = self.backend.tensor(xmax, dtype=self.backend.float32) if not isinstance(xmax, self.backend.Tensor) else xmax
1086
+ return self.backend.clamp(x, min=xmin, max=xmax)
1087
+ if self.BACKEND == self.NUMPY:
1088
+ return self.backend.clip(x,xmin,xmax)
1089
+
1023
1090
  def bk_cast(self, x):
1024
1091
  if isinstance(x, np.float64):
1025
1092
  if self.all_bk_type == "float32":
@@ -1055,7 +1122,19 @@ class foscat_backend:
1055
1122
  else:
1056
1123
  out_type = self.all_bk_type
1057
1124
 
1058
- return x.type(out_type)
1125
+ return x.type(out_type).to(self.torch_device)
1059
1126
 
1060
1127
  if self.BACKEND == self.NUMPY:
1061
1128
  return x.astype(out_type)
1129
+
1130
+ def to_numpy(self,x):
1131
+ if isinstance(x, np.ndarray):
1132
+ return x
1133
+
1134
+ if self.BACKEND == self.NUMPY:
1135
+ return x
1136
+ if self.BACKEND == self.TENSORFLOW:
1137
+ return x.numpy()
1138
+
1139
+ if self.BACKEND == self.TORCH:
1140
+ return x.cpu().numpy()