foscat 2025.6.3__py3-none-any.whl → 2025.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,7 +5,7 @@ import healpy as hp
5
5
  import numpy as np
6
6
  from scipy.interpolate import griddata
7
7
 
8
- TMPFILE_VERSION = "V5_0"
8
+ TMPFILE_VERSION = "V6_0"
9
9
 
10
10
 
11
11
  class FoCUS:
@@ -22,7 +22,7 @@ class FoCUS:
22
22
  mask_thres=None,
23
23
  mask_norm=False,
24
24
  isMPI=False,
25
- TEMPLATE_PATH="data",
25
+ TEMPLATE_PATH=None,
26
26
  BACKEND="torch",
27
27
  use_2D=False,
28
28
  use_1D=False,
@@ -35,7 +35,7 @@ class FoCUS:
35
35
  mpi_rank=0
36
36
  ):
37
37
 
38
- self.__version__ = "2025.06.3"
38
+ self.__version__ = "2025.07.1"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -62,6 +62,11 @@ class FoCUS:
62
62
  print("================================================")
63
63
  sys.stdout.flush()
64
64
 
65
+ home_dir = os.environ["HOME"]
66
+
67
+ if TEMPLATE_PATH is None:
68
+ TEMPLATE_PATH=home_dir+"/.FOSCAT/data"
69
+
65
70
  self.TEMPLATE_PATH = TEMPLATE_PATH
66
71
  if not os.path.exists(self.TEMPLATE_PATH):
67
72
  if not self.silent:
@@ -281,28 +286,10 @@ class FoCUS:
281
286
  self.KERNELSZ = KERNELSZ
282
287
 
283
288
  self.Idx_Neighbours = {}
289
+ self.w_smooth = {}
284
290
 
285
- if not self.use_2D and not self.use_1D:
286
- self.w_smooth = {}
287
- for i in range(nstep_max):
288
- lout = 2**i
289
- self.ww_Real[lout] = None
290
-
291
- for i in range(1, 6):
292
- lout = 2**i
293
- if not self.silent:
294
- print("Init Wave ", lout)
295
-
296
- if self.InitWave is None:
297
- wr, wi, ws, widx = self.init_index(lout)
298
- else:
299
- wr, wi, ws, widx = self.InitWave(self, lout)
300
-
301
- self.Idx_Neighbours[lout] = 1 # self.backend.bk_constant(widx)
302
- self.ww_Real[lout] = wr
303
- self.ww_Imag[lout] = wi
304
- self.w_smooth[lout] = ws
305
- elif self.use_1D:
291
+
292
+ if self.use_1D:
306
293
  self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
307
294
  self.ww_RealT = {}
308
295
  self.ww_ImagT = {}
@@ -329,7 +316,7 @@ class FoCUS:
329
316
  self.backend.bk_constant(np.array(w).reshape(xx.shape[0]))
330
317
  )
331
318
 
332
- else:
319
+ if self.use_2D:
333
320
  self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
334
321
  self.ww_RealT = {}
335
322
  self.ww_ImagT = {}
@@ -373,6 +360,30 @@ class FoCUS:
373
360
 
374
361
  self.loss = {}
375
362
 
363
+ self.dtype_dcode_map = {
364
+ 0: np.int64,
365
+ 1: np.int32,
366
+ 2: np.float32,
367
+ 3: np.float64,
368
+ 4: np.complex64,
369
+ 5: np.complex128
370
+ }
371
+ self.dtype_code_map = {
372
+ np.int64: 0,
373
+ np.int32: 1,
374
+ np.float32: 2,
375
+ np.float64: 3,
376
+ np.complex64: 4,
377
+ np.complex128: 5
378
+ }
379
+
380
def get_dtype_code(self, dtype):
    """Return the integer storage code of *dtype* for the on-disk index header.

    The code is looked up in ``self.dtype_code_map`` by comparing normalized
    ``np.dtype`` objects, so dtype objects, scalar types and dtype strings
    are all accepted. This mapping is used for storage only.

    Raises:
        ValueError: if *dtype* matches none of the entries of the map.
    """
    wanted = np.dtype(dtype)
    found = next(
        (code for key, code in self.dtype_code_map.items() if np.dtype(key) == wanted),
        None,
    )
    if found is None:
        raise ValueError(f"Unsupported data type: {dtype}")
    return found
386
+
376
387
  def get_type(self):
377
388
  return self.all_type
378
389
 
@@ -471,6 +482,114 @@ class FoCUS:
471
482
  )
472
483
  return indices, weights, xc, yc, zc
473
484
 
485
+ #======================================================================================
486
+ # The next two functions let FOSCAT store and load large index files, optionally chunk by chunk (offset/count)
487
+ #======================================================================================
488
+
489
def save_index(self, filepath, data, offset=0, count=None):
    """
    Save an N-dimensional NumPy array with shape (N, ...) to a binary file.

    The file is a 96-byte header (12 x int64) followed by the raw samples
    in C order.

    Header layout (12 x int64):
        [0]    = dtype code (0=int64, 1=int32, 2=float32, 3=float64,
                 4=complex64, 5=complex128), see get_dtype_code
        [1]    = number of extra dimensions (i.e., data.ndim - 1)
        [2:12] = shape[1:] padded with zeros

    Write modes:
    - Whole-array write (offset == 0 and the selected slice covers all of
      axis 0): the file is recreated from scratch. Reusing an existing file
      in place would keep its old header and any trailing samples of a
      previous, longer array, which a reader would then return.
    - Partial write (any other offset/count): only data[offset:offset + count]
      is written, at the matching sample position in the file. If the file
      already has a header it must describe the same dtype and per-sample
      shape; otherwise a fresh header is written first.

    Parameters:
    - filepath: target binary file path
    - data: NumPy array (or array-like) with shape (N, ...)
    - offset: number of samples to skip on axis 0 (both in `data` and in the file)
    - count: number of samples to write on axis 0 (default: rest of the array)

    Raises:
    - ValueError: no filepath, 0-d data, more than 10 extra dimensions,
      unsupported dtype, negative offset/count, or a header mismatch when
      writing part of an existing file.
    """
    header_size = 12 * 8  # 12 x int64

    if filepath is None:
        raise ValueError("No filepath specified for writing.")

    data = np.asarray(data)
    if data.ndim < 1:
        raise ValueError("Data must have at least one dimension.")

    extra_dims = data.shape[1:]
    if len(extra_dims) > 10:
        raise ValueError(f"Too many dimensions: {data.ndim}. Max supported is 11 (1 + 10 extra).")

    # A negative offset would seek back into (and corrupt) the header.
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    if count is None:
        count = data.shape[0]
    if count < 0:
        raise ValueError(f"count must be non-negative, got {count}")

    dtype_code = self.get_dtype_code(data.dtype)
    item_count = int(np.prod(extra_dims, dtype=np.int64)) if extra_dims else 1
    bytes_per_sample = data.dtype.itemsize * item_count

    header = np.zeros(12, dtype=np.int64)
    header[0] = dtype_code
    header[1] = len(extra_dims)
    header[2:2 + len(extra_dims)] = extra_dims

    chunk = data[offset:offset + count]
    byte_offset = header_size + int(offset) * bytes_per_sample
    whole_array = offset == 0 and chunk.shape[0] == data.shape[0]
    has_header = os.path.exists(filepath) and os.path.getsize(filepath) >= header_size

    if whole_array or not has_header:
        # "wb" truncates: the file ends up holding exactly this header and
        # these samples, with nothing left over from a previous version.
        with open(filepath, "wb") as f:
            f.write(header.tobytes())
            f.seek(byte_offset)
            f.write(chunk.tobytes())
        return

    # Partial update of an existing file: only valid if the layout matches.
    with open(filepath, "r+b") as f:
        stored = np.frombuffer(f.read(header_size), dtype=np.int64)
        if not np.array_equal(stored, header):
            raise ValueError(
                f"Header of {filepath} does not match the data being written "
                "(dtype or per-sample shape differ)."
            )
        f.seek(byte_offset)
        f.write(chunk.tobytes())
