foscat 3.0.9__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/backend.py CHANGED
@@ -1,716 +1,1061 @@
1
1
  import sys
2
+
2
3
  import numpy as np
3
4
 
5
+
4
6
  class foscat_backend:
5
-
6
- def __init__(self,name,mpi_rank=0,all_type='float64',gpupos=0):
7
-
8
- self.TENSORFLOW=1
9
- self.TORCH=2
10
- self.NUMPY=3
11
-
7
+
8
+ def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
9
+
10
+ self.TENSORFLOW = 1
11
+ self.TORCH = 2
12
+ self.NUMPY = 3
13
+
12
14
  # table use to compute the iso orientation rotation
13
- self._iso_orient={}
14
- self._iso_orient_T={}
15
- self._iso_orient_C={}
16
- self._iso_orient_C_T={}
17
- self._fft_1_orient={}
18
- self._fft_1_orient_C={}
19
- self._fft_2_orient={}
20
- self._fft_2_orient_C={}
21
- self._fft_3_orient={}
22
- self._fft_3_orient_C={}
23
-
24
- self.BACKEND=name
25
-
26
- if name not in ['tensorflow','torch','numpy']:
27
- print('Backend "%s" not yet implemented'%(name))
28
- print(' Choose inside the next 3 available backends :')
29
- print(' - tensorflow')
30
- print(' - torch')
31
- print(' - numpy (Impossible to do synthesis using numpy)')
32
- exit(0)
33
-
34
- if self.BACKEND=='tensorflow':
15
+ self._iso_orient = {}
16
+ self._iso_orient_T = {}
17
+ self._iso_orient_C = {}
18
+ self._iso_orient_C_T = {}
19
+ self._fft_1_orient = {}
20
+ self._fft_1_orient_C = {}
21
+ self._fft_2_orient = {}
22
+ self._fft_2_orient_C = {}
23
+ self._fft_3_orient = {}
24
+ self._fft_3_orient_C = {}
25
+
26
+ self.BACKEND = name
27
+
28
+ if name not in ["tensorflow", "torch", "numpy"]:
29
+ print('Backend "%s" not yet implemented' % (name))
30
+ print(" Choose inside the next 3 available backends :")
31
+ print(" - tensorflow")
32
+ print(" - torch")
33
+ print(" - numpy (Impossible to do synthesis using numpy)")
34
+ return None
35
+
36
+ if self.BACKEND == "tensorflow":
35
37
  import tensorflow as tf
36
-
37
- self.backend=tf
38
- self.BACKEND=self.TENSORFLOW
39
- #tf.config.threading.set_inter_op_parallelism_threads(1)
40
- #tf.config.threading.set_intra_op_parallelism_threads(1)
41
38
 
42
- if self.BACKEND=='torch':
39
+ self.backend = tf
40
+ self.BACKEND = self.TENSORFLOW
41
+ # tf.config.threading.set_inter_op_parallelism_threads(1)
42
+ # tf.config.threading.set_intra_op_parallelism_threads(1)
43
+ self.tf_function = tf.function
44
+
45
+ if self.BACKEND == "torch":
43
46
  import torch
44
- self.BACKEND=self.TORCH
45
- self.backend=torch
46
-
47
- if self.BACKEND=='numpy':
48
- self.BACKEND=self.NUMPY
49
- self.backend=np
47
+
48
+ self.BACKEND = self.TORCH
49
+ self.backend = torch
50
+ self.tf_function = self.tf_loc_function
51
+
52
+ if self.BACKEND == "numpy":
53
+ self.BACKEND = self.NUMPY
54
+ self.backend = np
50
55
  import scipy as scipy
51
- self.scipy=scipy
52
-
53
- self.float64=self.backend.float64
54
- self.float32=self.backend.float32
55
- self.int64=self.backend.int64
56
- self.int32=self.backend.int32
57
- self.complex64=self.backend.complex128
58
- self.complex128=self.backend.complex64
59
-
60
- if all_type=='float32':
61
- self.all_bk_type=self.backend.float32
62
- self.all_cbk_type=self.backend.complex64
56
+
57
+ self.scipy = scipy
58
+ self.tf_function = self.tf_loc_function
59
+
60
+ self.float64 = self.backend.float64
61
+ self.float32 = self.backend.float32
62
+ self.int64 = self.backend.int64
63
+ self.int32 = self.backend.int32
64
+ self.complex64 = self.backend.complex128
65
+ self.complex128 = self.backend.complex64
66
+
67
+ if all_type == "float32":
68
+ self.all_bk_type = self.backend.float32
69
+ self.all_cbk_type = self.backend.complex64
63
70
  else:
64
- if all_type=='float64':
65
- self.all_type='float64'
66
- self.all_bk_type=self.backend.float64
67
- self.all_cbk_type=self.backend.complex128
71
+ if all_type == "float64":
72
+ self.all_type = "float64"
73
+ self.all_bk_type = self.backend.float64
74
+ self.all_cbk_type = self.backend.complex128
68
75
  else:
69
- print('ERROR INIT FOCUS ',all_type,' should be float32 or float64')
70
- exit(0)
71
- #===========================================================================
72
- # INIT
73
- if mpi_rank==0:
74
- if self.BACKEND==self.TENSORFLOW:
75
- print("Num GPUs Available: ", len(self.backend.config.experimental.list_physical_devices('GPU')))
76
+ print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
77
+ return None
78
+ # ===========================================================================
79
+ # INIT
80
+ if mpi_rank == 0:
81
+ if self.BACKEND == self.TENSORFLOW and not silent:
82
+ print(
83
+ "Num GPUs Available: ",
84
+ len(self.backend.config.experimental.list_physical_devices("GPU")),
85
+ )
76
86
  sys.stdout.flush()
77
-
78
- if self.BACKEND==self.TENSORFLOW:
87
+
88
+ if self.BACKEND == self.TENSORFLOW:
79
89
  self.backend.debugging.set_log_device_placement(False)
80
90
  self.backend.config.set_soft_device_placement(True)
81
-
82
- gpus = self.backend.config.experimental.list_physical_devices('GPU')
83
-
84
- if self.BACKEND==self.TORCH:
85
- gpus=torch.cuda.is_available()
86
-
87
- if self.BACKEND==self.NUMPY:
88
- gpus=[]
89
- gpuname='CPU:0'
90
- self.gpulist={}
91
- self.gpulist[0]=gpuname
92
- self.ngpu=1
93
-
91
+
92
+ gpus = self.backend.config.experimental.list_physical_devices("GPU")
93
+
94
+ if self.BACKEND == self.TORCH:
95
+ gpus = torch.cuda.is_available()
96
+
97
+ if self.BACKEND == self.NUMPY:
98
+ gpus = []
99
+ gpuname = "CPU:0"
100
+ self.gpulist = {}
101
+ self.gpulist[0] = gpuname
102
+ self.ngpu = 1
103
+
94
104
  if gpus:
95
105
  try:
96
- if self.BACKEND==self.TENSORFLOW:
97
- # Currently, memory growth needs to be the same across GPUs
106
+ if self.BACKEND == self.TENSORFLOW:
107
+ # Currently, memory growth needs to be the same across GPUs
98
108
  for gpu in gpus:
99
109
  self.backend.config.experimental.set_memory_growth(gpu, True)
100
- logical_gpus = self.backend.config.experimental.list_logical_devices('GPU')
101
- print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
110
+ logical_gpus = (
111
+ self.backend.config.experimental.list_logical_devices("GPU")
112
+ )
113
+ print(
114
+ len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
115
+ )
102
116
  sys.stdout.flush()
103
- self.ngpu=len(logical_gpus)
104
- gpuname=logical_gpus[gpupos%self.ngpu].name
105
- self.gpulist={}
117
+ self.ngpu = len(logical_gpus)
118
+ gpuname = logical_gpus[gpupos % self.ngpu].name
119
+ self.gpulist = {}
106
120
  for i in range(self.ngpu):
107
- self.gpulist[i]=logical_gpus[i].name
108
- if self.BACKEND==self.TORCH:
109
- self.ngpu=torch.cuda.device_count()
110
- self.gpulist={}
121
+ self.gpulist[i] = logical_gpus[i].name
122
+ if self.BACKEND == self.TORCH:
123
+ self.ngpu = torch.cuda.device_count()
124
+ self.gpulist = {}
111
125
  for k in range(self.ngpu):
112
- self.gpulist[k]=torch.cuda.get_device_name(0)
126
+ self.gpulist[k] = torch.cuda.get_device_name(0)
113
127
 
114
128
  except RuntimeError as e:
115
129
  # Memory growth must be set before GPUs have been initialized
116
130
  print(e)
117
131
 
118
- def calc_iso_orient(self,norient):
119
- tmp=np.zeros([norient*norient,norient])
132
+ def tf_loc_function(self, func):
133
+ return func
134
+
135
+ def calc_iso_orient(self, norient):
136
+ tmp = np.zeros([norient * norient, norient])
120
137
  for i in range(norient):
121
138
  for j in range(norient):
122
- tmp[j*norient+(j+i)%norient,i]=0.25
123
-
124
- self._iso_orient[norient]=self.constant(self.bk_cast(tmp))
125
- self._iso_orient_T[norient]=self.constant(self.bk_cast(4*tmp.T))
126
- self._iso_orient_C[norient]=self.bk_complex(self._iso_orient[norient],0*self._iso_orient[norient])
127
- self._iso_orient_C_T[norient]=self.bk_complex(self._iso_orient_T[norient],0*self._iso_orient_T[norient])
128
-
129
- def calc_fft_orient(self,norient,nharm,imaginary):
139
+ tmp[j * norient + (j + i) % norient, i] = 0.25
140
+
141
+ self._iso_orient[norient] = self.constant(self.bk_cast(tmp))
142
+ self._iso_orient_T[norient] = self.constant(self.bk_cast(4 * tmp.T))
143
+ self._iso_orient_C[norient] = self.bk_complex(
144
+ self._iso_orient[norient], 0 * self._iso_orient[norient]
145
+ )
146
+ self._iso_orient_C_T[norient] = self.bk_complex(
147
+ self._iso_orient_T[norient], 0 * self._iso_orient_T[norient]
148
+ )
149
+
150
+ def calc_fft_orient(self, norient, nharm, imaginary):
151
+
152
+ x = np.arange(norient) / norient * 2 * np.pi
130
153
 
131
- x=np.arange(norient)/norient*2*np.pi
132
-
133
154
  if imaginary:
134
- tmp=np.zeros([norient,1+nharm*2])
135
- tmp[:,0]=1.0
155
+ tmp = np.zeros([norient, 1 + nharm * 2])
156
+ tmp[:, 0] = 1.0
136
157
  for k in range(nharm):
137
- tmp[:,k*2+1]=np.cos(x*(k+1))
138
- tmp[:,k*2+2]=np.sin(x*(k+1))
139
-
140
- self._fft_1_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp))
141
- self._fft_1_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_1_orient[(norient,nharm,imaginary)],0*self._fft_1_orient[(norient,nharm,imaginary)])
158
+ tmp[:, k * 2 + 1] = np.cos(x * (k + 1))
159
+ tmp[:, k * 2 + 2] = np.sin(x * (k + 1))
160
+
161
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
162
+ self.constant(tmp)
163
+ )
164
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
165
+ self._fft_1_orient[(norient, nharm, imaginary)],
166
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
167
+ )
142
168
  else:
143
- tmp=np.zeros([norient,1+nharm])
144
- for k in range(nharm+1):
145
- tmp[:,k]=np.cos(x*k)
146
-
147
- self._fft_1_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp))
148
- self._fft_1_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_1_orient[(norient,nharm,imaginary)],0*self._fft_1_orient[(norient,nharm,imaginary)])
169
+ tmp = np.zeros([norient, 1 + nharm])
170
+ for k in range(nharm + 1):
171
+ tmp[:, k] = np.cos(x * k)
149
172
 
150
- x=np.repeat(x,norient).reshape(norient,norient)
173
+ self._fft_1_orient[(norient, nharm, imaginary)] = self.bk_cast(
174
+ self.constant(tmp)
175
+ )
176
+ self._fft_1_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
177
+ self._fft_1_orient[(norient, nharm, imaginary)],
178
+ 0 * self._fft_1_orient[(norient, nharm, imaginary)],
179
+ )
180
+
181
+ x = np.repeat(x, norient).reshape(norient, norient)
151
182
 
152
183
  if imaginary:
153
- tmp=np.zeros([norient,norient,(1+nharm*2),(1+nharm*2)])
154
- tmp[:,:,0,0]=1.0
184
+ tmp = np.zeros([norient, norient, (1 + nharm * 2), (1 + nharm * 2)])
185
+ tmp[:, :, 0, 0] = 1.0
155
186
  for k in range(nharm):
156
- tmp[:,:,k*2+1,0]=np.cos(x*(k+1))
157
- tmp[:,:,k*2+2,0]=np.sin(x*(k+1))
158
- tmp[:,:,0,k*2+1]=np.cos((x.T)*(k+1))
159
- tmp[:,:,0,k*2+2]=np.sin((x.T)*(k+1))
160
- for l in range(nharm):
161
- tmp[:,:,k*2+1,l*2+1]=np.cos(x*(k+1))*np.cos((x.T)*(l+1))
162
- tmp[:,:,k*2+2,l*2+1]=np.sin(x*(k+1))*np.cos((x.T)*(l+1))
163
- tmp[:,:,k*2+1,l*2+2]=np.cos(x*(k+1))*np.sin((x.T)*(l+1))
164
- tmp[:,:,k*2+2,l*2+2]=np.sin(x*(k+1))*np.sin((x.T)*(l+1))
165
-
166
- self._fft_2_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient,(1+2*nharm)*(1+2*nharm))))
167
- self._fft_2_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_2_orient[(norient,nharm,imaginary)],0*self._fft_2_orient[(norient,nharm,imaginary)])
187
+ tmp[:, :, k * 2 + 1, 0] = np.cos(x * (k + 1))
188
+ tmp[:, :, k * 2 + 2, 0] = np.sin(x * (k + 1))
189
+ tmp[:, :, 0, k * 2 + 1] = np.cos((x.T) * (k + 1))
190
+ tmp[:, :, 0, k * 2 + 2] = np.sin((x.T) * (k + 1))
191
+ for l_orient in range(nharm):
192
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 1] = np.cos(
193
+ x * (k + 1)
194
+ ) * np.cos((x.T) * (l_orient + 1))
195
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 1] = np.sin(
196
+ x * (k + 1)
197
+ ) * np.cos((x.T) * (l_orient + 1))
198
+ tmp[:, :, k * 2 + 1, l_orient * 2 + 2] = np.cos(
199
+ x * (k + 1)
200
+ ) * np.sin((x.T) * (l_orient + 1))
201
+ tmp[:, :, k * 2 + 2, l_orient * 2 + 2] = np.sin(
202
+ x * (k + 1)
203
+ ) * np.sin((x.T) * (l_orient + 1))
204
+
205
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
206
+ self.constant(
207
+ tmp.reshape(norient * norient, (1 + 2 * nharm) * (1 + 2 * nharm))
208
+ )
209
+ )
210
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
211
+ self._fft_2_orient[(norient, nharm, imaginary)],
212
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
213
+ )
168
214
  else:
169
- tmp=np.zeros([norient,norient,(1+nharm),(1+nharm)])
170
-
171
- for k in range(nharm+1):
172
- for l in range(nharm+1):
173
- tmp[:,:,k,l]=np.cos(x*k)*np.cos((x.T)*l)
174
-
175
- self._fft_2_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient,(1+nharm)*(1+nharm))))
176
- self._fft_2_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_2_orient[(norient,nharm,imaginary)],0*self._fft_2_orient[(norient,nharm,imaginary)])
215
+ tmp = np.zeros([norient, norient, (1 + nharm), (1 + nharm)])
177
216
 
178
- x=np.arange(norient)/norient*2*np.pi
179
- xx=np.zeros([norient,norient,norient])
180
- yy=np.zeros([norient,norient,norient])
181
- zz=np.zeros([norient,norient,norient])
217
+ for k in range(nharm + 1):
218
+ for l_orient in range(nharm + 1):
219
+ tmp[:, :, k, l_orient] = np.cos(x * k) * np.cos((x.T) * l_orient)
220
+
221
+ self._fft_2_orient[(norient, nharm, imaginary)] = self.bk_cast(
222
+ self.constant(tmp.reshape(norient * norient, (1 + nharm) * (1 + nharm)))
223
+ )
224
+ self._fft_2_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
225
+ self._fft_2_orient[(norient, nharm, imaginary)],
226
+ 0 * self._fft_2_orient[(norient, nharm, imaginary)],
227
+ )
228
+
229
+ x = np.arange(norient) / norient * 2 * np.pi
230
+ xx = np.zeros([norient, norient, norient])
231
+ yy = np.zeros([norient, norient, norient])
232
+ zz = np.zeros([norient, norient, norient])
182
233
  for i in range(norient):
183
234
  for j in range(norient):
184
- xx[:,i,j]=x
185
- yy[i,:,j]=x
186
- zz[i,j,:]=x
187
-
235
+ xx[:, i, j] = x
236
+ yy[i, :, j] = x
237
+ zz[i, j, :] = x
238
+
188
239
  if imaginary:
189
- tmp=np.ones([norient,norient,norient,(1+nharm*2),(1+nharm*2),(1+nharm*2)])
190
-
240
+ tmp = np.ones(
241
+ [
242
+ norient,
243
+ norient,
244
+ norient,
245
+ (1 + nharm * 2),
246
+ (1 + nharm * 2),
247
+ (1 + nharm * 2),
248
+ ]
249
+ )
250
+
191
251
  for k in range(nharm):
192
- tmp[:,:,:,k*2+1,0,0]=np.cos(xx*(k+1))
193
- tmp[:,:,:,0,k*2+1,0]=np.cos(yy*(k+1))
194
- tmp[:,:,:,0,0,k*2+1]=np.cos(zz*(k+1))
195
-
196
- tmp[:,:,:,k*2+2,0,0]=np.sin(xx*(k+1))
197
- tmp[:,:,:,0,k*2+2,0]=np.sin(yy*(k+1))
198
- tmp[:,:,:,0,0,k*2+2]=np.sin(zz*(k+1))
199
- for l in range(nharm):
200
- tmp[:,:,:,k*2+1,l*2+1,0]=np.cos(xx*(k+1))*np.cos(yy*(l+1))
201
- tmp[:,:,:,k*2+1,l*2+2,0]=np.cos(xx*(k+1))*np.sin(yy*(l+1))
202
- tmp[:,:,:,k*2+2,l*2+1,0]=np.sin(xx*(k+1))*np.cos(yy*(l+1))
203
- tmp[:,:,:,k*2+2,l*2+2,0]=np.sin(xx*(k+1))*np.sin(yy*(l+1))
204
-
205
- tmp[:,:,:,k*2+1,0,l*2+1]=np.cos(xx*(k+1))*np.cos(zz*(l+1))
206
- tmp[:,:,:,k*2+1,0,l*2+2]=np.cos(xx*(k+1))*np.sin(zz*(l+1))
207
- tmp[:,:,:,k*2+2,0,l*2+1]=np.sin(xx*(k+1))*np.cos(zz*(l+1))
208
- tmp[:,:,:,k*2+2,0,l*2+2]=np.sin(xx*(k+1))*np.sin(zz*(l+1))
209
-
210
- tmp[:,:,:,0,k*2+1,l*2+1]=np.cos(yy*(k+1))*np.cos(zz*(l+1))
211
- tmp[:,:,:,0,k*2+1,l*2+2]=np.cos(yy*(k+1))*np.sin(zz*(l+1))
212
- tmp[:,:,:,0,k*2+2,l*2+1]=np.sin(yy*(k+1))*np.cos(zz*(l+1))
213
- tmp[:,:,:,0,k*2+2,l*2+2]=np.sin(yy*(k+1))*np.sin(zz*(l+1))
214
-
252
+ tmp[:, :, :, k * 2 + 1, 0, 0] = np.cos(xx * (k + 1))
253
+ tmp[:, :, :, 0, k * 2 + 1, 0] = np.cos(yy * (k + 1))
254
+ tmp[:, :, :, 0, 0, k * 2 + 1] = np.cos(zz * (k + 1))
255
+
256
+ tmp[:, :, :, k * 2 + 2, 0, 0] = np.sin(xx * (k + 1))
257
+ tmp[:, :, :, 0, k * 2 + 2, 0] = np.sin(yy * (k + 1))
258
+ tmp[:, :, :, 0, 0, k * 2 + 2] = np.sin(zz * (k + 1))
259
+ for l_orient in range(nharm):
260
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, 0] = np.cos(
261
+ xx * (k + 1)
262
+ ) * np.cos(yy * (l_orient + 1))
263
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, 0] = np.cos(
264
+ xx * (k + 1)
265
+ ) * np.sin(yy * (l_orient + 1))
266
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, 0] = np.sin(
267
+ xx * (k + 1)
268
+ ) * np.cos(yy * (l_orient + 1))
269
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, 0] = np.sin(
270
+ xx * (k + 1)
271
+ ) * np.sin(yy * (l_orient + 1))
272
+
273
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 1] = np.cos(
274
+ xx * (k + 1)
275
+ ) * np.cos(zz * (l_orient + 1))
276
+ tmp[:, :, :, k * 2 + 1, 0, l_orient * 2 + 2] = np.cos(
277
+ xx * (k + 1)
278
+ ) * np.sin(zz * (l_orient + 1))
279
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 1] = np.sin(
280
+ xx * (k + 1)
281
+ ) * np.cos(zz * (l_orient + 1))
282
+ tmp[:, :, :, k * 2 + 2, 0, l_orient * 2 + 2] = np.sin(
283
+ xx * (k + 1)
284
+ ) * np.sin(zz * (l_orient + 1))
285
+
286
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 1] = np.cos(
287
+ yy * (k + 1)
288
+ ) * np.cos(zz * (l_orient + 1))
289
+ tmp[:, :, :, 0, k * 2 + 1, l_orient * 2 + 2] = np.cos(
290
+ yy * (k + 1)
291
+ ) * np.sin(zz * (l_orient + 1))
292
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 1] = np.sin(
293
+ yy * (k + 1)
294
+ ) * np.cos(zz * (l_orient + 1))
295
+ tmp[:, :, :, 0, k * 2 + 2, l_orient * 2 + 2] = np.sin(
296
+ yy * (k + 1)
297
+ ) * np.sin(zz * (l_orient + 1))
298
+
215
299
  for m in range(nharm):
216
- tmp[:,:,:,k*2+1,l*2+1,m*2+1]=np.cos(xx*(k+1))*np.cos(yy*(l+1))*np.cos(zz*(m+1))
217
- tmp[:,:,:,k*2+1,l*2+1,m*2+2]=np.cos(xx*(k+1))*np.cos(yy*(l+1))*np.sin(zz*(m+1))
218
- tmp[:,:,:,k*2+1,l*2+2,m*2+1]=np.cos(xx*(k+1))*np.sin(yy*(l+1))*np.cos(zz*(m+1))
219
- tmp[:,:,:,k*2+1,l*2+2,m*2+2]=np.cos(xx*(k+1))*np.sin(yy*(l+1))*np.sin(zz*(m+1))
220
- tmp[:,:,:,k*2+2,l*2+1,m*2+1]=np.sin(xx*(k+1))*np.cos(yy*(l+1))*np.cos(zz*(m+1))
221
- tmp[:,:,:,k*2+2,l*2+1,m*2+2]=np.sin(xx*(k+1))*np.cos(yy*(l+1))*np.sin(zz*(m+1))
222
- tmp[:,:,:,k*2+2,l*2+2,m*2+1]=np.sin(xx*(k+1))*np.sin(yy*(l+1))*np.cos(zz*(m+1))
223
- tmp[:,:,:,k*2+2,l*2+2,m*2+2]=np.sin(xx*(k+1))*np.sin(yy*(l+1))*np.sin(zz*(m+1))
224
-
225
-
226
- self._fft_3_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient*norient,(1+nharm*2)*(1+nharm*2)*(1+nharm*2))))
227
- self._fft_3_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_3_orient[(norient,nharm,imaginary)],0*self._fft_3_orient[(norient,nharm,imaginary)])
300
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 1] = (
301
+ np.cos(xx * (k + 1))
302
+ * np.cos(yy * (l_orient + 1))
303
+ * np.cos(zz * (m + 1))
304
+ )
305
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 1, m * 2 + 2] = (
306
+ np.cos(xx * (k + 1))
307
+ * np.cos(yy * (l_orient + 1))
308
+ * np.sin(zz * (m + 1))
309
+ )
310
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 1] = (
311
+ np.cos(xx * (k + 1))
312
+ * np.sin(yy * (l_orient + 1))
313
+ * np.cos(zz * (m + 1))
314
+ )
315
+ tmp[:, :, :, k * 2 + 1, l_orient * 2 + 2, m * 2 + 2] = (
316
+ np.cos(xx * (k + 1))
317
+ * np.sin(yy * (l_orient + 1))
318
+ * np.sin(zz * (m + 1))
319
+ )
320
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 1] = (
321
+ np.sin(xx * (k + 1))
322
+ * np.cos(yy * (l_orient + 1))
323
+ * np.cos(zz * (m + 1))
324
+ )
325
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 1, m * 2 + 2] = (
326
+ np.sin(xx * (k + 1))
327
+ * np.cos(yy * (l_orient + 1))
328
+ * np.sin(zz * (m + 1))
329
+ )
330
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 1] = (
331
+ np.sin(xx * (k + 1))
332
+ * np.sin(yy * (l_orient + 1))
333
+ * np.cos(zz * (m + 1))
334
+ )
335
+ tmp[:, :, :, k * 2 + 2, l_orient * 2 + 2, m * 2 + 2] = (
336
+ np.sin(xx * (k + 1))
337
+ * np.sin(yy * (l_orient + 1))
338
+ * np.sin(zz * (m + 1))
339
+ )
340
+
341
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
342
+ self.constant(
343
+ tmp.reshape(
344
+ norient * norient * norient,
345
+ (1 + nharm * 2) * (1 + nharm * 2) * (1 + nharm * 2),
346
+ )
347
+ )
348
+ )
349
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
350
+ self._fft_3_orient[(norient, nharm, imaginary)],
351
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
352
+ )
228
353
  else:
229
- tmp=np.zeros([norient,norient,norient,(1+nharm),(1+nharm),(1+nharm)])
230
-
231
- for k in range(nharm+1):
232
- for l in range(nharm+1):
233
- for m in range(nharm+1):
234
- tmp[:,:,:,k,l,m]=np.cos(xx*k)*np.cos(yy*l)*np.cos(zz*m)
354
+ tmp = np.zeros(
355
+ [norient, norient, norient, (1 + nharm), (1 + nharm), (1 + nharm)]
356
+ )
357
+
358
+ for k in range(nharm + 1):
359
+ for l_orient in range(nharm + 1):
360
+ for m in range(nharm + 1):
361
+ tmp[:, :, :, k, l_orient, m] = (
362
+ np.cos(xx * k) * np.cos(yy * l_orient) * np.cos(zz * m)
363
+ )
364
+
365
+ self._fft_3_orient[(norient, nharm, imaginary)] = self.bk_cast(
366
+ self.constant(
367
+ tmp.reshape(
368
+ norient * norient * norient,
369
+ (1 + nharm) * (1 + nharm) * (1 + nharm),
370
+ )
371
+ )
372
+ )
373
+ self._fft_3_orient_C[(norient, nharm, imaginary)] = self.bk_complex(
374
+ self._fft_3_orient[(norient, nharm, imaginary)],
375
+ 0 * self._fft_3_orient[(norient, nharm, imaginary)],
376
+ )
235
377
 
236
- self._fft_3_orient[(norient,nharm,imaginary)]=self.bk_cast(self.constant(tmp.reshape(norient*norient*norient,(1+nharm)*(1+nharm)*(1+nharm))))
237
- self._fft_3_orient_C[(norient,nharm,imaginary)]=self.bk_complex(self._fft_3_orient[(norient,nharm,imaginary)],0*self._fft_3_orient[(norient,nharm,imaginary)])
238
-
239
378
  # ---------------------------------------------−---------
240
379
  # -- BACKEND DEFINITION --
241
380
  # ---------------------------------------------−---------
242
- def bk_SparseTensor(self,indice,w,dense_shape=[]):
243
- if self.BACKEND==self.TENSORFLOW:
244
- return(self.backend.SparseTensor(indice,w,dense_shape=dense_shape))
245
- if self.BACKEND==self.TORCH:
246
- return(self.backend.sparse_coo_tensor(indice.T,w,dense_shape))
247
- if self.BACKEND==self.NUMPY:
248
- return self.scipy.sparse.coo_matrix((w,(indice[:,0],indice[:,1])),shape=dense_shape)
249
-
250
- def bk_sparse_dense_matmul(self,smat,mat):
251
- if self.BACKEND==self.TENSORFLOW:
252
- return self.backend.sparse.sparse_dense_matmul(smat,mat)
253
- if self.BACKEND==self.TORCH:
381
+ def bk_SparseTensor(self, indice, w, dense_shape=[]):
382
+ if self.BACKEND == self.TENSORFLOW:
383
+ return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
384
+ if self.BACKEND == self.TORCH:
385
+ return self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
386
+ if self.BACKEND == self.NUMPY:
387
+ return self.scipy.sparse.coo_matrix(
388
+ (w, (indice[:, 0], indice[:, 1])), shape=dense_shape
389
+ )
390
+
391
+ def bk_stack(self, list, axis=0):
392
+ if self.BACKEND == self.TENSORFLOW:
393
+ return self.backend.stack(list, axis=axis)
394
+ if self.BACKEND == self.TORCH:
395
+ return self.backend.stack(list, axis=axis)
396
+ if self.BACKEND == self.NUMPY:
397
+ return self.backend.stack(list, axis=axis)
398
+
399
+ def bk_sparse_dense_matmul(self, smat, mat):
400
+ if self.BACKEND == self.TENSORFLOW:
401
+ return self.backend.sparse.sparse_dense_matmul(smat, mat)
402
+ if self.BACKEND == self.TORCH:
254
403
  return smat.matmul(mat)
255
- if self.BACKEND==self.NUMPY:
404
+ if self.BACKEND == self.NUMPY:
256
405
  return smat.dot(mat)
257
406
 
258
- def conv2d(self,x,w,strides=[1, 1, 1, 1],padding='SAME'):
259
- if self.BACKEND==self.TENSORFLOW:
260
- return self.backend.nn.conv2d(x,w,
261
- strides=strides,
262
- padding=padding)
407
+ def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
408
+ if self.BACKEND == self.TENSORFLOW:
409
+ kx = w.shape[0]
410
+ ky = w.shape[1]
411
+ paddings = self.backend.constant(
412
+ [[0, 0], [kx // 2, kx // 2], [ky // 2, ky // 2], [0, 0]]
413
+ )
414
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
415
+ return self.backend.nn.conv2d(tmp, w, strides=strides, padding="VALID")
263
416
  # to be written!!!
264
- if self.BACKEND==self.TORCH:
417
+ if self.BACKEND == self.TORCH:
265
418
  return x
266
- if self.BACKEND==self.NUMPY:
419
+ if self.BACKEND == self.NUMPY:
420
+ res = np.zeros(
421
+ [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
422
+ )
423
+ for k in range(w.shape[2]):
424
+ for l_orient in range(w.shape[3]):
425
+ for j in range(res.shape[0]):
426
+ tmp = self.scipy.signal.convolve2d(
427
+ x[j, :, :, k],
428
+ w[:, :, k, l_orient],
429
+ mode="same",
430
+ boundary="symm",
431
+ )
432
+ res[j, :, :, l_orient] += tmp
433
+ del tmp
434
+ return res
435
+
436
+ def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
437
+ if self.BACKEND == self.TENSORFLOW:
438
+ kx = w.shape[0]
439
+ paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
440
+ tmp = self.backend.pad(x, paddings, "SYMMETRIC")
441
+
442
+ return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
443
+ # to be written!!!
444
+ if self.BACKEND == self.TORCH:
267
445
  return x
446
+ if self.BACKEND == self.NUMPY:
447
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[2]], dtype=x.dtype)
448
+ for k in range(w.shape[2]):
449
+ for j in range(res.shape[0]):
450
+ tmp = self.scipy.signal.convolve1d(
451
+ x[j, :, k], w[:, k], mode="same", boundary="symm"
452
+ )
453
+ res[j, :, :] += tmp
454
+ del tmp
455
+ return res
268
456
 
269
- def bk_threshold(self,x,threshold,greater=True):
270
-
271
- if self.BACKEND==self.TENSORFLOW:
272
- return(self.backend.cast(x>threshold,x.dtype)*x)
273
- if self.BACKEND==self.TORCH:
274
- return(self.backend.cast(x>threshold,x.dtype)*x)
275
- if self.BACKEND==self.NUMPY:
276
- return (x>threshold)*x
277
-
278
- def bk_maximum(self,x1,x2):
279
- if self.BACKEND==self.TENSORFLOW:
280
- return(self.backend.maximum(x1,x2))
281
- if self.BACKEND==self.TORCH:
282
- return(self.backend.maximum(x1,x2))
283
- if self.BACKEND==self.NUMPY:
284
- return x1*(x1>x2)+x2*(x2>x1)
285
-
286
- def bk_device(self,device_name):
457
+ def bk_threshold(self, x, threshold, greater=True):
458
+
459
+ if self.BACKEND == self.TENSORFLOW:
460
+ return self.backend.cast(x > threshold, x.dtype) * x
461
+ if self.BACKEND == self.TORCH:
462
+ x.to(x.dtype)
463
+ return (x > threshold) * x
464
+ # return(self.backend.cast(x>threshold,x.dtype)*x)
465
+ if self.BACKEND == self.NUMPY:
466
+ return (x > threshold) * x
467
+
468
+ def bk_maximum(self, x1, x2):
469
+ if self.BACKEND == self.TENSORFLOW:
470
+ return self.backend.maximum(x1, x2)
471
+ if self.BACKEND == self.TORCH:
472
+ return self.backend.maximum(x1, x2)
473
+ if self.BACKEND == self.NUMPY:
474
+ return x1 * (x1 > x2) + x2 * (x2 > x1)
475
+
476
+ def bk_device(self, device_name):
287
477
  return self.backend.device(device_name)
288
-
289
- def bk_ones(self,shape,dtype=None):
478
+
479
+ def bk_ones(self, shape, dtype=None):
290
480
  if dtype is None:
291
- dtype=self.all_type
292
- if self.BACKEND==self.TORCH:
481
+ dtype = self.all_type
482
+ if self.BACKEND == self.TORCH:
293
483
  return self.bk_cast(np.ones(shape))
294
- return(self.backend.ones(shape,dtype=dtype))
295
-
296
- def bk_conv1d(self,x,w):
297
- if self.BACKEND==self.TENSORFLOW:
298
- return self.backend.nn.conv1d(x,w, stride=[1,1,1], padding='SAME')
299
- if self.BACKEND==self.TORCH:
300
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
301
- if self.BACKEND==self.NUMPY:
302
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
303
-
304
- def bk_flattenR(self,x):
305
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
484
+ return self.backend.ones(shape, dtype=dtype)
485
+
486
+ def bk_conv1d(self, x, w):
487
+ if self.BACKEND == self.TENSORFLOW:
488
+ return self.backend.nn.conv1d(x, w, stride=[1, 1, 1], padding="SAME")
489
+ if self.BACKEND == self.TORCH:
490
+ # Torch not yet done !!!
491
+ return self.backend.nn.conv1d(x, w, stride=1, padding="SAME")
492
+ if self.BACKEND == self.NUMPY:
493
+ res = np.zeros([x.shape[0], x.shape[1], w.shape[1]], dtype=x.dtype)
494
+ for k in range(w.shape[1]):
495
+ for l_orient in range(w.shape[2]):
496
+ res[:, :, l_orient] += self.scipy.ndimage.convolve1d(
497
+ x[:, :, k], w[:, k, l_orient], axis=1, mode="constant", cval=0.0
498
+ )
499
+ return res
500
+
501
+ def bk_flattenR(self, x):
502
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
306
503
  if self.bk_is_complex(x):
307
- rr=self.backend.reshape(self.bk_real(x),[np.prod(np.array(list(x.shape)))])
308
- ii=self.backend.reshape(self.bk_imag(x),[np.prod(np.array(list(x.shape)))])
309
- return self.bk_concat([rr,ii],axis=0)
504
+ rr = self.backend.reshape(
505
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
506
+ )
507
+ ii = self.backend.reshape(
508
+ self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
509
+ )
510
+ return self.bk_concat([rr, ii], axis=0)
310
511
  else:
311
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
312
-
313
- if self.BACKEND==self.NUMPY:
512
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
513
+
514
+ if self.BACKEND == self.NUMPY:
314
515
  if self.bk_is_complex(x):
315
- return np.concatenate([x.real.flatten(),x.imag.flatten()],0)
516
+ return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
316
517
  else:
317
518
  return x.flatten()
318
-
319
- def bk_flatten(self,x):
320
- if self.BACKEND==self.TENSORFLOW:
321
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
322
- if self.BACKEND==self.TORCH:
323
- return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
324
- if self.BACKEND==self.NUMPY:
519
+
520
+ def bk_flatten(self, x):
521
+ if self.BACKEND == self.TENSORFLOW:
522
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
523
+ if self.BACKEND == self.TORCH:
524
+ return self.backend.reshape(x, [np.prod(np.array(list(x.shape)))])
525
+ if self.BACKEND == self.NUMPY:
325
526
  return x.flatten()
326
-
327
- def bk_resize_image(self,x,shape):
328
- if self.BACKEND==self.TENSORFLOW:
329
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
330
-
331
- if self.BACKEND==self.TORCH:
332
- print(x.shape)
333
- tmp=self.backend.nn.functional.interpolate(x,
334
- size=shape,
335
- mode='bilinear',
336
- align_corners=False)
337
- print(tmp.shape)
527
+
528
+ def bk_size(self, x):
529
+ if self.BACKEND == self.TENSORFLOW:
530
+ return self.backend.size(x)
531
+ if self.BACKEND == self.TORCH:
532
+ return x.numel()
533
+
534
+ if self.BACKEND == self.NUMPY:
535
+ return x.size
536
+
537
+ def bk_resize_image(self, x, shape):
538
+ if self.BACKEND == self.TENSORFLOW:
539
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
540
+
541
+ if self.BACKEND == self.TORCH:
542
+ tmp = self.backend.nn.functional.interpolate(
543
+ x, size=shape, mode="bilinear", align_corners=False
544
+ )
338
545
  return self.bk_cast(tmp)
339
- if self.BACKEND==self.NUMPY:
340
- return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
341
-
342
- def bk_L1(self,x):
343
- if x.dtype==self.all_cbk_type:
344
- xr=self.bk_real(x)
345
- xi=self.bk_imag(x)
346
-
347
- r=self.backend.sign(xr)*self.backend.sqrt(self.backend.sign(xr)*xr)
348
- i=self.backend.sign(xi)*self.backend.sqrt(self.backend.sign(xi)*xi)
349
- return self.bk_complex(r,i)
546
+ if self.BACKEND == self.NUMPY:
547
+ return self.bk_cast(self.backend.image.resize(x, shape, method="bilinear"))
548
+
549
+ def bk_L1(self, x):
550
+ if x.dtype == self.all_cbk_type:
551
+ xr = self.bk_real(x)
552
+ xi = self.bk_imag(x)
553
+
554
+ r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
555
+ # return r
556
+ i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
557
+
558
+ if self.BACKEND == self.TORCH:
559
+ return r
560
+ else:
561
+ return self.bk_complex(r, i)
350
562
  else:
351
- return self.backend.sign(x)*self.backend.sqrt(self.backend.sign(x)*x)
352
-
353
- def bk_square_comp(self,x):
354
- if x.dtype==self.all_cbk_type:
355
- xr=self.bk_real(x)
356
- xi=self.bk_imag(x)
357
-
358
- r=xr*xr
359
- i=xi*xi
360
- return self.bk_complex(r,i)
563
+ return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
564
+
565
+ def bk_square_comp(self, x):
566
+ if x.dtype == self.all_cbk_type:
567
+ xr = self.bk_real(x)
568
+ xi = self.bk_imag(x)
569
+
570
+ r = xr * xr
571
+ i = xi * xi
572
+ return self.bk_complex(r, i)
361
573
  else:
362
- return x*x
363
-
364
- def bk_reduce_sum(self,data,axis=None):
365
-
574
+ return x * x
575
+
576
+ def bk_reduce_sum(self, data, axis=None):
577
+
366
578
  if axis is None:
367
- if self.BACKEND==self.TENSORFLOW:
368
- return(self.backend.reduce_sum(data))
369
- if self.BACKEND==self.TORCH:
370
- return(self.backend.sum(data))
371
- if self.BACKEND==self.NUMPY:
372
- return(np.sum(data))
579
+ if self.BACKEND == self.TENSORFLOW:
580
+ return self.backend.reduce_sum(data)
581
+ if self.BACKEND == self.TORCH:
582
+ return self.backend.sum(data)
583
+ if self.BACKEND == self.NUMPY:
584
+ return np.sum(data)
373
585
  else:
374
- if self.BACKEND==self.TENSORFLOW:
375
- return(self.backend.reduce_sum(data,axis=axis))
376
- if self.BACKEND==self.TORCH:
377
- return(self.backend.sum(data,axis))
378
- if self.BACKEND==self.NUMPY:
379
- return(np.sum(data,axis))
380
-
586
+ if self.BACKEND == self.TENSORFLOW:
587
+ return self.backend.reduce_sum(data, axis=axis)
588
+ if self.BACKEND == self.TORCH:
589
+ return self.backend.sum(data, axis)
590
+ if self.BACKEND == self.NUMPY:
591
+ return np.sum(data, axis)
592
+
381
593
  # ---------------------------------------------−---------
594
+ # return a tensor size
382
595
 
383
- def iso_mean(self,x,use_2D=False):
384
- shape=list(x.shape)
596
+ def bk_size(self, data):
597
+ if self.BACKEND == self.TENSORFLOW:
598
+ return self.backend.size(data)
599
+ if self.BACKEND == self.TORCH:
600
+ return data.numel()
601
+ if self.BACKEND == self.NUMPY:
602
+ return data.size
385
603
 
386
- i_orient=2
604
+ # ---------------------------------------------−---------
605
+
606
+ def iso_mean(self, x, use_2D=False):
607
+ shape = list(x.shape)
608
+
609
+ i_orient = 2
387
610
  if use_2D:
388
- i_orient=3
389
- norient=shape[i_orient]
611
+ i_orient = 3
612
+ norient = shape[i_orient]
613
+
614
+ if len(shape) == i_orient + 1:
615
+ return self.bk_reduce_mean(x, -1)
390
616
 
391
- if len(shape)==i_orient+1:
392
- return self.bk_reduce_mean(x,-1)
393
-
394
617
  if norient not in self._iso_orient:
395
618
  self.calc_iso_orient(norient)
396
619
 
397
620
  if self.bk_is_complex(x):
398
- lmat = self._iso_orient_C[norient]
399
- lmat_T = self._iso_orient_C_T[norient]
621
+ lmat = self._iso_orient_C[norient]
400
622
  else:
401
- lmat = self._iso_orient[norient]
402
- lmat_T = self._iso_orient_T[norient]
403
-
404
- oshape=shape[0]
405
- for k in range(1,len(shape)-2):
406
- oshape*=shape[k]
407
-
408
- oshape2=[shape[k] for k in range(0,len(shape)-1)]
409
-
410
- return self.bk_reshape(self.backend.matmul(
411
- self.bk_reshape(x,[oshape,norient*norient]),lmat),oshape2)
623
+ lmat = self._iso_orient[norient]
412
624
 
413
-
414
- def fft_ang(self,x,nharm=1,imaginary=False,use_2D=False):
415
- shape=list(x.shape)
625
+ oshape = shape[0]
626
+ for k in range(1, len(shape) - 2):
627
+ oshape *= shape[k]
628
+
629
+ oshape2 = [shape[k] for k in range(0, len(shape) - 1)]
630
+
631
+ return self.bk_reshape(
632
+ self.backend.matmul(self.bk_reshape(x, [oshape, norient * norient]), lmat),
633
+ oshape2,
634
+ )
416
635
 
417
- i_orient=2
636
+ def fft_ang(self, x, nharm=1, imaginary=False, use_2D=False):
637
+ shape = list(x.shape)
638
+
639
+ i_orient = 2
418
640
  if use_2D:
419
- i_orient=3
420
-
421
- norient=shape[i_orient]
422
- nout=1+nharm
423
-
424
- oshape_1=shape[0]
425
- for k in range(1,i_orient):
426
- oshape_1*=shape[k]
427
- oshape_2=norient
428
- for k in range(i_orient,len(shape)-1):
429
- oshape_2*=shape[k]
430
- oshape=[oshape_1,oshape_2]
431
-
432
-
641
+ i_orient = 3
642
+
643
+ norient = shape[i_orient]
644
+ nout = 1 + nharm
645
+
646
+ oshape_1 = shape[0]
647
+ for k in range(1, i_orient):
648
+ oshape_1 *= shape[k]
649
+ oshape_2 = norient
650
+ for k in range(i_orient, len(shape) - 1):
651
+ oshape_2 *= shape[k]
652
+ oshape = [oshape_1, oshape_2]
653
+
433
654
  if imaginary:
434
- nout=1+nharm*2
435
-
436
- oshape2=[shape[k] for k in range(0,i_orient)]+[nout for k in range(i_orient,len(shape))]
437
-
438
- if (norient,nharm) not in self._fft_1_orient:
439
- self.calc_fft_orient(norient,nharm,imaginary)
655
+ nout = 1 + nharm * 2
656
+
657
+ oshape2 = [shape[k] for k in range(0, i_orient)] + [
658
+ nout for k in range(i_orient, len(shape))
659
+ ]
440
660
 
441
- if len(shape)==i_orient+1:
661
+ if (norient, nharm) not in self._fft_1_orient:
662
+ self.calc_fft_orient(norient, nharm, imaginary)
663
+
664
+ if len(shape) == i_orient + 1:
442
665
  if self.bk_is_complex(x):
443
- lmat = self._fft_1_orient_C[(norient,nharm,imaginary)]
666
+ lmat = self._fft_1_orient_C[(norient, nharm, imaginary)]
444
667
  else:
445
- lmat = self._fft_1_orient[(norient,nharm,imaginary)]
446
-
447
- if len(shape)==i_orient+2:
668
+ lmat = self._fft_1_orient[(norient, nharm, imaginary)]
669
+
670
+ if len(shape) == i_orient + 2:
448
671
  if self.bk_is_complex(x):
449
- lmat = self._fft_2_orient_C[(norient,nharm,imaginary)]
672
+ lmat = self._fft_2_orient_C[(norient, nharm, imaginary)]
450
673
  else:
451
- lmat = self._fft_2_orient[(norient,nharm,imaginary)]
452
-
453
- if len(shape)==i_orient+3:
674
+ lmat = self._fft_2_orient[(norient, nharm, imaginary)]
675
+
676
+ if len(shape) == i_orient + 3:
454
677
  if self.bk_is_complex(x):
455
- lmat = self._fft_3_orient_C[(norient,nharm,imaginary)]
678
+ lmat = self._fft_3_orient_C[(norient, nharm, imaginary)]
456
679
  else:
457
- lmat = self._fft_3_orient[(norient,nharm,imaginary)]
458
-
459
- return self.bk_reshape(self.backend.matmul(self.bk_reshape(x,oshape),lmat),oshape2)
460
-
461
- def constant(self,data):
462
-
463
- if self.BACKEND==self.TENSORFLOW:
464
- return(self.backend.constant(data))
465
- return(data)
680
+ lmat = self._fft_3_orient[(norient, nharm, imaginary)]
681
+
682
+ return self.bk_reshape(
683
+ self.backend.matmul(self.bk_reshape(x, oshape), lmat), oshape2
684
+ )
685
+
686
+ def constant(self, data):
687
+
688
+ if self.BACKEND == self.TENSORFLOW:
689
+ return self.backend.constant(data)
690
+ return data
691
+
692
+ def bk_reduce_mean(self, data, axis=None):
466
693
 
467
- def bk_reduce_mean(self,data,axis=None):
468
-
469
694
  if axis is None:
470
- if self.BACKEND==self.TENSORFLOW:
471
- return(self.backend.reduce_mean(data))
472
- if self.BACKEND==self.TORCH:
473
- return(self.backend.mean(data))
474
- if self.BACKEND==self.NUMPY:
475
- return(np.mean(data))
695
+ if self.BACKEND == self.TENSORFLOW:
696
+ return self.backend.reduce_mean(data)
697
+ if self.BACKEND == self.TORCH:
698
+ return self.backend.mean(data)
699
+ if self.BACKEND == self.NUMPY:
700
+ return np.mean(data)
476
701
  else:
477
- if self.BACKEND==self.TENSORFLOW:
478
- return(self.backend.reduce_mean(data,axis=axis))
479
- if self.BACKEND==self.TORCH:
480
- return(self.backend.mean(data,axis))
481
- if self.BACKEND==self.NUMPY:
482
- return(np.mean(data,axis))
483
-
484
- def bk_reduce_std(self,data,axis=None):
485
-
702
+ if self.BACKEND == self.TENSORFLOW:
703
+ return self.backend.reduce_mean(data, axis=axis)
704
+ if self.BACKEND == self.TORCH:
705
+ return self.backend.mean(data, axis)
706
+ if self.BACKEND == self.NUMPY:
707
+ return np.mean(data, axis)
708
+
709
+ def bk_reduce_min(self, data, axis=None):
710
+
486
711
  if axis is None:
487
- if self.BACKEND==self.TENSORFLOW:
488
- return(self.backend.math.reduce_std(data))
489
- if self.BACKEND==self.TORCH:
490
- return(self.backend.std(data))
491
- if self.BACKEND==self.NUMPY:
492
- return(np.std(data))
712
+ if self.BACKEND == self.TENSORFLOW:
713
+ return self.backend.reduce_min(data)
714
+ if self.BACKEND == self.TORCH:
715
+ return self.backend.min(data)
716
+ if self.BACKEND == self.NUMPY:
717
+ return np.min(data)
493
718
  else:
494
- if self.BACKEND==self.TENSORFLOW:
495
- return(self.backend.math.reduce_std(data,axis=axis))
496
- if self.BACKEND==self.TORCH:
497
- return(self.backend.std(data,axis))
498
- if self.BACKEND==self.NUMPY:
499
- return(np.std(data,axis))
500
-
501
-
502
- def bk_sqrt(self,data):
503
-
504
- return(self.backend.sqrt(self.backend.abs(data)))
505
-
506
- def bk_abs(self,data):
507
- return(self.backend.abs(data))
719
+ if self.BACKEND == self.TENSORFLOW:
720
+ return self.backend.reduce_min(data, axis=axis)
721
+ if self.BACKEND == self.TORCH:
722
+ return self.backend.min(data, axis)
723
+ if self.BACKEND == self.NUMPY:
724
+ return np.min(data, axis)
508
725
 
509
- def bk_is_complex(self,data):
510
-
511
- if self.BACKEND==self.TENSORFLOW:
512
- if isinstance(data,np.ndarray):
513
- return (data.dtype=='complex64' or data.dtype=='complex128')
726
+ def bk_random_seed(self, value):
727
+
728
+ if self.BACKEND == self.TENSORFLOW:
729
+ return self.backend.random.set_seed(value)
730
+ if self.BACKEND == self.TORCH:
731
+ return self.backend.random.set_seed(value)
732
+ if self.BACKEND == self.NUMPY:
733
+ return np.random.seed(value)
734
+
735
+ def bk_random_uniform(self, shape):
736
+
737
+ if self.BACKEND == self.TENSORFLOW:
738
+ return self.backend.random.uniform(shape)
739
+ if self.BACKEND == self.TORCH:
740
+ return self.backend.random.uniform(shape)
741
+ if self.BACKEND == self.NUMPY:
742
+ return np.random.rand(shape)
743
+
744
+ def bk_reduce_std(self, data, axis=None):
745
+ if axis is None:
746
+ if self.BACKEND == self.TENSORFLOW:
747
+ r = self.backend.math.reduce_std(data)
748
+ if self.BACKEND == self.TORCH:
749
+ r = self.backend.std(data)
750
+ if self.BACKEND == self.NUMPY:
751
+ r = np.std(data)
752
+ return self.bk_complex(r, 0 * r)
753
+ else:
754
+ if self.BACKEND == self.TENSORFLOW:
755
+ r = self.backend.math.reduce_std(data, axis=axis)
756
+ if self.BACKEND == self.TORCH:
757
+ r = self.backend.std(data, axis)
758
+ if self.BACKEND == self.NUMPY:
759
+ r = np.std(data, axis)
760
+ if self.bk_is_complex(data):
761
+ return self.bk_complex(r, 0 * r)
762
+ else:
763
+ return r
764
+
765
+ def bk_sqrt(self, data):
766
+
767
+ return self.backend.sqrt(self.backend.abs(data))
768
+
769
+ def bk_abs(self, data):
770
+ return self.backend.abs(data)
771
+
772
+ def bk_is_complex(self, data):
773
+
774
+ if self.BACKEND == self.TENSORFLOW:
775
+ if isinstance(data, np.ndarray):
776
+ return data.dtype == "complex64" or data.dtype == "complex128"
514
777
  return data.dtype.is_complex
515
-
516
- if self.BACKEND==self.TORCH:
517
- if isinstance(data,np.ndarray):
518
- return (data.dtype=='complex64' or data.dtype=='complex128')
519
-
778
+
779
+ if self.BACKEND == self.TORCH:
780
+ if isinstance(data, np.ndarray):
781
+ return data.dtype == "complex64" or data.dtype == "complex128"
782
+
520
783
  return data.dtype.is_complex
521
-
522
- if self.BACKEND==self.NUMPY:
523
- return (data.dtype=='complex64' or data.dtype=='complex128')
524
-
525
- def bk_norm(self,data):
784
+
785
+ if self.BACKEND == self.NUMPY:
786
+ return data.dtype == "complex64" or data.dtype == "complex128"
787
+
788
+ def bk_distcomp(self, data):
526
789
  if self.bk_is_complex(data):
527
- res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
790
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
791
+ self.bk_imag(data)
792
+ )
793
+ return res
794
+ else:
795
+ return self.bk_square(data)
796
+
797
+ def bk_norm(self, data):
798
+ if self.bk_is_complex(data):
799
+ res = self.bk_square(self.bk_real(data)) + self.bk_square(
800
+ self.bk_imag(data)
801
+ )
528
802
  return self.bk_sqrt(res)
