foscat 3.8.0__tar.gz → 3.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {foscat-3.8.0/src/foscat.egg-info → foscat-3.9.0}/PKG-INFO +1 -1
  2. {foscat-3.8.0 → foscat-3.9.0}/pyproject.toml +1 -1
  3. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/BkBase.py +39 -29
  4. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/BkNumpy.py +56 -47
  5. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/BkTensorflow.py +89 -81
  6. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/BkTorch.py +95 -59
  7. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/FoCUS.py +72 -56
  8. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/Synthesis.py +3 -3
  9. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/alm.py +194 -177
  10. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/backend.py +84 -70
  11. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat_cov.py +1876 -2100
  12. foscat-3.9.0/src/foscat/scat_cov2D.py +211 -0
  13. {foscat-3.8.0 → foscat-3.9.0/src/foscat.egg-info}/PKG-INFO +1 -1
  14. foscat-3.8.0/src/foscat/scat_cov2D.py +0 -118
  15. {foscat-3.8.0 → foscat-3.9.0}/LICENSE +0 -0
  16. {foscat-3.8.0 → foscat-3.9.0}/README.md +0 -0
  17. {foscat-3.8.0 → foscat-3.9.0}/setup.cfg +0 -0
  18. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/CNN.py +0 -0
  19. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/CircSpline.py +0 -0
  20. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/GCNN.py +0 -0
  21. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/Softmax.py +0 -0
  22. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/Spline1D.py +0 -0
  23. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/__init__.py +0 -0
  24. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/backend_tens.py +0 -0
  25. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/loss_backend_tens.py +0 -0
  26. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/loss_backend_torch.py +0 -0
  27. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat.py +0 -0
  28. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat1D.py +0 -0
  29. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat2D.py +0 -0
  30. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat_cov1D.py +0 -0
  31. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat_cov_map.py +0 -0
  32. {foscat-3.8.0 → foscat-3.9.0}/src/foscat/scat_cov_map2D.py +0 -0
  33. {foscat-3.8.0 → foscat-3.9.0}/src/foscat.egg-info/SOURCES.txt +0 -0
  34. {foscat-3.8.0 → foscat-3.9.0}/src/foscat.egg-info/dependency_links.txt +0 -0
  35. {foscat-3.8.0 → foscat-3.9.0}/src/foscat.egg-info/requires.txt +0 -0
  36. {foscat-3.8.0 → foscat-3.9.0}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: foscat
3
- Version: 3.8.0
3
+ Version: 3.9.0
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "foscat"
3
- version = "3.8.0"
3
+ version = "3.9.0"
4
4
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
5
5
  readme = "README.md"
6
6
  license = { text = "BSD-3-Clause" }
@@ -1,14 +1,15 @@
1
1
  import numpy as np
2
2
 
3
+
3
4
  class BackendBase:
4
-
5
+
5
6
  def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
6
-
7
- self.BACKEND=name
8
- self.mpi_rank=mpi_rank
9
- self.all_type=all_type
10
- self.gpupos=gpupos
11
- self.silent=silent
7
+
8
+ self.BACKEND = name
9
+ self.mpi_rank = mpi_rank
10
+ self.all_type = all_type
11
+ self.gpupos = gpupos
12
+ self.silent = silent
12
13
  # ---------------------------------------------−---------
13
14
  # table use to compute the iso orientation rotation
14
15
  self._iso_orient = {}
@@ -101,6 +102,7 @@ class BackendBase:
101
102
  return self.bk_reshape(
102
103
  self.backend.matmul(self.bk_reshape(x, oshape), lmat), oshape2
103
104
  )
105
+
104
106
  def calc_iso_orient(self, norient):
105
107
  tmp = np.zeros([norient * norient, norient])
106
108
  for i in range(norient):
@@ -188,7 +190,9 @@ class BackendBase:
188
190
  tmp[:, :, k, l_orient] = np.cos(x * k) * np.cos((x.T) * l_orient)
189
191
 
190
192
  self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
191
- self.bk_constant(tmp.reshape(norient * norient, (1 + nharm) * (1 + nharm)))
193
+ self.bk_constant(
194
+ tmp.reshape(norient * norient, (1 + nharm) * (1 + nharm))
195
+ )
192
196
  )
