foscat 3.0.8__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -1,548 +1,1061 @@
1
1
  import sys
2
+
2
3
  import numpy as np
3
4
 
5
+
4
6
  class foscat_backend:
5
-
6
- def __init__(self,name,mpi_rank=0,all_type='float64',gpupos=0):
7
-
8
- self.TENSORFLOW=1
9
- self.TORCH=2
10
- self.NUMPY=3
11
-
7
+
8
+ def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
9
+
10
+ self.TENSORFLOW = 1
11
+ self.TORCH = 2
12
+ self.NUMPY = 3
13
+
12
14
  # table use to compute the iso orientation rotation
13
- self._iso_orient={}
14
- self._iso_orient_T={}
15
- self._iso_orient_C={}
16
- self._iso_orient_C_T={}
17
- self._fft_1_orient={}
18
- self._fft_1_orient_C={}
19
- self._fft_2_orient={}
20
- self._fft_2_orient_C={}
21
- self._fft_3_orient={}
22
- self._fft_3_orient_C={}
23
-
24
- self.BACKEND=name
25
-
26
- if name not in ['tensorflow','torch','numpy']:
27
- print('Backend %s not yet implemented')
28
- exit(0)
29
-
30
- if self.BACKEND=='tensorflow':
15
+ self._iso_orient = {}
16
+ self._iso_orient_T = {}
17
+ self._iso_orient_C = {}
18
+ self._iso_orient_C_T = {}
19
+ self._fft_1_orient = {}
20
+ self._fft_1_orient_C = {}
21
+ self._fft_2_orient = {}
22
+ self._fft_2_orient_C = {}
23
+ self._fft_3_orient = {}
24
+ self._fft_3_orient_C = {}
25
+
26
+ self.BACKEND = name
27
+
28
+ if name not in ["tensorflow", "torch", "numpy"]:
29
+ print('Backend "%s" not yet implemented' % (name))
30
+ print(" Choose inside the next 3 available backends :")
31
+ print(" - tensorflow")
32
+ print(" - torch")
33
+ print(" - numpy (Impossible to do synthesis using numpy)")
34
+ return None
35
+
36
+ if self.BACKEND == "tensorflow":
31
37
  import tensorflow as tf
32
-
33
- self.backend=tf
34
- self.BACKEND=self.TENSORFLOW
35
- #tf.config.threading.set_inter_op_parallelism_threads(1)
36
- #tf.config.threading.set_intra_op_parallelism_threads(1)
37
38
 
38
- if self.BACKEND=='torch':
39
+ self.backend = tf
40
+ self.BACKEND = self.TENSORFLOW
41
+ # tf.config.threading.set_inter_op_parallelism_threads(1)
42
+ # tf.config.threading.set_intra_op_parallelism_threads(1)
43
+ self.tf_function = tf.function
44
+
45
+ if self.BACKEND == "torch":
39
46
  import torch
40
- self.BACKEND=self.TORCH
41
- self.backend=torch
42
-
43
- if self.BACKEND=='numpy':
44
- self.BACKEND=self.NUMPY
45
- self.backend=np
46
-
47
- self.float64=self.backend.float64
48
- self.float32=self.backend.float32
49
- self.int64=self.backend.int64
50
- self.int32=self.backend.int32
51
- self.complex64=self.backend.complex128
52
- self.complex128=self.backend.complex64
53
-
54
- if all_type=='float32':
55
- self.all_bk_type=self.backend.float32
56
- self.all_cbk_type=self.backend.complex64
47
+
48
+ self.BACKEND = self.TORCH
49
+ self.backend = torch
50
+ self.tf_function = self.tf_loc_function
51
+
52
+ if self.BACKEND == "numpy":
53
+ self.BACKEND = self.NUMPY
54
+ self.backend = np
55
+ import scipy as scipy
56
+
57
+ self.scipy = scipy
58
+ self.tf_function = self.tf_loc_function
59
+
60
+ self.float64 = self.backend.float64
61
+ self.float32 = self.backend.float32
62
+ self.int64 = self.backend.int64
63
+ self.int32 = self.backend.int32
64
+ self.complex64 = self.backend.complex128
65
+ self.complex128 = self.backend.complex64
66
+
67
+ if all_type == "float32":
68
+ self.all_bk_type = self.backend.float32
69
+ self.all_cbk_type = self.backend.complex64
57
70
  else:
58
- if all_type=='float64':
59
- self.all_type='float64'
60
- self.all_bk_type=self.backend.float64
61
- self.all_cbk_type=self.backend.complex128
71
+ if all_type == "float64":
72
+ self.all_type = "float64"
73
+ self.all_bk_type = self.backend.float64
74
+ self.all_cbk_type = self.backend.complex128
62
75
  else:
63
- print('ERROR INIT FOCUS ',all_type,' should be float32 or float64')
64
- exit(0)
65
- #===========================================================================
66
- # INIT
67
- if mpi_rank==0:
68
- if self.BACKEND==self.TENSORFLOW:
69
- print("Num GPUs Available: ", len(self.backend.config.experimental.list_physical_devices('GPU')))
76
+ print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
77
+ return None
78
+ # ===========================================================================
79
+ # INIT
80
+ if mpi_rank == 0:
81
+ if self.BACKEND == self.TENSORFLOW and not silent:
82
+ print(
83
+ "Num GPUs Available: ",
84
+ len(self.backend.config.experimental.list_physical_devices("GPU")),
85
+ )
70
86
  sys.stdout.flush()
71
-
72
- if self.BACKEND==self.TENSORFLOW:
87
+
88
+ if self.BACKEND == self.TENSORFLOW:
73
89
  self.backend.debugging.set_log_device_placement(False)
74
90
  self.backend.config.set_soft_device_placement(True)
75
-
76
- gpus = self.backend.config.experimental.list_physical_devices('GPU')
77
-
78
- if self.BACKEND==self.TORCH:
79
- gpus=torch.cuda.is_available()
80
-
81
- if self.BACKEND==self.NUMPY:
82
- gpus=[]
83
- gpuname='CPU:0'
84
- self.gpulist={}
85
- self.gpulist[0]=gpuname
86
- self.ngpu=1
87
-
91
+
92
+ gpus = self.backend.config.experimental.list_physical_devices("GPU")
93
+
94
+ if self.BACKEND == self.TORCH:
95
+ gpus = torch.cuda.is_available()
96
+
97
+ if self.BACKEND == self.NUMPY:
98
+ gpus = []
99
+ gpuname = "CPU:0"
100
+ self.gpulist = {}
101
+ self.gpulist[0] = gpuname
102
+ self.ngpu = 1
103
+
88
104
  if gpus:
89
105
  try:
90
- if self.BACKEND==self.TENSORFLOW:
91
- # Currently, memory growth needs to be the same across GPUs
106
+ if self.BACKEND == self.TENSORFLOW:
107
+ # Currently, memory growth needs to be the same across GPUs
92
108
  for gpu in gpus:
93
109
  self.backend.config.experimental.set_memory_growth(gpu, True)
94
- logical_gpus = self.backend.config.experimental.list_logical_devices('GPU')
95
- print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
110
+ logical_gpus = (
111
+ self.backend.config.experimental.list_logical_devices("GPU")
112
+ )
113
+ print(
114
+ len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
115
+ )
96
116
  sys.stdout.flush()
97
- self.ngpu=len(logical_gpus)
98
- gpuname=logical_gpus[gpupos%self.ngpu].name
99
- self.gpulist={}
117
+ self.ngpu = len(logical_gpus)
118
+ gpuname = logical_gpus[gpupos % self.ngpu].name
119
+ self.gpulist = {}
100
120
  for i in range(self.ngpu):
