foscat 3.8.2__py3-none-any.whl → 3.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTorch.py CHANGED
@@ -1,15 +1,20 @@
1
1
  import sys
2
- import foscat.BkBase as BackendBase
2
+
3
3
  import numpy as np
4
4
  import torch
5
5
 
6
+ import foscat.BkBase as BackendBase
7
+
8
+
6
9
  class BkTorch(BackendBase.BackendBase):
7
-
10
+
8
11
  def __init__(self, *args, **kwargs):
9
12
  # Impose que use_2D=True pour la classe scat
10
- super().__init__(name='torch', *args, **kwargs)
13
+ super().__init__(name="torch", *args, **kwargs)
11
14
  self.backend = torch
12
- self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15
+ self.device = (
16
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
17
+ )
13
18
 
14
19
  self.float64 = self.backend.float64
15
20
  self.float32 = self.backend.float32
@@ -22,11 +27,13 @@ class BkTorch(BackendBase.BackendBase):
22
27
  "float32": (self.backend.float32, self.backend.complex64),
23
28
  "float64": (self.backend.float64, self.backend.complex128),
24
29
  }
25
-
30
+
26
31
  if self.all_type in dtype_map:
27
32
  self.all_bk_type, self.all_cbk_type = dtype_map[self.all_type]
28
33
  else:
29
- raise ValueError(f"ERROR INIT foscat: {all_type} should be float32 or float64")
34
+ raise ValueError(
35
+ f"ERROR INIT foscat: {self.all_type} should be float32 or float64"
36
+ )
30
37
 
31
38
  # ===========================================================================
32
39
  # INIT
@@ -50,14 +57,20 @@ class BkTorch(BackendBase.BackendBase):
50
57
  except RuntimeError as e:
51
58
  # Memory growth must be set before GPUs have been initialized
52
59
  print(e)
53
-
54
- self.torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
60
+
61
+ self.torch_device = (
62
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
63
+ )
55
64
 
56
65
  # ---------------------------------------------−---------
57
66
  # -- BACKEND DEFINITION --
58
67
  # ---------------------------------------------−---------
59
68
  def bk_SparseTensor(self, indice, w, dense_shape=[]):
60
- return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
69
+ return (
70
+ self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
71
+ .to_sparse_csr()
72
+ .to(self.torch_device)
73
+ )
61
74
 
62
75
  def bk_stack(self, list, axis=0):
63
76
  return self.backend.stack(list, axis=axis).to(self.torch_device)
@@ -65,21 +78,20 @@ class BkTorch(BackendBase.BackendBase):
65
78
  def bk_sparse_dense_matmul(self, smat, mat):
66
79
  return smat.matmul(mat)
67
80
 
68
-
69
-
70
81
  def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
71
82
  import torch.nn.functional as F
83
+
72
84
  lx = x.permute(0, 3, 1, 2)
73
85
  wx = w.permute(3, 2, 0, 1) # de (5, 5, 1, 4) à (4, 1, 5, 5)
74
86
 
75
87
  # Calculer le padding symétrique
76
88
  kx, ky = w.shape[0], w.shape[1]
77
-
89
+
78
90
  # Appliquer le padding
79
- x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode='circular')
91
+ x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
80
92
 
81
93
  # Appliquer la convolution
82
- return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0,2,3,1)
94
+ return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
83
95
 
84
96
  def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
85
97
  # to be written!!!
@@ -108,7 +120,7 @@ class BkTorch(BackendBase.BackendBase):
108
120
  def bk_flattenR(self, x):
109
121
  if self.bk_is_complex(x):
110
122
  rr = self.backend.reshape(
111
- self.bk_real(x), [np.prod(np.array(list(x.shape)))]
123
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
112
124
  )
113
125
  ii = self.backend.reshape(
114
126
  self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
@@ -122,18 +134,18 @@ class BkTorch(BackendBase.BackendBase):
122
134
 
123
135
  def bk_resize_image(self, x, shape):
124
136
  tmp = self.backend.nn.functional.interpolate(
125
- x.permute(0,3,1,2), size=shape, mode="bilinear", align_corners=False
126
- )
127
- return self.bk_cast(tmp.permute(0,2,3,1))
137
+ x.permute(0, 3, 1, 2), size=shape, mode="bilinear", align_corners=False
138
+ )
139
+ return self.bk_cast(tmp.permute(0, 2, 3, 1))
128
140
 