193
197
  self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
194
198
  self._fft_2_orient[(norient, nharm, imaginary)],
@@ -343,7 +347,7 @@ class BackendBase:
343
347
  self._fft_3_orient[(norient, nharm, imaginary)],
344
348
  0 * self._fft_3_orient[(norient, nharm, imaginary)],
345
349
  )
346
-
350
+
347
351
  # ---------------------------------------------−---------
348
352
  # -- BACKEND DEFINITION --
349
353
  # ---------------------------------------------−---------
@@ -355,7 +359,7 @@ class BackendBase:
355
359
 
356
360
  def bk_sparse_dense_matmul(self, smat, mat):
357
361
  raise NotImplementedError("This is an abstract class.")
358
-
362
+
359
363
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
360
364
  raise NotImplementedError("This is an abstract class.")
361
365
 
@@ -379,9 +383,6 @@ class BackendBase:
379
383
 
380
384
  def bk_flattenR(self, x):
381
385
  raise NotImplementedError("This is an abstract class.")
382
-
383
- def bk_flatten(self, x):
384
- raise NotImplementedError("This is an abstract class.")
385
386
 
386
387
  def bk_flatten(self, x):
387
388
  raise NotImplementedError("This is an abstract class.")
@@ -442,7 +443,7 @@ class BackendBase:
442
443
 
443
444
  def bk_tensor(self, data):
444
445
  raise NotImplementedError("This is an abstract class.")
445
-
446
+
446
447
  def bk_shape_tensor(self, shape):
447
448
  raise NotImplementedError("This is an abstract class.")
448
449
 
@@ -460,7 +461,7 @@ class BackendBase:
460
461
 
461
462
  def bk_tanh(self, data):
462
463
  raise NotImplementedError("This is an abstract class.")
463
-
464
+
464
465
  def bk_max(self, data):
465
466
  raise NotImplementedError("This is an abstract class.")
466
467
 
@@ -499,11 +500,11 @@ class BackendBase:
499
500
 
500
501
  def bk_fft(self, data):
501
502
  raise NotImplementedError("This is an abstract class.")
502
-
503
- def bk_fftn(self, data,dim=None):
503
+
504
+ def bk_fftn(self, data, dim=None):
504
505
  raise NotImplementedError("This is an abstract class.")
505
506
 
506
- def bk_ifftn(self, data,dim=None,norm=None):
507
+ def bk_ifftn(self, data, dim=None, norm=None):
507
508
  raise NotImplementedError("This is an abstract class.")
508
509
 
509
510
  def bk_rfft(self, data):
@@ -524,23 +525,32 @@ class BackendBase:
524
525
  def bk_relu(self, x):
525
526
  raise NotImplementedError("This is an abstract class.")
526
527
 
527
- def bk_clip_by_value(self, x,xmin,xmax):
528
+ def bk_clip_by_value(self, x, xmin, xmax):
528
529
  raise NotImplementedError("This is an abstract class.")
529
530
 
530
531
  def bk_cast(self, x):
531
532
  raise NotImplementedError("This is an abstract class.")
532
-
533
- def bk_variable(self,x):
533
+
534
+ def bk_variable(self, x):
535
+ raise NotImplementedError("This is an abstract class.")
536
+
537
+ def bk_assign(self, x, y):
538
+ raise NotImplementedError("This is an abstract class.")
539
+
540
+ def bk_constant(self, x):
541
+ raise NotImplementedError("This is an abstract class.")
542
+
543
+ def bk_cos(self, x):
534
544
  raise NotImplementedError("This is an abstract class.")
535
-
536
- def bk_assign(self,x,y):
545
+
546
+ def bk_sin(self, x):
537
547
  raise NotImplementedError("This is an abstract class.")
538
-
539
- def bk_constant(self,x):
548
+
549
+ def bk_arctan2(self, c, s):
540
550
  raise NotImplementedError("This is an abstract class.")
541
-
542
- def bk_empty(self,x):
551
+
552
+ def bk_empty(self, list):
543
553
  raise NotImplementedError("This is an abstract class.")
