foscat 2025.11.1__py3-none-any.whl → 2026.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/SphereUpGeo.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ from foscat.SphereDownGeo import SphereDownGeo
6
+
7
+
8
class SphereUpGeo(nn.Module):
    """Geometric HEALPix upsampling operator using the transpose of SphereDownGeo.

    `cell_ids_out` (coarse pixels at nside_out, NESTED) is mandatory.
    Forward expects x of shape [B, C, K_out] aligned with that order and
    returns a tuple ``(x_up, cell_ids_in)``:
      - x_up: full fine-grid map [B, C, N_in] at nside_in = 2*nside_out
      - cell_ids_in: arange(N_in), the fine pixel ids of the last axis of x_up

    Normalization (diagonal corrections), where M is the SphereDownGeo
    smoothing operator restricted to the rows listed in `cell_ids_out`:
      - up_norm='adjoint': x_up = M^T x
      - up_norm='col_l1' : x_up = (M^T x) / col_sum, col_sum[i] = sum_k M[k,i]
      - up_norm='diag_l2': x_up = (M^T x) / col_l2,  col_l2[i]  = sum_k M[k,i]^2

    Denominators are clamped to `eps`. A fine pixel that no selected coarse
    row touches therefore has M^T x == 0 and stays 0 after normalization.
    """

    def __init__(
        self,
        nside_out: int,
        cell_ids_out,
        radius_deg: float | None = None,
        sigma_deg: float | None = None,
        weight_norm: str = "l1",
        up_norm: str = "col_l1",
        eps: float = 1e-12,
        device=None,
        dtype=torch.float32,
    ):
        """Build the (restricted) transposed down-sampling operator.

        Args:
            nside_out: coarse HEALPix nside (power of 2). The fine grid is 2*nside_out.
            cell_ids_out: 1D list/ndarray/tensor of unique coarse ids (< 12*nside_out**2).
                Their order defines the order of the last axis expected by `forward`.
            radius_deg, sigma_deg, weight_norm: forwarded to SphereDownGeo (mode="smooth").
            up_norm: 'adjoint', 'col_l1' or 'diag_l2' (case/whitespace-insensitive).
            eps: lower clamp applied to normalization denominators.
            device: torch device. Defaults to CUDA when available, else CPU.
            dtype: floating dtype of the operator.

        Raises:
            ValueError: missing/invalid `cell_ids_out` or unknown `up_norm`.
            AssertionError: `nside_out` is not a positive power of 2.
        """
        super().__init__()

        if cell_ids_out is None:
            raise ValueError("cell_ids_out is mandatory (1D list/np/tensor of coarse HEALPix ids at nside_out).")

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype

        self.nside_out = int(nside_out)
        # The bit test alone accepts 0, so also require a strictly positive nside.
        assert self.nside_out > 0 and (self.nside_out & (self.nside_out - 1)) == 0, "nside_out must be a power of 2."
        self.nside_in = self.nside_out * 2

        self.N_out = 12 * self.nside_out * self.nside_out
        self.N_in = 12 * self.nside_in * self.nside_in

        up_norm = str(up_norm).lower().strip()
        if up_norm not in ("adjoint", "col_l1", "diag_l2"):
            raise ValueError("up_norm must be 'adjoint', 'col_l1', or 'diag_l2'.")
        self.up_norm = up_norm
        self.eps = float(eps)

        # Coarse ids in user-provided order (must be unique for alignment)
        if isinstance(cell_ids_out, torch.Tensor):
            cell_ids_out_np = cell_ids_out.detach().cpu().numpy().astype(np.int64)
        else:
            cell_ids_out_np = np.asarray(cell_ids_out, dtype=np.int64)

        if cell_ids_out_np.ndim != 1:
            raise ValueError("cell_ids_out must be 1D")
        if cell_ids_out_np.size == 0:
            raise ValueError("cell_ids_out must be non-empty")
        if cell_ids_out_np.min() < 0 or cell_ids_out_np.max() >= self.N_out:
            raise ValueError("cell_ids_out contains out-of-bounds ids for this nside_out")
        if np.unique(cell_ids_out_np).size != cell_ids_out_np.size:
            raise ValueError("cell_ids_out must not contain duplicates (order matters for alignment).")

        self.cell_ids_out_np = cell_ids_out_np
        self.K_out = int(cell_ids_out_np.size)
        self.register_buffer("cell_ids_out_t", torch.as_tensor(cell_ids_out_np, dtype=torch.long, device=self.device))

        # Build the FULL down operator at fine resolution (nside_in -> nside_out)
        tmp_down = SphereDownGeo(
            nside_in=self.nside_in,
            mode="smooth",
            radius_deg=radius_deg,
            sigma_deg=sigma_deg,
            weight_norm=weight_norm,
            device=self.device,
            dtype=self.dtype,
            use_csr=False,
        )

        M_down_full = torch.sparse_coo_tensor(
            tmp_down.M.indices(),
            tmp_down.M.values(),
            size=(tmp_down.N_out, tmp_down.N_in),
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

        # Extract ONLY the requested coarse rows, in the provided order (vectorized, on CPU).
        idx = M_down_full.indices().cpu().numpy()
        vals = M_down_full.values().cpu().numpy()
        new_rows, cols_sel, vals_sel = self._select_rows(
            idx[0], idx[1], vals, cell_ids_out_np, int(tmp_down.N_out)
        )

        M_down_sub = torch.sparse_coo_tensor(
            torch.as_tensor(np.stack([new_rows, cols_sel], axis=0), dtype=torch.long),
            torch.as_tensor(vals_sel, dtype=self.dtype),
            size=(self.K_out, self.N_in),
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

        # Store M^T (sparse) so forward is just sparse.mm.
        # indices/values are buffers so that module.to(...) moves them.
        M_up = self._transpose_sparse(M_down_sub)  # [N_in, K_out]
        self.register_buffer("M_indices", M_up.indices())
        self.register_buffer("M_values", M_up.values())
        self.M_size = M_up.size()

        # Diagonal normalizers (length N_in), based on the selected coarse rows only
        fine_cols = M_down_sub.indices()[1]
        vals_sub = M_down_sub.values()

        col_sum = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_l2 = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_sum.scatter_add_(0, fine_cols, vals_sub)
        col_l2.scatter_add_(0, fine_cols, vals_sub * vals_sub)

        self.register_buffer("col_sum", col_sum)
        self.register_buffer("col_l2", col_l2)

        # Fine ids (full sphere)
        self.register_buffer("cell_ids_in_t", torch.arange(self.N_in, dtype=torch.long, device=self.device))

        # CSR cache of M^T. It is a plain attribute (sparse CSR cannot be a buffer
        # reliably), so forward() re-derives it from the buffers when the input's
        # device/dtype differs (e.g. after module.to(...)).
        self.M_T = self._build_M_T(self.device, self.dtype)

    @staticmethod
    def _select_rows(rows, cols, vals, cell_ids, n_rows):
        """Keep COO entries whose row is in `cell_ids`; renumber rows to their position in `cell_ids`."""
        # Lookup table: original row id -> position in cell_ids, or -1 if not selected.
        lut = np.full(max(int(n_rows), int(cell_ids.max()) + 1), -1, dtype=np.int64)
        lut[cell_ids] = np.arange(cell_ids.size, dtype=np.int64)
        new_rows_all = lut[rows]
        keep = new_rows_all >= 0
        return new_rows_all[keep], cols[keep], vals[keep]

    def _build_M_T(self, device, dtype):
        """Assemble M^T as a CSR tensor on `device`/`dtype` from the registered buffers."""
        return torch.sparse_coo_tensor(
            self.M_indices.to(device=device),
            self.M_values.to(device=device, dtype=dtype),
            size=self.M_size,
            device=device,
            dtype=dtype,
        ).coalesce().to_sparse_csr()

    def _operator_for(self, x: torch.Tensor) -> torch.Tensor:
        """Return the cached CSR M^T, rebuilt if it does not match x's device (and floating dtype)."""
        M_T = self.M_T
        # Never cast the weights to an integer dtype: that would silently truncate them.
        target_dtype = x.dtype if x.is_floating_point() else M_T.dtype
        if M_T.device != x.device or M_T.dtype != target_dtype:
            M_T = self._build_M_T(x.device, target_dtype)
            self.M_T = M_T
        return M_T

    @staticmethod
    def _transpose_sparse(M: torch.Tensor) -> torch.Tensor:
        """Return the coalesced COO transpose of a 2D sparse COO tensor."""
        M = M.coalesce()
        idx = M.indices()
        vals = M.values()
        R, C = M.size()
        idx_T = torch.stack([idx[1], idx[0]], dim=0)
        return torch.sparse_coo_tensor(idx_T, vals, size=(C, R), device=M.device, dtype=M.dtype).coalesce()

    def forward(self, x: torch.Tensor):
        """Upsample coarse values to the full fine grid.

        Args:
            x: tensor [B, C, K_out] whose last axis follows `cell_ids_out`.

        Returns:
            (x_up, cell_ids_in): x_up is [B, C, N_in]; cell_ids_in is arange(N_in)
            on x's device.
        """
        B, C, K_out = x.shape
        assert K_out == self.K_out, f"Expected K_out={self.K_out}, got {K_out}"

        M_T = self._operator_for(x)

        x_bc = x.reshape(B * C, K_out)
        x_up_bc_T = torch.sparse.mm(M_T, x_bc.T)  # [N_in, B*C]
        x_up = x_up_bc_T.T.reshape(B, C, self.N_in)  # [B, C, N_in]

        if self.up_norm == "col_l1":
            denom = self.col_sum.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)
        elif self.up_norm == "diag_l2":
            denom = self.col_l2.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)

        return x_up, self.cell_ids_in_t.to(device=x.device)
@@ -2,13 +2,8 @@
2
2
  # Author: J.-M. Delouis
3
3
  import numpy as np
4
4
  import healpy as hp
5
- import foscat.scat_cov as sc
6
5
  import torch
7
6
 
8
- import numpy as np
9
- import torch
10
- import healpy as hp
11
-
12
7
 
13
8
  class SphericalStencil:
14
9
  """
@@ -61,8 +56,7 @@ class SphericalStencil:
61
56
  device=None,
62
57
  dtype=None,
63
58
  n_gauges=1,
64
- gauge_type='cosmo',
65
- scat_op=None,
59
+ gauge_type='phi',
66
60
  ):
67
61
  assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
68
62
  assert kernel_sz % 2 == 1, "kernel_sz must be odd"
@@ -75,10 +69,6 @@ class SphericalStencil:
75
69
  self.gauge_type=gauge_type
76
70
 
77
71
  self.nest = bool(nest)
78
- if scat_op is None:
79
- self.f=sc.funct(KERNELSZ=self.KERNELSZ)
80
- else:
81
- self.f=scat_op
82
72
 
83
73
  # Torch defaults
84
74
  if device is None:
@@ -354,10 +344,27 @@ class SphericalStencil:
354
344
  # --- build the local (P,3) stencil once on device
355
345
  P = self.P
356
346
  vec_np = np.zeros((P, 3), dtype=float)
357
- grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2) / self.nside
358
- vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
359
- vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
360
- vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
347
+ grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
348
+
349
+ # NEW: angular offsets
350
+ xx,yy=np.meshgrid(grid,grid)
351
+ s=1.0 # could be modified
352
+ alpha_pix = hp.nside2resol(self.nside, arcmin=False) # ~ taille angulaire typique
353
+ dtheta = (np.sqrt(xx**2+yy**2) * alpha_pix * s).ravel()
354
+ dphi = (np.arctan2(yy,xx)).ravel()
355
+ # local spherical displacement
356
+ # convert to unit vectors
357
+ x = np.sin(dtheta) * np.cos(dphi)
358
+ y = np.sin(dtheta) * np.sin(dphi)
359
+ z = np.cos(dtheta)
360
+ #print(self.nside*x.reshape(self.KERNELSZ,self.KERNELSZ))
361
+ #print(self.nside*y.reshape(self.KERNELSZ,self.KERNELSZ))
362
+ #print(self.nside*z.reshape(self.KERNELSZ,self.KERNELSZ))
363
+ vec_np = np.stack([x, y, z], axis=-1)
364
+
365
+ #vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
366
+ #vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
367
+ #vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
361
368
  vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype) # (P,3)
362
369
 
363
370
  # --- rotation matrices for all targets & gauges: (K,G,3,3)
@@ -371,7 +378,7 @@ class SphericalStencil:
371
378
  th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
372
379
  device=self.device, dtype=self.dtype
373
380
  ) # shape (K,G,3,3)
374
-
381
+
375
382
  # --- rotate stencil for each (target, gauge): (K,G,P,3)
376
383
  # einsum over local stencil (P,3) with rotation (K,G,3,3)
377
384
  rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t) # (K,G,P,3)
@@ -568,119 +575,6 @@ class SphericalStencil:
568
575
  self.dtype = dtype
569
576
 
570
577
 
571
- '''
572
- def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
573
- """
574
- Multi-gauge sparse binding (Step B).
575
- Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
576
- and builds, for each gauge g, (pos_safe, w_norm, present).
577
-
578
- Parameters
579
- ----------
580
- ids_sorted_np : np.ndarray (K,)
581
- Sorted pixel ids for available samples (matches the last axis of your data).
582
- device, dtype : torch device/dtype for the produced mapping tensors.
583
-
584
- Side effects
585
- ------------
586
- Sets:
587
- - self.ids_sorted_np : (K,)
588
- - self.pos_safe_t_multi : (G, 4, K*P) LongTensor
589
- - self.w_norm_t_multi : (G, 4, K*P) Tensor
590
- - self.present_t_multi : (G, 4, K*P) BoolTensor
591
- - (and mirrors device/dtype in self.device/self.dtype)
592
- """
593
- assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
594
- "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
595
- assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
596
-
597
- if device is None: device = self.device
598
- if dtype is None: dtype = self.dtype
599
-
600
- self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
601
- ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
602
-
603
- G, _, M = self.idx_t_multi.shape
604
- K = self.Kb
605
- P = self.P
606
- assert M == K*P, "idx_t_multi second axis must have K*P columns"
607
-
608
- pos_list, present_list, wnorm_list = [], [], []
609
-
610
- for g in range(G):
611
- idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
612
- w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
613
-
614
- pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
615
- in_range = pos < ids_sorted.numel()
616
- cmp_vals = torch.full_like(idx, -1)
617
- cmp_vals[in_range] = ids_sorted[pos[in_range]]
618
- present = (cmp_vals == idx)
619
-
620
- # normalize weights per column after masking
621
- w = w * present
622
- colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
623
- w_norm = w / colsum
624
-
625
- pos_safe = torch.where(present, pos, torch.zeros_like(pos))
626
-
627
- pos_list.append(pos_safe)
628
- present_list.append(present)
629
- wnorm_list.append(w_norm)
630
-
631
- self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
632
- self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
633
- self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
634
-
635
- # mirror runtime placement
636
- self.device = device
637
- self.dtype = dtype
638
-
639
- # ------------------------------------------------------------------
640
- # Step B: bind support Torch
641
- # ------------------------------------------------------------------
642
- def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
643
- """
644
- Map HEALPix neighbor indices (from Step A) to actual data samples
645
- sorted by pixel id. Produces pos_safe and normalized weights.
646
-
647
- Parameters
648
- ----------
649
- ids_sorted_np : np.ndarray (K,)
650
- Sorted pixel ids for available data.
651
- device, dtype : Torch device/dtype for results.
652
- """
653
- if device is None:
654
- device = self.device
655
- if dtype is None:
656
- dtype = self.dtype
657
-
658
- self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
659
- ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
660
-
661
- idx = self.idx_t.to(device=device, dtype=torch.long)
662
- w = self.w_t.to(device=device, dtype=dtype)
663
-
664
- M = self.Kb * self.P
665
- idx = idx.view(4, M)
666
- w = w.view(4, M)
667
-
668
- pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
669
- in_range = pos < ids_sorted.shape[0]
670
- cmp_vals = torch.full_like(idx, -1)
671
- cmp_vals[in_range] = ids_sorted[pos[in_range]]
672
- present = (cmp_vals == idx)
673
-
674
- w = w * present
675
- colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
676
- w_norm = w / colsum
677
-
678
- self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
679
- self.w_norm_t = w_norm
680
- self.present_t = present
681
- self.device = device
682
- self.dtype = dtype
683
- '''
684
578
  # ------------------------------------------------------------------
685
579
  # Step C: apply convolution (already Torch in your code)
686
580
  # ------------------------------------------------------------------
@@ -1215,7 +1109,7 @@ class SphericalStencil:
1215
1109
  vals = torch.cat(vals_all, dim=0)
1216
1110
 
1217
1111
 
1218
- indices = torch.stack([cols, rows], dim=0) # (2, nnz) invert rows/cols for foscat needs
1112
+ indices = torch.stack([cols, rows], dim=0)
1219
1113
 
1220
1114
  if return_sparse_tensor:
1221
1115
  M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
@@ -1224,123 +1118,10 @@ class SphericalStencil:
1224
1118
  return vals, indices, shape
1225
1119
 
1226
1120
 
1227
- def _to_numpy_1d(self, ids):
1228
- """Return a 1D numpy array of int64 for a single set of cell ids."""
1229
- import numpy as np, torch
1230
- if isinstance(ids, np.ndarray):
1231
- return ids.reshape(-1).astype(np.int64, copy=False)
1232
- if torch.is_tensor(ids):
1233
- return ids.detach().cpu().to(torch.long).view(-1).numpy()
1234
- # python list/tuple of ints
1235
- return np.asarray(ids, dtype=np.int64).reshape(-1)
1236
-
1237
- def _is_varlength_batch(self, ids):
1238
- """
1239
- True if ids is a list/tuple of per-sample id arrays (var-length batch).
1240
- False if ids is a single array/tensor of ids (shared for whole batch).
1241
- """
1242
- import numpy as np, torch
1243
- if isinstance(ids, (list, tuple)):
1244
- return True
1245
- if isinstance(ids, np.ndarray) and ids.ndim == 2:
1246
- # This would be a dense (B, Npix) matrix -> NOT var-length list
1247
- return False
1248
- if torch.is_tensor(ids) and ids.dim() == 2:
1249
- return False
1250
- return False
1251
-
1252
- def Down(self, im, cell_ids=None, nside=None,max_poll=False):
1253
- """
1254
- If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
1255
- If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
1256
- """
1257
- if self.f is None:
1258
- if self.dtype==torch.float64:
1259
- self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
1260
- else:
1261
- self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
1262
-
1263
- if cell_ids is None:
1264
- dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=max_poll)
1265
- return dim,cdim
1266
-
1267
- if nside is None:
1268
- nside = self.nside
1269
-
1270
- # var-length mode: list/tuple of ids, one per sample
1271
- if self._is_varlength_batch(cell_ids):
1272
- outs, outs_ids = [], []
1273
- B = len(cell_ids)
1274
- for b in range(B):
1275
- cid_b = self._to_numpy_1d(cell_ids[b])
1276
- # extraire le bon échantillon d'`im`
1277
- if torch.is_tensor(im):
1278
- xb = im[b:b+1] # (1, C, N_b)
1279
- yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
1280
- outs.append(yb.squeeze(0)) # (C, N_b')
1281
- else:
1282
- # si im est déjà une liste de (C, N_b)
1283
- xb = im[b]
1284
- yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
1285
- outs.append(yb.squeeze(0))
1286
- outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
1287
- return outs, outs_ids
1288
-
1289
- # grille commune (un seul vecteur d'ids)
1290
- cid = self._to_numpy_1d(cell_ids)
1291
- return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
1292
-
1293
- def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
1294
- """
1295
- If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
1296
- If they are lists (var-length per sample) -> return list[Tensor].
1297
- """
1298
- if self.f is None:
1299
- if self.dtype==torch.float64:
1300
- self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
1301
- else:
1302
- self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
1303
-
1304
- if cell_ids is None:
1305
- dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
1306
- return dim
1307
-
1308
- if nside is None:
1309
- nside = self.nside
1310
-
1311
- # var-length: listes parallèles
1312
- if self._is_varlength_batch(cell_ids):
1313
- assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
1314
- "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
1315
- outs = []
1316
- B = len(cell_ids)
1317
- for b in range(B):
1318
- cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
1319
- ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
1320
- if torch.is_tensor(im):
1321
- xb = im[b:b+1] # (1, C, N_b_coarse)
1322
- yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
1323
- o_cell_ids=ocid_b, force_init_index=True)
1324
- outs.append(yb.squeeze(0)) # (C, N_b_fine)
1325
- else:
1326
- xb = im[b] # (C, N_b_coarse)
1327
- yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
1328
- o_cell_ids=ocid_b, force_init_index=True)
1329
- outs.append(yb.squeeze(0))
1330
- return outs
1331
-
1332
- # grille commune
1333
- cid = self._to_numpy_1d(cell_ids)
1334
- ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
1335
- return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
1336
- o_cell_ids=ocid, force_init_index=True)
1337
-
1338
-
1339
1121
  def to_tensor(self,x):
1340
- return torch.tensor(x,device=self.device,dtype=self.dtype)
1341
-
1122
+ return torch.tensor(x,device='cuda')
1123
+
1342
1124
  def to_numpy(self,x):
1343
1125
  if isinstance(x,np.ndarray):
1344
1126
  return x
1345
- return x.cpu().numpy()
1346
-
1127
+ return x.cpu().numpy()