537
+
538
def read_index(self, filepath, offset=0, count=None):
    """
    Load a NumPy array from a binary file with a 12 x int64 header.

    Header layout:
        [0]       = dtype code (key of self.dtype_dcode_map)
        [1]       = number of extra dimensions (D)
        [2:2+D]   = shape[1:] of each sample (shape after axis 0)

    Parameters:
    - filepath: path to the binary file
    - offset: number of samples to skip on axis 0
    - count: number of samples to read (default: all remaining)

    Returns:
    - data: writable NumPy array with shape (count, ...) and the stored dtype.
      The samples are read into a mutable bytearray; np.frombuffer over
      immutable bytes would instead give a read-only array.

    Raises:
    - FileNotFoundError: if filepath does not exist
    - ValueError: missing/invalid header, unknown dtype code, negative
      offset/count, offset past the end of the file, or fewer samples in the
      file than requested
    """
    header_size = 12 * 8  # 12 x int64

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    if count is not None and count < 0:
        raise ValueError(f"count must be non-negative, got {count}")

    with open(filepath, "rb") as f:
        header_bytes = f.read(header_size)
        if len(header_bytes) != header_size:
            raise ValueError("Invalid or missing header (expected 96 bytes).")

        header = np.frombuffer(header_bytes, dtype=np.int64)
        dtype_code = int(header[0])
        ndim_extra = int(header[1])
        if dtype_code not in self.dtype_dcode_map:
            raise ValueError(f"Unknown dtype code in header: {dtype_code}")
        if not 0 <= ndim_extra <= 10:
            raise ValueError(f"Invalid number of extra dimensions in header: {ndim_extra}")

        dtype = np.dtype(self.dtype_dcode_map[dtype_code])
        shape1 = tuple(int(n) for n in header[2:2 + ndim_extra])
        if any(n < 0 for n in shape1):
            raise ValueError(f"Invalid sample shape in header: {shape1}")
        item_count = int(np.prod(shape1, dtype=np.int64)) if shape1 else 1
        bytes_per_sample = dtype.itemsize * item_count
        start = header_size + int(offset) * bytes_per_sample

        # Determine number of samples
        if count is None:
            if bytes_per_sample == 0:
                count = 0  # zero-size samples: nothing to read
            else:
                remaining_bytes = os.path.getsize(filepath) - start
                if remaining_bytes < 0:
                    raise ValueError(f"offset {offset} is past the end of {filepath}")
                count = remaining_bytes // bytes_per_sample

        nbytes = int(count) * bytes_per_sample
        f.seek(start)
        buf = bytearray(nbytes)
        nread = f.readinto(buf)
        if nread != nbytes:
            raise ValueError(
                f"{filepath} is truncated: expected {nbytes} bytes at offset {offset}, got {nread}."
            )

    data = np.frombuffer(buf, dtype=dtype)
    return data.reshape((int(count),) + shape1)
592
+
474
593
  # ---------------------------------------------−---------
475
594
  # ---------------------------------------------−---------
476
595
  def healpix_layer(self, im, ww, indices=None, weights=None):
@@ -614,10 +733,10 @@ class FoCUS:
614
733
  ),None
615
734
 
616
735
  # --------------------------------------------------------
617
- def up_grade(self, im, nout, axis=0, nouty=None):
736
+ def up_grade(self, im, nout, axis=-1, nouty=None):
618
737
 
738
+ ishape = list(im.shape)
619
739
  if self.use_2D:
620
- ishape = list(im.shape)
621
740
  if len(ishape) < axis + 2:
622
741
  if not self.silent:
623
742
  print("Use of 2D scat with data that has less than 2D")
@@ -632,9 +751,6 @@ class FoCUS:
632
751
  npix = im.shape[axis]
633
752
  npiy = im.shape[axis + 1]
634
753
  odata = 1
635
- if len(ishape) > axis + 2:
636
- for k in range(axis + 2, len(ishape)):
637
- odata = odata * ishape[k]
638
754
 
639
755
  ndata = 1
640
756
  for k in range(axis):
@@ -658,13 +774,12 @@ class FoCUS:
658
774
  return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
659
775
  else:
660
776
  return self.backend.bk_reshape(
661
- res, ishape[0:axis] + [nout, nouty] + ishape[axis + 2 :]
777
+ res, ishape[0:axis] + [nout, nouty]
662
778
  )
663
779
 
664
780
  return self.backend.bk_reshape(res, [nout, nouty])
665
781
 
666
782
  elif self.use_1D:
667
- ishape = list(im.shape)
668
783
  if len(ishape) < axis + 1:
669
784
  if not self.silent:
670
785
  print("Use of 1D scat with data that has less than 1D")
@@ -757,8 +872,6 @@ class FoCUS:
757
872
  imout = im
758
873
  else:
759
874
  # work only on the last column
760
-
761
- ishape = list(im.shape)
762
875
 
763
876
  ndata = 1
764
877
  for k in range(len(ishape)-1):
@@ -781,12 +894,12 @@ class FoCUS:
781
894
  tim,
782
895
  self.weight_interp_val[(lout,nout)],
783
896
  )
784
-
897
+
785
898
  if len(ishape) == 1:
786
899
  return self.backend.bk_reshape(imout, [12 * nout**2])
787
900
  else:
788
901
  return self.backend.bk_reshape(
789
- imout, ishape[0:axis-1]+[12 * nout**2]
902
+ imout, ishape[0:axis]+[12 * nout**2]
790
903
  )
791
904
  return imout
792
905
 
@@ -1074,14 +1187,13 @@ class FoCUS:
1074
1187
 
1075
1188
  try:
1076
1189
  if self.use_2D:
1077
- tmp = np.load(
1078
- "%s/W%d_%s_%d_IDX.npy"
1079
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1190
+ tmp = self.read_index("%s/W%d_%s_%d_IDX.fst"
1191
+ % (self.TEMPLATE_PATH, l_kernel**2,TMPFILE_VERSION, nside)
1080
1192
  )
1081
1193
  else:
1082
- if cell_ids is not None:
1083
- tmp = np.load(
1084
- "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1194
+ if cell_ids is not None and nside>512:
1195
+ tmp = self.read_index(
1196
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1085
1197
  % (
1086
1198
  self.TEMPLATE_PATH,
1087
1199
  TMPFILE_VERSION,
@@ -1092,8 +1204,16 @@ class FoCUS:
1092
1204
  )
1093
1205
 
1094
1206
  else:
1095
- tmp = np.load(
1096
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1207
+ print('LOAD ',"%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1208
+ % (
1209
+ self.TEMPLATE_PATH,
1210
+ TMPFILE_VERSION,
1211
+ l_kernel**2,
1212
+ self.NORIENT,
1213
+ nside,spin # if cell_ids computes the index
1214
+ ))
1215
+ tmp = self.read_index(
1216
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1097
1217
  % (
1098
1218
  self.TEMPLATE_PATH,
1099
1219
  TMPFILE_VERSION,
@@ -1104,28 +1224,56 @@ class FoCUS:
1104
1224
  )
1105
1225
 
1106
1226
  except:
1227
+ if cell_ids is not None and nside<=512:
1228
+ self.init_index(nside, kernel=kernel, spin=spin)
1229
+
1107
1230
  if not self.use_2D:
1231
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1232
+ % (
1233
+ self.TEMPLATE_PATH,
1234
+ TMPFILE_VERSION,
1235
+ l_kernel**2,
1236
+ self.NORIENT,
1237
+ nside,spin # if cell_ids computes the index
1238
+ )
1239
+ )
1108
1240
  if spin!=0:
1109
1241
  try:
1110
- tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1111
- self.TEMPLATE_PATH,
1112
- self.TMPFILE_VERSION,
1113
- self.KERNELSZ**2,
1114
- self.NORIENT,
1115
- nside)
1116
- )
1242
+ tmp = self.read_index(
1243
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
1244
+ % (
1245
+ self.TEMPLATE_PATH,
1246
+ TMPFILE_VERSION,
1247
+ l_kernel**2,
1248
+ self.NORIENT,
1249
+ nside
1250
+ )
1251
+ )
1117
1252
  except:
1253
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst'
1254
+ % (
1255
+ self.TEMPLATE_PATH,
1256
+ TMPFILE_VERSION,
1257
+ l_kernel**2,
1258
+ self.NORIENT,
1259
+ nside
1260
+ )
1261
+ )
1262
+
1118
1263
  self.init_index(nside, kernel=kernel, spin=0)
1119
1264
 
1120
- tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1121
- self.TEMPLATE_PATH,
1122
- self.TMPFILE_VERSION,
1123
- self.KERNELSZ**2,
1124
- self.NORIENT,
1125
- nside)
1126
- )
1265
+ tmp = self.read_index(
1266
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
1267
+ % (
1268
+ self.TEMPLATE_PATH,
1269
+ TMPFILE_VERSION,
1270
+ l_kernel**2,
1271
+ self.NORIENT,
1272
+ nside
1273
+ )
1274
+ )
1127
1275
 
1128
- tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.npy"% (
1276
+ tmpw = self.read_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.fst"% (
1129
1277
  self.TEMPLATE_PATH,
1130
1278
  self.TMPFILE_VERSION,
1131
1279
  self.KERNELSZ**2,
@@ -1168,52 +1316,45 @@ class FoCUS:
1168
1316
  tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1169
1317
  tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1170
1318
 
1171
-
1172
- np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"% (self.TEMPLATE_PATH,
1319
+
1320
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"% (self.TEMPLATE_PATH,
1173
1321
  self.TMPFILE_VERSION,
1174
1322
  self.KERNELSZ**2,
1175
1323
  self.NORIENT,
1176
1324
  nside,
1177
1325
  spin
1178
1326
  ),
1179
- idxEB
1180
- )
1181
- np.save("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"% (self.TEMPLATE_PATH,
1327
+ idxEB
1328
+ )
1329
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"% (self.TEMPLATE_PATH,
1182
1330
  self.TMPFILE_VERSION,
1183
1331
  self.KERNELSZ**2,
1184
1332
  self.NORIENT,
1185
1333
  nside,
1186
1334
  spin,
1187
1335
  ),
1188
- tmpEB
1189
- )
1190
- tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.npy"%
1191
- (
1192
- self.TEMPLATE_PATH,
1193
- self.TMPFILE_VERSION,
1194
- self.KERNELSZ**2,
1195
- self.NORIENT,
1196
- nside,
1197
- )
1198
- )
1199
- tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN0.npy"%
1200
- (
1201
- self.TEMPLATE_PATH,
1202
- self.TMPFILE_VERSION,
1203
- self.KERNELSZ**2,
1204
- self.NORIENT,
1205
- nside,
1206
- )
1207
- )
1208
-
1209
- nn=12*nside**2
1210
- idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1211
- idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1212
- idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1213
- idxEB[2*tmp.shape[0]:,1]+=nn
1214
-
1215
- tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1336
+ tmpEB
1337
+ )
1216
1338
 
1339
+ tmp = self.read_index(
1340
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.fst"
1341
+ % (
1342
+ self.TEMPLATE_PATH,
1343
+ TMPFILE_VERSION,
1344
+ l_kernel**2,
1345
+ self.NORIENT,
1346
+ nside
1347
+ )
1348
+ )
1349
+
1350
+ tmpw = self.read_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN0.fst"% (
1351
+ self.TEMPLATE_PATH,
1352
+ self.TMPFILE_VERSION,
1353
+ self.KERNELSZ**2,
1354
+ self.NORIENT,
1355
+ nside,
1356
+ )
1357
+ )
1217
1358
  for k in range(12*nside**2):
1218
1359
  if k%(nside**2)==0:
1219
1360
  print('Init index 2/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
@@ -1221,48 +1362,40 @@ class FoCUS:
1221
1362
  nside,
1222
1363
  self.KERNELSZ))
1223
1364
  idx=np.where(tmp[:,1]==k)[0]
1224
-
1365
+
1225
1366
  im=np.zeros([12*nside**2])
1226
- im[tmp[idx,0]]=tmpw[idx].real
1367
+ im[tmp[idx,0]]=tmpw[idx]
1227
1368
  almR=hp.map2alm(hp.reorder(im,n2r=True))
1228
- im[tmp[idx,0]]=tmpw[idx].imag
1229
- almI=hp.map2alm(hp.reorder(im,n2r=True))
1230
-
1369
+
1231
1370
  i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1232
- i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1233
-
1234
- tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1235
- tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1236
-
1371
+
1372
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1373
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1374
+
1237
1375
  i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1238
- i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1239
-
1240
- tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1241
- tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1242
-
1243
-
1244
- np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"%
1245
- (
1246
- self.TEMPLATE_PATH,
1247
- self.TMPFILE_VERSION,
1248
- self.KERNELSZ**2,
1249
- self.NORIENT,
1250
- nside,
1251
- spin,
1252
- ),
1253
- idxEB
1254
- )
1255
- np.save("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"%
1256
- (
1257
- self.TEMPLATE_PATH,
1258
- self.TMPFILE_VERSION,
1259
- self.KERNELSZ**2,
1260
- self.NORIENT,
1261
- nside,
1262
- spin,
1263
- ),
1264
- tmpEB
1265
- )
1376
+
1377
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1378
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1379
+
1380
+
1381
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"% (self.TEMPLATE_PATH,
1382
+ self.TMPFILE_VERSION,
1383
+ self.KERNELSZ**2,
1384
+ self.NORIENT,
1385
+ nside,
1386
+ spin
1387
+ ),
1388
+ idxEB
1389
+ )
1390
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"% (self.TEMPLATE_PATH,
1391
+ self.TMPFILE_VERSION,
1392
+ self.KERNELSZ**2,
1393
+ self.NORIENT,
1394
+ nside,
1395
+ spin,
1396
+ ),
1397
+ tmpEB
1398
+ )
1266
1399
  else:
1267
1400
 
1268
1401
  if l_kernel == 5:
@@ -1280,7 +1413,7 @@ class FoCUS:
1280
1413
  pw2 = 0.25
1281
1414
  threshold = 4e-5
1282
1415
 
1283
- if cell_ids is not None:
1416
+ if cell_ids is not None and nside>512:
1284
1417
  if not isinstance(cell_ids, np.ndarray):
1285
1418
  cell_ids = self.backend.to_numpy(cell_ids)
1286
1419
  th, ph = hp.pix2ang(nside, cell_ids, nest=True)
@@ -1304,15 +1437,20 @@ class FoCUS:
1304
1437
  phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1305
1438
  thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1306
1439
 
1307
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1440
+ indice2 = np.zeros([12 * nside * nside * 64, 2],
1441
+ dtype="int")
1442
+
1308
1443
  indice = np.zeros(
1309
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1444
+ [12 * nside * nside * 64 * self.NORIENT, 2],
1445
+ dtype="int"
1310
1446
  )
1311
1447
  wav = np.zeros(
1312
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1448
+ [12 * nside * nside * 64 * self.NORIENT],
1449
+ dtype="complex"
1313
1450
  )