544
-
545
- def to_numpy(self,x):
554
+
555
+ def to_numpy(self, x):
546
556
  raise NotImplementedError("This is an abstract class.")
@@ -1,11 +1,13 @@
1
- import foscat.BkBase as BackendBase
2
1
  import numpy as np
3
2
 
3
+ import foscat.BkBase as BackendBase
4
+
5
+
4
6
  class BkNumpy(BackendBase.BackendBase):
5
-
7
+
6
8
  def __init__(self, *args, **kwargs):
7
9
  # Impose que use_2D=True pour la classe scat
8
- super().__init__(name='tensorflow', *args, **kwargs)
10
+ super().__init__(name="tensorflow", *args, **kwargs)
9
11
 
10
12
  # ===========================================================================
11
13
  # INIT
@@ -31,13 +33,14 @@ class BkNumpy(BackendBase.BackendBase):
31
33
  self.all_bk_type = self.backend.float64
32
34
  self.all_cbk_type = self.backend.complex128
33
35
  else:
34
- print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
36
+ print(
37
+ "ERROR INIT FOCUS ", self.all_type, " should be float32 or float64"
38
+ )
35
39
  return None
36
-
40
+
37
41
  # ===========================================================================
38
42
  # INIT
39
43
 
40
- gpus = []
41
44
  gpuname = "CPU:0"
42
45
  self.gpulist = {}
43
46
  self.gpulist[0] = gpuname
@@ -49,27 +52,25 @@ class BkNumpy(BackendBase.BackendBase):
49
52
  def bk_SparseTensor(self, indice, w, dense_shape=[]):
50
53
  return self.scipy.sparse.coo_matrix(
51
54
  (w, (indice[:, 0], indice[:, 1])), shape=dense_shape
52
- )
55
+ )
53
56
 
54
57
  def bk_stack(self, list, axis=0):
55
58
  return self.backend.stack(list, axis=axis)
56
59
 
57
60
  def bk_sparse_dense_matmul(self, smat, mat):
58
61
  return smat.dot(mat)
59
-
62
+
60
63
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
61
- res = np.zeros(
62
- [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
63
- )
64
+ res = np.zeros([x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype)
64
65
  for k in range(w.shape[2]):
65
66
  for l_orient in range(w.shape[3]):
66
67
  for j in range(res.shape[0]):
67
68
  tmp = self.scipy.signal.convolve2d(
68
- x[j, :, :, k],
69
- w[:, :, k, l_orient],
70
- mode="same",
71
- boundary="symm",
72
- )
69
+ x[j, :, :, k],
70
+ w[:, :, k, l_orient],
71
+ mode="same",
72
+ boundary="symm",
73
+ )
73
74
  res[j, :, :, l_orient] += tmp
74
75
  del tmp
75
76
  return res
@@ -79,8 +80,8 @@ class BkNumpy(BackendBase.BackendBase):
79
80
  for k in range(w.shape[2]):
80
81
  for j in range(res.shape[0]):
81
82
  tmp = self.scipy.signal.convolve1d(
82
- x[j, :, k], w[:, k], mode="same", boundary="symm"
83
- )
83
+ x[j, :, k], w[:, k], mode="same", boundary="symm"
84
+ )
84
85
  res[j, :, :] += tmp
85
86
  del tmp
86
87
  return res
@@ -114,10 +115,6 @@ class BkNumpy(BackendBase.BackendBase):
114
115
  return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
115
116
  else:
116
117
  return x.flatten()
117
-
118
-
119
- def bk_flatten(self, x):
120
- return x.flatten()
121
118
 
122
119
  def bk_flatten(self, x):
123
120
  return x.flatten()
@@ -144,10 +141,6 @@ class BkNumpy(BackendBase.BackendBase):
144
141
  # ---------------------------------------------−---------
145
142
  # return a tensor size
146
143
 
147
- def bk_size(self, data):
148
- return data.size
149
-
150
-
151
144
  def bk_reduce_mean(self, data, axis=None):
152
145
 
153
146
  if axis is None:
@@ -223,7 +216,7 @@ class BkNumpy(BackendBase.BackendBase):
223
216
 
224
217
  def bk_tensor(self, data):
225
218
  return data
226
-
219
+
227
220
  def bk_shape_tensor(self, shape):
228
221
  return np.zeros(shape)
229
222
 
@@ -284,19 +277,27 @@ class BkNumpy(BackendBase.BackendBase):
284
277
  def bk_zeros(self, shape, dtype=None):
285
278
  return np.zeros(shape, dtype=dtype)
286
279
 
287
- def bk_gather(self, data, idx):
288
- return data[idx]
280
+ def bk_gather(self, data, idx, axis=0):
281
+ if axis == 0:
282
+ return data[idx]
283
+ elif axis == 1:
284
+ return data[:, idx]
285
+ elif axis == 2:
286
+ return data[:, :, idx]
287
+ elif axis == 3:
288
+ return data[:, :, :, idx]
289
+ return data[:, :, :, :, idx]
289
290
 
290
291
  def bk_reverse(self, data, axis=0):
291
292
  return np.reverse(data, axis=axis)
292
293
 
293
294
  def bk_fft(self, data):
294
295
  return self.backend.fft.fft(data)
295
-
296
- def bk_fftn(self, data,dim=None):
296
+
297
+ def bk_fftn(self, data, dim=None):
297
298
  return self.backend.fft.fftn(data)
298
299
 
299
- def bk_ifftn(self, data,dim=None,norm=None):
300
+ def bk_ifftn(self, data, dim=None, norm=None):
300
301
  return self.backend.fft.ifftn(data)
301
302
 
302
303
  def bk_rfft(self, data):
@@ -318,8 +319,8 @@ class BkNumpy(BackendBase.BackendBase):
318
319
  def bk_relu(self, x):
319
320
  return (x > 0) * x
320
321
 
321
- def bk_clip_by_value(self, x,xmin,xmax):
322
- return self.backend.clip(x,xmin,xmax)
322
+ def bk_clip_by_value(self, x, xmin, xmax):
323
+ return self.backend.clip(x, xmin, xmax)
323
324
 
324
325
  def bk_cast(self, x):
325
326
  if isinstance(x, np.float64):
@@ -355,20 +356,28 @@ class BkNumpy(BackendBase.BackendBase):
355
356
  out_type = self.all_bk_type
356
357
 
357
358
  return x.astype(out_type)
358
-
359
- def bk_variable(self,x):
360
-
359
+
360
+ def bk_variable(self, x):
361
+
361
362
  return self.bk_cast(x)
362
-
363
- def bk_assign(self,x,y):
364
- x=y
365
-
366
- def bk_constant(self,x):
367
-
363
+
364
+ def bk_assign(self, x, y):
365
+ return y
366
+
367
+ def bk_constant(self, x):
368
368
  return self.bk_cast(x)
369
-
370
- def bk_empty(self,list):
369
+
370
+ def bk_cos(self, x):
371
+ return self.backend.cos(x)
372
+
373
+ def bk_sin(self, x):
374
+ return self.backend.sin(x)
375
+
376
+ def bk_arctan2(self, c, s):
377
+ return self.backend.arctan2(c, s)
378
+
379
+ def bk_empty(self, list):
371
380
  return self.backend.empty(list)
372
-
373
- def to_numpy(self,x):
381
+
382
+ def to_numpy(self, x):
374
383
  return x
@@ -1,15 +1,16 @@
1
1
  import sys
2
2
 
3
- import foscat.BkBase as BackendBase
4
3
  import numpy as np
5
4
  import tensorflow as tf
6
-
5
+
6
+ import foscat.BkBase as BackendBase
7
+
8
+
7
9
  class BkTensorflow(BackendBase.BackendBase):
8
-
10
+
9
11
  def __init__(self, *args, **kwargs):
10
12
  # Impose que use_2D=True pour la classe scat
11
- super().__init__(name='tensorflow', *args, **kwargs)
12
-
13
+ super().__init__(name="tensorflow", *args, **kwargs)
13
14
 
14
15
  # ===========================================================================
15
16
  # INIT
@@ -30,12 +31,14 @@ class BkTensorflow(BackendBase.BackendBase):
30
31
  "float32": (self.backend.float32, self.backend.complex64),
31
32
  "float64": (self.backend.float64, self.backend.complex128),
32
33
  }
33
-
34
+
34
35
  if self.all_type in dtype_map:
35
36
  self.all_bk_type, self.all_cbk_type = dtype_map[self.all_type]
36
37
  else:
37
- raise ValueError(f"ERROR INIT foscat: {all_type} should be float32 or float64")
38
-
38
+ raise ValueError(
39
+ f"ERROR INIT foscat: {self.all_type} should be float32 or float64"
40
+ )
41
+
39
42
  if self.mpi_rank == 0:
40
43
  if not self.silent:
41
44
  print(
@@ -59,12 +62,10 @@ class BkTensorflow(BackendBase.BackendBase):
59
62
  # Currently, memory growth needs to be the same across GPUs
60
63
  for gpu in gpus:
61
64
  self.backend.config.experimental.set_memory_growth(gpu, True)
62
- logical_gpus = (
63
- self.backend.config.experimental.list_logical_devices("GPU")
64
- )
65
- print(
66
- len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
65
+ logical_gpus = self.backend.config.experimental.list_logical_devices(
66
+ "GPU"
67
67
  )
68
+ print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
68
69
  sys.stdout.flush()
69
70
  self.ngpu = len(logical_gpus)
70
71
  gpuname = logical_gpus[self.gpupos % self.ngpu].name
@@ -92,10 +93,10 @@ class BkTensorflow(BackendBase.BackendBase):
92
93
  return self.backend.sparse.sparse_dense_matmul(smat, mat)
93
94
 
94
95
  # for tensorflow wrapping only
95
- def periodic_pad(self,x, pad_height, pad_width):
96
+ def periodic_pad(self, x, pad_height, pad_width):
96
97
  """
97
98
  Applies periodic ('wrap') padding to a 4D TensorFlow tensor (N, H, W, C).
98
-
99
+
99
100
  Args:
100
101
  x (tf.Tensor): Input tensor with shape (batch_size, height, width, channels).
101
102
  pad_height (tuple): Tuple (top, bottom) defining the vertical padding size.
@@ -104,23 +105,27 @@ class BkTensorflow(BackendBase.BackendBase):
104
105
  Returns:
105
106
  tf.Tensor: Tensor with periodic padding applied.
106
107
  """
107
- #Vertical padding: take slices from bottom and top to wrap around
108
- top_pad = x[:, -pad_height:, :, :] # Top padding from the bottom rows
108
+ # Vertical padding: take slices from bottom and top to wrap around
109
+ top_pad = x[:, -pad_height:, :, :] # Top padding from the bottom rows
109
110
  bottom_pad = x[:, :pad_height, :, :] # Bottom padding from the top rows
110
- x_padded = self.backend.concat([top_pad, x, bottom_pad], axis=1) # Concatenate vertically
111
+ x_padded = self.backend.concat(
112
+ [top_pad, x, bottom_pad], axis=1
113
+ ) # Concatenate vertically
114
+
115
+ # Horizontal padding: take slices from right and left to wrap around
116
+ left_pad = x_padded[:, :, -pad_width:, :] # Left padding from right columns
117
+ right_pad = x_padded[:, :, :pad_width, :] # Right padding from left columns
111
118
 
112
- #Horizontal padding: take slices from right and left to wrap around
113
- left_pad = x_padded[:, :, -pad_width:, :] # Left padding from right columns
114
- right_pad = x_padded[:, :, :pad_width, :] # Right padding from left columns
115
-
116
- x_padded = self.backend.concat([left_pad, x_padded, right_pad], axis=2) # Concatenate horizontally
119
+ x_padded = self.backend.concat(
120
+ [left_pad, x_padded, right_pad], axis=2
121
+ ) # Concatenate horizontally
117
122
 
118
123
  return x_padded
119
-
124
+
120
125
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
121
126
  kx = w.shape[0]
122
127
  ky = w.shape[1]
123
- x_padded = self.periodic_pad(x, kx // 2, ky // 2)
128
+ x_padded = self.periodic_pad(x, kx // 2, ky // 2)
124
129
  return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
125
130
 
126
131
  def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
@@ -133,11 +138,9 @@ class BkTensorflow(BackendBase.BackendBase):
133
138
  def bk_threshold(self, x, threshold, greater=True):
134
139
 
135
140
  return self.backend.cast(x > threshold, x.dtype) * x
136
-
137
141
 
138
142
  def bk_maximum(self, x1, x2):
139
143
  return self.backend.maximum(x1, x2)
140
-
141
144
 
142
145
  def bk_device(self, device_name):
143
146
  return self.backend.device(device_name)
@@ -153,22 +156,18 @@ class BkTensorflow(BackendBase.BackendBase):
153
156
  def bk_flattenR(self, x):
154
157
  if self.bk_is_complex(x):
155
158
  rr = self.backend.reshape(
156
- self.bk_real(x), [np.prod(np.array(list(x.shape)))]
157
- )
159
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
160
+ )
158
161
  ii = self.backend.reshape(
159
- self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
160
- )
162
+ self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
163
+ )
161
164
  return self.bk_concat([rr, ii], axis=0)
162
165
  else:
163
166
  return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
164
-
165
-
167
+
166
168
  def bk_flatten(self, x):
167
169
  return self.backend.flatten(x)
168
170
 
169
- def bk_size(self, x):
170
- return self.backend.size(x)
171
-
172
171
  def bk_resize_image(self, x, shape):
173
172
  return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
174
173
 
@@ -282,7 +281,7 @@ class BkTensorflow(BackendBase.BackendBase):
282
281
 
283
282
  def bk_tensor(self, data):
284
283
  return self.backend.constant(data)
285
-
284
+
286
285
  def bk_shape_tensor(self, shape):
287
286
  return self.backend.tensor(shape=shape)
288
287
 
@@ -320,10 +319,10 @@ class BkTensorflow(BackendBase.BackendBase):
320
319
  return self.backend.repeat(data, nn, axis=axis)
321
320
 
322
321
  def bk_tile(self, data, nn, axis=0):
323
- order=[1 for k in data.shape]
324
- order[axis]=nn
322
+ order = [1 for k in data.shape]
323
+ order[axis] = nn
325
324
  return self.backend.tile(data, self.backend.constant(order, tf.int32))
326
-
325
+
327
326
  def bk_roll(self, data, nn, axis=0):
328
327
  return self.backend.roll(data, nn, axis=axis)
329
328
 
@@ -338,12 +337,8 @@ class BkTensorflow(BackendBase.BackendBase):
338
337
  if axis is None:
339
338
  if data[0].dtype == self.all_cbk_type:
340
339
  ndata = len(data)
341
- xr = self.backend.concat(
342
- [self.bk_real(data[k]) for k in range(ndata)]
343
- )
344
- xi = self.backend.concat(
345
- [self.bk_imag(data[k]) for k in range(ndata)]
346
- )
340
+ xr = self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
341
+ xi = self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
347
342
  return self.bk_complex(xr, xi)
348
343
  else:
349
344
  return self.backend.concat(data)
@@ -351,10 +346,10 @@ class BkTensorflow(BackendBase.BackendBase):
351
346
  if data[0].dtype == self.all_cbk_type:
352
347
  ndata = len(data)
353
348
  xr = self.backend.concat(
354
- [self.bk_real(data[k]) for k in range(ndata)], axis=axis
355
- )
349
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
350
+ )
356
351
  xi = self.backend.concat(
357
- [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
352
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
358
353
  )
359
354
  return self.bk_complex(xr, xi)
360
355
  else:
@@ -363,38 +358,42 @@ class BkTensorflow(BackendBase.BackendBase):
363
358
  def bk_zeros(self, shape, dtype=None):
364
359
  return self.backend.zeros(shape, dtype=dtype)
365
360
 
366
- def bk_gather(self, data, idx):
367
- return self.backend.gather(data, idx)
361
+ def bk_gather(self, data, idx, axis=0):
362
+ return self.backend.gather(data, idx, axis=axis)
368
363
 
369
364
  def bk_reverse(self, data, axis=0):
370
365
  return self.backend.reverse(data, axis=[axis])
371
366
 
372
367
  def bk_fft(self, data):
373
368
  return self.backend.signal.fft(data)
374
-
375
369
 
376
- def bk_fftn(self, data,dim=None):
377
- #Equivalent of torch.fft.fftn(x, dim=dims) in TensorFlow
378
- if len(dim)==2:
379
- return self.backend.signal.fft2d(self.bk_complex(data, 0*data))
370
+ def bk_fftn(self, data, dim=None):
371
+ # Equivalent of torch.fft.fftn(x, dim=dims) in TensorFlow
372
+ if len(dim) == 2:
373
+ return self.backend.signal.fft2d(self.bk_complex(data, 0 * data))
380
374
  else:
381
- return self.backend.signal.fft1d(self.bk_complex(data, 0*data))
375
+ return self.backend.signal.fft1d(self.bk_complex(data, 0 * data))
382
376
 
383
- def bk_ifftn(self, data,dim=None,norm=None):
377
+ def bk_ifftn(self, data, dim=None, norm=None):
384
378
  if norm is not None:
385
- if len(dim)==2:
386
- normalization=self.backend.sqrt(self.backend.cast(data.shape[dim[0]]*data.shape[dim[1]], self.all_cbk_type))
387
- return self.backend.signal.ifft2d(data)*normalization
388
-
379
+ if len(dim) == 2:
380
+ normalization = self.backend.sqrt(
381
+ self.backend.cast(
382
+ data.shape[dim[0]] * data.shape[dim[1]], self.all_cbk_type
383
+ )
384
+ )
385
+ return self.backend.signal.ifft2d(data) * normalization
386
+
389
387
  else:
390
- normalization=self.backend.sqrt(self.backend.cast(data.shape[dim[0]], self.all_cbk_type))
391
- return self.backend.signal.ifft1d(data)*normalization
388
+ normalization = self.backend.sqrt(
389
+ self.backend.cast(data.shape[dim[0]], self.all_cbk_type)
390
+ )
391
+ return self.backend.signal.ifft1d(data) * normalization
392
392
  else:
393
- if len(dim)==2:
393
+ if len(dim) == 2:
394
394
  return self.backend.signal.ifft2d(data)
395
395
  else:
396
396
  return self.backend.signal.ifft1d(data)
397
-
398
397
 
399
398
  def bk_rfft(self, data):
400
399
  return self.backend.signal.rfft(data)
@@ -420,10 +419,10 @@ class BkTensorflow(BackendBase.BackendBase):
420
419
  else:
421
420
  return self.backend.nn.relu(x)
422
421
 
423
- def bk_clip_by_value(self, x,xmin,xmax):
422
+ def bk_clip_by_value(self, x, xmin, xmax):
424
423
  if isinstance(x, np.ndarray):
425
- x = np.clip(x,xmin,xmax)
426
- return self.backend.clip_by_value(x,xmin,xmax)
424
+ x = np.clip(x, xmin, xmax)
425
+ return self.backend.clip_by_value(x, xmin, xmax)
427
426
 
428
427
  def bk_cast(self, x):
429
428
  if isinstance(x, np.float64):
@@ -459,21 +458,30 @@ class BkTensorflow(BackendBase.BackendBase):
459
458
  out_type = self.all_bk_type
460
459
 
461
460
  return self.backend.cast(x, out_type)
462
-
463
- def bk_variable(self,x):
461
+
462
+ def bk_variable(self, x):
464
463
  return self.backend.Variable(x)
465
-
466
- def bk_assign(self,x,y):
467
- x.assign(y)
468
-
469
- def bk_constant(self,x):
464
+
465
+ def bk_assign(self, x, y):
466
+ return x.assign(y)
467
+
468
+ def bk_constant(self, x):
470
469
  return self.backend.constant(x)
471
-
472
- def bk_empty(self,list):
470
+
471
+ def bk_cos(self, x):
472
+ return self.backend.cos(x)
473
+
474
+ def bk_sin(self, x):
475
+ return self.backend.sin(x)
476
+
477
+ def bk_arctan2(self, c, s):
478
+ return self.backend.arctan2(c, s)
479
+
480
+ def bk_empty(self, list):
473
481
  return self.backend.constant(list)
474
-
475
- def to_numpy(self,x):
482
+
483
+ def to_numpy(self, x):
476
484
  if isinstance(x, np.ndarray):
477
485
  return x
478
-
486
+
479
487
  return x.numpy()