foscat 3.1.5__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -1,851 +1,1005 @@
1
1
  import sys
2
+
2
3
  import numpy as np
3
4
 
5
+
4
6
  class foscat_backend:
5
-
6
def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
    """Select and configure the array backend.

    Parameters
    ----------
    name : str
        "tensorflow", "torch" or "numpy". After a successful selection
        ``self.BACKEND`` holds the matching integer code
        (TENSORFLOW=1, TORCH=2, NUMPY=3).
    mpi_rank : int
        Only rank 0 prints the TensorFlow "Num GPUs Available" message.
    all_type : str
        "float32" or "float64"; sets ``all_type``, ``all_bk_type`` (real
        dtype) and ``all_cbk_type`` (complex dtype).
    gpupos : int
        Accepted for backward compatibility; not used by this constructor.
    silent : bool
        Suppress the TensorFlow GPU-count message.

    Notes
    -----
    An unsupported ``name`` or ``all_type`` prints an error and returns
    early, leaving the object only partially initialised (historical
    behaviour, kept so existing callers see no new exceptions).
    """
    self.TENSORFLOW = 1
    self.TORCH = 2
    self.NUMPY = 3

    # table use to compute the iso orientation rotation (filled by
    # calc_iso_orient) and orientation FFT tables (filled by calc_fft_orient)
    self._iso_orient = {}
    self._iso_orient_T = {}
    self._iso_orient_C = {}
    self._iso_orient_C_T = {}
    self._fft_1_orient = {}
    self._fft_1_orient_C = {}
    self._fft_2_orient = {}
    self._fft_2_orient_C = {}
    self._fft_3_orient = {}
    self._fft_3_orient_C = {}

    self.BACKEND = name

    if name not in ["tensorflow", "torch", "numpy"]:
        print('Backend "%s" not yet implemented' % (name))
        print(" Choose inside the next 3 available backends :")
        print(" - tensorflow")
        print(" - torch")
        print(" - numpy (Impossible to do synthesis using numpy)")
        return None

    if name == "tensorflow":
        import tensorflow as tf

        self.backend = tf
        self.BACKEND = self.TENSORFLOW
        self.tf_function = tf.function
    elif name == "torch":
        import torch

        self.BACKEND = self.TORCH
        self.backend = torch
        self.tf_function = self.tf_loc_function
    else:  # "numpy"
        self.BACKEND = self.NUMPY
        self.backend = np
        import scipy as scipy

        self.scipy = scipy
        self.tf_function = self.tf_loc_function

    self.float64 = self.backend.float64
    self.float32 = self.backend.float32
    self.int64 = self.backend.int64
    self.int32 = self.backend.int32
    # Each attribute maps to the backend dtype of the same name.
    self.complex64 = self.backend.complex64
    self.complex128 = self.backend.complex128

    if all_type == "float32":
        # all_type must be set here too: bk_ones reads it as its default dtype.
        self.all_type = "float32"
        self.all_bk_type = self.backend.float32
        self.all_cbk_type = self.backend.complex64
    elif all_type == "float64":
        self.all_type = "float64"
        self.all_bk_type = self.backend.float64
        self.all_cbk_type = self.backend.complex128
    else:
        print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
        return None

    # ===========================================================================
    # INIT
    if mpi_rank == 0 and self.BACKEND == self.TENSORFLOW and not silent:
        print(
            "Num GPUs Available: ",
            len(self.backend.config.experimental.list_physical_devices("GPU")),
        )
        sys.stdout.flush()

    # Only TensorFlow exposes config.experimental; the other backends must
    # not query it.
    gpus = []
    if self.BACKEND == self.TENSORFLOW:
        self.backend.debugging.set_log_device_placement(False)
        self.backend.config.set_soft_device_placement(True)
        gpus = self.backend.config.experimental.list_physical_devices("GPU")
    elif self.BACKEND == self.TORCH:
        # A bool for torch: True when CUDA is usable.
        gpus = self.backend.cuda.is_available()
    else:
        self.gpulist = {0: "CPU:0"}
        self.ngpu = 1

    if gpus:
        try:
            if self.BACKEND == self.TENSORFLOW:
                # Currently, memory growth needs to be the same across GPUs
                for gpu in gpus:
                    self.backend.config.experimental.set_memory_growth(gpu, True)
                logical_gpus = self.backend.config.experimental.list_logical_devices(
                    "GPU"
                )
                print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
                sys.stdout.flush()
                self.ngpu = len(logical_gpus)
                self.gpulist = {i: logical_gpus[i].name for i in range(self.ngpu)}
            if self.BACKEND == self.TORCH:
                self.ngpu = self.backend.cuda.device_count()
                # Name each entry after its own device index.
                self.gpulist = {
                    k: self.backend.cuda.get_device_name(k) for k in range(self.ngpu)
                }
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)
120
-
121
def tf_loc_function(self, func):
    """Identity decorator used in place of ``tf.function`` on non-TF backends.

    Hands back ``func`` itself, so decorated code runs unchanged.
    """
    undecorated = func
    return undecorated
123
-
124
def calc_iso_orient(self, norient):
    """Build and cache the iso-orientation rotation tables for ``norient``.

    Caches, keyed by ``norient``:
      * ``_iso_orient``: a (norient**2, norient) matrix whose column ``i``
        holds 0.25 at rows ``j * norient + (j + i) % norient`` for every ``j``;
      * ``_iso_orient_T``: 4 times its transpose (entries 1.0);
      * ``_iso_orient_C`` / ``_iso_orient_C_T``: the same two tables built
        with ``bk_complex`` and a zero imaginary part.
    """
    col, blk = np.meshgrid(np.arange(norient), np.arange(norient), indexing="ij")
    table = np.zeros([norient * norient, norient])
    table[blk * norient + (blk + col) % norient, col] = 0.25

    fwd = self.constant(self.bk_cast(table))
    bwd = self.constant(self.bk_cast(4 * table.T))
    self._iso_orient[norient] = fwd
    self._iso_orient_T[norient] = bwd
    self._iso_orient_C[norient] = self.bk_complex(fwd, 0 * fwd)
    self._iso_orient_C_T[norient] = self.bk_complex(bwd, 0 * bwd)