1314
1451
  wwav = np.zeros(
1315
- [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1452
+ [12 * nside * nside * 64 * self.NORIENT],
1453
+ dtype="float"
1316
1454
  )
1317
1455
  iv = 0
1318
1456
  iv2 = 0
@@ -1405,26 +1543,26 @@ class FoCUS:
1405
1543
  if cell_ids is None:
1406
1544
  if not self.silent:
1407
1545
  print(
1408
- "Write FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1409
- % (TMPFILE_VERSION, self.KERNELSZ**2,
1546
+ "Write %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1547
+ % ( self.TEMPLATE_PATH,
1548
+ TMPFILE_VERSION, self.KERNELSZ**2,
1410
1549
  self.NORIENT,
1411
1550
  nside,
1412
- spin,)
1551
+ spin)
1413
1552
  )
1414
- np.save(
1415
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1416
- % (
1417
- self.TEMPLATE_PATH,
1418
- TMPFILE_VERSION,
1419
- self.KERNELSZ**2,
1420
- self.NORIENT,
1421
- nside,
1422
- spin,
1423
- ),
1424
- indice,
1425
- )
1426
- np.save(
1427
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1553
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1554
+ % (
1555
+ self.TEMPLATE_PATH,
1556
+ TMPFILE_VERSION,
1557
+ self.KERNELSZ**2,
1558
+ self.NORIENT,
1559
+ nside,
1560
+ spin,
1561
+ ),
1562
+ indice
1563
+ )
1564
+ self.save_index(
1565
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1428
1566
  % (
1429
1567
  self.TEMPLATE_PATH,
1430
1568
  TMPFILE_VERSION,
@@ -1435,8 +1573,8 @@ class FoCUS:
1435
1573
  ),
1436
1574
  wav,
1437
1575
  )
1438
- np.save(
1439
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1576
+ self.save_index(
1577
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"
1440
1578
  % (
1441
1579
  self.TEMPLATE_PATH,
1442
1580
  TMPFILE_VERSION,
@@ -1447,8 +1585,8 @@ class FoCUS:
1447
1585
  ),
1448
1586
  indice2,
1449
1587
  )
1450
- np.save(
1451
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1588
+ self.save_index(
1589
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"
1452
1590
  % (
1453
1591
  self.TEMPLATE_PATH,
1454
1592
  TMPFILE_VERSION,
@@ -1475,11 +1613,11 @@ class FoCUS:
1475
1613
  )
1476
1614
  return None
1477
1615
 
1478
- if cell_ids is None:
1616
+ if cell_ids is None or nside<=512:
1479
1617
  self.barrier()
1480
1618
  if self.use_2D:
1481
- tmp = np.load(
1482
- "%s/W%d_%s_%d_IDX-SPIN%d.npy"
1619
+ tmp = self.read_index(
1620
+ "%s/W%d_%s_%d_IDX-SPIN%d.fst"
1483
1621
  % (
1484
1622
  self.TEMPLATE_PATH,
1485
1623
  l_kernel**2,
@@ -1488,8 +1626,8 @@ class FoCUS:
1488
1626
  spin)
1489
1627
  )
1490
1628
  else:
1491
- tmp = np.load(
1492
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1629
+ tmp = self.read_index(
1630
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1493
1631
  % (
1494
1632
  self.TEMPLATE_PATH,
1495
1633
  TMPFILE_VERSION,
@@ -1499,8 +1637,8 @@ class FoCUS:
1499
1637
  spin,
1500
1638
  )
1501
1639
  )
1502
- tmp2 = np.load(
1503
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1640
+ tmp2 = self.read_index(
1641
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"
1504
1642
  % (
1505
1643
  self.TEMPLATE_PATH,
1506
1644
  TMPFILE_VERSION,
@@ -1510,8 +1648,8 @@ class FoCUS:
1510
1648
  spin,
1511
1649
  )
1512
1650
  )
1513
- wr = np.load(
1514
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1651
+ wr = self.read_index(
1652
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1515
1653
  % (
1516
1654
  self.TEMPLATE_PATH,
1517
1655
  TMPFILE_VERSION,
@@ -1521,8 +1659,8 @@ class FoCUS:
1521
1659
  spin,
1522
1660
  )
1523
1661
  ).real
1524
- wi = np.load(
1525
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1662
+ wi = self.read_index(
1663
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1526
1664
  % (
1527
1665
  self.TEMPLATE_PATH,
1528
1666
  TMPFILE_VERSION,
@@ -1532,8 +1670,8 @@ class FoCUS:
1532
1670
  spin,
1533
1671
  )
1534
1672
  ).imag
1535
- ws = self.slope * np.load(
1536
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1673
+ ws = self.slope * self.read_index(
1674
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"
1537
1675
  % (
1538
1676
  self.TEMPLATE_PATH,
1539
1677
  TMPFILE_VERSION,
@@ -1543,6 +1681,39 @@ class FoCUS:
1543
1681
  spin,
1544
1682
  )
1545
1683
  )
1684
+
1685
+ if cell_ids is not None:
1686
+ idx_map=-np.ones([12*nside**2],dtype='int32')
1687
+ lcell_ids=cell_ids
1688
+
1689
+ try:
1690
+ idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32')
1691
+ except:
1692
+ lcell_ids=self.to_numpy(cell_ids)
1693
+ idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32')
1694
+
1695
+ lidx=np.where(idx_map[tmp[:,1]%(12*nside**2)]!=-1)[0]
1696
+ orientation=tmp[lidx,1]//(12*nside**2)
1697
+ tmp=tmp[lidx]
1698
+ wr=wr[lidx]
1699
+ wi=wi[lidx]
1700
+ tmp=idx_map[tmp%(12*nside**2)]
1701
+ lidx=np.where(tmp[:,0]==-1)[0]
1702
+ wr[lidx]=0.0
1703
+ wi[lidx]=0.0
1704
+ tmp[lidx,0]=0
1705
+ tmp[:,1]+=orientation*lcell_ids.shape[0]
1706
+
1707
+ idx_map=-np.ones([12*nside**2],dtype='int32')
1708
+ idx_map[lcell_ids]=np.arange(cell_ids.shape[0],dtype='int32')
1709
+ lidx=np.where(idx_map[tmp2[:,1]]!=-1)[0]
1710
+ tmp2=tmp2[lidx]
1711
+ ws=ws[lidx]
1712
+ tmp2=idx_map[tmp2]
1713
+ lidx=np.where(tmp2[:,0]==-1)[0]
1714
+ ws[lidx]=0.0
1715
+ tmp2[lidx,0]=0
1716
+
1546
1717
  else:
1547
1718
  tmp = indice
1548
1719
  tmp2 = indice2
@@ -1550,6 +1721,7 @@ class FoCUS:
1550
1721
  wi = wav.imag
1551
1722
  ws = self.slope * wwav
1552
1723
 
1724
+
1553
1725
  if spin==0:
1554
1726
  wr = self.backend.bk_SparseTensor(
1555
1727
  self.backend.bk_constant(tmp),
@@ -1590,7 +1762,7 @@ class FoCUS:
1590
1762
  if kernel != -1:
1591
1763
  return tmp
1592
1764
 
1593
- return wr, wi, ws, tmp
1765
+ return wr, wi, ws,tmp
1594
1766
 
1595
1767
 
1596
1768
  # ---------------------------------------------−---------
@@ -1609,8 +1781,8 @@ class FoCUS:
1609
1781
  try:
1610
1782
 
1611
1783
  if cell_ids is not None:
1612
- tmp = np.load(
1613
- "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1784
+ tmp = self.read_index(
1785
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1614
1786
  % (
1615
1787
  self.TEMPLATE_PATH,
1616
1788
  TMPFILE_VERSION,
@@ -1621,8 +1793,8 @@ class FoCUS:
1621
1793
  )
1622
1794
 
1623
1795
  else:
1624
- tmp = np.load(
1625
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1796
+ tmp = self.read_index(
1797
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1626
1798
  % (
1627
1799
  self.TEMPLATE_PATH,
1628
1800
  TMPFILE_VERSION,
@@ -1756,11 +1928,11 @@ class FoCUS:
1756
1928
  if cell_ids is None:
1757
1929
  if not self.silent:
1758
1930
  print(
1759
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1931
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1760
1932
  % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
1761
1933
  )
1762
- np.save(
1763
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1934
+ self.save_index(
1935
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1764
1936
  % (
1765
1937
  self.TEMPLATE_PATH,
1766
1938
  TMPFILE_VERSION,
@@ -1770,8 +1942,8 @@ class FoCUS:
1770
1942
  ),
1771
1943
  indice,
1772
1944
  )
1773
- np.save(
1774
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1945
+ self.save_index(
1946
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst"
1775
1947
  % (
1776
1948
  self.TEMPLATE_PATH,
1777
1949
  TMPFILE_VERSION,
@@ -1785,13 +1957,13 @@ class FoCUS:
1785
1957
  if cell_ids is None:
1786
1958
  self.barrier()
1787
1959
  if self.use_2D:
1788
- tmp = np.load(
1789
- "%s/W%d_%s_%d_IDX.npy"
1960
+ tmp = self.read_index(
1961
+ "%s/W%d_%s_%d_IDX.fst"
1790
1962
  % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1791
1963
  )
1792
1964
  else:
1793
- tmp = np.load(
1794
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1965
+ tmp = self.read_index(
1966
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1795
1967
  % (
1796
1968
  self.TEMPLATE_PATH,
1797
1969
  TMPFILE_VERSION,
@@ -1800,8 +1972,8 @@ class FoCUS:
1800
1972
  nside,
1801
1973
  )
1802
1974
  )
1803
- wav = np.load(
1804
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1975
+ wav = self.read_index(
1976
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst"
1805
1977
  % (
1806
1978
  self.TEMPLATE_PATH,
1807
1979
  TMPFILE_VERSION,
@@ -2299,34 +2471,19 @@ class FoCUS:
2299
2471
  if nside is None:
2300
2472
  nside = int(np.sqrt(image.shape[-1] // 12))
2301
2473
 
2302
- if spin==0:
2303
- if nside not in self.Idx_Neighbours:
2304
- if self.InitWave is None:
2305
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2306
- else:
2307
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2308
-
2309
- self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2310
- self.ww_Real[nside] = wr
2311
- self.ww_Imag[nside] = wi
2312
- self.w_smooth[nside] = ws
2313
-
2314
- l_ww_real = self.ww_Real[nside]
2315
- l_ww_imag = self.ww_Imag[nside]
2316
- else:
2317
- if (spin,nside) not in self.Idx_Neighbours:
2318
- if self.InitWave is None:
2319
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2320
- else:
2321
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2474
+ if (spin,nside) not in self.Idx_Neighbours:
2475
+ if self.InitWave is None:
2476
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2477
+ else:
2478
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2322
2479
 
2323
- self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2324
- self.ww_Real[(spin,nside)] = wr
2325
- self.ww_Imag[(spin,nside)] = wi
2326
- self.w_smooth[(spin,nside)] = ws
2480
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2481
+ self.ww_Real[(spin,nside)] = wr
2482
+ self.ww_Imag[(spin,nside)] = wi
2483
+ self.w_smooth[(spin,nside)] = ws
2327
2484
 
2328
- l_ww_real = self.ww_Real[(spin,nside)]
2329
- l_ww_imag = self.ww_Imag[(spin,nside)]
2485
+ l_ww_real = self.ww_Real[(spin,nside)]
2486
+ l_ww_imag = self.ww_Imag[(spin,nside)]
2330
2487
 
2331
2488
  # always convolve the last dimension
2332
2489
 
@@ -2342,7 +2499,6 @@ class FoCUS:
2342
2499
  tim = self.backend.bk_reshape(
2343
2500
  self.backend.bk_cast(image), [ndata, ishape[-1]]
2344
2501
  )
2345
-
2346
2502
  if tim.dtype == self.all_cbk_type:
2347
2503
  rr1 = self.backend.bk_reshape(
2348
2504
  self.backend.bk_sparse_dense_matmul(
@@ -2399,7 +2555,6 @@ class FoCUS:
2399
2555
  else:
2400
2556
  return self.backend.bk_reshape(res, [2,self.NORIENT, ishape[-1]])
2401
2557
 
2402
-
2403
2558
  return res
2404
2559
 
2405
2560
  # ---------------------------------------------−---------
@@ -2467,31 +2622,18 @@ class FoCUS:
2467
2622
  if nside is None:
2468
2623
  nside = int(np.sqrt(image.shape[-1] // 12))
2469
2624
 
2470
- if spin==0:
2471
- if nside not in self.Idx_Neighbours:
2472
- if self.InitWave is None:
2473
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2474
- else:
2475
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2476
-
2477
- self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2478
- self.ww_Real[nside] = wr
2479
- self.ww_Imag[nside] = wi
2480
- self.w_smooth[nside] = ws
2481
-
2482
- l_w_smooth = self.w_smooth[nside]
2483
- else:
2484
- if (spin,nside) not in self.Idx_Neighbours:
2485
- if self.InitWave is None:
2486
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2487
- else:
2488
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2625
+ if (spin,nside) not in self.Idx_Neighbours:
2626
+ if self.InitWave is None:
2627
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2628
+ else:
2629
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2489
2630
 
2490
- self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2491
- self.ww_Real[(spin,nside)] = wr
2492
- self.ww_Imag[(spin,nside)] = wi
2493
- self.w_smooth[(spin,nside)] = ws
2494
- l_w_smooth = self.w_smooth[(spin,nside)]
2631
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2632
+ self.ww_Real[(spin,nside)] = wr
2633
+ self.ww_Imag[(spin,nside)] = wi
2634
+ self.w_smooth[(spin,nside)] = ws
2635
+
2636
+ l_w_smooth = self.w_smooth[(spin,nside)]
2495
2637
 
2496
2638
  odata = 1
2497
2639
  for k in range(0, len(ishape) - 1):