129
141
  def bk_L1(self, x):
130
142
  if x.dtype == self.all_cbk_type:
131
143
  xr = self.bk_real(x)
132
- xi = self.bk_imag(x)
144
+ # xi = self.bk_imag(x)
133
145
 
134
146
  r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
135
147
  # return r
136
- i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
148
+ # i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
137
149
 
138
150
  return r
139
151
  else:
@@ -163,7 +175,6 @@ class BkTorch(BackendBase.BackendBase):
163
175
  def bk_size(self, data):
164
176
  return data.numel()
165
177
 
166
-
167
178
  def constant(self, data):
168
179
  return data
169
180
 
@@ -194,7 +205,7 @@ class BkTorch(BackendBase.BackendBase):
194
205
  r = self.backend.std(data)
195
206
  else:
196
207
  r = self.backend.std(data, axis)
197
-
208
+
198
209
  if self.bk_is_complex(data):
199
210
  return self.bk_complex(r, 0 * r)
200
211
  else:
@@ -245,7 +256,7 @@ class BkTorch(BackendBase.BackendBase):
245
256
 
246
257
  def bk_tensor(self, data):
247
258
  return self.backend.constant(data).to(self.torch_device)
248
-
259
+
249
260
  def bk_shape_tensor(self, shape):
250
261
  return self.backend.tensor(shape=shape).to(self.torch_device)
251
262
 
@@ -304,12 +315,8 @@ class BkTorch(BackendBase.BackendBase):
304
315
  if axis is None:
305
316
  if data[0].dtype == self.all_cbk_type:
306
317
  ndata = len(data)
307
- xr = self.backend.concat(
308
- [self.bk_real(data[k]) for k in range(ndata)]
309
- )
310
- xi = self.backend.concat(
311
- [self.bk_imag(data[k]) for k in range(ndata)]
312
- )
318
+ xr = self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
319
+ xi = self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
313
320
  return self.bk_complex(xr, xi)
314
321
  else:
315
322
  return self.backend.concat(data)
@@ -317,11 +324,11 @@ class BkTorch(BackendBase.BackendBase):
317
324
  if data[0].dtype == self.all_cbk_type:
318
325
  ndata = len(data)
319
326
  xr = self.backend.concat(
320
- [self.bk_real(data[k]) for k in range(ndata)], axis=axis
321
- )
327
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
328
+ )
322
329
  xi = self.backend.concat(
323
- [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
324
- )
330
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
331
+ )
325
332
  return self.bk_complex(xr, xi)
326
333
  else:
327
334
  return self.backend.concat(data, axis=axis)
@@ -329,28 +336,28 @@ class BkTorch(BackendBase.BackendBase):
329
336
  def bk_zeros(self, shape, dtype=None):
330
337
  return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
331
338
 
332
- def bk_gather(self, data, idx,axis=0):
333
- if axis==0:
339
+ def bk_gather(self, data, idx, axis=0):
340
+ if axis == 0:
334
341
  return data[idx]
335
- elif axis==1:
336
- return data[:,idx]
337
- elif axis==2:
338
- return data[:,:,idx]
339
- elif axis==3:
340
- return data[:,:,:,idx]
341
- return data[:,:,:,:,idx]
342
+ elif axis == 1:
343
+ return data[:, idx]
344
+ elif axis == 2:
345
+ return data[:, :, idx]
346
+ elif axis == 3:
347
+ return data[:, :, :, idx]
348
+ return data[:, :, :, :, idx]
342
349
 
343
350
  def bk_reverse(self, data, axis=0):
344
351
  return self.backend.flip(data, dims=[axis])
345
352
 
346
353
  def bk_fft(self, data):
347
354
  return self.backend.fft.fft(data)
348
-
349
- def bk_fftn(self, data,dim=None):
350
- return self.backend.fft.fftn(data,dim=dim)
351
355
 
352
- def bk_ifftn(self, data,dim=None,norm=None):
353
- return self.backend.fft.ifftn(data,dim=dim,norm=norm)
356
+ def bk_fftn(self, data, dim=None):
357
+ return self.backend.fft.fftn(data, dim=dim)
358
+
359
+ def bk_ifftn(self, data, dim=None, norm=None):
360
+ return self.backend.fft.ifftn(data, dim=dim, norm=norm)
354
361
 
