foscat 3.6.1__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
foscat/backend.py CHANGED
@@ -44,6 +44,7 @@ class foscat_backend:
 
         if self.BACKEND == "torch":
            import torch
+            self.torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
            self.BACKEND = self.TORCH
            self.backend = torch
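
For reference, the new torch_device attribute follows the standard PyTorch device-selection idiom; a minimal standalone sketch (not part of the package):

    import torch

    # Prefer the GPU when torch reports one, otherwise fall back to the CPU,
    # mirroring the torch_device attribute added above.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    x = torch.zeros(3, device=device)  # created directly on the selected device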
@@ -404,18 +405,53 @@ class foscat_backend:
         if self.BACKEND == self.NUMPY:
             return smat.dot(mat)
 
+    # for tensorflow wrapping only
+    def periodic_pad(self, x, pad_height, pad_width):
+        """
+        Applies periodic ('wrap') padding to a 4D TensorFlow tensor (N, H, W, C).
+
+        Args:
+            x (tf.Tensor): Input tensor with shape (batch_size, height, width, channels).
+            pad_height (int): Number of rows of vertical padding added on each side.
+            pad_width (int): Number of columns of horizontal padding added on each side.
+
+        Returns:
+            tf.Tensor: Tensor with periodic padding applied.
+        """
+        # Vertical padding: take slices from bottom and top to wrap around
+        top_pad = x[:, -pad_height:, :, :]  # Top padding from the bottom rows
+        bottom_pad = x[:, :pad_height, :, :]  # Bottom padding from the top rows
+        x_padded = self.backend.concat([top_pad, x, bottom_pad], axis=1)  # Concatenate vertically
+
+        # Horizontal padding: take slices from right and left to wrap around
+        left_pad = x_padded[:, :, -pad_width:, :]  # Left padding from the right columns
+        right_pad = x_padded[:, :, :pad_width, :]  # Right padding from the left columns
+
+        x_padded = self.backend.concat([left_pad, x_padded, right_pad], axis=2)  # Concatenate horizontally
+
+        return x_padded
+
     def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
         if self.BACKEND == self.TENSORFLOW:
             kx = w.shape[0]
             ky = w.shape[1]
-            paddings = self.backend.constant(
-                [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
-            )
-            tmp = self.backend.pad(x, paddings, "SYMMETRIC")
-            return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
-            # to be written!!!
+            x_padded = self.periodic_pad(x, kx // 2, ky // 2)
+            return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
+
         if self.BACKEND == self.TORCH:
-            return x
+            import torch.nn.functional as F
+            lx = x.permute(0, 3, 1, 2)
+            wx = self.backend.from_numpy(w).to(self.torch_device).permute(3, 2, 0, 1)  # e.g. from (5, 5, 1, 4) to (4, 1, 5, 5)
+
+            # Compute the symmetric (half-kernel) padding sizes
+            kx, ky = w.shape[0], w.shape[1]
+
+            # Apply the padding
+            x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode='circular')
+
+            # Apply the convolution
+            return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0,2,3,1)
+
         if self.BACKEND == self.NUMPY:
             res = np.zeros(
                 [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
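
Both branches now implement periodic boundary conditions: the TensorFlow path wraps the image explicitly with periodic_pad before a VALID convolution, and the Torch path relies on F.pad(..., mode='circular'). A small NumPy sketch of the wrap logic, with made-up shapes (1-pixel padding, as produced for a 3x3 kernel):

    import numpy as np

    x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)  # NHWC toy image

    top = x[:, -1:, :, :]       # last row wraps to the top
    bottom = x[:, :1, :, :]     # first row wraps to the bottom
    xp = np.concatenate([top, x, bottom], axis=1)
    left = xp[:, :, -1:, :]     # last column wraps to the left
    right = xp[:, :, :1, :]     # first column wraps to the right
    xp = np.concatenate([left, xp, right], axis=2)

    # NumPy's built-in wrap padding gives the same result.
    ref = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)), mode="wrap")
    assert np.array_equal(xp, ref)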
@@ -540,9 +576,10 @@ class foscat_backend:
 
         if self.BACKEND == self.TORCH:
             tmp = self.backend.nn.functional.interpolate(
-                x, size=shape, mode="bilinear", align_corners=False
+                x.permute(0,3,1,2), size=shape, mode="bilinear", align_corners=False
             )
-            return self.bk_cast(tmp)
+            return self.bk_cast(tmp.permute(0,2,3,1))
+
         if self.BACKEND == self.NUMPY:
             return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
 
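
The Torch resize path now permutes the channel-last (NHWC) input to channel-first (NCHW) before torch.nn.functional.interpolate and permutes the result back. A minimal sketch of that round trip, with illustrative sizes:

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 8, 8, 3)  # NHWC, as the backend stores images

    # interpolate expects NCHW, so permute in, resize, permute back out.
    y = F.interpolate(x.permute(0, 3, 1, 2), size=(16, 16),
                      mode="bilinear", align_corners=False)
    y = y.permute(0, 2, 3, 1)
    print(y.shape)  # torch.Size([1, 16, 16, 3])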
@@ -592,7 +629,7 @@ class foscat_backend:
 
     # ---------------------------------------------−---------
     # return a tensor size
-
+
     def bk_size(self, data):
         if self.BACKEND == self.TENSORFLOW:
             return self.backend.size(data)
@@ -600,7 +637,7 @@ class foscat_backend:
             return data.numel()
         if self.BACKEND == self.NUMPY:
             return data.size
-
+
     # ---------------------------------------------−---------
 
     def iso_mean(self, x, use_2D=False):
@@ -836,6 +873,14 @@ class foscat_backend:
             return self.backend.constant(data)
         if self.BACKEND == self.NUMPY:
             return data
+
+    def bk_shape_tensor(self, shape):
+        if self.BACKEND == self.TENSORFLOW:
+            return self.backend.tensor(shape=shape)
+        if self.BACKEND == self.TORCH:
+            return self.backend.tensor(shape=shape)
+        if self.BACKEND == self.NUMPY:
+            return np.zeros(shape)
 
     def bk_complex(self, real, imag):
         if self.BACKEND == self.TENSORFLOW:
@@ -879,10 +924,10 @@ class foscat_backend:
     def bk_repeat(self, data, nn, axis=0):
         return self.backend.repeat(data, nn, axis=axis)
 
-    def bk_tile(self, data, nn,axis=0):
+    def bk_tile(self, data, nn, axis=0):
         if self.BACKEND == self.TENSORFLOW:
             return self.backend.tile(data, [nn])
-
+
         return self.backend.tile(data, nn)
 
     def bk_roll(self, data, nn, axis=0):
@@ -939,30 +984,30 @@ class foscat_backend:
         else:
             return np.concatenate(data, axis=axis)
 
-    def bk_zeros(self, shape,dtype=None):
+    def bk_zeros(self, shape, dtype=None):
         if self.BACKEND == self.TENSORFLOW:
-            return self.backend.zeros(shape,dtype=dtype)
+            return self.backend.zeros(shape, dtype=dtype)
         if self.BACKEND == self.TORCH:
-            return self.backend.zeros(shape,dtype=dtype)
+            return self.backend.zeros(shape, dtype=dtype)
         if self.BACKEND == self.NUMPY:
-            return np.zeros(shape,dtype=dtype)
+            return np.zeros(shape, dtype=dtype)
 
-    def bk_gather(self, data,idx):
+    def bk_gather(self, data, idx):
         if self.BACKEND == self.TENSORFLOW:
-            return self.backend.gather(data,idx)
+            return self.backend.gather(data, idx)
         if self.BACKEND == self.TORCH:
             return data[idx]
         if self.BACKEND == self.NUMPY:
             return data[idx]
-
-    def bk_reverse(self, data,axis=0):
+
+    def bk_reverse(self, data, axis=0):
         if self.BACKEND == self.TENSORFLOW:
-            return self.backend.reverse(data,axis=[axis])
+            return self.backend.reverse(data, axis=[axis])
         if self.BACKEND == self.TORCH:
-            return self.backend.reverse(data,axis=axis)
+            return self.backend.reverse(data, axis=axis)
         if self.BACKEND == self.NUMPY:
-            return np.reverse(data,axis=axis)
-
+            return np.reverse(data, axis=axis)
+
     def bk_fft(self, data):
         if self.BACKEND == self.TENSORFLOW:
             return self.backend.signal.fft(data)
@@ -970,7 +1015,7 @@ class foscat_backend:
             return self.backend.fft(data)
         if self.BACKEND == self.NUMPY:
             return self.backend.fft.fft(data)
-
+
     def bk_rfft(self, data):
         if self.BACKEND == self.TENSORFLOW:
             return self.backend.signal.rfft(data)
@@ -979,7 +1024,6 @@ class foscat_backend:
         if self.BACKEND == self.NUMPY:
             return self.backend.fft.rfft(data)
 
-
     def bk_irfft(self, data):
         if self.BACKEND == self.TENSORFLOW:
             return self.backend.signal.irfft(data)
@@ -987,7 +1031,7 @@ class foscat_backend:
             return self.backend.irfft(data)
         if self.BACKEND == self.NUMPY:
             return self.backend.fft.irfft(data)
-
+
     def bk_conjugate(self, data):
 
         if self.BACKEND == self.TENSORFLOW:
@@ -1030,6 +1074,19 @@ class foscat_backend:
         if self.BACKEND == self.NUMPY:
             return (x > 0) * x
 
+    def bk_clip_by_value(self, x, xmin, xmax):
+        if isinstance(x, np.ndarray):
+            x = np.clip(x, xmin, xmax)
+        if self.BACKEND == self.TENSORFLOW:
+            return self.backend.clip_by_value(x, xmin, xmax)
+        if self.BACKEND == self.TORCH:
+            x = self.backend.tensor(x, dtype=self.backend.float32) if not isinstance(x, self.backend.Tensor) else x
+            xmin = self.backend.tensor(xmin, dtype=self.backend.float32) if not isinstance(xmin, self.backend.Tensor) else xmin
+            xmax = self.backend.tensor(xmax, dtype=self.backend.float32) if not isinstance(xmax, self.backend.Tensor) else xmax
+            return self.backend.clamp(x, min=xmin, max=xmax)
+        if self.BACKEND == self.NUMPY:
+            return self.backend.clip(x, xmin, xmax)
+
     def bk_cast(self, x):
         if isinstance(x, np.float64):
             if self.all_bk_type == "float32":
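
The new bk_clip_by_value dispatches to tf.clip_by_value, torch.clamp, or np.clip depending on the active backend. A minimal sketch of the NumPy and Torch paths agreeing (values are arbitrary):

    import numpy as np
    import torch

    x = np.array([-2.0, 0.5, 3.0], dtype=np.float32)

    np_clipped = np.clip(x, 0.0, 1.0)                               # NumPy backend
    torch_clipped = torch.clamp(torch.tensor(x), min=0.0, max=1.0)  # Torch backend

    assert np.allclose(np_clipped, torch_clipped.numpy())  # [0.0, 0.5, 1.0]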
@@ -1065,7 +1122,19 @@ class foscat_backend:
             else:
                 out_type = self.all_bk_type
 
-            return x.type(out_type)
+            return x.type(out_type).to(self.torch_device)
 
         if self.BACKEND == self.NUMPY:
             return x.astype(out_type)
+
+    def to_numpy(self, x):
+        if isinstance(x, np.ndarray):
+            return x
+
+        if self.BACKEND == self.NUMPY:
+            return x
+        if self.BACKEND == self.TENSORFLOW:
+            return x.numpy()
+
+        if self.BACKEND == self.TORCH:
+            return x.cpu().numpy()
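
With bk_cast now placing torch tensors on self.torch_device, the new to_numpy helper is the matching way back to the host; a short sketch of the Torch case (contents are arbitrary):

    import numpy as np
    import torch

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    x = torch.rand(4, dtype=torch.float32).to(device)  # roughly what bk_cast would hand back

    # Torch tensors must be copied to the CPU before conversion; TensorFlow
    # tensors go through .numpy() and NumPy arrays pass through unchanged.
    x_np = x.cpu().numpy()
    assert isinstance(x_np, np.ndarray)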