101
- self.gpulist[i]=logical_gpus[i].name
102
- if self.BACKEND==self.TORCH:
103
- self.ngpu=torch.cuda.device_count()
104
- self.gpulist={}
121
+ self.gpulist[i] = logical_gpus[i].name
122
+ if self.BACKEND == self.TORCH:
123
+ self.ngpu = torch.cuda.device_count()
124
+ self.gpulist = {}
105
125
  for k in range(self.ngpu):
106
- self.gpulist[k]=torch.cuda.get_device_name(0)
126
+ self.gpulist[k] = torch.cuda.get_device_name(0)
107
127
 
108
128
  except RuntimeError as e:
109
129
  # Memory growth must be set before GPUs have been initialized
110
130
  print(e)
111
131
 
112
- def calc_iso_orient(self,norient):
113
- tmp=np.zeros([norient*norient,norient])
132
+ def tf_loc_function(self, func):
133
+ return func
134
+
135
+ def calc_iso_orient(self, norient):
136
+ tmp = np.zeros([norient * norient, norient])
114
137
  for i in range(norient):
115
138
  for j in range(norient):
116
- tmp[j*norient+(j+i)%norient,i]=0.25
117
-
118
- self._iso_orient[norient]=self.constant(self.bk_cast(tmp))
119
- self._iso_orient_T[norient]=self.constant(self.bk_cast(4*tmp.T))
120
- self._iso_orient_C[norient]=self.bk_complex(self._iso_orient[norient],0*self._iso_orient[norient])
121
- self._iso_orient_C_T[norient]=self.bk_complex(self._iso_orient_T[norient],0*self._iso_orient_T[norient])
122
-
123
- def calc_fft_orient(self,norient,nharm):
139
+ tmp[j * norient + (j + i) % norient, i] = 0.25
124
140
 
125
- x=np.arange(norient)/norient*2*np.pi
126
-
127
- tmp=np.zeros([norient,1+nharm])
128
- for k in range(nharm+1):
129
- tmp[:,k]=np.cos(x*k)
130
-
131
- self._fft_1_orient[(norient,nharm)]=self.constant(self.bk_cast(tmp))
132
- self._fft_1_orient_C[(norient,nharm)]=self.bk_complex(self._fft_1_orient[(norient,nharm)],0*self._fft_1_orient[(norient,nharm)])
133
-
134
- x=np.repeat(x,norient).reshape(norient,norient)
135
-
136
- tmp=np.zeros([norient,norient,(1+nharm),(1+nharm)])
137
-
138
- for k in range(nharm+1):
139
- for l in range(nharm+1):
140
- tmp[:,:,k,l]=np.cos(x*k)*np.cos((x.T)*l)
141
-
142
- self._fft_2_orient[(norient,nharm)]=self.constant(self.bk_cast(tmp.reshape(norient*norient,(1+nharm)*(1+nharm))))
143
- self._fft_2_orient_C[(norient,nharm)]=self.bk_complex(self._fft_2_orient[(norient,nharm)],0*self._fft_2_orient[(norient,nharm)])
144
-
145
- tmp=np.zeros([norient,norient,norient,(1+nharm),(1+nharm),(1+nharm)])
146
- x=np.arange(norient)/norient*2*np.pi
147
- xx=np.zeros([norient,norient,norient])
148
- yy=np.zeros([norient,norient,norient])
149
- zz=np.zeros([norient,norient,norient])
141
+ self._iso_orient[norient] = self.constant(self.bk_cast(tmp))
142
+ self._iso_orient_T[norient] = self.constant(self.bk_cast(4 * tmp.T))
143
+ self._iso_orient_C[norient] = self.bk_complex(
144
+ self._iso_orient[norient], 0 * self._iso_orient[norient]
145
+ )
146
+ self._iso_orient_C_T[norient] = self.bk_complex(
147
+ self._iso_orient_T[norient], 0 * self._iso_orient_T[norient]
148
+ )
149
+
150
+ def calc_fft_orient(self, norient, nharm, imaginary):
151
+
152
+ x = np.arange(norient) / norient * 2 * np.pi
153
+
154
+ if imaginary:
155
+ tmp = np.zeros([norient, 1 + nharm * 2])
156
+ tmp[:, 0] = 1.0
157
+ for k in range(nharm):
158
+ tmp[:, k * 2 + 1] = np.cos(x * (k + 1))
159
+ tmp[:, k * 2 + 2] = np.sin(x * (k + 1))
160
+
161
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
162
+ self.constant(tmp)
163
+ )
164
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
165
+ self._fft_1_orient[(norient, nharm, imaginary)],
166
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
167
+ )
168
+ else:
169
+ tmp = np.zeros([norient, 1 + nharm])
170
+ for k in range(nharm + 1):
171
+ tmp[:, k] = np.cos(x * k)
172
+
173
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
174
+ self.constant(tmp)
175
+ )
176
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
177
+ self._fft_1_orient[(norient, nharm, imaginary)],
178
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
179
+ )
180
+
181
+ x = np.repeat(x, norient).reshape(norient, norient)
182
+
183
+ if imaginary:
184
+ tmp = np.zeros([norient, norient, (1 + nharm * 2), (1 + nharm * 2)])
185
+ tmp[:, :, 0, 0] = 1.0
186
+ for k in range(nharm):
187
+ tmp[:, :, k * 2 + 1, 0] = np.cos(x * (k + 1))
188
+ tmp[:, :, k * 2 + 2, 0] = np.sin(x * (k + 1))
189
+ tmp[:, :, 0, k * 2 + 1] = np.cos((x.T) * (k + 1))
190
+ tmp[:, :, 0, k * 2 + 2] = np.sin((x.T) * (k + 1))
191
+ for l_orient in range(nharm):
192
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 1] = np.cos(
193
+ x * (k + 1)
194
+ ) * np.cos((x.T) * (l_orient + 1))
195
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 1] = np.sin(
196
+ x * (k + 1)
197
+ ) * np.cos((x.T) * (l_orient + 1))
198
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 2] = np.cos(
199
+ x * (k + 1)
200
+ ) * np.sin((x.T) * (l_orient + 1))
201
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 2] = np.sin(
202
+ x * (k + 1)
203
+ ) * np.sin((x.T) * (l_orient + 1))
204
+
205
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
206
+ self.constant(
207
+ tmp.reshape(norient * norient, (1 + 2 * nharm) * (1 + 2 * nharm))
208
+ )
209
+ )
210
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
211
+ self._fft_2_orient[(norient, nharm, imaginary)],
212
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
213
+ )
214
+ else:
215
+ tmp = np.zeros([norient, norient, (1 + nharm), (1 + nharm)])
216
+
217
+ for k in range(nharm + 1):
218
+ for l_orient in range(nharm + 1):
219
+ tmp[:, :, k, l_orient] = np.cos(x * k) * np.cos((x.T) * l_orient)
220
+
221
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
222
+ self.constant(tmp.reshape(norient * norient, (1 + nharm) * (1 + nharm)))
223
+ )
224
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
225
+ self._fft_2_orient[(norient, nharm, imaginary)],
226
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
227
+ )
228
+
229
+ x = np.arange(norient) / norient * 2 * np.pi
230
+ xx = np.zeros([norient, norient, norient])
231
+ yy = np.zeros([norient, norient, norient])
232
+ zz = np.zeros([norient, norient, norient])
150
233
  for i in range(norient):
151
234
  for j in range(norient):
152
- xx[:,i,j]=x
153
- yy[i,:,j]=x
154
- zz[i,j,:]=x
155
-
156
- for k in range(nharm+1):
157
- for l in range(nharm+1):
158
- for m in range(nharm+1):
159
- tmp[:,:,:,k,l,m]=np.cos(xx*k)*np.cos(yy*l)*np.cos(zz*m)
235
+ xx[:, i, j] = x
236
+ yy[i, :, j] = x
237
+ zz[i, j, :] = x
238
+
239
+ if imaginary:
240
+ tmp = np.ones(
241
+ [
242
+ norient,
243
+ norient,
244
+ norient,
245
+ (1 + nharm * 2),
246
+ (1 + nharm * 2),
247
+ (1 + nharm * 2),
248
+ ]
249
+ )
250
+
251
+ for k in range(nharm):
252
+ tmp[:, :, :, k * 2 + 1, 0, 0] = np.cos(xx * (k + 1))
253
+ tmp[:, :, :, 0, k * 2 + 1, 0] = np.cos(yy * (k + 1))
254
+ tmp[:, :, :, 0, 0, k * 2 + 1] = np.cos(zz * (k + 1))
255
+
256
+ tmp[:, :, :, k * 2 + 2, 0, 0] = np.sin(xx * (k + 1))
257
+ tmp[:, :, :, 0, k * 2 + 2, 0] = np.sin(yy * (k + 1))
258
+ tmp[:, :, :, 0, 0, k * 2 + 2] = np.sin(zz * (k + 1))
259
+ for l_orient in range(nharm):
260
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, 0] = np.cos(
261
+ xx * (k + 1)
262
+ ) * np.cos(yy * (l_orient + 1))
263
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, 0] = np.cos(
264
+ xx * (k + 1)
265
+ ) * np.sin(yy * (l_orient + 1))
266
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, 0] = np.sin(
267
+ xx * (k + 1)
268
+ ) * np.cos(yy * (l_orient + 1))
269
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, 0] = np.sin(
270
+ xx * (k + 1)
271
+ ) * np.sin(yy * (l_orient + 1))
272
+
273
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 1] = np.cos(
274
+ xx * (k + 1)
275
+ ) * np.cos(zz * (l_orient + 1))
276
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 2] = np.cos(
277
+ xx * (k + 1)
278
+ ) * np.sin(zz * (l_orient + 1))
279
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 1] = np.sin(
280
+ xx * (k + 1)
281
+ ) * np.cos(zz * (l_orient + 1))
282
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 2] = np.sin(
283
+ xx * (k + 1)
284
+ ) * np.sin(zz * (l_orient + 1))
285
+
286
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 1] = np.cos(
287
+ yy * (k + 1)
288
+ ) * np.cos(zz * (l_orient + 1))
289
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 2] = np.cos(
290
+ yy * (k + 1)
291
+ ) * np.sin(zz * (l_orient + 1))
292
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 1] = np.sin(
293
+ yy * (k + 1)
294
+ ) * np.cos(zz * (l_orient + 1))
295
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 2] = np.sin(
296
+ yy * (k + 1)
297
+ ) * np.sin(zz * (l_orient + 1))
298
+
299
+ for m in range(nharm):
300
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 1] = (
301
+ np.cos(xx * (k + 1))
302
+ * np.cos(yy * (l_orient + 1))
303
+ * np.cos(zz * (m + 1))
304
+ )
305
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 2] = (
306
+ np.cos(xx * (k + 1))
307
+ * np.cos(yy * (l_orient + 1))
308
+ * np.sin(zz * (m + 1))
309
+ )
310
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 1] = (
311
+ np.cos(xx * (k + 1))
312
+ * np.sin(yy * (l_orient + 1))
313
+ * np.cos(zz * (m + 1))
314
+ )
315
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 2] = (
316
+ np.cos(xx * (k + 1))
317
+ * np.sin(yy * (l_orient + 1))
318
+ * np.sin(zz * (m + 1))
319
+ )
320
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 1] = (
321
+ np.sin(xx * (k + 1))
322
+ * np.cos(yy * (l_orient + 1))
323
+ * np.cos(zz * (m + 1))
324
+ )
325
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 2] = (
326
+ np.sin(xx * (k + 1))
327
+ * np.cos(yy * (l_orient + 1))
328
+ * np.sin(zz * (m + 1))
329
+ )
330
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 1] = (
331
+ np.sin(xx * (k + 1))
332
+ * np.sin(yy * (l_orient + 1))
333
+ * np.cos(zz * (m + 1))
334
+ )
335
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 2] = (
336
+ np.sin(xx * (k + 1))
337
+ * np.sin(yy * (l_orient + 1))
338
+ * np.sin(zz * (m + 1))
339
+ )
340
+
341
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
342
+ self.constant(
343
+ tmp.reshape(
344
+ norient * norient * norient,
345
+ (1 + nharm * 2) * (1 + nharm * 2) * (1 + nharm * 2),
346
+ )
347
+ )
348
+ )
349
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
350
+ self._fft_3_orient[(norient, nharm, imaginary)],
351
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
352
+ )
353
+ else:
354
+ tmp = np.zeros(
355
+ [norient, norient, norient, (1 + nharm), (1 + nharm), (1 + nharm)]
356
+ )
357
+
358
+ for k in range(nharm + 1):
359
+ for l_orient in range(nharm + 1):
360
+ for m in range(nharm + 1):
361
+ tmp[:, :, :, k, l_orient, m] = (
362
+ np.cos(xx * k) * np.cos(yy * l_orient) * np.cos(zz * m)
363
+ )
364
+
365
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
366
+ self.constant(
367
+ tmp.reshape(
368
+ norient * norient * norient,
369
+ (1 + nharm) * (1 + nharm) * (1 + nharm),
370
+ )
371
+ )
372
+ )
373
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
374
+ self._fft_3_orient[(norient, nharm, imaginary)],
375
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
376
+ )
160
377
 
161
- self._fft_3_orient[(norient,nharm)]=self.constant(self.bk_cast(tmp.reshape(norient*norient*norient,(1+nharm)*(1+nharm)*(1+nharm))))
162
- self._fft_3_orient_C[(norient,nharm)]=self.bk_complex(self._fft_3_orient[(norient,nharm)],0*self._fft_3_orient[(norient,nharm)])
163
-
164
378
  # ---------------------------------------------−---------
165
379
  # -- BACKEND DEFINITION --
166
380
  # ---------------------------------------------−---------
167
- def bk_SparseTensor(self,indice,w,dense_shape=[]):
168
- if self.BACKEND==self.TENSORFLOW:
169
- return(self.backend.SparseTensor(indice,w,dense_shape=dense_shape))
170
- if self.BACKEND==self.TORCH:
171
- return(self.backend.sparse_coo_tensor(indice.T,w,dense_shape))
172
- if self.BACKEND==self.NUMPY:
173
- return np.sparse_matrix(indice,w,dense_shape=dense_shape)
174
-
175
- def bk_sparse_dense_matmul(self,smat,mat):
176
- if self.BACKEND==self.TENSORFLOW:
177
- return self.backend.sparse.sparse_dense_matmul(smat,mat)
178
- if self.BACKEND==self.TORCH:
381
+ def bk_SparseTensor(self, indice, w, dense_shape=[]):
382
+ if self.BACKEND == self.TENSORFLOW:
383
+ return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
384
+ if self.BACKEND == self.TORCH:
385
+ return self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
386
+ if self.BACKEND == self.NUMPY:
387
+ return self.scipy.sparse.coo_matrix(
388
+ (w, (indice[:, 0], indice[:, 1])), shape=dense_shape
389
+ )
390
+
391
+ def bk_stack(self, list, axis=0):
392
+ if self.BACKEND == self.TENSORFLOW:
393
+ return self.backend.stack(list, axis=axis)
394
+ if self.BACKEND == self.TORCH:
395
+ return self.backend.stack(list, axis=axis)
396
+ if self.BACKEND == self.NUMPY:
397
+ return self.backend.stack(list, axis=axis)
398
+
399
+ def bk_sparse_dense_matmul(self, smat, mat):
400
+ if self.BACKEND == self.TENSORFLOW:
401
+ return self.backend.sparse.sparse_dense_matmul(smat, mat)
402
+ if self.BACKEND == self.TORCH:
179
403
  return smat.matmul(mat)
180
- if self.BACKEND==self.NUMPY:
181
- return np.sparse.sparse_dense_matmul(smat,mat)
182
-
183
- def conv2d(self,x,w,strides=[1, 1, 1, 1],padding='SAME'):
184
- if self.BACKEND==self.TENSORFLOW:
185
- return self.backend.nn.conv2d(x,w,
186
- strides=strides,
187
- padding=padding)
404
+ if self.BACKEND == self.NUMPY:
405
+ return smat.dot(mat)
406
+
407
+ def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
408
+ if self.BACKEND == self.TENSORFLOW:
409
+ kx = w.shape[0]
410
+ ky = w.shape[1]
411
+ paddings = self.backend.constant(
412
+ [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
413
+ )
414
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
415
+ return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
188
416
  # to be written!!!
189
- if self.BACKEND==self.TORCH:
417
+ if self.BACKEND == self.TORCH:
190
418
  return x
191
- if self.BACKEND==self.NUMPY:
419
+ if self.BACKEND == self.NUMPY:
420
+ res = np.zeros(
421
+ [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
422
+ )
423
+ for k in range(w.shape[2]):
424
+ for l_orient in range(w.shape[3]):
425
+ for j in range(res.shape[0]):
426
+ tmp = self.scipy.signal.convolve2d(
427
+ x[j, :, :, k],
428
+ w[:, :, k, l_orient],
429
+ mode="same",
430
+ boundary="symm",
431
+ )
432
+ res[j, :, :, l_orient] += tmp
433
+ del tmp
434
+ return res
435
+
436
+ def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
437
+ if self.BACKEND == self.TENSORFLOW:
438
+ kx = w.shape[0]
439
+ paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
440
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
441
+
442
+ return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
443
+ # to be written!!!
444
+ if self.BACKEND == self.TORCH:
192
445
  return x
446
+ if self.BACKEND == self.NUMPY:
447
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[2]], dtype=x.dtype)
448
+ for k in range(w.shape[2]):
449
+ for j in range(res.shape[0]):
450
+ tmp = self.scipy.signal.convolve1d(
451
+ x[j, :, k], w[:, k], mode="same", boundary="symm"
452
+ )
453
+ res[j, :, :] += tmp
454
+ del tmp
455
+ return res
193
456
 
194
- def bk_threshold(self,x,threshold,greater=True):
195
-
196
- if self.BACKEND==self.TENSORFLOW:
197
- return(self.backend.cast(x>threshold,x.dtype)*x)
198
- if self.BACKEND==self.TORCH:
199
- return(self.backend.cast(x>threshold,x.dtype)*x)
200
- if self.BACKEND==self.NUMPY:
201
- return (x>threshold)*x
202
-
203
- def bk_maximum(self,x1,x2):
204
- if self.BACKEND==self.TENSORFLOW:
205
- return(self.backend.maximum(x1,x2))
206
- if self.BACKEND==self.TORCH:
207
- return(self.backend.maximum(x1,x2))
208
- if self.BACKEND==self.NUMPY:
209
- return x1*(x1>x2)+x2*(x2>x1)
210
-
211
- def bk_device(self,device_name):
457
+ def bk_threshold(self, x, threshold, greater=True):
458
+
459
+ if self.BACKEND == self.TENSORFLOW:
460
+ return self.backend.cast(x > threshold, x.dtype) * x
461
+ if self.BACKEND == self.TORCH:
462
+ x.to(x.dtype)
463
+ return (x > threshold) * x
464
+ # return(self.backend.cast(x>threshold,x.dtype)*x)
465
+ if self.BACKEND == self.NUMPY:
466
+ return (x > threshold) * x
467
+
468
+ def bk_maximum(self, x1, x2):
469
+ if self.BACKEND == self.TENSORFLOW:
470
+ return self.backend.maximum(x1, x2)
471
+ if self.BACKEND == self.TORCH:
472
+ return self.backend.maximum(x1, x2)
473
+ if self.BACKEND == self.NUMPY:
474
+ return x1 * (x1 > x2) + x2 * (x2 > x1)
475
+
476
+ def bk_device(self, device_name):
212
477
  return self.backend.device(device_name)
213
-
214
- def bk_ones(self,shape,dtype=None):
478
+
479
+ def bk_ones(self, shape, dtype=None):
215
480
  if dtype is None:
216
- dtype=self.all_type
217
- return(self.backend.ones(shape,dtype=dtype))
218
-
219
- def bk_conv1d(self,x,w):
220
- if self.BACKEND==self.TENSORFLOW:
221
- return self.backend.nn.conv1d(x,w, stride=[1,1,1], padding='SAME')
222
- if self.BACKEND==self.TORCH:
223
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
224
- if self.BACKEND==self.NUMPY:
225
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
226
-
227
- def bk_flattenR(self,x):
228
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
229
- if x.dtype=='complex32' or x.dtype=='complex64':
230
- rr=self.backend.reshape(self.bk_real(x),[np.prod(np.array(list(x.shape)))])
231
- ii=self.backend.reshape(self.bk_imag(x),[np.prod(np.array(list(x.shape)))])
232
- return self.bk_concat([rr,ii],axis=0)
481
+ dtype = self.all_type
482
+ if self.BACKEND == self.TORCH:
483
+ return self.bk_cast(np.ones(shape))
484
+ return self.backend.ones(shape, dtype=dtype)
485
+
486
+ def bk_conv1d(self, x, w):
487
+ if self.BACKEND == self.TENSORFLOW:
488
+ return self.backend.nn.conv1d(x, w, stride=[1, 1, 1], padding="SAME")
489
+ if self.BACKEND == self.TORCH:
490
+ # Torch not yet done !!!
491
+ return self.backend.nn.conv1d(x, w, stride=1, padding="SAME")
492
+ if self.BACKEND == self.NUMPY:
493
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[1]], dtype=x.dtype)
494
+ for k in range(w.shape[1]):
495
+ for l_orient in range(w.shape[2]):
496
+ res[:, :, l_orient] += self.scipy.ndimage.convolve1d(
497
+ x[:, :, k], w[:, k, l_orient], axis=1, mode="constant", cval=0.0
498
+ )
499
+ return res
500
+
501
+ def bk_flattenR(self, x):
502
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
503
+ if self.bk_is_complex(x):
504
+ rr = self.backend.reshape(
505
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
506
+ )
507
+ ii = self.backend.reshape(
508
+ self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
509
+ )
510
+ return self.bk_concat([rr, ii], axis=0)
233
511
  else:
234
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
235
-
236
- if self.BACKEND==self.NUMPY:
237
- if x.dtype=='complex32' or x.dtype=='complex64':
238
- return np.concatenate([x.real.flatten(),x.imag.flatten()],0)
512
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
513
+
514
+ if self.BACKEND == self.NUMPY:
515
+ if self.bk_is_complex(x):
516
+ return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
239
517
  else:
240
518
  return x.flatten()
241
-
242
- def bk_flatten(self,x):
243
- if self.BACKEND==self.TENSORFLOW:
244
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
245
- if self.BACKEND==self.TORCH:
246
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
247
- if self.BACKEND==self.NUMPY:
519
+
520
+ def bk_flatten(self, x):
521
+ if self.BACKEND == self.TENSORFLOW:
522
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
523
+ if self.BACKEND == self.TORCH:
524
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
525
+ if self.BACKEND == self.NUMPY:
248
526
  return x.flatten()
249
-
250
- def bk_resize_image(self,x,shape):
251
- if self.BACKEND==self.TENSORFLOW:
252
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
253
- if self.BACKEND==self.TORCH:
254
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
255
- if self.BACKEND==self.NUMPY:
256
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
257
-
258
- def bk_L1(self,x):
259
- if x.dtype==self.all_cbk_type:
260
- xr=self.bk_real(x)
261
- xi=self.bk_imag(x)
262
-
263
- r=self.backend.sign(xr)*self.backend.sqrt(self.backend.sign(xr)*xr)
264
- i=self.backend.sign(xi)*self.backend.sqrt(self.backend.sign(xi)*xi)
265
- return self.bk_complex(r,i)
527
+
528
+ def bk_size(self, x):
529
+ if self.BACKEND == self.TENSORFLOW:
530
+ return self.backend.size(x)
531
+ if self.BACKEND == self.TORCH:
532
+ return x.numel()
533
+
534
+ if self.BACKEND == self.NUMPY:
535
+ return x.size
536
+
537
+ def bk_resize_image(self, x, shape):
538
+ if self.BACKEND == self.TENSORFLOW:
539
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
540
+
541
+ if self.BACKEND == self.TORCH:
542
+ tmp = self.backend.nn.functional.interpolate(
543
+ x, size=shape, mode="bilinear", align_corners=False
544
+ )
545
+ return self.bk_cast(tmp)
546
+ if self.BACKEND == self.NUMPY:
547
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
548
+
549
+ def bk_L1(self, x):
550
+ if x.dtype == self.all_cbk_type:
551
+ xr = self.bk_real(x)
552
+ xi = self.bk_imag(x)
553
+
554
+ r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
555
+ # return r
556
+ i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
557
+
558
+ if self.BACKEND == self.TORCH:
559
+ return r
560
+ else:
561
+ return self.bk_complex(r, i)
266
562
  else:
267
- return self.backend.sign(x)*self.backend.sqrt(self.backend.sign(x)*x)
268
-
269
- def bk_square_comp(self,x):
270
- if x.dtype==self.all_cbk_type:
271
- xr=self.bk_real(x)
272
- xi=self.bk_imag(x)
273
-
274
- r=xr*xr
275
- i=xi*xi
276
- return self.bk_complex(r,i)
563
+ return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
564
+
565
+ def bk_square_comp(self, x):
566
+ if x.dtype == self.all_cbk_type:
567
+ xr = self.bk_real(x)
568
+ xi = self.bk_imag(x)
569
+
570
+ r = xr * xr
571
+ i = xi * xi
572
+ return self.bk_complex(r, i)
277
573
  else:
278
- return x*x
279
-
280
- def bk_reduce_sum(self,data,axis=None):
281
-
574
+ return x * x
575
+
576
+ def bk_reduce_sum(self, data, axis=None):
577
+
282
578
  if axis is None:
283
- if self.BACKEND==self.TENSORFLOW:
284
- return(self.backend.reduce_sum(data))
285
- if self.BACKEND==self.TORCH:
286
- return(self.backend.sum(data))
287
- if self.BACKEND==self.NUMPY:
288
- return(np.sum(data))
579
+ if self.BACKEND == self.TENSORFLOW:
580
+ return self.backend.reduce_sum(data)
581
+ if self.BACKEND == self.TORCH:
582
+ return self.backend.sum(data)
583
+ if self.BACKEND == self.NUMPY:
584
+ return np.sum(data)
289
585
  else:
290
- if self.BACKEND==self.TENSORFLOW:
291
- return(self.backend.reduce_sum(data,axis=axis))
292
- if self.BACKEND==self.TORCH:
293
- return(self.backend.sum(data,axis))
294
- if self.BACKEND==self.NUMPY:
295
- return(np.sum(data,axis))
296
-
586
+ if self.BACKEND == self.TENSORFLOW:
587
+ return self.backend.reduce_sum(data, axis=axis)
588
+ if self.BACKEND == self.TORCH:
589
+ return self.backend.sum(data, axis)
590
+ if self.BACKEND == self.NUMPY:
591
+ return np.sum(data, axis)
592
+
297
593
  # ---------------------------------------------−---------
298
- def check_dense(self,data,datasz):
299
- if self.BACKEND==self.TENSORFLOW:
300
- if isinstance(data, tf.Tensor):
301
- return data
302
-
303
- idx=tf.cast(data.indices, tf.int32)
304
- data=tf.math.bincount(idx,weights=data.values,
305
- minlength=datasz)
306
- return data
307
-
308
- return data
309
-
594
+ # return a tensor size
310
595
 
311
- def constant(self,data):
596
+ def bk_size(self, data):
597
+ if self.BACKEND == self.TENSORFLOW:
598
+ return self.backend.size(data)
599
+ if self.BACKEND == self.TORCH:
600
+ return data.numel()
601
+ if self.BACKEND == self.NUMPY:
602
+ return data.size
312
603
 
313
- if self.BACKEND==self.TENSORFLOW:
314
- return(self.backend.constant(data))
315
- return(data)
604
+ # ---------------------------------------------−---------
605
+
606
+ def iso_mean(self, x, use_2D=False):
607
+ shape = list(x.shape)
608
+
609
+ i_orient = 2
610
+ if use_2D:
611
+ i_orient = 3
612
+ norient = shape[i_orient]
613
+
614
+ if len(shape) == i_orient + 1:
615
+ return self.bk_reduce_mean(x, -1)
616
+
617
+ if norient not in self._iso_orient:
618
+ self.calc_iso_orient(norient)
619
+
620
+ if self.bk_is_complex(x):
621
+ lmat = self._iso_orient_C[norient]
622
+ else:
623
+ lmat = self._iso_orient[norient]
624
+
625
+ oshape = shape[0]
626
+ for k in range(1, len(shape) - 2):
627
+ oshape *= shape[k]
628
+
629
+ oshape2 = [shape[k] for k in range(0, len(shape) - 1)]
630
+
631
+ return self.bk_reshape(
632
+ self.backend.matmul(self.bk_reshape(x, [oshape, norient * norient]), lmat),
633
+ oshape2,
634
+ )
635
+
636
def fft_ang(self, x, nharm=1, imaginary=False, use_2D=False):
    """Project the orientation axes of ``x`` onto angular harmonics.

    Axes before the first orientation axis (index 2, or 3 when ``use_2D``)
    are kept unchanged.  Each of the 1 to 3 trailing orientation axes (length
    ``norient``) becomes an axis of length ``nout``:

    * ``nout = 1 + nharm`` by default;
    * ``nout = 1 + 2 * nharm`` when ``imaginary`` is true.

    The projection is a single matmul with a table taken from
    ``self._fft_{1,2,3}_orient`` (the ``_C`` variant for complex ``x``).  The
    table is keyed by ``(norient, nharm, imaginary)`` and built on first use
    by ``self.calc_fft_orient``.

    NOTE(review): ``nout`` presumably counts the constant term plus ``nharm``
    harmonics (real and imaginary parts when ``imaginary``) -- confirm against
    calc_fft_orient.

    Raises:
        ValueError: if ``x`` has more than three orientation axes.  Previously
            this surfaced as an ``UnboundLocalError``.
    """
    shape = list(x.shape)
    i_orient = 3 if use_2D else 2
    norient = shape[i_orient]
    n_orient_axes = len(shape) - i_orient
    if n_orient_axes > 3:
        raise ValueError(
            "fft_ang supports 1 to 3 orientation axes, got %d" % n_orient_axes
        )

    nout = 1 + 2 * nharm if imaginary else 1 + nharm

    # The membership test must use the same key as the lookups below.  The
    # old 2-tuple check (norient, nharm) did not match the 3-tuple keys the
    # lookups require, so the tables were recomputed on every call.
    key = (norient, nharm, imaginary)
    if key not in self._fft_1_orient:
        self.calc_fft_orient(norient, nharm, imaginary)

    is_cplx = self.bk_is_complex(x)
    if n_orient_axes == 1:
        lmat = self._fft_1_orient_C[key] if is_cplx else self._fft_1_orient[key]
    elif n_orient_axes == 2:
        lmat = self._fft_2_orient_C[key] if is_cplx else self._fft_2_orient[key]
    else:
        lmat = self._fft_3_orient_C[key] if is_cplx else self._fft_3_orient[key]

    # Rows: product of the leading (non-orientation) axes.
    # Columns: product of all orientation axes.
    rows = shape[0]
    for dim in shape[1:i_orient]:
        rows *= dim
    cols = norient
    for dim in shape[i_orient:-1]:
        cols *= dim

    out_shape = shape[:i_orient] + [nout] * n_orient_axes
    return self.bk_reshape(
        self.backend.matmul(self.bk_reshape(x, [rows, cols]), lmat), out_shape
    )
685
+
686
def constant(self, data):
    """Wrap ``data`` as a backend constant.

    Only the TensorFlow backend converts, via ``self.backend.constant``; any
    other backend hands ``data`` back untouched.
    """
    if self.BACKEND != self.TENSORFLOW:
        return data
    return self.backend.constant(data)
691
+
692
def bk_reduce_mean(self, data, axis=None):
    """Mean of ``data`` over every element, or along ``axis`` when given.

    Dispatches to ``reduce_mean`` (TensorFlow), ``mean`` (torch) or
    ``np.mean`` (NumPy).  An unrecognised backend yields ``None``.
    """
    if self.BACKEND == self.TENSORFLOW:
        if axis is None:
            return self.backend.reduce_mean(data)
        return self.backend.reduce_mean(data, axis=axis)
    if self.BACKEND == self.TORCH:
        if axis is None:
            return self.backend.mean(data)
        return self.backend.mean(data, axis)
    if self.BACKEND == self.NUMPY:
        if axis is None:
            return np.mean(data)
        return np.mean(data, axis)
708
+
709
def bk_reduce_min(self, data, axis=None):
    """Minimum of ``data`` over every element, or along ``axis`` when given.

    Dispatches to ``reduce_min`` (TensorFlow), ``min`` (torch) or ``np.min``
    (NumPy).  An unrecognised backend yields ``None``.
    """
    if self.BACKEND == self.TENSORFLOW:
        if axis is None:
            return self.backend.reduce_min(data)
        return self.backend.reduce_min(data, axis=axis)
    if self.BACKEND == self.TORCH:
        if axis is None:
            return self.backend.min(data)
        return self.backend.min(data, axis)
    if self.BACKEND == self.NUMPY:
        if axis is None:
            return np.min(data)
        return np.min(data, axis)
725
+
726
def bk_random_seed(self, value):
    """Seed the active backend's global random generator with ``value``.

    TensorFlow uses ``random.set_seed`` and NumPy uses ``np.random.seed``.
    torch has no ``random.set_seed`` (the previous call raised
    ``AttributeError``); its global seeding entry point is ``manual_seed``.
    Returns whatever the backend call returns.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.random.set_seed(value)
    if self.BACKEND == self.TORCH:
        return self.backend.manual_seed(value)
    if self.BACKEND == self.NUMPY:
        return np.random.seed(value)
734
+
735
def bk_random_uniform(self, shape):
    """Draw samples uniformly distributed in [0, 1) with the given ``shape``.

    Under torch and NumPy, ``shape`` may be an int or a sequence of ints.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.random.uniform(shape)
    if self.BACKEND == self.TORCH:
        # torch has no ``random.uniform``; ``torch.rand`` is its [0, 1) sampler.
        return self.backend.rand(shape)
    if self.BACKEND == self.NUMPY:
        # ``np.random.rand`` takes the dims as separate arguments and rejects a
        # list/tuple.  ``random_sample`` accepts either form and draws from the
        # same legacy global generator.
        return np.random.random_sample(shape)
743
+
744
def bk_reduce_std(self, data, axis=None):
    """Standard deviation of ``data`` over every element, or along ``axis``.

    The result is promoted to complex (zero imaginary part) via
    ``self.bk_complex`` in two cases:

    * always, when ``axis`` is ``None``;
    * when ``axis`` is given, only if ``data`` itself is complex.

    In every other case the plain result is returned.

    NOTE(review): the ``axis=None`` path promotes even real input.  This
    asymmetry looks unintentional -- confirm what callers expect before
    changing it.
    """
    whole = axis is None
    if self.BACKEND == self.TENSORFLOW:
        if whole:
            r = self.backend.math.reduce_std(data)
        else:
            r = self.backend.math.reduce_std(data, axis=axis)
    if self.BACKEND == self.TORCH:
        r = self.backend.std(data) if whole else self.backend.std(data, axis)
    if self.BACKEND == self.NUMPY:
        r = np.std(data) if whole else np.std(data, axis)
    if whole or self.bk_is_complex(data):
        return self.bk_complex(r, 0 * r)
    return r
764
+
765
def bk_sqrt(self, data):
    """Square root of the magnitude of ``data``, i.e. ``sqrt(|data|)``.

    The absolute value is taken first, so negative inputs do not go through
    a negative square root.
    """
    magnitude = self.backend.abs(data)
    return self.backend.sqrt(magnitude)
768
+
769
def bk_abs(self, data):
    """Element-wise absolute value, delegated to ``self.backend.abs``."""
    absolute = self.backend.abs
    return absolute(data)
771
+
772
def bk_is_complex(self, data):
    """Tell whether ``data`` holds complex values.

    NumPy arrays and NumPy scalars are judged by their dtype
    (``complex64`` or ``complex128``) whichever backend is active.
    TensorFlow and torch tensors are judged by ``dtype.is_complex``.
    An unrecognised backend yields ``None``.
    """

    def _numpy_complex(arr):
        return arr.dtype == "complex64" or arr.dtype == "complex128"

    if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
        # NumPy scalars (e.g. np.complex128(1j)) are not ndarrays, yet their
        # np.dtype has no ``is_complex`` attribute.  Route them through the
        # same dtype test as arrays instead of crashing.
        if isinstance(data, (np.ndarray, np.generic)):
            return _numpy_complex(data)
        return data.dtype.is_complex

    if self.BACKEND == self.NUMPY:
        return _numpy_complex(data)
787
+
788
def bk_distcomp(self, data):
    """Squared modulus of ``data``.

    Computes ``real**2 + imag**2`` for complex input and ``data**2``
    otherwise, using ``self.bk_square``.
    """
    if not self.bk_is_complex(data):
        return self.bk_square(data)
    real_sq = self.bk_square(self.bk_real(data))
    imag_sq = self.bk_square(self.bk_imag(data))
    return real_sq + imag_sq
796
+
797
def bk_norm(self, data):
    """Modulus of ``data``.

    Computes ``sqrt(real**2 + imag**2)`` for complex input and ``|data|``
    otherwise.
    """
    if not self.bk_is_complex(data):
        return self.bk_abs(data)
    real_sq = self.bk_square(self.bk_real(data))
    imag_sq = self.bk_square(self.bk_imag(data))
    return self.bk_sqrt(real_sq + imag_sq)
806
+
807
def bk_square(self, data):
    """Element-wise square of ``data``.

    NumPy multiplies ``data * data``; TensorFlow and torch call
    ``self.backend.square``.
    """
    if self.BACKEND == self.NUMPY:
        return data * data
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        return self.backend.square(data)
815
+
816
def bk_log(self, data):
    """Element-wise natural logarithm of ``data``."""
    backend = self.BACKEND
    if backend == self.NUMPY:
        return np.log(data)
    if backend == self.TORCH:
        return self.backend.log(data)
    if backend == self.TENSORFLOW:
        return self.backend.math.log(data)
823
+
824
def bk_matmul(self, a, b):
    """Matrix product of ``a`` and ``b``.

    NumPy uses ``np.dot``; TensorFlow and torch use ``self.backend.matmul``.
    """
    if self.BACKEND == self.NUMPY:
        return np.dot(a, b)
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        return self.backend.matmul(a, b)
831
+
832
def bk_tensor(self, data):
    """Convert ``data`` into a backend tensor.

    TensorFlow uses ``constant``.  torch has no ``constant`` function (the
    previous call raised ``AttributeError``), so ``torch.tensor`` is used,
    which copies the input.  NumPy returns ``data`` unchanged.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.constant(data)
    if self.BACKEND == self.TORCH:
        return self.backend.tensor(data)
    if self.BACKEND == self.NUMPY:
        return data
839
+
840
def bk_complex(self, real, imag):
    """Combine ``real`` and ``imag`` parts into one complex value/tensor.

    NumPy computes ``real + 1j * imag``; torch uses ``complex`` and
    TensorFlow uses ``dtypes.complex``.
    """
    if self.BACKEND == self.NUMPY:
        return real + 1j * imag
    if self.BACKEND == self.TORCH:
        return self.backend.complex(real, imag)
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.dtypes.complex(real, imag)
847
+
848
def bk_exp(self, data):
    """Element-wise exponential, delegated to ``self.backend.exp``."""
    exponential = self.backend.exp
    return exponential(data)
851
+
852
def bk_min(self, data):
    """Smallest element of ``data`` (reduction over all axes).

    Only TensorFlow names this reduction ``reduce_min``; torch and NumPy call
    it ``min``.  The previous version used ``reduce_min`` unconditionally,
    which raised ``AttributeError`` on those backends.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.min(data)
    if self.BACKEND == self.NUMPY:
        return np.min(data)
    return self.backend.reduce_min(data)
855
+
856
def bk_argmin(self, data):
    """Index of the smallest element, via ``self.backend.argmin`` with its default axis."""
    locate = self.backend.argmin
    return locate(data)
859
+
860
def bk_tanh(self, data):
    """Element-wise hyperbolic tangent.

    TensorFlow keeps ``math.tanh``.  torch and NumPy use their element-wise
    ``tanh``.  ``np.math`` was only an alias for the scalar stdlib ``math``
    module and was removed in NumPy 2.0, so the old call failed on arrays.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.tanh(data)
    if self.BACKEND == self.NUMPY:
        return np.tanh(data)
    return self.backend.math.tanh(data)
863
+
864
def bk_max(self, data):
    """Largest element of ``data`` (reduction over all axes).

    Only TensorFlow names this reduction ``reduce_max``; torch and NumPy call
    it ``max``.  The previous version used ``reduce_max`` unconditionally,
    which raised ``AttributeError`` on those backends.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.max(data)
    if self.BACKEND == self.NUMPY:
        return np.max(data)
    return self.backend.reduce_max(data)
867
+
868
def bk_argmax(self, data):
    """Index of the largest element, via ``self.backend.argmax`` with its default axis."""
    locate = self.backend.argmax
    return locate(data)
871
+
872
def bk_reshape(self, data, shape):
    """Reshape ``data`` to ``shape``.

    Under torch, a NumPy array is reshaped with its own ``reshape`` method
    rather than being passed to the backend function.
    """
    numpy_under_torch = self.BACKEND == self.TORCH and isinstance(data, np.ndarray)
    if numpy_under_torch:
        return data.reshape(shape)
    return self.backend.reshape(data, shape)
878
+
879
def bk_repeat(self, data, nn, axis=0):
    """Repeat each element of ``data`` ``nn`` times along ``axis``.

    torch has no ``repeat`` function with these semantics; its element-wise
    repeat is ``repeat_interleave(input, repeats, dim)``.  TensorFlow and
    NumPy use ``repeat(data, nn, axis=axis)``.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.repeat_interleave(data, nn, dim=axis)
    return self.backend.repeat(data, nn, axis=axis)
881
+
882
def bk_tile(self, data, nn, axis=0):
    """Tile ``data`` ``nn`` times with ``self.backend.tile``.

    TensorFlow expects the multiples as a list, so ``nn`` is wrapped as
    ``[nn]``; other backends receive ``nn`` as-is.

    NOTE(review): ``axis`` is accepted but never used.
    """
    multiples = [nn] if self.BACKEND == self.TENSORFLOW else nn
    return self.backend.tile(data, multiples)
887
+
888
def bk_roll(self, data, nn, axis=0):
    """Circularly shift ``data`` by ``nn`` positions along ``axis``.

    ``torch.roll`` names its axis argument ``dims`` (not ``axis``), so it is
    passed by that keyword.  TensorFlow and NumPy take ``axis``.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.roll(data, nn, dims=axis)
    return self.backend.roll(data, nn, axis=axis)
890
+
891
def bk_expand_dims(self, data, axis=0):
    """Insert a length-1 axis into ``data`` at position ``axis``.

    Under torch, a NumPy array is first converted with ``from_numpy`` and then
    passed to ``unsqueeze``.
    """
    backend = self.BACKEND
    if backend == self.NUMPY:
        return np.expand_dims(data, axis)
    if backend == self.TORCH:
        if isinstance(data, np.ndarray):
            tensor = self.backend.from_numpy(data)
        else:
            tensor = data
        return self.backend.unsqueeze(tensor, axis)
    if backend == self.TENSORFLOW:
        return self.backend.expand_dims(data, axis=axis)
900
+
901
def bk_transpose(self, data, thelist):
    """Reorder the axes of ``data`` according to the permutation ``thelist``.

    ``torch.transpose`` swaps exactly two dims and rejects a permutation
    list, so torch uses ``permute``.  TensorFlow and NumPy take the list
    directly.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.transpose(data, thelist)
    if self.BACKEND == self.TORCH:
        return self.backend.permute(data, thelist)
    if self.BACKEND == self.NUMPY:
        return np.transpose(data, thelist)
908
+
909
def bk_concat(self, data, axis=None):
    """Concatenate the sequence ``data`` along ``axis`` (axis 0 when ``None``).

    TensorFlow and torch:
    * If ``data[0]`` has dtype ``self.all_cbk_type``, the real and imaginary
      parts are concatenated separately and recombined with
      ``self.backend.complex``.
    * Otherwise ``self.backend.concat`` is called directly.

    NumPy uses ``np.concatenate``.

    ``tf.concat`` has no default axis, so ``None`` is translated to 0 here
    instead of the argument being omitted; torch and NumPy already
    defaulted to 0.
    """
    if axis is None:
        axis = 0

    if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
        if data[0].dtype == self.all_cbk_type:
            xr = self.backend.concat([self.bk_real(d) for d in data], axis=axis)
            xi = self.backend.concat([self.bk_imag(d) for d in data], axis=axis)
            return self.backend.complex(xr, xi)
        return self.backend.concat(data, axis=axis)

    return np.concatenate(data, axis=axis)
941
+
942
def bk_zeros(self, shape, dtype=None):
    """Zero-filled array/tensor of ``shape``.

    ``dtype=None`` means the backend's default dtype.  ``dtype`` is forwarded
    only when it is given: ``tf.zeros`` cannot convert ``None`` to a DType,
    while for torch and NumPy omitting it is equivalent to passing ``None``.
    """
    kwargs = {} if dtype is None else {"dtype": dtype}
    if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
        return self.backend.zeros(shape, **kwargs)
    if self.BACKEND == self.NUMPY:
        return np.zeros(shape, **kwargs)
949
+
950
def bk_gather(self, data, idx):
    """Select entries of ``data`` at positions ``idx``.

    TensorFlow uses ``gather``; torch and NumPy index with ``data[idx]``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.gather(data, idx)
    if self.BACKEND in (self.TORCH, self.NUMPY):
        return data[idx]
429
957
 
430
- return(self.backend.math.tanh(data))
431
-
432
- def bk_max(self,data):
958
def bk_reverse(self, data, axis=0):
    """Reverse the order of ``data``'s elements along ``axis``.

    TensorFlow keeps ``reverse``.  Neither ``np.reverse`` nor
    ``torch.reverse`` exists; both libraries call this operation ``flip``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.reverse(data, axis=[axis])
    if self.BACKEND == self.TORCH:
        return self.backend.flip(data, dims=[axis])
    if self.BACKEND == self.NUMPY:
        return np.flip(data, axis=axis)
433
965
 
434
- return(self.backend.reduce_max(data))
435
-
436
- def bk_argmax(self,data):
966
def bk_fft(self, data):
    """One-dimensional discrete Fourier transform of ``data``.

    ``torch.fft`` is a module, not a callable, so torch goes through
    ``torch.fft.fft``.  TensorFlow uses ``signal.fft`` and NumPy uses
    ``fft.fft``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.fft(data)
    if self.BACKEND == self.TORCH:
        return self.backend.fft.fft(data)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.fft(data)
437
973
 
438
- return(self.backend.argmax(data))
439
-
440
- def bk_reshape(self,data,shape):
441
- return(self.backend.reshape(data,shape))
442
-
443
- def bk_repeat(self,data,nn,axis=0):
444
- return(self.backend.repeat(data,nn,axis=axis))
445
-
446
- def bk_tile(self,data,nn,axis=0):
447
- return(self.backend.tile(data,nn))
448
-
449
- def bk_roll(self,data,nn,axis=0):
450
- return(self.backend.roll(data,nn,axis=axis))
451
-
452
- def bk_expand_dims(self,data,axis=0):
453
- if self.BACKEND==self.TENSORFLOW:
454
- return(self.backend.expand_dims(data,axis=axis))
455
- if self.BACKEND==self.TORCH:
456
- if isinstance(data,np.ndarray):
457
- data=self.backend.from_numpy(data)
458
- return(self.backend.unsqueeze(data,axis))
459
- if self.BACKEND==self.NUMPY:
460
- return(np.expand_dims(data,axis))
461
-
462
- def bk_transpose(self,data,thelist):
463
- if self.BACKEND==self.TENSORFLOW:
464
- return(self.backend.transpose(data,thelist))
465
- if self.BACKEND==self.TORCH:
466
- return(self.backend.transpose(data,thelist))
467
- if self.BACKEND==self.NUMPY:
468
- return(np.transpose(data,thelist))
469
-
470
- def bk_concat(self,data,axis=None):
471
-
472
- if axis is None:
473
- return(self.backend.concat(data))
474
- else:
475
- return(self.backend.concat(data,axis=axis))
974
def bk_rfft(self, data):
    """One-dimensional DFT of real ``data``, keeping non-negative frequencies.

    ``torch.rfft`` was removed from torch; the current entry point is
    ``torch.fft.rfft``.  TensorFlow uses ``signal.rfft`` and NumPy uses
    ``fft.rfft``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.rfft(data)
    if self.BACKEND == self.TORCH:
        return self.backend.fft.rfft(data)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.rfft(data)
981
def bk_conjugate(self, data):
    """Complex conjugate of ``data``.

    NumPy uses ``data.conjugate()``, torch uses ``conj`` and TensorFlow uses
    ``math.conj``.
    """
    if self.BACKEND == self.NUMPY:
        return data.conjugate()
    if self.BACKEND == self.TORCH:
        return self.backend.conj(data)
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.conj(data)
486
-
487
- def bk_real(self,data):
488
- if self.BACKEND==self.TENSORFLOW:
989
+
990
def bk_real(self, data):
    """Real part of ``data``.

    TensorFlow uses ``math.real``; torch and NumPy read ``data.real``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.real(data)
    if self.BACKEND in (self.TORCH, self.NUMPY):
        return data.real
494
997
 
495
- def bk_imag(self,data):
496
- if self.BACKEND==self.TENSORFLOW:
998
def bk_imag(self, data):
    """Imaginary part of ``data``.

    Under torch:
    * NumPy arrays use ``ndarray.imag``.
    * Complex tensors of any precision use ``.imag``.  Previously only dtype
      ``self.all_cbk_type`` was recognised.
    * Real tensors yield ``zeros_like(data)``, matching ``ndarray.imag`` for
      real arrays.  Previously they yielded the scalar ``0``, which broke
      tensor-only callers such as ``torch.square``.

    TensorFlow uses ``math.imag``; NumPy uses ``data.imag``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.imag(data)
    if self.BACKEND == self.TORCH:
        if isinstance(data, np.ndarray):
            return data.imag
        if data.is_complex():
            return data.imag
        return self.backend.zeros_like(data)
    if self.BACKEND == self.NUMPY:
        return data.imag
1009
+
1010
def bk_relu(self, x):
    """Rectified linear unit.

    TensorFlow and torch: complex input has ReLU applied separately to its
    real and imaginary parts, then the parts are recombined.  TensorFlow
    detects complex input by ``x.dtype == self.all_cbk_type``; torch by
    ``x.is_complex()``, because ``torch.relu`` rejects complex tensors.

    NumPy computes ``(x > 0) * x``.
    """
    if self.BACKEND == self.TENSORFLOW:
        if x.dtype == self.all_cbk_type:
            xr = self.backend.nn.relu(self.bk_real(x))
            xi = self.backend.nn.relu(self.bk_imag(x))
            return self.backend.complex(xr, xi)
        return self.backend.nn.relu(x)
    if self.BACKEND == self.TORCH:
        if x.is_complex():
            return self.backend.complex(
                self.backend.relu(x.real), self.backend.relu(x.imag)
            )
        return self.backend.relu(x)
    if self.BACKEND == self.NUMPY:
        return (x > 0) * x
1022
+
1023
def bk_cast(self, x):
    """Cast ``x`` to the configured real type, or to the complex type when complex.

    The real type is ``self.all_bk_type`` and the complex type is
    ``self.all_cbk_type``.

    Scalars:
    * ``np.float64``: becomes ``np.float32`` when ``all_bk_type == "float32"``;
      otherwise returned unchanged.
    * ``np.float32``: becomes ``np.float64`` when ``all_bk_type == "float64"``;
      otherwise returned unchanged.
    * Any other real scalar (Python ``int``/``float``, any NumPy integer or
      floating scalar): becomes ``np.float64`` when
      ``all_bk_type == "float64"``, else ``np.float32``.  Previously only
      ``int``, ``np.int32`` and ``np.int64`` were handled; a plain Python
      ``float`` or e.g. ``np.int16`` fell through to the array code and
      crashed.

    Arrays/tensors:
    * TensorFlow: ``cast``.
    * torch: ``Tensor.type``, after ``from_numpy`` for NumPy input.
    * NumPy: ``ndarray.astype``.
    """
    # np.float64 subclasses Python float, so it must be checked first.
    if isinstance(x, np.float64):
        return np.float32(x) if self.all_bk_type == "float32" else x
    if isinstance(x, np.float32):
        return np.float64(x) if self.all_bk_type == "float64" else x
    if isinstance(x, (int, float, np.integer, np.floating)):
        return np.float64(x) if self.all_bk_type == "float64" else np.float32(x)

    out_type = self.all_cbk_type if self.bk_is_complex(x) else self.all_bk_type

    if self.BACKEND == self.TENSORFLOW:
        return self.backend.cast(x, out_type)

    if self.BACKEND == self.TORCH:
        if isinstance(x, np.ndarray):
            x = self.backend.from_numpy(x)
        out_type = self.all_cbk_type if x.dtype.is_complex else self.all_bk_type
        return x.type(out_type)

    if self.BACKEND == self.NUMPY:
        return x.astype(out_type)