355
362
  def bk_rfft(self, data):
356
363
  return self.backend.fft.rfft(data)
@@ -374,12 +381,24 @@ class BkTorch(BackendBase.BackendBase):
374
381
  def bk_relu(self, x):
375
382
  return self.backend.relu(x)
376
383
 
377
- def bk_clip_by_value(self, x,xmin,xmax):
384
+ def bk_clip_by_value(self, x, xmin, xmax):
378
385
  if isinstance(x, np.ndarray):
379
- x = np.clip(x,xmin,xmax)
380
- x = self.backend.tensor(x, dtype=self.backend.float32) if not isinstance(x, self.backend.Tensor) else x
381
- xmin = self.backend.tensor(xmin, dtype=self.backend.float32) if not isinstance(xmin, self.backend.Tensor) else xmin
382
- xmax = self.backend.tensor(xmax, dtype=self.backend.float32) if not isinstance(xmax, self.backend.Tensor) else xmax
386
+ x = np.clip(x, xmin, xmax)
387
+ x = (
388
+ self.backend.tensor(x, dtype=self.backend.float32)
389
+ if not isinstance(x, self.backend.Tensor)
390
+ else x
391
+ )
392
+ xmin = (
393
+ self.backend.tensor(xmin, dtype=self.backend.float32)
394
+ if not isinstance(xmin, self.backend.Tensor)
395
+ else xmin
396
+ )
397
+ xmax = (
398
+ self.backend.tensor(xmax, dtype=self.backend.float32)
399
+ if not isinstance(xmax, self.backend.Tensor)
400
+ else xmax
401
+ )
383
402
  return self.backend.clamp(x, min=xmin, max=xmax)
384
403
 
385
404
  def bk_cast(self, x):
@@ -424,31 +443,31 @@ class BkTorch(BackendBase.BackendBase):
424
443
  out_type = self.all_bk_type
425
444
 
426
445
  return x.type(out_type).to(self.torch_device)
427
-
428
- def bk_variable(self,x):
446
+
447
+ def bk_variable(self, x):
429
448
  return self.bk_cast(x)
430
-
431
- def bk_assign(self,x,y):
432
- x=y
433
-
434
- def bk_constant(self,x):
435
-
449
+
450
+ def bk_assign(self, x, y):
451
+ return y
452
+
453
+ def bk_constant(self, x):
454
+
436
455
  return self.bk_cast(x)
437
-
438
- def bk_cos(self,x):
456
+
457
+ def bk_cos(self, x):
439
458
  return self.backend.cos(x)
440
-
441
- def bk_sin(self,x):
459
+
460
+ def bk_sin(self, x):
442
461
  return self.backend.sin(x)
443
-
444
- def bk_arctan2(self,c,s):
445
- return self.backend.arctan2(c,s)
446
-
447
- def bk_empty(self,list):
462
+
463
+ def bk_arctan2(self, c, s):
464
+ return self.backend.arctan2(c, s)
465
+
466
+ def bk_empty(self, list):
448
467
  return self.backend.empty(list)
449
-
450
- def to_numpy(self,x):
468
+
469
+ def to_numpy(self, x):
451
470
  if isinstance(x, np.ndarray):
452
471
  return x
453
-
472
+
454
473
  return x.cpu().numpy()
foscat/FoCUS.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import os
2
- import os
3
2
  import sys
4
3
 
5
4
  import healpy as hp
@@ -11,32 +10,32 @@ TMPFILE_VERSION = "V4_0"
11
10
 
12
11
  class FoCUS:
13
12
  def __init__(
14
- self,
15
- NORIENT=4,
16
- LAMBDA=1.2,
17
- KERNELSZ=3,
18
- slope=1.0,
19
- all_type="float32",
20
- nstep_max=16,
21
- padding="SAME",
22
- gpupos=0,
23
- mask_thres=None,
24
- mask_norm=False,
25
- isMPI=False,
26
- TEMPLATE_PATH="data",
27
- BACKEND="tensorflow",
28
- use_2D=False,
29
- use_1D=False,
30
- return_data=False,
31
- JmaxDelta=0,
32
- DODIV=False,
33
- InitWave=None,
34
- silent=True,
35
- mpi_size=1,
36
- mpi_rank=0,
13
+ self,
14
+ NORIENT=4,
15
+ LAMBDA=1.2,
16
+ KERNELSZ=3,
17
+ slope=1.0,
18
+ all_type="float32",
19
+ nstep_max=16,
20
+ padding="SAME",
21
+ gpupos=0,
22
+ mask_thres=None,
23
+ mask_norm=False,
24
+ isMPI=False,
25
+ TEMPLATE_PATH="data",
26
+ BACKEND="tensorflow",
27
+ use_2D=False,
28
+ use_1D=False,
29
+ return_data=False,
30
+ JmaxDelta=0,
31
+ DODIV=False,
32
+ InitWave=None,
33
+ silent=True,
34
+ mpi_size=1,
35
+ mpi_rank=0,
37
36
  ):
38
37
 
39
- self.__version__ = "3.8.2"
38
+ self.__version__ = "3.9.0"
40
39
  # P00 coeff for normalization for scat_cov
41
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
41
  self.P1_dic = None
@@ -45,7 +44,7 @@ class FoCUS:
45
44
  self.mask_thres = mask_thres
46
45
  self.mask_norm = mask_norm
47
46
  self.InitWave = InitWave
48
- self.mask_mask=None
47
+ self.mask_mask = None
49
48
  self.mpi_size = mpi_size
50
49
  self.mpi_rank = mpi_rank
51
50
  self.return_data = return_data
@@ -105,31 +104,34 @@ class FoCUS:
105
104
 
106
105
  self.all_type = all_type
107
106
  self.BACKEND = BACKEND
108
-
109
- if BACKEND=='torch':
107
+
108
+ if BACKEND == "torch":
110
109
  from foscat.BkTorch import BkTorch
110
+
111
111
  self.backend = BkTorch(
112
112
  all_type=all_type,
113
113
  mpi_rank=mpi_rank,
114
114
  gpupos=gpupos,
115
115
  silent=self.silent,
116
- )
117
- elif BACKEND=='tensorflow':
116
+ )
117
+ elif BACKEND == "tensorflow":
118
118
  from foscat.BkTensorflow import BkTensorflow
119
+
119
120
  self.backend = BkTensorflow(
120
121
  all_type=all_type,
121
122
  mpi_rank=mpi_rank,
122
123
  gpupos=gpupos,
123
124
  silent=self.silent,
124
- )
125
+ )
125
126
  else:
126
127
  from foscat.BkNumpy import BkNumpy
128
+
127
129
  self.backend = BkNumpy(
128
130
  all_type=all_type,
129
131
  mpi_rank=mpi_rank,
130
132
  gpupos=gpupos,
131
133
  silent=self.silent,
132
- )
134
+ )
133
135
 
134
136
  self.all_bk_type = self.backend.all_bk_type
135
137
  self.all_cbk_type = self.backend.all_cbk_type
@@ -172,9 +174,9 @@ class FoCUS:
172
174
  self.Y_CNN = {}
173
175
  self.Z_CNN = {}
174
176
 
175
- self.filters_set={}
176
- self.edge_masks={}
177
-
177
+ self.filters_set = {}
178
+ self.edge_masks = {}
179
+
178
180
  wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
179
181
  wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
180
182
 
@@ -209,15 +211,27 @@ class FoCUS:
209
211
  w_smooth = w_smooth.flatten()
210
212
  else:
211
213
  for i in range(NORIENT):