529
803
 
530
804
  else:
531
805
  return self.bk_abs(data)
532
-
533
- def bk_square(self,data):
534
-
535
- if self.BACKEND==self.TENSORFLOW:
536
- return(self.backend.square(data))
537
- if self.BACKEND==self.TORCH:
538
- return(self.backend.square(data))
539
- if self.BACKEND==self.NUMPY:
540
- return(data*data)
541
-
542
- def bk_log(self,data):
543
- if self.BACKEND==self.TENSORFLOW:
544
- return(self.backend.math.log(data))
545
- if self.BACKEND==self.TORCH:
546
- return(self.backend.log(data))
547
- if self.BACKEND==self.NUMPY:
548
- return(np.log(data))
549
-
550
- def bk_matmul(self,a,b):
551
- if self.BACKEND==self.TENSORFLOW:
552
- return(self.backend.matmul(a,b))
553
- if self.BACKEND==self.TORCH:
554
- return(self.backend.matmul(a,b))
555
- if self.BACKEND==self.NUMPY:
556
- return(np.dot(a,b))
557
-
558
- def bk_tensor(self,data):
559
- if self.BACKEND==self.TENSORFLOW:
560
- return(self.backend.constant(data))
561
- if self.BACKEND==self.TORCH:
562
- return(self.backend.constant(data))
563
- if self.BACKEND==self.NUMPY:
564
- return(data)
565
-
566
- def bk_complex(self,real,imag):
567
- if self.BACKEND==self.TENSORFLOW:
568
- return(self.backend.dtypes.complex(real,imag))
569
- if self.BACKEND==self.TORCH:
570
- return(self.backend.complex(real,imag))
571
- if self.BACKEND==self.NUMPY:
572
- return real+1J*imag
573
-
574
- def bk_exp(self,data):
575
-
576
- return(self.backend.exp(data))
577
-
578
- def bk_min(self,data):
579
-
580
- return(self.backend.reduce_min(data))
581
-
582
- def bk_argmin(self,data):
583
-
584
- return(self.backend.argmin(data))
585
-
586
- def bk_tanh(self,data):
587
-
588
- return(self.backend.math.tanh(data))
589
-
590
- def bk_max(self,data):
591
-
592
- return(self.backend.reduce_max(data))
593
-
594
- def bk_argmax(self,data):
595
-
596
- return(self.backend.argmax(data))
597
-
598
- def bk_reshape(self,data,shape):
599
- if self.BACKEND==self.TORCH:
600
- if isinstance(data,np.ndarray):
806
+
807
+ def bk_square(self, data):
808
+
809
+ if self.BACKEND == self.TENSORFLOW:
810
+ return self.backend.square(data)
811
+ if self.BACKEND == self.TORCH:
812
+ return self.backend.square(data)
813
+ if self.BACKEND == self.NUMPY:
814
+ return data * data
815
+
816
+ def bk_log(self, data):
817
+ if self.BACKEND == self.TENSORFLOW:
818
+ return self.backend.math.log(data)
819
+ if self.BACKEND == self.TORCH:
820
+ return self.backend.log(data)
821
+ if self.BACKEND == self.NUMPY:
822
+ return np.log(data)
823
+
824
+ def bk_matmul(self, a, b):
825
+ if self.BACKEND == self.TENSORFLOW:
826
+ return self.backend.matmul(a, b)
827
+ if self.BACKEND == self.TORCH:
828
+ return self.backend.matmul(a, b)
829
+ if self.BACKEND == self.NUMPY:
830
+ return np.dot(a, b)
831
+
832
+ def bk_tensor(self, data):
833
+ if self.BACKEND == self.TENSORFLOW:
834
+ return self.backend.constant(data)
835
+ if self.BACKEND == self.TORCH:
836
+ return self.backend.constant(data)
837
+ if self.BACKEND == self.NUMPY:
838
+ return data
839
+
840
+ def bk_complex(self, real, imag):
841
+ if self.BACKEND == self.TENSORFLOW:
842
+ return self.backend.dtypes.complex(real, imag)
843
+ if self.BACKEND == self.TORCH:
844
+ return self.backend.complex(real, imag)
845
+ if self.BACKEND == self.NUMPY:
846
+ return real + 1j * imag
847
+
848
+ def bk_exp(self, data):
849
+
850
+ return self.backend.exp(data)
851
+
852
+ def bk_min(self, data):
853
+
854
+ return self.backend.reduce_min(data)
855
+
856
+ def bk_argmin(self, data):
857
+
858
+ return self.backend.argmin(data)
859
+
860
+ def bk_tanh(self, data):
861
+
862
+ return self.backend.math.tanh(data)
863
+
864
+ def bk_max(self, data):
865
+
866
+ return self.backend.reduce_max(data)
867
+
868
+ def bk_argmax(self, data):
869
+
870
+ return self.backend.argmax(data)
871
+
872
+ def bk_reshape(self, data, shape):
873
+ if self.BACKEND == self.TORCH:
874
+ if isinstance(data, np.ndarray):
601
875
  return data.reshape(shape)
602
-
603
- return(self.backend.reshape(data,shape))
604
-
605
- def bk_repeat(self,data,nn,axis=0):
606
- return(self.backend.repeat(data,nn,axis=axis))
607
-
608
- def bk_tile(self,data,nn,axis=0):
609
- return(self.backend.tile(data,nn))
610
-
611
- def bk_roll(self,data,nn,axis=0):
612
- return(self.backend.roll(data,nn,axis=axis))
613
-
614
- def bk_expand_dims(self,data,axis=0):
615
- if self.BACKEND==self.TENSORFLOW:
616
- return(self.backend.expand_dims(data,axis=axis))
617
- if self.BACKEND==self.TORCH:
618
- if isinstance(data,np.ndarray):
619
- data=self.backend.from_numpy(data)
620
- return(self.backend.unsqueeze(data,axis))
621
- if self.BACKEND==self.NUMPY:
622
- return(np.expand_dims(data,axis))
623
-
624
- def bk_transpose(self,data,thelist):
625
- if self.BACKEND==self.TENSORFLOW:
626
- return(self.backend.transpose(data,thelist))
627
- if self.BACKEND==self.TORCH:
628
- return(self.backend.transpose(data,thelist))
629
- if self.BACKEND==self.NUMPY:
630
- return(np.transpose(data,thelist))
631
-
632
- def bk_concat(self,data,axis=None):
633
-
634
- if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
876
+
877
+ return self.backend.reshape(data, shape)
878
+
879
+ def bk_repeat(self, data, nn, axis=0):
880
+ return self.backend.repeat(data, nn, axis=axis)
881
+
882
+ def bk_tile(self, data, nn,axis=0):
883
+ if self.BACKEND == self.TENSORFLOW:
884
+ return self.backend.tile(data, [nn])
885
+
886
+ return self.backend.tile(data, nn)
887
+
888
+ def bk_roll(self, data, nn, axis=0):
889
+ return self.backend.roll(data, nn, axis=axis)
890
+
891
+ def bk_expand_dims(self, data, axis=0):
892
+ if self.BACKEND == self.TENSORFLOW:
893
+ return self.backend.expand_dims(data, axis=axis)
894
+ if self.BACKEND == self.TORCH:
895
+ if isinstance(data, np.ndarray):
896
+ data = self.backend.from_numpy(data)
897
+ return self.backend.unsqueeze(data, axis)
898
+ if self.BACKEND == self.NUMPY:
899
+ return np.expand_dims(data, axis)
900
+
901
+ def bk_transpose(self, data, thelist):
902
+ if self.BACKEND == self.TENSORFLOW:
903
+ return self.backend.transpose(data, thelist)
904
+ if self.BACKEND == self.TORCH:
905
+ return self.backend.transpose(data, thelist)
906
+ if self.BACKEND == self.NUMPY:
907
+ return np.transpose(data, thelist)
908
+
909
+ def bk_concat(self, data, axis=None):
910
+
911
+ if self.BACKEND == self.TENSORFLOW or self.BACKEND == self.TORCH:
635
912
  if axis is None:
636
- return(self.backend.concat(data))
913
+ if data[0].dtype == self.all_cbk_type:
914
+ ndata = len(data)
915
+ xr = self.backend.concat(
916
+ [self.bk_real(data[k]) for k in range(ndata)]
917
+ )
918
+ xi = self.backend.concat(
919
+ [self.bk_imag(data[k]) for k in range(ndata)]
920
+ )
921
+ return self.backend.complex(xr, xi)
922
+ else:
923
+ return self.backend.concat(data)
637
924
  else:
638
- return(self.backend.concat(data,axis=axis))
925
+ if data[0].dtype == self.all_cbk_type:
926
+ ndata = len(data)
927
+ xr = self.backend.concat(
928
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
929
+ )
930
+ xi = self.backend.concat(
931
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
932
+ )
933
+ return self.backend.complex(xr, xi)
934
+ else:
935
+ return self.backend.concat(data, axis=axis)
639
936
  else:
640
937
  if axis is None:
641
- return np.concatenate(data,axis=0)
938
+ return np.concatenate(data, axis=0)
642
939
  else:
643
- return np.concatenate(data,axis=axis)
940
+ return np.concatenate(data, axis=axis)
644
941
 
645
-
646
- def bk_conjugate(self,data):
647
-
648
- if self.BACKEND==self.TENSORFLOW:
942
+ def bk_zeros(self, shape,dtype=None):
943
+ if self.BACKEND == self.TENSORFLOW:
944
+ return self.backend.zeros(shape,dtype=dtype)
945
+ if self.BACKEND == self.TORCH:
946
+ return self.backend.zeros(shape,dtype=dtype)
947
+ if self.BACKEND == self.NUMPY:
948
+ return np.zeros(shape,dtype=dtype)
949
+
950
+ def bk_gather(self, data,idx):
951
+ if self.BACKEND == self.TENSORFLOW:
952
+ return self.backend.gather(data,idx)
953
+ if self.BACKEND == self.TORCH:
954
+ return data[idx]
955
+ if self.BACKEND == self.NUMPY:
956
+ return data[idx]
957
+
958
+ def bk_reverse(self, data,axis=0):
959
+ if self.BACKEND == self.TENSORFLOW:
960
+ return self.backend.reverse(data,axis=[axis])
961
+ if self.BACKEND == self.TORCH:
962
+ return self.backend.reverse(data,axis=axis)
963
+ if self.BACKEND == self.NUMPY:
964
+ return np.reverse(data,axis=axis)
965
+
966
+ def bk_fft(self, data):
967
+ if self.BACKEND == self.TENSORFLOW:
968
+ return self.backend.signal.fft(data)
969
+ if self.BACKEND == self.TORCH:
970
+ return self.backend.fft(data)
971
+ if self.BACKEND == self.NUMPY:
972
+ return self.backend.fft.fft(data)
973
+
974
+ def bk_rfft(self, data):
975
+ if self.BACKEND == self.TENSORFLOW:
976
+ return self.backend.signal.rfft(data)
977
+ if self.BACKEND == self.TORCH:
978
+ return self.backend.rfft(data)
979
+ if self.BACKEND == self.NUMPY:
980
+ return self.backend.fft.rfft(data)
981
+ def bk_conjugate(self, data):
982
+
983
+ if self.BACKEND == self.TENSORFLOW:
649
984
  return self.backend.math.conj(data)
650
- if self.BACKEND==self.TORCH:
985
+ if self.BACKEND == self.TORCH:
651
986
  return self.backend.conj(data)
652
- if self.BACKEND==self.NUMPY:
987
+ if self.BACKEND == self.NUMPY:
653
988
  return data.conjugate()
654
-
655
- def bk_real(self,data):
656
- if self.BACKEND==self.TENSORFLOW:
989
+
990
+ def bk_real(self, data):
991
+ if self.BACKEND == self.TENSORFLOW:
657
992
  return self.backend.math.real(data)
658
- if self.BACKEND==self.TORCH:
659
- return self.backend.real(data)
660
- if self.BACKEND==self.NUMPY:
661
- return self.backend.real(data)
993
+ if self.BACKEND == self.TORCH:
994
+ return data.real
995
+ if self.BACKEND == self.NUMPY:
996
+ return data.real
662
997
 
663
- def bk_imag(self,data):
664
- if self.BACKEND==self.TENSORFLOW:
998
+ def bk_imag(self, data):
999
+ if self.BACKEND == self.TENSORFLOW:
665
1000
  return self.backend.math.imag(data)
666
- if self.BACKEND==self.TORCH:
667
- return self.backend.imag(data)
668
- if self.BACKEND==self.NUMPY:
669
- return self.backend.imag(data)
670
-
671
- def bk_relu(self,x):
672
- if self.BACKEND==self.TENSORFLOW:
673
- if x.dtype==self.all_cbk_type:
674
- xr=self.backend.nn.relu(self.bk_real(x))
675
- xi=self.backend.nn.relu(self.bk_imag(x))
676
- return self.backend.complex(xr,xi)
1001
+ if self.BACKEND == self.TORCH:
1002
+ if data.dtype == self.all_cbk_type:
1003
+ return data.imag
1004
+ else:
1005
+ return 0
1006
+
1007
+ if self.BACKEND == self.NUMPY:
1008
+ return data.imag
1009
+
1010
+ def bk_relu(self, x):
1011
+ if self.BACKEND == self.TENSORFLOW:
1012
+ if x.dtype == self.all_cbk_type:
1013
+ xr = self.backend.nn.relu(self.bk_real(x))
1014
+ xi = self.backend.nn.relu(self.bk_imag(x))
1015
+ return self.backend.complex(xr, xi)
677
1016
  else:
678
1017
  return self.backend.nn.relu(x)
679
- if self.BACKEND==self.TORCH:
1018
+ if self.BACKEND == self.TORCH:
680
1019
  return self.backend.relu(x)
681
- if self.BACKEND==self.NUMPY:
682
- return (x>0)*x
683
-
684
- def bk_cast(self,x):
685
- if isinstance(x,np.float64):
686
- if self.all_bk_type=='float32':
687
- return(np.float32(x))
1020
+ if self.BACKEND == self.NUMPY:
1021
+ return (x > 0) * x
1022
+
1023
+ def bk_cast(self, x):
1024
+ if isinstance(x, np.float64):
1025
+ if self.all_bk_type == "float32":
1026
+ return np.float32(x)
688
1027
  else:
689
- return(x)
690
- if isinstance(x,np.float32):
691
- if self.all_bk_type=='float64':
692
- return(np.float64(x))
1028
+ return x
1029
+ if isinstance(x, np.float32):
1030
+ if self.all_bk_type == "float64":
1031
+ return np.float64(x)
693
1032
  else:
694
- return(x)
1033
+ return x
1034
+
1035
+ if isinstance(x, np.int32) or isinstance(x, np.int64) or isinstance(x, int):
1036
+ if self.all_bk_type == "float64":
1037
+ return np.float64(x)
1038
+ else:
1039
+ return np.float32(x)
695
1040
 
696
1041
  if self.bk_is_complex(x):
697
- out_type=self.all_cbk_type
1042
+ out_type = self.all_cbk_type
698
1043
  else:
699
- out_type=self.all_bk_type
700
-
701
- if self.BACKEND==self.TENSORFLOW:
702
- return self.backend.cast(x,out_type)
703
-
704
- if self.BACKEND==self.TORCH:
705
- if isinstance(x,np.ndarray):
706
- x=self.backend.from_numpy(x)
707
-
1044
+ out_type = self.all_bk_type
1045
+
1046
+ if self.BACKEND == self.TENSORFLOW:
1047
+ return self.backend.cast(x, out_type)
1048
+
1049
+ if self.BACKEND == self.TORCH:
1050
+ if isinstance(x, np.ndarray):
1051
+ x = self.backend.from_numpy(x)
1052
+
708
1053
  if x.dtype.is_complex:
709
- out_type=self.all_cbk_type
1054
+ out_type = self.all_cbk_type
710
1055
  else:
711
- out_type=self.all_bk_type
712
-
1056
+ out_type = self.all_bk_type
1057
+
713
1058
  return x.type(out_type)
714
-
715
- if self.BACKEND==self.NUMPY:
1059
+
1060
+ if self.BACKEND == self.NUMPY:
716
1061
  return x.astype(out_type)