def calc_fft_orient(self, norient, nharm, imaginary):
    """Build and cache harmonic bases over 1, 2 and 3 orientation axes.

    Angles are ``2*pi*n/norient`` for n = 0..norient-1. The 1-D basis has
    columns ``[1, cos(a), sin(a), cos(2a), sin(2a), ...]`` up to ``nharm``
    when ``imaginary`` is truthy, and ``[cos(0*a), cos(a), ..., cos(nharm*a)]``
    otherwise. The 2-D and 3-D tables are outer products of that basis over
    independent angle axes, reshaped to (norient**d, nbasis**d).

    Results are stored under ``(norient, nharm, imaginary)`` in
    ``_fft_{1,2,3}_orient`` and, with a zero imaginary part, in
    ``_fft_{1,2,3}_orient_C``.
    """
    key = (norient, nharm, imaginary)
    angles = np.arange(norient) / norient * 2 * np.pi

    if imaginary:
        nbasis = 1 + 2 * nharm
        basis = np.zeros([norient, nbasis])
        basis[:, 0] = 1.0
        for h in range(nharm):
            basis[:, 2 * h + 1] = np.cos(angles * (h + 1))
            basis[:, 2 * h + 2] = np.sin(angles * (h + 1))
    else:
        nbasis = 1 + nharm
        basis = np.zeros([norient, nbasis])
        for h in range(nbasis):
            basis[:, h] = np.cos(angles * h)

    def _store(real_cache, complex_cache, table):
        # Cache the real table and its zero-imaginary complex twin.
        real_cache[key] = self.bk_cast(self.constant(table))
        complex_cache[key] = self.bk_complex(real_cache[key], 0 * real_cache[key])

    _store(self._fft_1_orient, self._fft_1_orient_C, basis)

    # Factors are multiplied left to right (first axis, then second, then
    # third), i.e. basis[p, a] * basis[q, b] * basis[r, c].
    grid2 = basis[:, None, :, None] * basis[None, :, None, :]
    _store(
        self._fft_2_orient,
        self._fft_2_orient_C,
        grid2.reshape(norient * norient, nbasis * nbasis),
    )

    grid3 = (
        basis[:, None, None, :, None, None]
        * basis[None, :, None, None, :, None]
        * basis[None, None, :, None, None, :]
    )
    _store(
        self._fft_3_orient,
        self._fft_3_orient_C,
        grid3.reshape(norient * norient * norient, nbasis * nbasis * nbasis),
    )
377
+
245
378
  # ---------------------------------------------−---------
246
379
  # -- BACKEND DEFINITION --
247
380
  # ---------------------------------------------−---------
248
def bk_SparseTensor(self, indice, w, dense_shape=[]):
    """Build a backend sparse matrix from COO indices and values.

    ``indice`` holds one (row, col) pair per row, ``w`` the matching values.
    Returns ``None`` when no backend matches.
    """
    if self.BACKEND == self.NUMPY:
        rows, cols = indice[:, 0], indice[:, 1]
        return self.scipy.sparse.coo_matrix((w, (rows, cols)), shape=dense_shape)
    if self.BACKEND == self.TORCH:
        # torch wants the indices as (2, nnz), i.e. transposed.
        return self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
    return None


def bk_stack(self, list, axis=0):
    """Stack a sequence of tensors along a new ``axis``.

    The parameter name shadows the builtin ``list`` but is kept so keyword
    callers keep working.
    """
    if self.BACKEND in (self.TENSORFLOW, self.TORCH, self.NUMPY):
        return self.backend.stack(list, axis=axis)
    return None


def bk_sparse_dense_matmul(self, smat, mat):
    """Multiply sparse matrix ``smat`` by dense ``mat``; ``None`` if no backend matches."""
    dispatch = {
        self.TENSORFLOW: lambda: self.backend.sparse.sparse_dense_matmul(smat, mat),
        self.TORCH: lambda: smat.matmul(mat),
        self.NUMPY: lambda: smat.dot(mat),
    }
    op = dispatch.get(self.BACKEND)
    return None if op is None else op()
271
406
 
272
def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
    """2-D cross-correlation with symmetric (edge-repeating) boundaries.

    Parameters
    ----------
    x : array, indexed as (batch, ny, nx, c_in) by the NumPy branch.
    w : array, indexed as (ky, kx, c_in, c_out).
    strides : passed to ``tf.nn.conv2d``; ignored by the NumPy branch.
    padding : accepted for API compatibility but ignored. TensorFlow pads
        ``k // 2`` samples per side with SYMMETRIC mode and then runs a VALID
        correlation; NumPy uses ``correlate2d(mode="same", boundary="symm")``.

    Returns the filtered array with ``c_out`` channels. The torch backend is
    not implemented yet and returns ``x`` unchanged.
    """
    if self.BACKEND == self.TENSORFLOW:
        kx = w.shape[0]
        ky = w.shape[1]
        paddings = self.backend.constant(
            [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
        )
        tmp = self.backend.pad(x, paddings, "SYMMETRIC")
        return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
    # to be written!!!
    if self.BACKEND == self.TORCH:
        return x
    if self.BACKEND == self.NUMPY:
        res = np.zeros(
            [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
        )
        for k in range(w.shape[2]):
            for l_orient in range(w.shape[3]):
                for j in range(res.shape[0]):
                    # correlate2d, not convolve2d: tf.nn.conv2d does not flip
                    # the kernel, and flipping would change the result for
                    # non-symmetric (e.g. odd/imaginary-part) kernels.
                    res[j, :, :, l_orient] += self.scipy.signal.correlate2d(
                        x[j, :, :, k],
                        w[:, :, k, l_orient],
                        mode="same",
                        boundary="symm",
                    )
        return res
296
-
297
def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
    """1-D cross-correlation with symmetric (edge-repeating) boundaries.

    Parameters
    ----------
    x : array (batch, n, c_in) — the layout of the TF branch's padding spec.
    w : array (k, c_in, c_out), as expected by ``tf.nn.conv1d``.
    strides : passed to ``tf.nn.conv1d``; the NumPy branch assumes stride 1.
    padding : accepted for API compatibility but ignored. Both implemented
        backends pad ``k // 2`` samples per side by symmetric reflection and
        then apply a VALID correlation.

    Returns an array (batch, n_out, c_out). The torch backend is not
    implemented yet and returns ``x`` unchanged.
    """
    if self.BACKEND == self.TENSORFLOW:
        kx = w.shape[0]
        paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
        tmp = self.backend.pad(x, paddings, "SYMMETRIC")

        return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
    # to be written!!!
    if self.BACKEND == self.TORCH:
        return x
    if self.BACKEND == self.NUMPY:
        kx = w.shape[0]
        # np.pad "symmetric" repeats the edge sample, like tf.pad "SYMMETRIC".
        xpad = np.pad(x, [(0, 0), (kx // 2, kx // 2), (0, 0)], mode="symmetric")
        nout = xpad.shape[1] - kx + 1  # VALID output length, as in TF
        res = np.zeros(
            [x.shape[0], nout, w.shape[2]], dtype=np.result_type(x.dtype, w.dtype)
        )
        for tap in range(kx):
            # (batch, nout, c_in) @ (c_in, c_out) -> (batch, nout, c_out)
            res += xpad[:, tap : tap + nout, :] @ w[tap]
        return res
319
453
 
320
def bk_threshold(self, x, threshold, greater=True):
    """Zero every element of ``x`` that is not strictly greater than ``threshold``.

    ``greater`` is accepted for API compatibility but currently ignored: the
    comparison is always ``x > threshold``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.cast(x > threshold, x.dtype) * x
    if self.BACKEND == self.TORCH:
        return (x > threshold) * x
    if self.BACKEND == self.NUMPY:
        return (x > threshold) * x


def bk_maximum(self, x1, x2):
    """Element-wise (broadcasting) maximum of ``x1`` and ``x2``."""
    # tf.maximum, torch.maximum and np.maximum share this signature. A masked
    # sum such as x1*(x1>x2) + x2*(x2>x1) would return 0 wherever x1 == x2.
    if self.BACKEND in (self.TENSORFLOW, self.TORCH, self.NUMPY):
        return self.backend.maximum(x1, x2)


def bk_device(self, device_name):
    """Return the backend's device object/context for ``device_name``.

    NumPy has no device API, so a do-nothing context manager is returned
    there (NOTE(review): assumes callers use the result in a ``with``
    statement, as ``tf.device`` is used).
    """
    if self.BACKEND == self.NUMPY:
        import contextlib

        return contextlib.nullcontext()
    return self.backend.device(device_name)
341
-
342
def bk_ones(self, shape, dtype=None):
    """Return an array/tensor of ones with the given ``shape``.

    ``dtype`` defaults to the working real type ``self.all_bk_type``, which
    __init__ sets for both the float32 and float64 configurations. On torch
    the ones are built with NumPy and converted by ``bk_cast``; ``dtype`` is
    not used there.
    """
    if dtype is None:
        dtype = self.all_bk_type
    if self.BACKEND == self.TORCH:
        return self.bk_cast(np.ones(shape))
    return self.backend.ones(shape, dtype=dtype)


def bk_conv1d(self, x, w):
    """1-D "SAME" cross-correlation with zero padding and stride 1.

    Parameters
    ----------
    x : array (batch, n, c_in) — NWC layout, as used by ``tf.nn.conv1d``.
    w : array (k, c_in, c_out).

    Returns an array (batch, n, c_out).
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.nn.conv1d(x, w, stride=[1, 1, 1], padding="SAME")
    if self.BACKEND == self.TORCH:
        # torch expects (batch, c_in, n) inputs and (c_out, c_in, k) weights.
        out = self.backend.nn.functional.conv1d(
            x.permute(0, 2, 1), w.permute(2, 1, 0), padding="same"
        )
        return out.permute(0, 2, 1)
    if self.BACKEND == self.NUMPY:
        ksize = w.shape[0]
        # TF "SAME" (stride 1) puts (k - 1) // 2 zeros on the left, rest on the right.
        left = (ksize - 1) // 2
        xpad = np.pad(x, [(0, 0), (left, ksize - 1 - left), (0, 0)], mode="constant")
        n = x.shape[1]
        res = np.zeros(
            [x.shape[0], n, w.shape[2]], dtype=np.result_type(x.dtype, w.dtype)
        )
        for tap in range(ksize):
            # cross-correlation (no kernel flip), matching tf.nn.conv1d
            res += xpad[:, tap : tap + n, :] @ w[tap]
        return res
361
497
 
362
def bk_flattenR(self, x):
    """Flatten ``x`` to 1-D, splitting complex input into real then imaginary parts.

    Real input yields ``x.size`` values. Complex input yields ``2 * x.size``
    values: every real part first, then every imaginary part.
    """
    if self.BACKEND == self.NUMPY:
        if not self.bk_is_complex(x):
            return x.flatten()
        return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)

    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        is_cplx = self.bk_is_complex(x)
        flat_shape = [np.prod(np.array(list(x.shape)))]
        if not is_cplx:
            return self.backend.reshape(x, flat_shape)
        real_part = self.backend.reshape(self.bk_real(x), flat_shape)
        imag_part = self.backend.reshape(self.bk_imag(x), flat_shape)
        return self.bk_concat([real_part, imag_part], axis=0)
376
-
377
def bk_flatten(self, x):
    """Reshape ``x`` into a 1-D array/tensor holding all of its elements."""
    if self.BACKEND == self.NUMPY:
        return x.flatten()
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        n_elem = np.prod(np.array(list(x.shape)))
        return self.backend.reshape(x, [n_elem])
384
524
 
385
def bk_size(self, x):
    """Return the total number of elements in ``x`` for the active backend."""
    if self.BACKEND == self.NUMPY:
        return x.size
    if self.BACKEND == self.TORCH:
        return x.numel()
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.size(x)
393
-
394
- def bk_resize_image(self,x,shape):
395
- if self.BACKEND==self.TENSORFLOW:
396
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
397
-
398
- if self.BACKEND==self.TORCH:
399
- tmp=self.backend.nn.functional.interpolate(x,
400
- size=shape,
401
- mode='bilinear',
402
- align_corners=False)
533
+
534
+ def bk_resize_image(self, x, shape):
535
+ if self.BACKEND == self.TENSORFLOW:
536
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
537
+
538
+ if self.BACKEND == self.TORCH:
539
+ tmp = self.backend.nn.functional.interpolate(
540
+ x, size=shape, mode="bilinear", align_corners=False
541
+ )
403
542
  return self.bk_cast(tmp)
404
- if self.BACKEND==self.NUMPY:
405
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
406
-
407
- def bk_L1(self,x):
408
- if x.dtype==self.all_cbk_type:
409
- xr=self.bk_real(x)
410
- #xi=self.bk_imag(x)
411
-
412
- r=self.backend.sign(xr)*self.backend.sqrt(self.backend.sign(xr)*xr)
413
- return r
414
- #i=self.backend.sign(xi)*self.backend.sqrt(self.backend.sign(xi)*xi)
415
- """
543
+ if self.BACKEND == self.NUMPY:
544
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
545
+
546
+ def bk_L1(self, x):
547
+ if x.dtype == self.all_cbk_type:
548
+ xr = self.bk_real(x)
549
+ xi = self.bk_imag(x)
550
+
551
+ r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
552
+ #return r
553
+ i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
554
+
416
555
  if self.BACKEND==self.TORCH:
417
556
  return r
418
557
  else:
419
558
  return self.bk_complex(r,i)
420
- """
421
559
  else:
422
- return self.backend.sign(x)*self.backend.sqrt(self.backend.sign(x)*x)
423
-
424
- def bk_square_comp(self,x):
425
- if x.dtype==self.all_cbk_type:
426
- xr=self.bk_real(x)
427
- xi=self.bk_imag(x)
428
-
429
- r=xr*xr
430
- i=xi*xi
431
- return self.bk_complex(r,i)
560
+ return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
561
+
562
+ def bk_square_comp(self, x):
563
+ if x.dtype == self.all_cbk_type:
564
+ xr = self.bk_real(x)
565
+ xi = self.bk_imag(x)
566
+
567
+ r = xr * xr
568
+ i = xi * xi
569
+ return self.bk_complex(r, i)
432
570
  else:
433
- return x*x
434
-
435
- def bk_reduce_sum(self,data,axis=None):
436
-
571
+ return x * x
572
+
573
+ def bk_reduce_sum(self, data, axis=None):
574
+
437
575
  if axis is None:
438
- if self.BACKEND==self.TENSORFLOW:
439
- return(self.backend.reduce_sum(data))
440
- if self.BACKEND==self.TORCH:
441
- return(self.backend.sum(data))
442
- if self.BACKEND==self.NUMPY:
443
- return(np.sum(data))
576
+ if self.BACKEND == self.TENSORFLOW:
577
+ return self.backend.reduce_sum(data)
578
+ if self.BACKEND == self.TORCH:
579
+ return self.backend.sum(data)
580
+ if self.BACKEND == self.NUMPY:
581
+ return np.sum(data)
444
582
  else:
445
- if self.BACKEND==self.TENSORFLOW:
446
- return(self.backend.reduce_sum(data,axis=axis))
447
- if self.BACKEND==self.TORCH:
448
- return(self.backend.sum(data,axis))
449
- if self.BACKEND==self.NUMPY:
450
- return(np.sum(data,axis))
451
-
583
+ if self.BACKEND == self.TENSORFLOW:
584
+ return self.backend.reduce_sum(data, axis=axis)
585
+ if self.BACKEND == self.TORCH:
586
+ return self.backend.sum(data, axis)
587
+ if self.BACKEND == self.NUMPY:
588
+ return np.sum(data, axis)
589
+
452
590
  # ---------------------------------------------−---------
453
-
454
- def iso_mean(self,x,use_2D=False):
455
- shape=list(x.shape)
456
-
457
- i_orient=2
591
+
592
+ def iso_mean(self, x, use_2D=False):
593
+ shape = list(x.shape)
594
+
595
+ i_orient = 2
458
596
  if use_2D:
459
- i_orient=3
460
- norient=shape[i_orient]
597
+ i_orient = 3
598
+ norient = shape[i_orient]
599
+
600
+ if len(shape) == i_orient + 1:
601
+ return self.bk_reduce_mean(x, -1)
461
602
 
462
- if len(shape)==i_orient+1:
463
- return self.bk_reduce_mean(x,-1)
464
-
465
603
  if norient not in self._iso_orient:
466
604
  self.calc_iso_orient(norient)
467
605
 
468
606
  if self.bk_is_complex(x):
469
- lmat = self._iso_orient_C[norient]
470
- lmat_T = self._iso_orient_C_T[norient]
607
+ lmat = self._iso_orient_C[norient]
471
608
  else:
472
- lmat = self._iso_orient[norient]
473
- lmat_T = self._iso_orient_T[norient]
609
+ lmat = self._iso_orient[norient]
474
610
 
475
- oshape=shape[0]
476
- for k in range(1,len(shape)-2):
477
- oshape*=shape[k]
478
-
479
- oshape2=[shape[k] for k in range(0,len(shape)-1)]
480
-
481
- return self.bk_reshape(self.backend.matmul(
482
- self.bk_reshape(x,[oshape,norient*norient]),lmat),oshape2)
611
+ oshape = shape[0]
612
+ for k in range(1, len(shape) - 2):
613
+ oshape *= shape[k]
614
+
615
+ oshape2 = [shape[k] for k in range(0, len(shape) - 1)]
616
+
617
+ return self.bk_reshape(
618
+ self.backend.matmul(self.bk_reshape(x, [oshape, norient * norient]), lmat),
619
+ oshape2,
620
+ )
483
621
 
484
-
485
- def fft_ang(self,x,nharm=1,imaginary=False,use_2D=False):
486
- shape=list(x.shape)
622
+ def fft_ang(self, x, nharm=1, imaginary=False, use_2D=False):
623
+ shape = list(x.shape)
487
624
 
488
- i_orient=2
625
+ i_orient = 2
489
626
  if use_2D:
490
- i_orient=3
491
-
492
- norient=shape[i_orient]
493
- nout=1+nharm
494
-
495
- oshape_1=shape[0]
496
- for k in range(1,i_orient):
497
- oshape_1*=shape[k]
498
- oshape_2=norient
499
- for k in range(i_orient,len(shape)-1):
500
- oshape_2*=shape[k]
501
- oshape=[oshape_1,oshape_2]
502
-
503
-
627
+ i_orient = 3
628
+
629
+ norient = shape[i_orient]
630
+ nout = 1 + nharm
631
+
632
+ oshape_1 = shape[0]
633
+ for k in range(1, i_orient):
634
+ oshape_1 *= shape[k]
635
+ oshape_2 = norient
636
+ for k in range(i_orient, len(shape) - 1):
637
+ oshape_2 *= shape[k]
638
+ oshape = [oshape_1, oshape_2]
639
+
504
640
  if imaginary:
505
- nout=1+nharm*2
506
-
507
- oshape2=[shape[k] for k in range(0,i_orient)]+[nout for k in range(i_orient,len(shape))]
508
-
509
- if (norient,nharm) not in self._fft_1_orient:
510
- self.calc_fft_orient(norient,nharm,imaginary)
641
+ nout = 1 + nharm * 2
642
+
643
+ oshape2 = [shape[k] for k in range(0, i_orient)] + [
644
+ nout for k in range(i_orient, len(shape))
645
+ ]
646
+
647
+ if (norient, nharm) not in self._fft_1_orient:
648
+ self.calc_fft_orient(norient, nharm, imaginary)
511
649
 
512
- if len(shape)==i_orient+1:
650
+ if len(shape) == i_orient + 1:
513
651
  if self.bk_is_complex(x):
514
- lmat = self._fft_1_orient_C[(norient,nharm,imaginary)]
652
+ lmat = self._fft_1_orient_C[(norient, nharm, imaginary)]
515
653
  else:
516
- lmat = self._fft_1_orient[(norient,nharm,imaginary)]
517
-
518
- if len(shape)==i_orient+2:
654
+ lmat = self._fft_1_orient[(norient, nharm, imaginary)]
655
+
656
+ if len(shape) == i_orient + 2:
519
657
  if self.bk_is_complex(x):
520
- lmat = self._fft_2_orient_C[(norient,nharm,imaginary)]
658
+ lmat = self._fft_2_orient_C[(norient, nharm, imaginary)]
521
659
  else:
522
- lmat = self._fft_2_orient[(norient,nharm,imaginary)]
523
-
524
- if len(shape)==i_orient+3:
660
+ lmat = self._fft_2_orient[(norient, nharm, imaginary)]
661
+
662
+ if len(shape) == i_orient + 3:
525
663
  if self.bk_is_complex(x):
526
- lmat = self._fft_3_orient_C[(norient,nharm,imaginary)]
664
+ lmat = self._fft_3_orient_C[(norient, nharm, imaginary)]
527
665
  else:
528
- lmat = self._fft_3_orient[(norient,nharm,imaginary)]
529
-
530
- return self.bk_reshape(self.backend.matmul(self.bk_reshape(x,oshape),lmat),oshape2)
531
-
532
- def constant(self,data):
533
-
534
- if self.BACKEND==self.TENSORFLOW:
535
- return(self.backend.constant(data))
536
- return(data)
537
-
538
- def bk_reduce_mean(self,data,axis=None):
539
-
666
+ lmat = self._fft_3_orient[(norient, nharm, imaginary)]
667
+
668
+ return self.bk_reshape(
669
+ self.backend.matmul(self.bk_reshape(x, oshape), lmat), oshape2
670
+ )
671
+
672
+ def constant(self, data):
673
+
674
+ if self.BACKEND == self.TENSORFLOW:
675
+ return self.backend.constant(data)
676
+ return data
677
+
678
+ def bk_reduce_mean(self, data, axis=None):
679
+
540
680
  if axis is None:
541
- if self.BACKEND==self.TENSORFLOW:
542
- return(self.backend.reduce_mean(data))
543
- if self.BACKEND==self.TORCH:
544
- return(self.backend.mean(data))
545
- if self.BACKEND==self.NUMPY:
546
- return(np.mean(data))
681
+ if self.BACKEND == self.TENSORFLOW:
682
+ return self.backend.reduce_mean(data)
683
+ if self.BACKEND == self.TORCH:
684
+ return self.backend.mean(data)
685
+ if self.BACKEND == self.NUMPY:
686
+ return np.mean(data)
547
687
  else:
548
- if self.BACKEND==self.TENSORFLOW:
549
- return(self.backend.reduce_mean(data,axis=axis))
550
- if self.BACKEND==self.TORCH:
551
- return(self.backend.mean(data,axis))
552
- if self.BACKEND==self.NUMPY:
553
- return(np.mean(data,axis))
554
-
555
- def bk_reduce_min(self,data,axis=None):
556
-
688
+ if self.BACKEND == self.TENSORFLOW:
689
+ return self.backend.reduce_mean(data, axis=axis)
690
+ if self.BACKEND == self.TORCH:
691
+ return self.backend.mean(data, axis)
692
+ if self.BACKEND == self.NUMPY:
693
+ return np.mean(data, axis)
694
+
695
+ def bk_reduce_min(self, data, axis=None):
696
+
557
697
  if axis is None:
558
- if self.BACKEND==self.TENSORFLOW:
559
- return(self.backend.reduce_min(data))
560
- if self.BACKEND==self.TORCH:
561
- return(self.backend.min(data))
562
- if self.BACKEND==self.NUMPY:
563
- return(np.min(data))
698
+ if self.BACKEND == self.TENSORFLOW:
699
+ return self.backend.reduce_min(data)
700
+ if self.BACKEND == self.TORCH:
701
+ return self.backend.min(data)
702
+ if self.BACKEND == self.NUMPY:
703
+ return np.min(data)
564
704
  else:
565
- if self.BACKEND==self.TENSORFLOW:
566
- return(self.backend.reduce_min(data,axis=axis))
567
- if self.BACKEND==self.TORCH:
568
- return(self.backend.min(data,axis))
569
- if self.BACKEND==self.NUMPY:
570
- return(np.min(data,axis))
571
-
572
- def bk_random_seed(self,value):
573
-
574
- if self.BACKEND==self.TENSORFLOW:
575
- return(self.backend.random.set_seed(value))
576
- if self.BACKEND==self.TORCH:
577
- return(self.backend.random.set_seed(value))
578
- if self.BACKEND==self.NUMPY:
579
- return(np.random.seed(value))
580
-
581
- def bk_random_uniform(self,shape):
582
-
583
- if self.BACKEND==self.TENSORFLOW:
584
- return(self.backend.random.uniform(shape))
585
- if self.BACKEND==self.TORCH:
586
- return(self.backend.random.uniform(shape))
587
- if self.BACKEND==self.NUMPY:
588
- return(np.random.rand(shape))
589
-
590
- def bk_reduce_std(self,data,axis=None):
591
-
705
+ if self.BACKEND == self.TENSORFLOW:
706
+ return self.backend.reduce_min(data, axis=axis)
707
+ if self.BACKEND == self.TORCH:
708
+ return self.backend.min(data, axis)
709
+ if self.BACKEND == self.NUMPY:
710
+ return np.min(data, axis)
711
+
712
+ def bk_random_seed(self, value):
713
+
714
+ if self.BACKEND == self.TENSORFLOW:
715
+ return self.backend.random.set_seed(value)
716
+ if self.BACKEND == self.TORCH:
717
+ return self.backend.random.set_seed(value)
718
+ if self.BACKEND == self.NUMPY:
719
+ return np.random.seed(value)
720
+
721
+ def bk_random_uniform(self, shape):
722
+
723
+ if self.BACKEND == self.TENSORFLOW:
724
+ return self.backend.random.uniform(shape)
725
+ if self.BACKEND == self.TORCH:
726
+ return self.backend.random.uniform(shape)
727
+ if self.BACKEND == self.NUMPY:
728
+ return np.random.rand(shape)
729
+
730
+ def bk_reduce_std(self, data, axis=None):
592
731
  if axis is None:
593
- if self.BACKEND==self.TENSORFLOW:
594
- return(self.backend.math.reduce_std(data))
595
- if self.BACKEND==self.TORCH:
596
- return(self.backend.std(data))
597
- if self.BACKEND==self.NUMPY:
598
- return(np.std(data))
732
+ if self.BACKEND == self.TENSORFLOW:
733
+ r=self.backend.math.reduce_std(data)
734
+ if self.BACKEND == self.TORCH:
735
+ r=self.backend.std(data)
736
+ if self.BACKEND == self.NUMPY:
737
+ r=np.std(data)
738
+ return self.bk_complex(r,0*r)
599
739
  else:
600
- if self.BACKEND==self.TENSORFLOW:
601
- return(self.backend.math.reduce_std(data,axis=axis))
602
- if self.BACKEND==self.TORCH:
603
- return(self.backend.std(data,axis))
604
- if self.BACKEND==self.NUMPY:
605
- return(np.std(data,axis))
606
-
607
-
608
- def bk_sqrt(self,data):
609
-
610
- return(self.backend.sqrt(self.backend.abs(data)))
611
-
612
- def bk_abs(self,data):
613
- return(self.backend.abs(data))
614
-
615
- def bk_is_complex(self,data):
616
-
617
- if self.BACKEND==self.TENSORFLOW:
618
- if isinstance(data,np.ndarray):
619
- return (data.dtype=='complex64' or data.dtype=='complex128')
740
+ if self.BACKEND == self.TENSORFLOW:
741
+ r=self.backend.math.reduce_std(data, axis=axis)
742
+ if self.BACKEND == self.TORCH:
743
+ r=self.backend.std(data, axis)
744
+ if self.BACKEND == self.NUMPY:
745
+ r=np.std(data, axis)
746
+ if self.bk_is_complex(data):
747
+ return self.bk_complex(r,0*r)
748
+ else:
749
+ return r
750
+
751
+ def bk_sqrt(self, data):
752
+
753
+ return self.backend.sqrt(self.backend.abs(data))
754
+
755
+ def bk_abs(self, data):
756
+ return self.backend.abs(data)
757
+
758
+ def bk_is_complex(self, data):
759
+
760
+ if self.BACKEND == self.TENSORFLOW:
761
+ if isinstance(data, np.ndarray):
762
+ return data.dtype == "complex64" or data.dtype == "complex128"
620
763
  return data.dtype.is_complex
621
-
622
- if self.BACKEND==self.TORCH:
623
- if isinstance(data,np.ndarray):
624
- return (data.dtype=='complex64' or data.dtype=='complex128')
625
-
764
+
765
+ if self.BACKEND == self.TORCH:
766
+ if isinstance(data, np.ndarray):
767
+ return data.dtype == "complex64" or data.dtype == "complex128"
768
+
626
769
  return data.dtype.is_complex
627
-
628
- if self.BACKEND==self.NUMPY:
629
- return (data.dtype=='complex64' or data.dtype=='complex128')
630
770
 
631
- def bk_distcomp(self,data):
771
+ if self.BACKEND == self.NUMPY:
772
+ return data.dtype == "complex64" or data.dtype == "complex128"
773
+
774
+ def bk_distcomp(self, data):
632
775
  if self.bk_is_complex(data):
633
- res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
776
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
777
+ self.bk_imag(data)
778
+ )
634
779
  return res
635
780
  else:
636
781
  return self.bk_square(data)
637
-
638
- def bk_norm(self,data):
782
+
783
+ def bk_norm(self, data):
639
784
  if self.bk_is_complex(data):
640
- res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
785
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
786
+ self.bk_imag(data)
787
+ )
641
788
  return self.bk_sqrt(res)
642
789
 
643
790
  else:
644
791
  return self.bk_abs(data)
645
-
646
- def bk_square(self,data):
647
-
648
- if self.BACKEND==self.TENSORFLOW:
649
- return(self.backend.square(data))
650
- if self.BACKEND==self.TORCH:
651
- return(self.backend.square(data))
652
- if self.BACKEND==self.NUMPY:
653
- return(data*data)
654
-
655
- def bk_log(self,data):
656
- if self.BACKEND==self.TENSORFLOW:
657
- return(self.backend.math.log(data))
658
- if self.BACKEND==self.TORCH:
659
- return(self.backend.log(data))
660
- if self.BACKEND==self.NUMPY:
661
- return(np.log(data))
662
-
663
- def bk_matmul(self,a,b):
664
- if self.BACKEND==self.TENSORFLOW:
665
- return(self.backend.matmul(a,b))
666
- if self.BACKEND==self.TORCH:
667
- return(self.backend.matmul(a,b))
668
- if self.BACKEND==self.NUMPY:
669
- return(np.dot(a,b))
670
-
671
- def bk_tensor(self,data):
672
- if self.BACKEND==self.TENSORFLOW:
673
- return(self.backend.constant(data))
674
- if self.BACKEND==self.TORCH:
675
- return(self.backend.constant(data))
676
- if self.BACKEND==self.NUMPY:
677
- return(data)
678
-
679
- def bk_complex(self,real,imag):
680
- if self.BACKEND==self.TENSORFLOW:
681
- return(self.backend.dtypes.complex(real,imag))
682
- if self.BACKEND==self.TORCH:
683
- return(self.backend.complex(real,imag))
684
- if self.BACKEND==self.NUMPY:
685
- return real+1J*imag
686
-
687
- def bk_exp(self,data):
688
-
689
- return(self.backend.exp(data))
690
-
691
- def bk_min(self,data):
692
-
693
- return(self.backend.reduce_min(data))
694
-
695
- def bk_argmin(self,data):
696
-
697
- return(self.backend.argmin(data))
698
-
699
- def bk_tanh(self,data):
700
-
701
- return(self.backend.math.tanh(data))
702
-
703
- def bk_max(self,data):
704
-
705
- return(self.backend.reduce_max(data))
706
-
707
- def bk_argmax(self,data):
708
-
709
- return(self.backend.argmax(data))
710
-
711
- def bk_reshape(self,data,shape):
712
- if self.BACKEND==self.TORCH:
713
- if isinstance(data,np.ndarray):
792
+
793
+ def bk_square(self, data):
794
+
795
+ if self.BACKEND == self.TENSORFLOW:
796
+ return self.backend.square(data)
797
+ if self.BACKEND == self.TORCH:
798
+ return self.backend.square(data)
799
+ if self.BACKEND == self.NUMPY:
800
+ return data * data
801
+
802
+ def bk_log(self, data):
803
+ if self.BACKEND == self.TENSORFLOW:
804
+ return self.backend.math.log(data)
805
+ if self.BACKEND == self.TORCH:
806
+ return self.backend.log(data)
807
+ if self.BACKEND == self.NUMPY:
808
+ return np.log(data)
809
+
810
+ def bk_matmul(self, a, b):
811
+ if self.BACKEND == self.TENSORFLOW:
812
+ return self.backend.matmul(a, b)
813
+ if self.BACKEND == self.TORCH:
814
+ return self.backend.matmul(a, b)
815
+ if self.BACKEND == self.NUMPY:
816
+ return np.dot(a, b)
817
+
818
+ def bk_tensor(self, data):
819
+ if self.BACKEND == self.TENSORFLOW:
820
+ return self.backend.constant(data)
821
+ if self.BACKEND == self.TORCH:
822
+ return self.backend.constant(data)
823
+ if self.BACKEND == self.NUMPY:
824
+ return data
825
+
826
+ def bk_complex(self, real, imag):
827
+ if self.BACKEND == self.TENSORFLOW:
828
+ return self.backend.dtypes.complex(real, imag)
829
+ if self.BACKEND == self.TORCH:
830
+ return self.backend.complex(real, imag)
831
+ if self.BACKEND == self.NUMPY:
832
+ return real + 1j * imag
833
+
834
+ def bk_exp(self, data):
835
+
836
+ return self.backend.exp(data)
837
+
838
+ def bk_min(self, data):
839
+
840
+ return self.backend.reduce_min(data)
841
+
842
+ def bk_argmin(self, data):
843
+
844
+ return self.backend.argmin(data)
845
+
846
+ def bk_tanh(self, data):
847
+
848
+ return self.backend.math.tanh(data)
849
+
850
+ def bk_max(self, data):
851
+
852
+ return self.backend.reduce_max(data)
853
+
854
+ def bk_argmax(self, data):
855
+
856
+ return self.backend.argmax(data)
857
+
858
+ def bk_reshape(self, data, shape):
859
+ if self.BACKEND == self.TORCH:
860
+ if isinstance(data, np.ndarray):
714
861
  return data.reshape(shape)
715
-
716
- return(self.backend.reshape(data,shape))
717
-
718
- def bk_repeat(self,data,nn,axis=0):
719
- return(self.backend.repeat(data,nn,axis=axis))
720
-
721
- def bk_tile(self,data,nn,axis=0):
722
- return(self.backend.tile(data,nn))
723
-
724
- def bk_roll(self,data,nn,axis=0):
725
- return(self.backend.roll(data,nn,axis=axis))
726
-
727
- def bk_expand_dims(self,data,axis=0):
728
- if self.BACKEND==self.TENSORFLOW:
729
- return(self.backend.expand_dims(data,axis=axis))
730
- if self.BACKEND==self.TORCH:
731
- if isinstance(data,np.ndarray):
732
- data=self.backend.from_numpy(data)
733
- return(self.backend.unsqueeze(data,axis))
734
- if self.BACKEND==self.NUMPY:
735
- return(np.expand_dims(data,axis))
736
-
737
- def bk_transpose(self,data,thelist):
738
- if self.BACKEND==self.TENSORFLOW:
739
- return(self.backend.transpose(data,thelist))
740
- if self.BACKEND==self.TORCH:
741
- return(self.backend.transpose(data,thelist))
742
- if self.BACKEND==self.NUMPY:
743
- return(np.transpose(data,thelist))
744
-
745
- def bk_concat(self,data,axis=None):
746
-
747
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
862
+
863
+ return self.backend.reshape(data, shape)
864
+
865
+ def bk_repeat(self, data, nn, axis=0):
866
+ return self.backend.repeat(data, nn, axis=axis)
867
+
868
+ def bk_tile(self, data, nn, axis=0):
869
+ return self.backend.tile(data, nn)
870
+
871
+ def bk_roll(self, data, nn, axis=0):
872
+ return self.backend.roll(data, nn, axis=axis)
873
+
874
+ def bk_expand_dims(self, data, axis=0):
875
+ if self.BACKEND == self.TENSORFLOW:
876
+ return self.backend.expand_dims(data, axis=axis)
877
+ if self.BACKEND == self.TORCH:
878
+ if isinstance(data, np.ndarray):
879
+ data = self.backend.from_numpy(data)
880
+ return self.backend.unsqueeze(data, axis)
881
+ if self.BACKEND == self.NUMPY:
882
+ return np.expand_dims(data, axis)
883
+
884
+ def bk_transpose(self, data, thelist):
885
+ if self.BACKEND == self.TENSORFLOW:
886
+ return self.backend.transpose(data, thelist)
887
+ if self.BACKEND == self.TORCH:
888
+ return self.backend.transpose(data, thelist)
889
+ if self.BACKEND == self.NUMPY:
890
+ return np.transpose(data, thelist)
891
+
892
+ def bk_concat(self, data, axis=None):
893
+
894
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
748
895
  if axis is None:
749
- if data[0].dtype==self.all_cbk_type:
750
- ndata=len(data)
751
- xr=self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
752
- xi=self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
753
- return self.backend.complex(xr,xi)
896
+ if data[0].dtype == self.all_cbk_type:
897
+ ndata = len(data)
898
+ xr = self.backend.concat(
899
+ [self.bk_real(data[k]) for k in range(ndata)]
900
+ )
901
+ xi = self.backend.concat(
902
+ [self.bk_imag(data[k]) for k in range(ndata)]
903
+ )
904
+ return self.backend.complex(xr, xi)
754
905
  else:
755
- return(self.backend.concat(data))
906
+ return self.backend.concat(data)
756
907
  else:
757
- if data[0].dtype==self.all_cbk_type:
758
- ndata=len(data)
759
- xr=self.backend.concat([self.bk_real(data[k]) for k in range(ndata)],axis=axis)
760
- xi=self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)],axis=axis)
761
- return self.backend.complex(xr,xi)
908
+ if data[0].dtype == self.all_cbk_type:
909
+ ndata = len(data)
910
+ xr = self.backend.concat(
911
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
912
+ )
913
+ xi = self.backend.concat(
914
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
915
+ )
916
+ return self.backend.complex(xr, xi)
762
917
  else:
763
- return(self.backend.concat(data,axis=axis))
918
+ return self.backend.concat(data, axis=axis)
764
919
  else:
765
920
  if axis is None:
766
- return np.concatenate(data,axis=0)
921
+ return np.concatenate(data, axis=0)
767
922
  else:
768
- return np.concatenate(data,axis=axis)
923
+ return np.concatenate(data, axis=axis)
769
924
 
770
-
771
- def bk_conjugate(self,data):
772
-
773
- if self.BACKEND==self.TENSORFLOW:
925
+ def bk_conjugate(self, data):
926
+
927
+ if self.BACKEND == self.TENSORFLOW:
774
928
  return self.backend.math.conj(data)
775
- if self.BACKEND==self.TORCH:
929
+ if self.BACKEND == self.TORCH:
776
930
  return self.backend.conj(data)
777
- if self.BACKEND==self.NUMPY:
931
+ if self.BACKEND == self.NUMPY:
778
932
  return data.conjugate()
779
-
780
- def bk_real(self,data):
781
- if self.BACKEND==self.TENSORFLOW:
933
+
934
+ def bk_real(self, data):
935
+ if self.BACKEND == self.TENSORFLOW:
782
936
  return self.backend.math.real(data)
783
- if self.BACKEND==self.TORCH:
937
+ if self.BACKEND == self.TORCH:
784
938
  return data.real
785
- if self.BACKEND==self.NUMPY:
939
+ if self.BACKEND == self.NUMPY:
786
940
  return data.real
787
941
 
788
- def bk_imag(self,data):
789
- if self.BACKEND==self.TENSORFLOW:
942
+ def bk_imag(self, data):
943
+ if self.BACKEND == self.TENSORFLOW:
790
944
  return self.backend.math.imag(data)
791
- if self.BACKEND==self.TORCH:
792
- if data.dtype==self.all_cbk_type:
945
+ if self.BACKEND == self.TORCH:
946
+ if data.dtype == self.all_cbk_type:
793
947
  return data.imag
794
948
  else:
795
949
  return 0
796
-
797
- if self.BACKEND==self.NUMPY:
950
+
951
+ if self.BACKEND == self.NUMPY:
798
952
  return data.imag
799
-
800
- def bk_relu(self,x):
801
- if self.BACKEND==self.TENSORFLOW:
802
- if x.dtype==self.all_cbk_type:
803
- xr=self.backend.nn.relu(self.bk_real(x))
804
- xi=self.backend.nn.relu(self.bk_imag(x))
805
- return self.backend.complex(xr,xi)
953
+
954
+ def bk_relu(self, x):
955
+ if self.BACKEND == self.TENSORFLOW:
956
+ if x.dtype == self.all_cbk_type:
957
+ xr = self.backend.nn.relu(self.bk_real(x))
958
+ xi = self.backend.nn.relu(self.bk_imag(x))
959
+ return self.backend.complex(xr, xi)
806
960
  else:
807
961
  return self.backend.nn.relu(x)
808
- if self.BACKEND==self.TORCH:
962
+ if self.BACKEND == self.TORCH:
809
963
  return self.backend.relu(x)
810
- if self.BACKEND==self.NUMPY:
811
- return (x>0)*x
812
-
813
- def bk_cast(self,x):
814
- if isinstance(x,np.float64):
815
- if self.all_bk_type=='float32':
816
- return(np.float32(x))
964
+ if self.BACKEND == self.NUMPY:
965
+ return (x > 0) * x
966
+
967
+ def bk_cast(self, x):
968
+ if isinstance(x, np.float64):
969
+ if self.all_bk_type == "float32":
970
+ return np.float32(x)
817
971
  else:
818
- return(x)
819
- if isinstance(x,np.float32):
820
- if self.all_bk_type=='float64':
821
- return(np.float64(x))
972
+ return x
973
+ if isinstance(x, np.float32):
974
+ if self.all_bk_type == "float64":
975
+ return np.float64(x)
822
976
  else:
823
- return(x)
977
+ return x
824
978
 
825
- if isinstance(x,np.int32) or isinstance(x,np.int64) or isinstance(x,int):
826
- if self.all_bk_type=='float64':
827
- return(np.float64(x))
979
+ if isinstance(x, np.int32) or isinstance(x, np.int64) or isinstance(x, int):
980
+ if self.all_bk_type == "float64":
981
+ return np.float64(x)
828
982
  else:
829
- return(np.float32(x))
830
-
983
+ return np.float32(x)
984
+
831
985
  if self.bk_is_complex(x):
832
- out_type=self.all_cbk_type
986
+ out_type = self.all_cbk_type
833
987
  else:
834
- out_type=self.all_bk_type
835
-
836
- if self.BACKEND==self.TENSORFLOW:
837
- return self.backend.cast(x,out_type)
838
-
839
- if self.BACKEND==self.TORCH:
840
- if isinstance(x,np.ndarray):
841
- x=self.backend.from_numpy(x)
842
-
988
+ out_type = self.all_bk_type
989
+
990
+ if self.BACKEND == self.TENSORFLOW:
991
+ return self.backend.cast(x, out_type)
992
+
993
+ if self.BACKEND == self.TORCH:
994
+ if isinstance(x, np.ndarray):
995
+ x = self.backend.from_numpy(x)
996
+
843
997
  if x.dtype.is_complex:
844
- out_type=self.all_cbk_type
998
+ out_type = self.all_cbk_type
845
999
  else:
846
- out_type=self.all_bk_type
847
-
1000
+ out_type = self.all_bk_type
1001
+
848
1002
  return x.type(out_type)
849
-
850
- if self.BACKEND==self.NUMPY:
1003
+
1004
+ if self.BACKEND == self.NUMPY:
851
1005
  return x.astype(out_type)