212
- a = (NORIENT-1-i) / float(NORIENT) * np.pi # get the same angle number than scattering lib
213
- if KERNELSZ<5:
214
- xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
215
- yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
214
+ a = (
215
+ (NORIENT - 1 - i) / float(NORIENT) * np.pi
216
+ ) # get the same angle number than scattering lib
217
+ if KERNELSZ < 5:
218
+ xx = (
219
+ (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
220
+ )
221
+ yy = (
222
+ (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
223
+ )
216
224
  else:
217
- xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
218
- yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
225
+ xx = (3 / 5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
226
+ yy = (3 / 5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
219
227
  if KERNELSZ == 5:
220
- w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
228
+ w_smooth = np.exp(
229
+ -2
230
+ * (
231
+ (3.0 / float(KERNELSZ) * xx) ** 2
232
+ + (3.0 / float(KERNELSZ) * yy) ** 2
233
+ )
234
+ )
221
235
  else:
222
236
  w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
223
237
  tmp1 = np.cos(yy * np.pi) * w_smooth
@@ -225,7 +239,7 @@ class FoCUS:
225
239
 
226
240
  wwc[:, i] = tmp1.flatten() - tmp1.mean()
227
241
  wws[:, i] = tmp2.flatten() - tmp2.mean()
228
- #sigma = np.sqrt((wwc[:, i] ** 2).mean())
242
+ # sigma = np.sqrt((wwc[:, i] ** 2).mean())
229
243
  sigma = np.mean(w_smooth)
230
244
  wwc[:, i] /= sigma
231
245
  wws[:, i] /= sigma
@@ -239,7 +253,7 @@ class FoCUS:
239
253
 
240
254
  wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
241
255
  wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
242
- #sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
256
+ # sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
243
257
  sigma = np.mean(w_smooth)
244
258
 
245
259
  wwc[:, NORIENT] /= sigma
@@ -249,13 +263,13 @@ class FoCUS:
249
263
 
250
264
  wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
251
265
  wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
252
- #sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
266
+ # sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
253
267
  sigma = np.mean(w_smooth)
254
268
  wwc[:, NORIENT + 1] /= sigma
255
269
  wws[:, NORIENT + 1] /= sigma
256
270
 
257
271
  w_smooth = w_smooth.flatten()
258
-
272
+
259
273
  if self.use_1D:
260
274
  KERNELSZ = 5
261
275
 
@@ -723,7 +737,7 @@ class FoCUS:
723
737
  def ud_grade(self, im, j, axis=0):
724
738
  rim = im
725
739
  for k in range(j):
726
- #rim = self.smooth(rim, axis=axis)
740
+ # rim = self.smooth(rim, axis=axis)
727
741
  rim = self.ud_grade_2(rim, axis=axis)
728
742
  return rim
729
743
 
@@ -1794,14 +1808,14 @@ class FoCUS:
1794
1808
  if self.padding == "VALID":
1795
1809
  l_mask = l_mask[
1796
1810
  :,
1797
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1798
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1811
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1812
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1799
1813
  ]
1800
1814
  if shape[axis] != l_mask.shape[1]:
1801
1815
  l_mask = l_mask[
1802
1816
  :,
1803
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1804
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1817
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1818
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1805
1819
  ]
1806
1820
 
1807
1821
  ichannel = 1
@@ -1868,10 +1882,10 @@ class FoCUS:
1868
1882
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1869
1883
 
1870
1884
  if self.use_2D:
1871
- #if self.padding == "VALID":
1885
+ # if self.padding == "VALID":
1872
1886
  mtmp = l_mask
1873
1887
  vtmp = l_x
1874
- #else:
1888
+ # else:
1875
1889
  # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1876
1890
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1877
1891
 
@@ -2707,9 +2721,11 @@ class FoCUS:
2707
2721
  # ---------------------------------------------−---------
2708
2722
  def get_ww(self, nside=1):
2709
2723
  if self.use_2D:
2710
-
2711
- return (self.ww_RealT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT),
2712
- self.ww_ImagT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT))
2724
+
2725
+ return (
2726
+ self.ww_RealT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
2727
+ self.ww_ImagT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
2728
+ )
2713
2729
  else:
2714
2730
  return (self.ww_Real[nside], self.ww_Imag[nside])
2715
2731
 
foscat/Synthesis.py CHANGED
@@ -240,9 +240,9 @@ class Synthesis:
240
240
  grd_mask = self.grd_mask
241
241
 
242
242
  if grd_mask is not None:
243
- g_tot = grd_mask * self.to_numpy(g_tot)
243
+ g_tot = self.operation.backend.to_numpy(g_tot*grd_mask)
244
244
  else:
245
- g_tot = self.to_numpy(g_tot)
245
+ g_tot = self.operation.backend.to_numpy(g_tot)
246
246
 
247
247
  g_tot[np.isnan(g_tot)] = 0.0
248
248
 
@@ -426,7 +426,7 @@ class Synthesis:
426
426
  factr=factr,
427
427
  maxiter=maxitt,
428
428
  )
429
- print('Final Loss ',loss)
429
+ print("Final Loss ", loss)
430
430
  # update bias input data
431
431
  if iteration < NUM_STEP_BIAS - 1:
432
432
  # if self.mpi_rank==0: