foscat 3.1.6__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -1,851 +1,1024 @@
1
1
  import sys
2
+
2
3
  import numpy as np
3
4
 
5
+
4
6
  class foscat_backend:
5
-
6
- def __init__(self,name,mpi_rank=0,all_type='float64',gpupos=0,silent=False):
7
-
8
- self.TENSORFLOW=1
9
- self.TORCH=2
10
- self.NUMPY=3
11
-
7
+
8
+ def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
9
+
10
+ self.TENSORFLOW = 1
11
+ self.TORCH = 2
12
+ self.NUMPY = 3
13
+
12
14
  # table use to compute the iso orientation rotation
13
- self._iso_orient={}
14
- self._iso_orient_T={}
15
- self._iso_orient_C={}
16
- self._iso_orient_C_T={}
17
- self._fft_1_orient={}
18
- self._fft_1_orient_C={}
19
- self._fft_2_orient={}
20
- self._fft_2_orient_C={}
21
- self._fft_3_orient={}
22
- self._fft_3_orient_C={}
23
-
24
- self.BACKEND=name
25
-
26
- if name not in ['tensorflow','torch','numpy']:
27
- print('Backend "%s" not yet implemented'%(name))
28
- print(' Choose inside the next 3 available backends :')
29
- print(' - tensorflow')
30
- print(' - torch')
31
- print(' - numpy (Impossible to do synthesis using numpy)')
15
+ self._iso_orient = {}
16
+ self._iso_orient_T = {}
17
+ self._iso_orient_C = {}
18
+ self._iso_orient_C_T = {}
19
+ self._fft_1_orient = {}
20
+ self._fft_1_orient_C = {}
21
+ self._fft_2_orient = {}
22
+ self._fft_2_orient_C = {}
23
+ self._fft_3_orient = {}
24
+ self._fft_3_orient_C = {}
25
+
26
+ self.BACKEND = name
27
+
28
+ if name not in ["tensorflow", "torch", "numpy"]:
29
+ print('Backend "%s" not yet implemented' % (name))
30
+ print(" Choose inside the next 3 available backends :")
31
+ print(" - tensorflow")
32
+ print(" - torch")
33
+ print(" - numpy (Impossible to do synthesis using numpy)")
32
34
  return None
33
-
34
- if self.BACKEND=='tensorflow':
35
+
36
+ if self.BACKEND == "tensorflow":
35
37
  import tensorflow as tf
36
-
37
- self.backend=tf
38
- self.BACKEND=self.TENSORFLOW
39
- #tf.config.threading.set_inter_op_parallelism_threads(1)
40
- #tf.config.threading.set_intra_op_parallelism_threads(1)
38
+
39
+ self.backend = tf
40
+ self.BACKEND = self.TENSORFLOW
41
+ # tf.config.threading.set_inter_op_parallelism_threads(1)
42
+ # tf.config.threading.set_intra_op_parallelism_threads(1)
41
43
  self.tf_function = tf.function
42
44
 
43
- if self.BACKEND=='torch':
45
+ if self.BACKEND == "torch":
44
46
  import torch
45
- self.BACKEND=self.TORCH
46
- self.backend=torch
47
+
48
+ self.BACKEND = self.TORCH
49
+ self.backend = torch
47
50
  self.tf_function = self.tf_loc_function
48
-
49
- if self.BACKEND=='numpy':
50
- self.BACKEND=self.NUMPY
51
- self.backend=np
51
+
52
+ if self.BACKEND == "numpy":
53
+ self.BACKEND = self.NUMPY
54
+ self.backend = np
52
55
  import scipy as scipy
53
- self.scipy=scipy
56
+
57
+ self.scipy = scipy
54
58
  self.tf_function = self.tf_loc_function
55
-
56
- self.float64=self.backend.float64
57
- self.float32=self.backend.float32
58
- self.int64=self.backend.int64
59
- self.int32=self.backend.int32
60
- self.complex64=self.backend.complex128
61
- self.complex128=self.backend.complex64
62
-
63
- if all_type=='float32':
64
- self.all_bk_type=self.backend.float32
65
- self.all_cbk_type=self.backend.complex64
59
+
60
+ self.float64 = self.backend.float64
61
+ self.float32 = self.backend.float32
62
+ self.int64 = self.backend.int64
63
+ self.int32 = self.backend.int32
64
+ self.complex64 = self.backend.complex128
65
+ self.complex128 = self.backend.complex64
66
+
67
+ if all_type == "float32":
68
+ self.all_bk_type = self.backend.float32
69
+ self.all_cbk_type = self.backend.complex64
66
70
  else:
67
- if all_type=='float64':
68
- self.all_type='float64'
69
- self.all_bk_type=self.backend.float64
70
- self.all_cbk_type=self.backend.complex128
71
+ if all_type == "float64":
72
+ self.all_type = "float64"
73
+ self.all_bk_type = self.backend.float64
74
+ self.all_cbk_type = self.backend.complex128
71
75
  else:
72
- print('ERROR INIT FOCUS ',all_type,' should be float32 or float64')
76
+ print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
73
77
  return None
74
- #===========================================================================
75
- # INIT
76
- if mpi_rank==0:
77
- if self.BACKEND==self.TENSORFLOW and silent==False:
78
- print("Num GPUs Available: ", len(self.backend.config.experimental.list_physical_devices('GPU')))
78
+ # ===========================================================================
79
+ # INIT
80
+ if mpi_rank == 0:
81
+ if self.BACKEND == self.TENSORFLOW and not silent:
82
+ print(
83
+ "Num GPUs Available: ",
84
+ len(self.backend.config.experimental.list_physical_devices("GPU")),
85
+ )
79
86
  sys.stdout.flush()
80
-
81
- if self.BACKEND==self.TENSORFLOW:
87
+
88
+ if self.BACKEND == self.TENSORFLOW:
82
89
  self.backend.debugging.set_log_device_placement(False)
83
90
  self.backend.config.set_soft_device_placement(True)
84
-
85
- gpus = self.backend.config.experimental.list_physical_devices('GPU')
86
-
87
- if self.BACKEND==self.TORCH:
88
- gpus=torch.cuda.is_available()
89
-
90
- if self.BACKEND==self.NUMPY:
91
- gpus=[]
92
- gpuname='CPU:0'
93
- self.gpulist={}
94
- self.gpulist[0]=gpuname
95
- self.ngpu=1
96
-
91
+
92
+ gpus = self.backend.config.experimental.list_physical_devices("GPU")
93
+
94
+ if self.BACKEND == self.TORCH:
95
+ gpus = torch.cuda.is_available()
96
+
97
+ if self.BACKEND == self.NUMPY:
98
+ gpus = []
99
+ gpuname = "CPU:0"
100
+ self.gpulist = {}
101
+ self.gpulist[0] = gpuname
102
+ self.ngpu = 1
103
+
97
104
  if gpus:
98
105
  try:
99
- if self.BACKEND==self.TENSORFLOW:
100
- # Currently, memory growth needs to be the same across GPUs
106
+ if self.BACKEND == self.TENSORFLOW:
107
+ # Currently, memory growth needs to be the same across GPUs
101
108
  for gpu in gpus:
102
109
  self.backend.config.experimental.set_memory_growth(gpu, True)
103
- logical_gpus = self.backend.config.experimental.list_logical_devices('GPU')
104
- print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
110
+ logical_gpus = (
111
+ self.backend.config.experimental.list_logical_devices("GPU")
112
+ )
113
+ print(
114
+ len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
115
+ )
105
116
  sys.stdout.flush()
106
- self.ngpu=len(logical_gpus)
107
- gpuname=logical_gpus[gpupos%self.ngpu].name
108
- self.gpulist={}
117
+ self.ngpu = len(logical_gpus)
118
+ gpuname = logical_gpus[gpupos % self.ngpu].name
119
+ self.gpulist = {}
109
120
  for i in range(self.ngpu):
110
- self.gpulist[i]=logical_gpus[i].name
111
- if self.BACKEND==self.TORCH:
112
- self.ngpu=torch.cuda.device_count()
113
- self.gpulist={}
121
+ self.gpulist[i] = logical_gpus[i].name
122
+ if self.BACKEND == self.TORCH:
123
+ self.ngpu = torch.cuda.device_count()
124
+ self.gpulist = {}
114
125
  for k in range(self.ngpu):
115
- self.gpulist[k]=torch.cuda.get_device_name(0)
126
+ self.gpulist[k] = torch.cuda.get_device_name(0)
116
127
 
117
128
  except RuntimeError as e:
118
129
  # Memory growth must be set before GPUs have been initialized
119
130
  print(e)
120
-
121
- def tf_loc_function(self,func):
131
+
132
+ def tf_loc_function(self, func):
122
133
  return func
123
-
124
- def calc_iso_orient(self,norient):
125
- tmp=np.zeros([norient*norient,norient])
134
+
135
+ def calc_iso_orient(self, norient):
136
+ tmp = np.zeros([norient * norient, norient])
126
137
  for i in range(norient):
127
138
  for j in range(norient):
128
- tmp[j*norient+(j+i)%norient,i]=0.25
129
-
130
- self._iso_orient[norient]=self.constant(self.bk_cast(tmp))
131
- self._iso_orient_T[norient]=self.constant(self.bk_cast(4*tmp.T))
132
- self._iso_orient_C[norient]=self.bk_complex(self._iso_orient[norient],0*self._iso_orient[norient])
133
- self._iso_orient_C_T[norient]=self.bk_complex(self._iso_orient_T[norient],0*self._iso_orient_T[norient])
134
-
135
- def calc_fft_orient(self,norient,nharm,imaginary):
139
+ tmp[j * norient + (j + i) % norient, i] = 0.25
140
+
141
+ self._iso_orient[norient] = self.constant(self.bk_cast(tmp))
142
+ self._iso_orient_T[norient] = self.constant(self.bk_cast(4 * tmp.T))
143
+ self._iso_orient_C[norient] = self.bk_complex(
144
+ self._iso_orient[norient], 0 * self._iso_orient[norient]
145
+ )
146
+ self._iso_orient_C_T[norient] = self.bk_complex(
147
+ self._iso_orient_T[norient], 0 * self._iso_orient_T[norient]
148
+ )
149
+
150
+ def calc_fft_orient(self, norient, nharm, imaginary):
151
+
152
+ x = np.arange(norient) / norient * 2 * np.pi
136
153
 
137
- x=np.arange(norient)/norient*2*np.pi
138
-
139
154
  if imaginary:
140
- tmp=np.zeros([norient,1+nharm*2])
141
- tmp[:,0]=1.0
155
+ tmp = np.zeros([norient, 1 + nharm * 2])
156
+ tmp[:, 0] = 1.0
142
157
  for k in range(nharm):
143
- tmp[:,k*2+1]=np.cos(x*(k+1))
144
- tmp[:,k*2+2]=np.sin(x*(k+1))
145
-
146
- self._fft_1_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp))
147
- self._fft_1_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_1_orient[(norient,nharm,imaginary)],0*self._fft_1_orient[(norient,nharm,imaginary)])
158
+ tmp[:, k * 2 + 1] = np.cos(x * (k + 1))
159
+ tmp[:, k * 2 + 2] = np.sin(x * (k + 1))
160
+
161
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
162
+ self.constant(tmp)
163
+ )
164
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
165
+ self._fft_1_orient[(norient, nharm, imaginary)],
166
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
167
+ )
148
168
  else:
149
- tmp=np.zeros([norient,1+nharm])
150
- for k in range(nharm+1):
151
- tmp[:,k]=np.cos(x*k)
152
-
153
- self._fft_1_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp))
154
- self._fft_1_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_1_orient[(norient,nharm,imaginary)],0*self._fft_1_orient[(norient,nharm,imaginary)])
169
+ tmp = np.zeros([norient, 1 + nharm])
170
+ for k in range(nharm + 1):
171
+ tmp[:, k] = np.cos(x * k)
155
172
 
156
- x=np.repeat(x,norient).reshape(norient,norient)
173
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
174
+ self.constant(tmp)
175
+ )
176
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
177
+ self._fft_1_orient[(norient, nharm, imaginary)],
178
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
179
+ )
180
+
181
+ x = np.repeat(x, norient).reshape(norient, norient)
157
182
 
158
183
  if imaginary:
159
- tmp=np.zeros([norient,norient,(1+nharm*2),(1+nharm*2)])
160
- tmp[:,:,0,0]=1.0
184
+ tmp = np.zeros([norient, norient, (1 + nharm * 2), (1 + nharm * 2)])
185
+ tmp[:, :, 0, 0] = 1.0
161
186
  for k in range(nharm):
162
- tmp[:,:,k*2+1,0]=np.cos(x*(k+1))
163
- tmp[:,:,k*2+2,0]=np.sin(x*(k+1))
164
- tmp[:,:,0,k*2+1]=np.cos((x.T)*(k+1))
165
- tmp[:,:,0,k*2+2]=np.sin((x.T)*(k+1))
166
- for l in range(nharm):
167
- tmp[:,:,k*2+1,l*2+1]=np.cos(x*(k+1))*np.cos((x.T)*(l+1))
168
- tmp[:,:,k*2+2,l*2+1]=np.sin(x*(k+1))*np.cos((x.T)*(l+1))
169
- tmp[:,:,k*2+1,l*2+2]=np.cos(x*(k+1))*np.sin((x.T)*(l+1))
170
- tmp[:,:,k*2+2,l*2+2]=np.sin(x*(k+1))*np.sin((x.T)*(l+1))
171
-
172
- self._fft_2_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient,(1+2*nharm)*(1+2*nharm))))
173
- self._fft_2_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_2_orient[(norient,nharm,imaginary)],0*self._fft_2_orient[(norient,nharm,imaginary)])
187
+ tmp[:, :, k * 2 + 1, 0] = np.cos(x * (k + 1))
188
+ tmp[:, :, k * 2 + 2, 0] = np.sin(x * (k + 1))
189
+ tmp[:, :, 0, k * 2 + 1] = np.cos((x.T) * (k + 1))
190
+ tmp[:, :, 0, k * 2 + 2] = np.sin((x.T) * (k + 1))
191
+ for l_orient in range(nharm):
192
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 1] = np.cos(
193
+ x * (k + 1)
194
+ ) * np.cos((x.T) * (l_orient + 1))
195
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 1] = np.sin(
196
+ x * (k + 1)
197
+ ) * np.cos((x.T) * (l_orient + 1))
198
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 2] = np.cos(
199
+ x * (k + 1)
200
+ ) * np.sin((x.T) * (l_orient + 1))
201
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 2] = np.sin(
202
+ x * (k + 1)
203
+ ) * np.sin((x.T) * (l_orient + 1))
204
+
205
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
206
+ self.constant(
207
+ tmp.reshape(norient * norient, (1 + 2 * nharm) * (1 + 2 * nharm))
208
+ )
209
+ )
210
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
211
+ self._fft_2_orient[(norient, nharm, imaginary)],
212
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
213
+ )
174
214
  else:
175
- tmp=np.zeros([norient,norient,(1+nharm),(1+nharm)])
176
-
177
- for k in range(nharm+1):
178
- for l in range(nharm+1):
179
- tmp[:,:,k,l]=np.cos(x*k)*np.cos((x.T)*l)
180
-
181
- self._fft_2_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient,(1+nharm)*(1+nharm))))
182
- self._fft_2_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_2_orient[(norient,nharm,imaginary)],0*self._fft_2_orient[(norient,nharm,imaginary)])
215
+ tmp = np.zeros([norient, norient, (1 + nharm), (1 + nharm)])
216
+
217
+ for k in range(nharm + 1):
218
+ for l_orient in range(nharm + 1):
219
+ tmp[:, :, k, l_orient] = np.cos(x * k) * np.cos((x.T) * l_orient)
220
+
221
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
222
+ self.constant(tmp.reshape(norient * norient, (1 + nharm) * (1 + nharm)))
223
+ )
224
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
225
+ self._fft_2_orient[(norient, nharm, imaginary)],
226
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
227
+ )
183
228
 
184
- x=np.arange(norient)/norient*2*np.pi
185
- xx=np.zeros([norient,norient,norient])
186
- yy=np.zeros([norient,norient,norient])
187
- zz=np.zeros([norient,norient,norient])
229
+ x = np.arange(norient) / norient * 2 * np.pi
230
+ xx = np.zeros([norient, norient, norient])
231
+ yy = np.zeros([norient, norient, norient])
232
+ zz = np.zeros([norient, norient, norient])
188
233
  for i in range(norient):
189
234
  for j in range(norient):
190
- xx[:,i,j]=x
191
- yy[i,:,j]=x
192
- zz[i,j,:]=x
193
-
235
+ xx[:, i, j] = x
236
+ yy[i, :, j] = x
237
+ zz[i, j, :] = x
238
+
194
239
  if imaginary:
195
- tmp=np.ones([norient,norient,norient,(1+nharm*2),(1+nharm*2),(1+nharm*2)])
196
-
240
+ tmp = np.ones(
241
+ [
242
+ norient,
243
+ norient,
244
+ norient,
245
+ (1 + nharm * 2),
246
+ (1 + nharm * 2),
247
+ (1 + nharm * 2),
248
+ ]
249
+ )
250
+
197
251
  for k in range(nharm):
198
- tmp[:,:,:,k*2+1,0,0]=np.cos(xx*(k+1))
199
- tmp[:,:,:,0,k*2+1,0]=np.cos(yy*(k+1))
200
- tmp[:,:,:,0,0,k*2+1]=np.cos(zz*(k+1))
201
-
202
- tmp[:,:,:,k*2+2,0,0]=np.sin(xx*(k+1))
203
- tmp[:,:,:,0,k*2+2,0]=np.sin(yy*(k+1))
204
- tmp[:,:,:,0,0,k*2+2]=np.sin(zz*(k+1))
205
- for l in range(nharm):
206
- tmp[:,:,:,k*2+1,l*2+1,0]=np.cos(xx*(k+1))*np.cos(yy*(l+1))
207
- tmp[:,:,:,k*2+1,l*2+2,0]=np.cos(xx*(k+1))*np.sin(yy*(l+1))
208
- tmp[:,:,:,k*2+2,l*2+1,0]=np.sin(xx*(k+1))*np.cos(yy*(l+1))
209
- tmp[:,:,:,k*2+2,l*2+2,0]=np.sin(xx*(k+1))*np.sin(yy*(l+1))
210
-
211
- tmp[:,:,:,k*2+1,0,l*2+1]=np.cos(xx*(k+1))*np.cos(zz*(l+1))
212
- tmp[:,:,:,k*2+1,0,l*2+2]=np.cos(xx*(k+1))*np.sin(zz*(l+1))
213
- tmp[:,:,:,k*2+2,0,l*2+1]=np.sin(xx*(k+1))*np.cos(zz*(l+1))
214
- tmp[:,:,:,k*2+2,0,l*2+2]=np.sin(xx*(k+1))*np.sin(zz*(l+1))
215
-
216
- tmp[:,:,:,0,k*2+1,l*2+1]=np.cos(yy*(k+1))*np.cos(zz*(l+1))
217
- tmp[:,:,:,0,k*2+1,l*2+2]=np.cos(yy*(k+1))*np.sin(zz*(l+1))
218
- tmp[:,:,:,0,k*2+2,l*2+1]=np.sin(yy*(k+1))*np.cos(zz*(l+1))
219
- tmp[:,:,:,0,k*2+2,l*2+2]=np.sin(yy*(k+1))*np.sin(zz*(l+1))
220
-
252
+ tmp[:, :, :, k * 2 + 1, 0, 0] = np.cos(xx * (k + 1))
253
+ tmp[:, :, :, 0, k * 2 + 1, 0] = np.cos(yy * (k + 1))
254
+ tmp[:, :, :, 0, 0, k * 2 + 1] = np.cos(zz * (k + 1))
255
+
256
+ tmp[:, :, :, k * 2 + 2, 0, 0] = np.sin(xx * (k + 1))
257
+ tmp[:, :, :, 0, k * 2 + 2, 0] = np.sin(yy * (k + 1))
258
+ tmp[:, :, :, 0, 0, k * 2 + 2] = np.sin(zz * (k + 1))
259
+ for l_orient in range(nharm):
260
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, 0] = np.cos(
261
+ xx * (k + 1)
262
+ ) * np.cos(yy * (l_orient + 1))
263
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, 0] = np.cos(
264
+ xx * (k + 1)
265
+ ) * np.sin(yy * (l_orient + 1))
266
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, 0] = np.sin(
267
+ xx * (k + 1)
268
+ ) * np.cos(yy * (l_orient + 1))
269
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, 0] = np.sin(
270
+ xx * (k + 1)
271
+ ) * np.sin(yy * (l_orient + 1))
272
+
273
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 1] = np.cos(
274
+ xx * (k + 1)
275
+ ) * np.cos(zz * (l_orient + 1))
276
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 2] = np.cos(
277
+ xx * (k + 1)
278
+ ) * np.sin(zz * (l_orient + 1))
279
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 1] = np.sin(
280
+ xx * (k + 1)
281
+ ) * np.cos(zz * (l_orient + 1))
282
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 2] = np.sin(
283
+ xx * (k + 1)
284
+ ) * np.sin(zz * (l_orient + 1))
285
+
286
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 1] = np.cos(
287
+ yy * (k + 1)
288
+ ) * np.cos(zz * (l_orient + 1))
289
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 2] = np.cos(
290
+ yy * (k + 1)
291
+ ) * np.sin(zz * (l_orient + 1))
292
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 1] = np.sin(
293
+ yy * (k + 1)
294
+ ) * np.cos(zz * (l_orient + 1))
295
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 2] = np.sin(
296
+ yy * (k + 1)
297
+ ) * np.sin(zz * (l_orient + 1))
298
+
221
299
  for m in range(nharm):
222
- tmp[:,:,:,k*2+1,l*2+1,m*2+1]=np.cos(xx*(k+1))*np.cos(yy*(l+1))*np.cos(zz*(m+1))
223
- tmp[:,:,:,k*2+1,l*2+1,m*2+2]=np.cos(xx*(k+1))*np.cos(yy*(l+1))*np.sin(zz*(m+1))
224
- tmp[:,:,:,k*2+1,l*2+2,m*2+1]=np.cos(xx*(k+1))*np.sin(yy*(l+1))*np.cos(zz*(m+1))
225
- tmp[:,:,:,k*2+1,l*2+2,m*2+2]=np.cos(xx*(k+1))*np.sin(yy*(l+1))*np.sin(zz*(m+1))
226
- tmp[:,:,:,k*2+2,l*2+1,m*2+1]=np.sin(xx*(k+1))*np.cos(yy*(l+1))*np.cos(zz*(m+1))
227
- tmp[:,:,:,k*2+2,l*2+1,m*2+2]=np.sin(xx*(k+1))*np.cos(yy*(l+1))*np.sin(zz*(m+1))
228
- tmp[:,:,:,k*2+2,l*2+2,m*2+1]=np.sin(xx*(k+1))*np.sin(yy*(l+1))*np.cos(zz*(m+1))
229
- tmp[:,:,:,k*2+2,l*2+2,m*2+2]=np.sin(xx*(k+1))*np.sin(yy*(l+1))*np.sin(zz*(m+1))
230
-
231
-
232
- self._fft_3_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient*norient,(1+nharm*2)*(1+nharm*2)*(1+nharm*2))))
233
- self._fft_3_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_3_orient[(norient,nharm,imaginary)],0*self._fft_3_orient[(norient,nharm,imaginary)])
300
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 1] = (
301
+ np.cos(xx * (k + 1))
302
+ * np.cos(yy * (l_orient + 1))
303
+ * np.cos(zz * (m + 1))
304
+ )
305
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 2] = (
306
+ np.cos(xx * (k + 1))
307
+ * np.cos(yy * (l_orient + 1))
308
+ * np.sin(zz * (m + 1))
309
+ )
310
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 1] = (
311
+ np.cos(xx * (k + 1))
312
+ * np.sin(yy * (l_orient + 1))
313
+ * np.cos(zz * (m + 1))
314
+ )
315
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 2] = (
316
+ np.cos(xx * (k + 1))
317
+ * np.sin(yy * (l_orient + 1))
318
+ * np.sin(zz * (m + 1))
319
+ )
320
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 1] = (
321
+ np.sin(xx * (k + 1))
322
+ * np.cos(yy * (l_orient + 1))
323
+ * np.cos(zz * (m + 1))
324
+ )
325
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 2] = (
326
+ np.sin(xx * (k + 1))
327
+ * np.cos(yy * (l_orient + 1))
328
+ * np.sin(zz * (m + 1))
329
+ )
330
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 1] = (
331
+ np.sin(xx * (k + 1))
332
+ * np.sin(yy * (l_orient + 1))
333
+ * np.cos(zz * (m + 1))
334
+ )
335
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 2] = (
336
+ np.sin(xx * (k + 1))
337
+ * np.sin(yy * (l_orient + 1))
338
+ * np.sin(zz * (m + 1))
339
+ )
340
+
341
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
342
+ self.constant(
343
+ tmp.reshape(
344
+ norient * norient * norient,
345
+ (1 + nharm * 2) * (1 + nharm * 2) * (1 + nharm * 2),
346
+ )
347
+ )
348
+ )
349
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
350
+ self._fft_3_orient[(norient, nharm, imaginary)],
351
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
352
+ )
234
353
  else:
235
- tmp=np.zeros([norient,norient,norient,(1+nharm),(1+nharm),(1+nharm)])
236
-
237
- for k in range(nharm+1):
238
- for l in range(nharm+1):
239
- for m in range(nharm+1):
240
- tmp[:,:,:,k,l,m]=np.cos(xx*k)*np.cos(yy*l)*np.cos(zz*m)
354
+ tmp = np.zeros(
355
+ [norient, norient, norient, (1 + nharm), (1 + nharm), (1 + nharm)]
356
+ )
357
+
358
+ for k in range(nharm + 1):
359
+ for l_orient in range(nharm + 1):
360
+ for m in range(nharm + 1):
361
+ tmp[:, :, :, k, l_orient, m] = (
362
+ np.cos(xx * k) * np.cos(yy * l_orient) * np.cos(zz * m)
363
+ )
364
+
365
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
366
+ self.constant(
367
+ tmp.reshape(
368
+ norient * norient * norient,
369
+ (1 + nharm) * (1 + nharm) * (1 + nharm),
370
+ )
371
+ )
372
+ )
373
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
374
+ self._fft_3_orient[(norient, nharm, imaginary)],
375
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
376
+ )
241
377
 
242
- self._fft_3_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient*norient,(1+nharm)*(1+nharm)*(1+nharm))))
243
- self._fft_3_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_3_orient[(norient,nharm,imaginary)],0*self._fft_3_orient[(norient,nharm,imaginary)])
244
-
245
378
  # ---------------------------------------------−---------
246
379
  # -- BACKEND DEFINITION --
247
380
  # ---------------------------------------------−---------
248
- def bk_SparseTensor(self,indice,w,dense_shape=[]):
249
- if self.BACKEND==self.TENSORFLOW:
250
- return(self.backend.SparseTensor(indice,w,dense_shape=dense_shape))
251
- if self.BACKEND==self.TORCH:
252
- return(self.backend.sparse_coo_tensor(indice.T,w,dense_shape))
253
- if self.BACKEND==self.NUMPY:
254
- return self.scipy.sparse.coo_matrix((w,(indice[:,0],indice[:,1])),shape=dense_shape)
255
-
256
- def bk_stack(self,list,axis=0):
257
- if self.BACKEND==self.TENSORFLOW:
258
- return self.backend.stack(list,axis=axis)
259
- if self.BACKEND==self.TORCH:
260
- return self.backend.stack(list,axis=axis)
261
- if self.BACKEND==self.NUMPY:
262
- return self.backend.stack(list,axis=axis)
263
-
264
- def bk_sparse_dense_matmul(self,smat,mat):
265
- if self.BACKEND==self.TENSORFLOW:
266
- return self.backend.sparse.sparse_dense_matmul(smat,mat)
267
- if self.BACKEND==self.TORCH:
381
+ def bk_SparseTensor(self, indice, w, dense_shape=[]):
382
+ if self.BACKEND == self.TENSORFLOW:
383
+ return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
384
+ if self.BACKEND == self.TORCH:
385
+ return self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
386
+ if self.BACKEND == self.NUMPY:
387
+ return self.scipy.sparse.coo_matrix(
388
+ (w, (indice[:, 0], indice[:, 1])), shape=dense_shape
389
+ )
390
+
391
+ def bk_stack(self, list, axis=0):
392
+ if self.BACKEND == self.TENSORFLOW:
393
+ return self.backend.stack(list, axis=axis)
394
+ if self.BACKEND == self.TORCH:
395
+ return self.backend.stack(list, axis=axis)
396
+ if self.BACKEND == self.NUMPY:
397
+ return self.backend.stack(list, axis=axis)
398
+
399
+ def bk_sparse_dense_matmul(self, smat, mat):
400
+ if self.BACKEND == self.TENSORFLOW:
401
+ return self.backend.sparse.sparse_dense_matmul(smat, mat)
402
+ if self.BACKEND == self.TORCH:
268
403
  return smat.matmul(mat)
269
- if self.BACKEND==self.NUMPY:
404
+ if self.BACKEND == self.NUMPY:
270
405
  return smat.dot(mat)
271
406
 
272
- def conv2d(self,x,w,strides=[1, 1, 1, 1],padding='SAME'):
273
- if self.BACKEND==self.TENSORFLOW:
274
- kx=w.shape[0]
275
- ky=w.shape[1]
276
- paddings = self.backend.constant([[0,0],
277
- [kx//2,kx//2],
278
- [ky//2,ky//2],
279
- [0,0]])
280
- tmp=self.backend.pad(x, paddings, "SYMMETRIC")
281
- return self.backend.nn.conv2d(tmp,w,
282
- strides=strides,
283
- padding="VALID")
407
+ def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
408
+ if self.BACKEND == self.TENSORFLOW:
409
+ kx = w.shape[0]
410
+ ky = w.shape[1]
411
+ paddings = self.backend.constant(
412
+ [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
413
+ )
414
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
415
+ return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
284
416
  # to be written!!!
285
- if self.BACKEND==self.TORCH:
417
+ if self.BACKEND == self.TORCH:
286
418
  return x
287
- if self.BACKEND==self.NUMPY:
288
- res=np.zeros([x.shape[0],x.shape[1],x.shape[2],w.shape[3]],dtype=x.dtype)
419
+ if self.BACKEND == self.NUMPY:
420
+ res = np.zeros(
421
+ [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
422
+ )
289
423
  for k in range(w.shape[2]):
290
- for l in range(w.shape[3]):
424
+ for l_orient in range(w.shape[3]):
291
425
  for j in range(res.shape[0]):
292
- tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='symm')
293
- res[j,:,:,l]+=tmp
426
+ tmp = self.scipy.signal.convolve2d(
427
+ x[j, :, :, k],
428
+ w[:, :, k, l_orient],
429
+ mode="same",
430
+ boundary="symm",
431
+ )
432
+ res[j, :, :, l_orient] += tmp
294
433
  del tmp
295
434
  return res
296
-
297
- def conv1d(self,x,w,strides=[1, 1, 1],padding='SAME'):
298
- if self.BACKEND==self.TENSORFLOW:
299
- kx=w.shape[0]
300
- paddings = self.backend.constant([[0,0],
301
- [kx//2,kx//2],
302
- [0,0]])
303
- tmp=self.backend.pad(x, paddings, "SYMMETRIC")
304
-
305
- return self.backend.nn.conv1d(tmp,w,
306
- stride=strides,
307
- padding="VALID")
435
+
436
+ def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
437
+ if self.BACKEND == self.TENSORFLOW:
438
+ kx = w.shape[0]
439
+ paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
440
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
441
+
442
+ return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
308
443
  # to be written!!!
309
- if self.BACKEND==self.TORCH:
444
+ if self.BACKEND == self.TORCH:
310
445
  return x
311
- if self.BACKEND==self.NUMPY:
312
- res=np.zeros([x.shape[0],x.shape[1],w.shape[2]],dtype=x.dtype)
446
+ if self.BACKEND == self.NUMPY:
447
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[2]], dtype=x.dtype)
313
448
  for k in range(w.shape[2]):
314
449
  for j in range(res.shape[0]):
315
- tmp=self.scipy.signal.convolve1d(x[j,:,k],w[:,k,l], mode='same', boundary='symm')
316
- res[j,:,:,l]+=tmp
450
+ tmp = self.scipy.signal.convolve1d(
451
+ x[j, :, k], w[:, k], mode="same", boundary="symm"
452
+ )
453
+ res[j, :, :] += tmp
317
454
  del tmp
318
455
  return res
319
456
 
320
- def bk_threshold(self,x,threshold,greater=True):
457
+ def bk_threshold(self, x, threshold, greater=True):
321
458
 
322
- if self.BACKEND==self.TENSORFLOW:
323
- return(self.backend.cast(x>threshold,x.dtype)*x)
324
- if self.BACKEND==self.TORCH:
459
+ if self.BACKEND == self.TENSORFLOW:
460
+ return self.backend.cast(x > threshold, x.dtype) * x
461
+ if self.BACKEND == self.TORCH:
325
462
  x.to(x.dtype)
326
- return (x>threshold)*x
327
- #return(self.backend.cast(x>threshold,x.dtype)*x)
328
- if self.BACKEND==self.NUMPY:
329
- return (x>threshold)*x
330
-
331
- def bk_maximum(self,x1,x2):
332
- if self.BACKEND==self.TENSORFLOW:
333
- return(self.backend.maximum(x1,x2))
334
- if self.BACKEND==self.TORCH:
335
- return(self.backend.maximum(x1,x2))
336
- if self.BACKEND==self.NUMPY:
337
- return x1*(x1>x2)+x2*(x2>x1)
338
-
339
- def bk_device(self,device_name):
463
+ return (x > threshold) * x
464
+ # return(self.backend.cast(x>threshold,x.dtype)*x)
465
+ if self.BACKEND == self.NUMPY:
466
+ return (x > threshold) * x
467
+
468
+ def bk_maximum(self, x1, x2):
469
+ if self.BACKEND == self.TENSORFLOW:
470
+ return self.backend.maximum(x1, x2)
471
+ if self.BACKEND == self.TORCH:
472
+ return self.backend.maximum(x1, x2)
473
+ if self.BACKEND == self.NUMPY:
474
+ return x1 * (x1 > x2) + x2 * (x2 > x1)
475
+
476
+ def bk_device(self, device_name):
340
477
  return self.backend.device(device_name)
341
-
342
- def bk_ones(self,shape,dtype=None):
478
+
479
+ def bk_ones(self, shape, dtype=None):
343
480
  if dtype is None:
344
- dtype=self.all_type
345
- if self.BACKEND==self.TORCH:
481
+ dtype = self.all_type
482
+ if self.BACKEND == self.TORCH:
346
483
  return self.bk_cast(np.ones(shape))
347
- return(self.backend.ones(shape,dtype=dtype))
348
-
349
- def bk_conv1d(self,x,w):
350
- if self.BACKEND==self.TENSORFLOW:
351
- return self.backend.nn.conv1d(x,w, stride=[1,1,1], padding='SAME')
352
- if self.BACKEND==self.TORCH:
484
+ return self.backend.ones(shape, dtype=dtype)
485
+
486
+ def bk_conv1d(self, x, w):
487
+ if self.BACKEND == self.TENSORFLOW:
488
+ return self.backend.nn.conv1d(x, w, stride=[1, 1, 1], padding="SAME")
489
+ if self.BACKEND == self.TORCH:
353
490
  # Torch not yet done !!!
354
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
355
- if self.BACKEND==self.NUMPY:
356
- res=np.zeros([x.shape[0],x.shape[1],w.shape[1]],dtype=x.dtype)
491
+ return self.backend.nn.conv1d(x, w, stride=1, padding="SAME")
492
+ if self.BACKEND == self.NUMPY:
493
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[1]], dtype=x.dtype)
357
494
  for k in range(w.shape[1]):
358
- for l in range(w.shape[2]):
359
- res[:,:,l]+=self.scipy.ndimage.convolve1d(x[:,:,k],w[:,k,l],axis=1,mode='constant',cval=0.0)
495
+ for l_orient in range(w.shape[2]):
496
+ res[:, :, l_orient] += self.scipy.ndimage.convolve1d(
497
+ x[:, :, k], w[:, k, l_orient], axis=1, mode="constant", cval=0.0
498
+ )
360
499
  return res
361
500
 
362
- def bk_flattenR(self,x):
363
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
501
+ def bk_flattenR(self, x):
502
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
364
503
  if self.bk_is_complex(x):
365
- rr=self.backend.reshape(self.bk_real(x),[np.prod(np.array(list(x.shape)))])
366
- ii=self.backend.reshape(self.bk_imag(x),[np.prod(np.array(list(x.shape)))])
367
- return self.bk_concat([rr,ii],axis=0)
504
+ rr = self.backend.reshape(
505
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
506
+ )
507
+ ii = self.backend.reshape(
508
+ self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
509
+ )
510
+ return self.bk_concat([rr, ii], axis=0)
368
511
  else:
369
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
370
-
371
- if self.BACKEND==self.NUMPY:
512
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
513
+
514
+ if self.BACKEND == self.NUMPY:
372
515
  if self.bk_is_complex(x):
373
- return np.concatenate([x.real.flatten(),x.imag.flatten()],0)
516
+ return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
374
517
  else:
375
518
  return x.flatten()
376
-
377
- def bk_flatten(self,x):
378
- if self.BACKEND==self.TENSORFLOW:
379
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
380
- if self.BACKEND==self.TORCH:
381
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
382
- if self.BACKEND==self.NUMPY:
519
+
520
+ def bk_flatten(self, x):
521
+ if self.BACKEND == self.TENSORFLOW:
522
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
523
+ if self.BACKEND == self.TORCH:
524
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
525
+ if self.BACKEND == self.NUMPY:
383
526
  return x.flatten()
384
527
 
385
- def bk_size(self,x):
386
- if self.BACKEND==self.TENSORFLOW:
528
+ def bk_size(self, x):
529
+ if self.BACKEND == self.TENSORFLOW:
387
530
  return self.backend.size(x)
388
- if self.BACKEND==self.TORCH:
531
+ if self.BACKEND == self.TORCH:
389
532
  return x.numel()
390
-
391
- if self.BACKEND==self.NUMPY:
533
+
534
+ if self.BACKEND == self.NUMPY:
392
535
  return x.size
393
-
394
- def bk_resize_image(self,x,shape):
395
- if self.BACKEND==self.TENSORFLOW:
396
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
397
-
398
- if self.BACKEND==self.TORCH:
399
- tmp=self.backend.nn.functional.interpolate(x,
400
- size=shape,
401
- mode='bilinear',
402
- align_corners=False)
536
+
537
+ def bk_resize_image(self, x, shape):
538
+ if self.BACKEND == self.TENSORFLOW:
539
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
540
+
541
+ if self.BACKEND == self.TORCH:
542
+ tmp = self.backend.nn.functional.interpolate(
543
+ x, size=shape, mode="bilinear", align_corners=False
544
+ )
403
545
  return self.bk_cast(tmp)
404
- if self.BACKEND==self.NUMPY:
405
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
406
-
407
- def bk_L1(self,x):
408
- if x.dtype==self.all_cbk_type:
409
- xr=self.bk_real(x)
410
- #xi=self.bk_imag(x)
411
-
412
- r=self.backend.sign(xr)*self.backend.sqrt(self.backend.sign(xr)*xr)
413
- return r
414
- #i=self.backend.sign(xi)*self.backend.sqrt(self.backend.sign(xi)*xi)
415
- """
416
- if self.BACKEND==self.TORCH:
546
+ if self.BACKEND == self.NUMPY:
547
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
548
+
549
+ def bk_L1(self, x):
550
+ if x.dtype == self.all_cbk_type:
551
+ xr = self.bk_real(x)
552
+ xi = self.bk_imag(x)
553
+
554
+ r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
555
+ # return r
556
+ i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
557
+
558
+ if self.BACKEND == self.TORCH:
417
559
  return r
418
560
  else:
419
- return self.bk_complex(r,i)
420
- """
561
+ return self.bk_complex(r, i)
421
562
  else:
422
- return self.backend.sign(x)*self.backend.sqrt(self.backend.sign(x)*x)
423
-
424
- def bk_square_comp(self,x):
425
- if x.dtype==self.all_cbk_type:
426
- xr=self.bk_real(x)
427
- xi=self.bk_imag(x)
428
-
429
- r=xr*xr
430
- i=xi*xi
431
- return self.bk_complex(r,i)
563
+ return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
564
+
565
+ def bk_square_comp(self, x):
566
+ if x.dtype == self.all_cbk_type:
567
+ xr = self.bk_real(x)
568
+ xi = self.bk_imag(x)
569
+
570
+ r = xr * xr
571
+ i = xi * xi
572
+ return self.bk_complex(r, i)
432
573
  else:
433
- return x*x
434
-
435
- def bk_reduce_sum(self,data,axis=None):
436
-
574
+ return x * x
575
+
576
+ def bk_reduce_sum(self, data, axis=None):
577
+
437
578
  if axis is None:
438
- if self.BACKEND==self.TENSORFLOW:
439
- return(self.backend.reduce_sum(data))
440
- if self.BACKEND==self.TORCH:
441
- return(self.backend.sum(data))
442
- if self.BACKEND==self.NUMPY:
443
- return(np.sum(data))
579
+ if self.BACKEND == self.TENSORFLOW:
580
+ return self.backend.reduce_sum(data)
581
+ if self.BACKEND == self.TORCH:
582
+ return self.backend.sum(data)
583
+ if self.BACKEND == self.NUMPY:
584
+ return np.sum(data)
444
585
  else:
445
- if self.BACKEND==self.TENSORFLOW:
446
- return(self.backend.reduce_sum(data,axis=axis))
447
- if self.BACKEND==self.TORCH:
448
- return(self.backend.sum(data,axis))
449
- if self.BACKEND==self.NUMPY:
450
- return(np.sum(data,axis))
451
-
586
+ if self.BACKEND == self.TENSORFLOW:
587
+ return self.backend.reduce_sum(data, axis=axis)
588
+ if self.BACKEND == self.TORCH:
589
+ return self.backend.sum(data, axis)
590
+ if self.BACKEND == self.NUMPY:
591
+ return np.sum(data, axis)
592
+
452
593
  # ---------------------------------------------−---------
453
-
454
- def iso_mean(self,x,use_2D=False):
455
- shape=list(x.shape)
456
-
457
- i_orient=2
594
+
595
+ def iso_mean(self, x, use_2D=False):
596
+ shape = list(x.shape)
597
+
598
+ i_orient = 2
458
599
  if use_2D:
459
- i_orient=3
460
- norient=shape[i_orient]
600
+ i_orient = 3
601
+ norient = shape[i_orient]
602
+
603
+ if len(shape) == i_orient + 1:
604
+ return self.bk_reduce_mean(x, -1)
461
605
 
462
- if len(shape)==i_orient+1:
463
- return self.bk_reduce_mean(x,-1)
464
-
465
606
  if norient not in self._iso_orient:
466
607
  self.calc_iso_orient(norient)
467
608
 
468
609
  if self.bk_is_complex(x):
469
- lmat = self._iso_orient_C[norient]
470
- lmat_T = self._iso_orient_C_T[norient]
610
+ lmat = self._iso_orient_C[norient]
471
611
  else:
472
- lmat = self._iso_orient[norient]
473
- lmat_T = self._iso_orient_T[norient]
474
-
475
- oshape=shape[0]
476
- for k in range(1,len(shape)-2):
477
- oshape*=shape[k]
478
-
479
- oshape2=[shape[k] for k in range(0,len(shape)-1)]
480
-
481
- return self.bk_reshape(self.backend.matmul(
482
- self.bk_reshape(x,[oshape,norient*norient]),lmat),oshape2)
612
+ lmat = self._iso_orient[norient]
613
+
614
+ oshape = shape[0]
615
+ for k in range(1, len(shape) - 2):
616
+ oshape *= shape[k]
483
617
 
484
-
485
- def fft_ang(self,x,nharm=1,imaginary=False,use_2D=False):
486
- shape=list(x.shape)
618
+ oshape2 = [shape[k] for k in range(0, len(shape) - 1)]
487
619
 
488
- i_orient=2
620
+ return self.bk_reshape(
621
+ self.backend.matmul(self.bk_reshape(x, [oshape, norient * norient]), lmat),
622
+ oshape2,
623
+ )
624
+
625
+ def fft_ang(self, x, nharm=1, imaginary=False, use_2D=False):
626
+ shape = list(x.shape)
627
+
628
+ i_orient = 2
489
629
  if use_2D:
490
- i_orient=3
491
-
492
- norient=shape[i_orient]
493
- nout=1+nharm
494
-
495
- oshape_1=shape[0]
496
- for k in range(1,i_orient):
497
- oshape_1*=shape[k]
498
- oshape_2=norient
499
- for k in range(i_orient,len(shape)-1):
500
- oshape_2*=shape[k]
501
- oshape=[oshape_1,oshape_2]
502
-
503
-
630
+ i_orient = 3
631
+
632
+ norient = shape[i_orient]
633
+ nout = 1 + nharm
634
+
635
+ oshape_1 = shape[0]
636
+ for k in range(1, i_orient):
637
+ oshape_1 *= shape[k]
638
+ oshape_2 = norient
639
+ for k in range(i_orient, len(shape) - 1):
640
+ oshape_2 *= shape[k]
641
+ oshape = [oshape_1, oshape_2]
642
+
504
643
  if imaginary:
505
- nout=1+nharm*2
506
-
507
- oshape2=[shape[k] for k in range(0,i_orient)]+[nout for k in range(i_orient,len(shape))]
508
-
509
- if (norient,nharm) not in self._fft_1_orient:
510
- self.calc_fft_orient(norient,nharm,imaginary)
644
+ nout = 1 + nharm * 2
511
645
 
512
- if len(shape)==i_orient+1:
646
+ oshape2 = [shape[k] for k in range(0, i_orient)] + [
647
+ nout for k in range(i_orient, len(shape))
648
+ ]
649
+
650
+ if (norient, nharm) not in self._fft_1_orient:
651
+ self.calc_fft_orient(norient, nharm, imaginary)
652
+
653
+ if len(shape) == i_orient + 1:
513
654
  if self.bk_is_complex(x):
514
- lmat = self._fft_1_orient_C[(norient,nharm,imaginary)]
655
+ lmat = self._fft_1_orient_C[(norient, nharm, imaginary)]
515
656
  else:
516
- lmat = self._fft_1_orient[(norient,nharm,imaginary)]
517
-
518
- if len(shape)==i_orient+2:
657
+ lmat = self._fft_1_orient[(norient, nharm, imaginary)]
658
+
659
+ if len(shape) == i_orient + 2:
519
660
  if self.bk_is_complex(x):
520
- lmat = self._fft_2_orient_C[(norient,nharm,imaginary)]
661
+ lmat = self._fft_2_orient_C[(norient, nharm, imaginary)]
521
662
  else:
522
- lmat = self._fft_2_orient[(norient,nharm,imaginary)]
523
-
524
- if len(shape)==i_orient+3:
663
+ lmat = self._fft_2_orient[(norient, nharm, imaginary)]
664
+
665
+ if len(shape) == i_orient + 3:
525
666
  if self.bk_is_complex(x):
526
- lmat = self._fft_3_orient_C[(norient,nharm,imaginary)]
667
+ lmat = self._fft_3_orient_C[(norient, nharm, imaginary)]
527
668
  else:
528
- lmat = self._fft_3_orient[(norient,nharm,imaginary)]
529
-
530
- return self.bk_reshape(self.backend.matmul(self.bk_reshape(x,oshape),lmat),oshape2)
531
-
532
- def constant(self,data):
533
-
534
- if self.BACKEND==self.TENSORFLOW:
535
- return(self.backend.constant(data))
536
- return(data)
669
+ lmat = self._fft_3_orient[(norient, nharm, imaginary)]
670
+
671
+ return self.bk_reshape(
672
+ self.backend.matmul(self.bk_reshape(x, oshape), lmat), oshape2
673
+ )
674
+
675
+ def constant(self, data):
676
+
677
+ if self.BACKEND == self.TENSORFLOW:
678
+ return self.backend.constant(data)
679
+ return data
680
+
681
+ def bk_reduce_mean(self, data, axis=None):
537
682
 
538
- def bk_reduce_mean(self,data,axis=None):
539
-
540
683
  if axis is None:
541
- if self.BACKEND==self.TENSORFLOW:
542
- return(self.backend.reduce_mean(data))
543
- if self.BACKEND==self.TORCH:
544
- return(self.backend.mean(data))
545
- if self.BACKEND==self.NUMPY:
546
- return(np.mean(data))
684
+ if self.BACKEND == self.TENSORFLOW:
685
+ return self.backend.reduce_mean(data)
686
+ if self.BACKEND == self.TORCH:
687
+ return self.backend.mean(data)
688
+ if self.BACKEND == self.NUMPY:
689
+ return np.mean(data)
547
690
  else:
548
- if self.BACKEND==self.TENSORFLOW:
549
- return(self.backend.reduce_mean(data,axis=axis))
550
- if self.BACKEND==self.TORCH:
551
- return(self.backend.mean(data,axis))
552
- if self.BACKEND==self.NUMPY:
553
- return(np.mean(data,axis))
554
-
555
- def bk_reduce_min(self,data,axis=None):
556
-
691
+ if self.BACKEND == self.TENSORFLOW:
692
+ return self.backend.reduce_mean(data, axis=axis)
693
+ if self.BACKEND == self.TORCH:
694
+ return self.backend.mean(data, axis)
695
+ if self.BACKEND == self.NUMPY:
696
+ return np.mean(data, axis)
697
+
698
+ def bk_reduce_min(self, data, axis=None):
699
+
557
700
  if axis is None:
558
- if self.BACKEND==self.TENSORFLOW:
559
- return(self.backend.reduce_min(data))
560
- if self.BACKEND==self.TORCH:
561
- return(self.backend.min(data))
562
- if self.BACKEND==self.NUMPY:
563
- return(np.min(data))
701
+ if self.BACKEND == self.TENSORFLOW:
702
+ return self.backend.reduce_min(data)
703
+ if self.BACKEND == self.TORCH:
704
+ return self.backend.min(data)
705
+ if self.BACKEND == self.NUMPY:
706
+ return np.min(data)
564
707
  else:
565
- if self.BACKEND==self.TENSORFLOW:
566
- return(self.backend.reduce_min(data,axis=axis))
567
- if self.BACKEND==self.TORCH:
568
- return(self.backend.min(data,axis))
569
- if self.BACKEND==self.NUMPY:
570
- return(np.min(data,axis))
571
-
572
- def bk_random_seed(self,value):
573
-
574
- if self.BACKEND==self.TENSORFLOW:
575
- return(self.backend.random.set_seed(value))
576
- if self.BACKEND==self.TORCH:
577
- return(self.backend.random.set_seed(value))
578
- if self.BACKEND==self.NUMPY:
579
- return(np.random.seed(value))
580
-
581
- def bk_random_uniform(self,shape):
582
-
583
- if self.BACKEND==self.TENSORFLOW:
584
- return(self.backend.random.uniform(shape))
585
- if self.BACKEND==self.TORCH:
586
- return(self.backend.random.uniform(shape))
587
- if self.BACKEND==self.NUMPY:
588
- return(np.random.rand(shape))
589
-
590
- def bk_reduce_std(self,data,axis=None):
591
-
708
+ if self.BACKEND == self.TENSORFLOW:
709
+ return self.backend.reduce_min(data, axis=axis)
710
+ if self.BACKEND == self.TORCH:
711
+ return self.backend.min(data, axis)
712
+ if self.BACKEND == self.NUMPY:
713
+ return np.min(data, axis)
714
+
715
+ def bk_random_seed(self, value):
716
+
717
+ if self.BACKEND == self.TENSORFLOW:
718
+ return self.backend.random.set_seed(value)
719
+ if self.BACKEND == self.TORCH:
720
+ return self.backend.random.set_seed(value)
721
+ if self.BACKEND == self.NUMPY:
722
+ return np.random.seed(value)
723
+
724
+ def bk_random_uniform(self, shape):
725
+
726
+ if self.BACKEND == self.TENSORFLOW:
727
+ return self.backend.random.uniform(shape)
728
+ if self.BACKEND == self.TORCH:
729
+ return self.backend.random.uniform(shape)
730
+ if self.BACKEND == self.NUMPY:
731
+ return np.random.rand(shape)
732
+
733
+ def bk_reduce_std(self, data, axis=None):
592
734
  if axis is None:
593
- if self.BACKEND==self.TENSORFLOW:
594
- return(self.backend.math.reduce_std(data))
595
- if self.BACKEND==self.TORCH:
596
- return(self.backend.std(data))
597
- if self.BACKEND==self.NUMPY:
598
- return(np.std(data))
735
+ if self.BACKEND == self.TENSORFLOW:
736
+ r = self.backend.math.reduce_std(data)
737
+ if self.BACKEND == self.TORCH:
738
+ r = self.backend.std(data)
739
+ if self.BACKEND == self.NUMPY:
740
+ r = np.std(data)
741
+ return self.bk_complex(r, 0 * r)
599
742
  else:
600
- if self.BACKEND==self.TENSORFLOW:
601
- return(self.backend.math.reduce_std(data,axis=axis))
602
- if self.BACKEND==self.TORCH:
603
- return(self.backend.std(data,axis))
604
- if self.BACKEND==self.NUMPY:
605
- return(np.std(data,axis))
606
-
607
-
608
- def bk_sqrt(self,data):
609
-
610
- return(self.backend.sqrt(self.backend.abs(data)))
611
-
612
- def bk_abs(self,data):
613
- return(self.backend.abs(data))
743
+ if self.BACKEND == self.TENSORFLOW:
744
+ r = self.backend.math.reduce_std(data, axis=axis)
745
+ if self.BACKEND == self.TORCH:
746
+ r = self.backend.std(data, axis)
747
+ if self.BACKEND == self.NUMPY:
748
+ r = np.std(data, axis)
749
+ if self.bk_is_complex(data):
750
+ return self.bk_complex(r, 0 * r)
751
+ else:
752
+ return r
614
753
 
615
- def bk_is_complex(self,data):
616
-
617
- if self.BACKEND==self.TENSORFLOW:
618
- if isinstance(data,np.ndarray):
619
- return (data.dtype=='complex64' or data.dtype=='complex128')
754
+ def bk_sqrt(self, data):
755
+
756
+ return self.backend.sqrt(self.backend.abs(data))
757
+
758
+ def bk_abs(self, data):
759
+ return self.backend.abs(data)
760
+
761
+ def bk_is_complex(self, data):
762
+
763
+ if self.BACKEND == self.TENSORFLOW:
764
+ if isinstance(data, np.ndarray):
765
+ return data.dtype == "complex64" or data.dtype == "complex128"
620
766
  return data.dtype.is_complex
621
-
622
- if self.BACKEND==self.TORCH:
623
- if isinstance(data,np.ndarray):
624
- return (data.dtype=='complex64' or data.dtype=='complex128')
625
-
767
+
768
+ if self.BACKEND == self.TORCH:
769
+ if isinstance(data, np.ndarray):
770
+ return data.dtype == "complex64" or data.dtype == "complex128"
771
+
626
772
  return data.dtype.is_complex
627
-
628
- if self.BACKEND==self.NUMPY:
629
- return (data.dtype=='complex64' or data.dtype=='complex128')
630
773
 
631
- def bk_distcomp(self,data):
774
+ if self.BACKEND == self.NUMPY:
775
+ return data.dtype == "complex64" or data.dtype == "complex128"
776
+
777
+ def bk_distcomp(self, data):
632
778
  if self.bk_is_complex(data):
633
- res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
779
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
780
+ self.bk_imag(data)
781
+ )
634
782
  return res
635
783
  else:
636
784
  return self.bk_square(data)
637
-
638
- def bk_norm(self,data):
785
+
786
+ def bk_norm(self, data):
639
787
  if self.bk_is_complex(data):
640
- res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
788
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
789
+ self.bk_imag(data)
790
+ )
641
791
  return self.bk_sqrt(res)
642
792
 
643
793
  else:
644
794
  return self.bk_abs(data)
645
-
646
- def bk_square(self,data):
647
-
648
- if self.BACKEND==self.TENSORFLOW:
649
- return(self.backend.square(data))
650
- if self.BACKEND==self.TORCH:
651
- return(self.backend.square(data))
652
- if self.BACKEND==self.NUMPY:
653
- return(data*data)
654
-
655
- def bk_log(self,data):
656
- if self.BACKEND==self.TENSORFLOW:
657
- return(self.backend.math.log(data))
658
- if self.BACKEND==self.TORCH:
659
- return(self.backend.log(data))
660
- if self.BACKEND==self.NUMPY:
661
- return(np.log(data))
662
-
663
- def bk_matmul(self,a,b):
664
- if self.BACKEND==self.TENSORFLOW:
665
- return(self.backend.matmul(a,b))
666
- if self.BACKEND==self.TORCH:
667
- return(self.backend.matmul(a,b))
668
- if self.BACKEND==self.NUMPY:
669
- return(np.dot(a,b))
670
-
671
- def bk_tensor(self,data):
672
- if self.BACKEND==self.TENSORFLOW:
673
- return(self.backend.constant(data))
674
- if self.BACKEND==self.TORCH:
675
- return(self.backend.constant(data))
676
- if self.BACKEND==self.NUMPY:
677
- return(data)
678
-
679
- def bk_complex(self,real,imag):
680
- if self.BACKEND==self.TENSORFLOW:
681
- return(self.backend.dtypes.complex(real,imag))
682
- if self.BACKEND==self.TORCH:
683
- return(self.backend.complex(real,imag))
684
- if self.BACKEND==self.NUMPY:
685
- return real+1J*imag
686
-
687
- def bk_exp(self,data):
688
-
689
- return(self.backend.exp(data))
690
-
691
- def bk_min(self,data):
692
-
693
- return(self.backend.reduce_min(data))
694
-
695
- def bk_argmin(self,data):
696
-
697
- return(self.backend.argmin(data))
698
-
699
- def bk_tanh(self,data):
700
-
701
- return(self.backend.math.tanh(data))
702
-
703
- def bk_max(self,data):
704
-
705
- return(self.backend.reduce_max(data))
706
-
707
- def bk_argmax(self,data):
708
-
709
- return(self.backend.argmax(data))
710
-
711
- def bk_reshape(self,data,shape):
712
- if self.BACKEND==self.TORCH:
713
- if isinstance(data,np.ndarray):
795
+
796
+ def bk_square(self, data):
797
+
798
+ if self.BACKEND == self.TENSORFLOW:
799
+ return self.backend.square(data)
800
+ if self.BACKEND == self.TORCH:
801
+ return self.backend.square(data)
802
+ if self.BACKEND == self.NUMPY:
803
+ return data * data
804
+
805
+ def bk_log(self, data):
806
+ if self.BACKEND == self.TENSORFLOW:
807
+ return self.backend.math.log(data)
808
+ if self.BACKEND == self.TORCH:
809
+ return self.backend.log(data)
810
+ if self.BACKEND == self.NUMPY:
811
+ return np.log(data)
812
+
813
+ def bk_matmul(self, a, b):
814
+ if self.BACKEND == self.TENSORFLOW:
815
+ return self.backend.matmul(a, b)
816
+ if self.BACKEND == self.TORCH:
817
+ return self.backend.matmul(a, b)
818
+ if self.BACKEND == self.NUMPY:
819
+ return np.dot(a, b)
820
+
821
+ def bk_tensor(self, data):
822
+ if self.BACKEND == self.TENSORFLOW:
823
+ return self.backend.constant(data)
824
+ if self.BACKEND == self.TORCH:
825
+ return self.backend.constant(data)
826
+ if self.BACKEND == self.NUMPY:
827
+ return data
828
+
829
+ def bk_complex(self, real, imag):
830
+ if self.BACKEND == self.TENSORFLOW:
831
+ return self.backend.dtypes.complex(real, imag)
832
+ if self.BACKEND == self.TORCH:
833
+ return self.backend.complex(real, imag)
834
+ if self.BACKEND == self.NUMPY:
835
+ return real + 1j * imag
836
+
837
+ def bk_exp(self, data):
838
+
839
+ return self.backend.exp(data)
840
+
841
+ def bk_min(self, data):
842
+
843
+ return self.backend.reduce_min(data)
844
+
845
+ def bk_argmin(self, data):
846
+
847
+ return self.backend.argmin(data)
848
+
849
+ def bk_tanh(self, data):
850
+
851
+ return self.backend.math.tanh(data)
852
+
853
+ def bk_max(self, data):
854
+
855
+ return self.backend.reduce_max(data)
856
+
857
+ def bk_argmax(self, data):
858
+
859
+ return self.backend.argmax(data)
860
+
861
+ def bk_reshape(self, data, shape):
862
+ if self.BACKEND == self.TORCH:
863
+ if isinstance(data, np.ndarray):
714
864
  return data.reshape(shape)
715
-
716
- return(self.backend.reshape(data,shape))
717
-
718
- def bk_repeat(self,data,nn,axis=0):
719
- return(self.backend.repeat(data,nn,axis=axis))
720
-
721
- def bk_tile(self,data,nn,axis=0):
722
- return(self.backend.tile(data,nn))
723
-
724
- def bk_roll(self,data,nn,axis=0):
725
- return(self.backend.roll(data,nn,axis=axis))
726
-
727
- def bk_expand_dims(self,data,axis=0):
728
- if self.BACKEND==self.TENSORFLOW:
729
- return(self.backend.expand_dims(data,axis=axis))
730
- if self.BACKEND==self.TORCH:
731
- if isinstance(data,np.ndarray):
732
- data=self.backend.from_numpy(data)
733
- return(self.backend.unsqueeze(data,axis))
734
- if self.BACKEND==self.NUMPY:
735
- return(np.expand_dims(data,axis))
736
-
737
- def bk_transpose(self,data,thelist):
738
- if self.BACKEND==self.TENSORFLOW:
739
- return(self.backend.transpose(data,thelist))
740
- if self.BACKEND==self.TORCH:
741
- return(self.backend.transpose(data,thelist))
742
- if self.BACKEND==self.NUMPY:
743
- return(np.transpose(data,thelist))
744
-
745
- def bk_concat(self,data,axis=None):
746
-
747
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
865
+
866
+ return self.backend.reshape(data, shape)
867
+
868
+ def bk_repeat(self, data, nn, axis=0):
869
+ return self.backend.repeat(data, nn, axis=axis)
870
+
871
+ def bk_tile(self, data, nn, axis=0):
872
+ return self.backend.tile(data, nn)
873
+
874
+ def bk_roll(self, data, nn, axis=0):
875
+ return self.backend.roll(data, nn, axis=axis)
876
+
877
+ def bk_expand_dims(self, data, axis=0):
878
+ if self.BACKEND == self.TENSORFLOW:
879
+ return self.backend.expand_dims(data, axis=axis)
880
+ if self.BACKEND == self.TORCH:
881
+ if isinstance(data, np.ndarray):
882
+ data = self.backend.from_numpy(data)
883
+ return self.backend.unsqueeze(data, axis)
884
+ if self.BACKEND == self.NUMPY:
885
+ return np.expand_dims(data, axis)
886
+
887
+ def bk_transpose(self, data, thelist):
888
+ if self.BACKEND == self.TENSORFLOW:
889
+ return self.backend.transpose(data, thelist)
890
+ if self.BACKEND == self.TORCH:
891
+ return self.backend.transpose(data, thelist)
892
+ if self.BACKEND == self.NUMPY:
893
+ return np.transpose(data, thelist)
894
+
895
+ def bk_concat(self, data, axis=None):
896
+
897
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
748
898
  if axis is None:
749
- if data[0].dtype==self.all_cbk_type:
750
- ndata=len(data)
751
- xr=self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
752
- xi=self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
753
- return self.backend.complex(xr,xi)
899
+ if data[0].dtype == self.all_cbk_type:
900
+ ndata = len(data)
901
+ xr = self.backend.concat(
902
+ [self.bk_real(data[k]) for k in range(ndata)]
903
+ )
904
+ xi = self.backend.concat(
905
+ [self.bk_imag(data[k]) for k in range(ndata)]
906
+ )
907
+ return self.backend.complex(xr, xi)
754
908
  else:
755
- return(self.backend.concat(data))
909
+ return self.backend.concat(data)
756
910
  else:
757
- if data[0].dtype==self.all_cbk_type:
758
- ndata=len(data)
759
- xr=self.backend.concat([self.bk_real(data[k]) for k in range(ndata)],axis=axis)
760
- xi=self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)],axis=axis)
761
- return self.backend.complex(xr,xi)
911
+ if data[0].dtype == self.all_cbk_type:
912
+ ndata = len(data)
913
+ xr = self.backend.concat(
914
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
915
+ )
916
+ xi = self.backend.concat(
917
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
918
+ )
919
+ return self.backend.complex(xr, xi)
762
920
  else:
763
- return(self.backend.concat(data,axis=axis))
921
+ return self.backend.concat(data, axis=axis)
764
922
  else:
765
923
  if axis is None:
766
- return np.concatenate(data,axis=0)
924
+ return np.concatenate(data, axis=0)
767
925
  else:
768
- return np.concatenate(data,axis=axis)
926
+ return np.concatenate(data, axis=axis)
927
+
928
+ def bk_zeros(self, shape,dtype=None):
929
+ if self.BACKEND == self.TENSORFLOW:
930
+ return self.backend.zeros(shape,dtype=dtype)
931
+ if self.BACKEND == self.TORCH:
932
+ return self.backend.zeros(shape,dtype=dtype)
933
+ if self.BACKEND == self.NUMPY:
934
+ return np.zeros(shape,dtype=dtype)
935
+
936
+ def bk_fft(self, data):
937
+ if self.BACKEND == self.TENSORFLOW:
938
+ return self.backend.signal.fft(data)
939
+ if self.BACKEND == self.TORCH:
940
+ return self.backend.fft(data)
941
+ if self.BACKEND == self.NUMPY:
942
+ return self.backend.fft.fft(data)
943
+
944
+ def bk_conjugate(self, data):
769
945
 
770
-
771
- def bk_conjugate(self,data):
772
-
773
- if self.BACKEND==self.TENSORFLOW:
946
+ if self.BACKEND == self.TENSORFLOW:
774
947
  return self.backend.math.conj(data)
775
- if self.BACKEND==self.TORCH:
948
+ if self.BACKEND == self.TORCH:
776
949
  return self.backend.conj(data)
777
- if self.BACKEND==self.NUMPY:
950
+ if self.BACKEND == self.NUMPY:
778
951
  return data.conjugate()
779
-
780
- def bk_real(self,data):
781
- if self.BACKEND==self.TENSORFLOW:
952
+
953
+ def bk_real(self, data):
954
+ if self.BACKEND == self.TENSORFLOW:
782
955
  return self.backend.math.real(data)
783
- if self.BACKEND==self.TORCH:
956
+ if self.BACKEND == self.TORCH:
784
957
  return data.real
785
- if self.BACKEND==self.NUMPY:
958
+ if self.BACKEND == self.NUMPY:
786
959
  return data.real
787
960
 
788
- def bk_imag(self,data):
789
- if self.BACKEND==self.TENSORFLOW:
961
+ def bk_imag(self, data):
962
+ if self.BACKEND == self.TENSORFLOW:
790
963
  return self.backend.math.imag(data)
791
- if self.BACKEND==self.TORCH:
792
- if data.dtype==self.all_cbk_type:
964
+ if self.BACKEND == self.TORCH:
965
+ if data.dtype == self.all_cbk_type:
793
966
  return data.imag
794
967
  else:
795
968
  return 0
796
-
797
- if self.BACKEND==self.NUMPY:
969
+
970
+ if self.BACKEND == self.NUMPY:
798
971
  return data.imag
799
-
800
- def bk_relu(self,x):
801
- if self.BACKEND==self.TENSORFLOW:
802
- if x.dtype==self.all_cbk_type:
803
- xr=self.backend.nn.relu(self.bk_real(x))
804
- xi=self.backend.nn.relu(self.bk_imag(x))
805
- return self.backend.complex(xr,xi)
972
+
973
+ def bk_relu(self, x):
974
+ if self.BACKEND == self.TENSORFLOW:
975
+ if x.dtype == self.all_cbk_type:
976
+ xr = self.backend.nn.relu(self.bk_real(x))
977
+ xi = self.backend.nn.relu(self.bk_imag(x))
978
+ return self.backend.complex(xr, xi)
806
979
  else:
807
980
  return self.backend.nn.relu(x)
808
- if self.BACKEND==self.TORCH:
981
+ if self.BACKEND == self.TORCH:
809
982
  return self.backend.relu(x)
810
- if self.BACKEND==self.NUMPY:
811
- return (x>0)*x
812
-
813
- def bk_cast(self,x):
814
- if isinstance(x,np.float64):
815
- if self.all_bk_type=='float32':
816
- return(np.float32(x))
983
+ if self.BACKEND == self.NUMPY:
984
+ return (x > 0) * x
985
+
986
+ def bk_cast(self, x):
987
+ if isinstance(x, np.float64):
988
+ if self.all_bk_type == "float32":
989
+ return np.float32(x)
817
990
  else:
818
- return(x)
819
- if isinstance(x,np.float32):
820
- if self.all_bk_type=='float64':
821
- return(np.float64(x))
991
+ return x
992
+ if isinstance(x, np.float32):
993
+ if self.all_bk_type == "float64":
994
+ return np.float64(x)
822
995
  else:
823
- return(x)
996
+ return x
824
997
 
825
- if isinstance(x,np.int32) or isinstance(x,np.int64) or isinstance(x,int):
826
- if self.all_bk_type=='float64':
827
- return(np.float64(x))
998
+ if isinstance(x, np.int32) or isinstance(x, np.int64) or isinstance(x, int):
999
+ if self.all_bk_type == "float64":
1000
+ return np.float64(x)
828
1001
  else:
829
- return(np.float32(x))
830
-
1002
+ return np.float32(x)
1003
+
831
1004
  if self.bk_is_complex(x):
832
- out_type=self.all_cbk_type
1005
+ out_type = self.all_cbk_type
833
1006
  else:
834
- out_type=self.all_bk_type
835
-
836
- if self.BACKEND==self.TENSORFLOW:
837
- return self.backend.cast(x,out_type)
838
-
839
- if self.BACKEND==self.TORCH:
840
- if isinstance(x,np.ndarray):
841
- x=self.backend.from_numpy(x)
842
-
1007
+ out_type = self.all_bk_type
1008
+
1009
+ if self.BACKEND == self.TENSORFLOW:
1010
+ return self.backend.cast(x, out_type)
1011
+
1012
+ if self.BACKEND == self.TORCH:
1013
+ if isinstance(x, np.ndarray):
1014
+ x = self.backend.from_numpy(x)
1015
+
843
1016
  if x.dtype.is_complex:
844
- out_type=self.all_cbk_type
1017
+ out_type = self.all_cbk_type
845
1018
  else:
846
- out_type=self.all_bk_type
847
-
1019
+ out_type = self.all_bk_type
1020
+
848
1021
  return x.type(out_type)
849
-
850
- if self.BACKEND==self.NUMPY:
1022
+
1023
+ if self.BACKEND == self.NUMPY:
851
1024
  return x.astype(out_type)