foscat 2025.9.3__py3-none-any.whl → 2025.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/Plot.py CHANGED
@@ -1,106 +1,114 @@
1
1
  import numpy as np
2
2
  import matplotlib.pyplot as plt
3
3
  import healpy as hp
4
+ import math
5
+ import io
6
+ import requests
7
+ from PIL import Image
4
8
 
5
9
  def lgnomproject(
6
- cell_ids, # array-like (N,), HEALPix pixel indices of your samples
7
- data, # array-like (N,), values per cell id
8
- nside: int,
9
- rot=None, # (lon0_deg, lat0_deg, psi_deg). If None: auto-center from cell_ids (pix centers)
10
- xsize: int = 400,
11
- ysize: int = 400,
12
- reso: float = None, # deg/pixel on tangent plane; if None, use fov_deg
13
- fov_deg=None, # full FoV deg (scalar or (fx,fy))
14
- nest: bool = True, # True if your cell_ids are NESTED (and ang2pix to be done in NEST)
15
- reduce: str = "mean", # 'mean'|'median'|'sum'|'first' when duplicates in cell_ids
16
- mask_outside: bool = True,
17
- unseen_value=None, # default to hp.UNSEEN
18
- return_image_only: bool = False,
19
- title: str = None, cmap: str = "viridis", vmin=None, vmax=None,
20
- notext: bool = False, # True to avoid tick marks
21
- hold: bool = True, # create a new figure if True otherwise use sub to split panel
22
- sub=(1,1,1), # declare sub plot
23
- cbar: bool = False, # plot colorbar if True
24
-
10
+ cell_ids, # (N,) HEALPix pixel indices
11
+ data, # (N,) scalar values OR (N,3) RGB values in [0,1]
12
+ nside: int,
13
+ rot=None,
14
+ xsize: int = 400,
15
+ ysize: int = 400,
16
+ reso: float = None,
17
+ fov_deg=None,
18
+ nest: bool = True,
19
+ reduce: str = "mean", # 'mean'|'median'|'sum'|'first'
20
+ mask_outside: bool = True,
21
+ unseen_value=None, # defaults to hp.UNSEEN (scalar) or np.nan (RGB)
22
+ return_image_only: bool = False,
23
+ title: str = None, cmap: str = "viridis", vmin=None, vmax=None,
24
+ notext: bool = False,
25
+ hold: bool = True,
26
+ interp: bool = False,
27
+ sub=(1,1,1),
28
+ cbar: bool = False,
29
+ unit: str = "Value",
30
+ rgb_clip=(0.0, 1.0), # clip range for RGB
25
31
  ):
26
32
  """
27
- Gnomonic projection from *sparse* HEALPix samples (cell_ids, data) to an image (ysize, xsize).
28
-
29
- For each output image pixel (i,j):
30
- plane (x,y) --inverse gnomonic--> (theta, phi) --HEALPix--> ipix
31
- if ipix in `cell_ids`: assign aggregated value, else UNSEEN.
32
-
33
- Parameters
34
- ----------
35
- cell_ids : (N,) int array
36
- HEALPix pixel indices of your samples. Must correspond to `nside` and `nest`.
37
- data : (N,) float array
38
- Sample values for each cell id.
39
- nside : int
40
- HEALPix NSIDE used for both `cell_ids` and the image reprojection.
41
- rot : (lon0_deg, lat0_deg, psi_deg) or None
42
- Gnomonic center (lon, lat) and in-plane rotation psi (deg).
43
- If None, we auto-center from the *centers of the provided pixels* (via hp.pix2ang).
44
- xsize, ysize : int
45
- Output image size (pixels).
46
- reso : float or None
47
- Pixel size (deg/pixel) on the tangent plane. If None, derived from `fov_deg`.
48
- fov_deg : float or (float,float)
49
- Full field of view in degrees.
50
- nest : bool
51
- Use True if your `cell_ids` correspond to NESTED indexing.
52
- reduce : str
53
- How to combine duplicate cell ids: 'mean'|'median'|'sum'|'first'.
54
- mask_outside : bool
55
- Mask pixels outside the valid gnomonic hemisphere (cosc <= 0).
56
- unseen_value : float or None
57
- Value for invalid pixels (default hp.UNSEEN).
58
-
59
- return_image_only : bool
60
- If True, return the 2D array only (no plotting).
61
-
62
- Returns
63
- -------
64
- (fig, ax, img) or img
65
- If return_image_only=True, returns img (ysize, xsize).
33
+ Gnomonic projection from *sparse* HEALPix samples (cell_ids, data) to an image.
34
+ Supports scalar data (N,) and RGB data (N,3). For RGB, colorbar/cmap/vmin/vmax are ignored.
66
35
  """
36
+
37
+ # -------- 0) Input normalization --------
67
38
  if unseen_value is None:
68
- unseen_value = hp.UNSEEN
39
+ unseen_value = hp.UNSEEN if (np.ndim(data) == 1 or (np.ndim(data)==2 and data.shape[1] != 3)) else np.nan
69
40
 
70
41
  cell_ids = np.asarray(cell_ids, dtype=np.int64)
71
- vals_in = np.asarray(data, dtype=float)
72
- if cell_ids.shape != vals_in.shape:
73
- raise ValueError("cell_ids and data must have the same shape (N,)")
74
-
75
- # -------- 1) Aggregate duplicates in cell_ids (if any) --------
76
- uniq, inv = np.unique(cell_ids, return_inverse=True) # uniq is sorted
77
- if reduce == "first":
78
- first_idx = np.full(uniq.size, -1, dtype=np.int64)
79
- for i, g in enumerate(inv):
80
- if first_idx[g] < 0:
81
- first_idx[g] = i
82
- agg_vals = vals_in[first_idx]
83
- elif reduce == "sum":
84
- agg_vals = np.zeros(uniq.size, dtype=float)
85
- np.add.at(agg_vals, inv, vals_in)
86
- elif reduce == "median":
87
- agg_vals = np.empty(uniq.size, dtype=float)
88
- for k, pix in enumerate(uniq):
89
- agg_vals[k] = np.median(vals_in[cell_ids == pix])
90
- elif reduce == "mean":
91
- sums = np.zeros(uniq.size, dtype=float)
92
- cnts = np.zeros(uniq.size, dtype=float)
93
- np.add.at(sums, inv, vals_in)
94
- np.add.at(cnts, inv, 1.0)
95
- agg_vals = sums / np.maximum(cnts, 1.0)
42
+ vals_in = np.asarray(data)
43
+
44
+ if vals_in.ndim == 1:
45
+ is_rgb = False
46
+ if cell_ids.shape != vals_in.shape:
47
+ raise ValueError("For scalar mode, `data` must have shape (N,).")
48
+ vals_in = vals_in.astype(float)
49
+ elif vals_in.ndim == 2 and vals_in.shape[1] == 3:
50
+ is_rgb = True
51
+ if cell_ids.shape[0] != vals_in.shape[0]:
52
+ raise ValueError("For RGB mode, `data` must have shape (N,3) matching `cell_ids` length.")
53
+ vals_in = vals_in.astype(float)
54
+ else:
55
+ raise ValueError("`data` must be (N,) for scalar or (N,3) for RGB.")
56
+
57
+ # -------- 1) Aggregate duplicates in cell_ids --------
58
+ uniq, inv = np.unique(cell_ids, return_inverse=True) # uniq sorted
59
+ if is_rgb:
60
+ if reduce == "first":
61
+ first_idx = np.full(uniq.size, -1, dtype=np.int64)
62
+ for i, g in enumerate(inv):
63
+ if first_idx[g] < 0:
64
+ first_idx[g] = i
65
+ agg_vals = vals_in[first_idx, :] # (U,3)
66
+ elif reduce == "sum":
67
+ agg_vals = np.zeros((uniq.size, 3), dtype=float)
68
+ # sum per channel
69
+ for c in range(3):
70
+ np.add.at(agg_vals[:, c], inv, vals_in[:, c])
71
+ elif reduce == "median":
72
+ agg_vals = np.empty((uniq.size, 3), dtype=float)
73
+ for k, pix in enumerate(uniq):
74
+ sel = (cell_ids == pix)
75
+ agg_vals[k, :] = np.median(vals_in[sel, :], axis=0)
76
+ elif reduce == "mean":
77
+ sums = np.zeros((uniq.size, 3), dtype=float)
78
+ cnts = np.zeros(uniq.size, dtype=float)
79
+ for c in range(3):
80
+ np.add.at(sums[:, c], inv, vals_in[:, c])
81
+ np.add.at(cnts, inv, 1.0)
82
+ agg_vals = sums / np.maximum(cnts[:, None], 1.0)
83
+ else:
84
+ raise ValueError("reduce must be one of {'mean','median','sum','first'}")
96
85
  else:
97
- raise ValueError("reduce must be one of {'mean','median','sum','first'}")
86
+ # scalar path (comme ta version)
87
+ if reduce == "first":
88
+ first_idx = np.full(uniq.size, -1, dtype=np.int64)
89
+ for i, g in enumerate(inv):
90
+ if first_idx[g] < 0:
91
+ first_idx[g] = i
92
+ agg_vals = vals_in[first_idx]
93
+ elif reduce == "sum":
94
+ agg_vals = np.zeros(uniq.size, dtype=float)
95
+ np.add.at(agg_vals, inv, vals_in)
96
+ elif reduce == "median":
97
+ agg_vals = np.empty(uniq.size, dtype=float)
98
+ for k, pix in enumerate(uniq):
99
+ agg_vals[k] = np.median(vals_in[cell_ids == pix])
100
+ elif reduce == "mean":
101
+ sums = np.zeros(uniq.size, dtype=float)
102
+ cnts = np.zeros(uniq.size, dtype=float)
103
+ np.add.at(sums, inv, vals_in)
104
+ np.add.at(cnts, inv, 1.0)
105
+ agg_vals = sums / np.maximum(cnts, 1.0)
106
+ else:
107
+ raise ValueError("reduce must be one of {'mean','median','sum','first'}")
98
108
 
99
109
  # -------- 2) Choose gnomonic center (rot) --------
100
110
  if rot is None:
101
- # Center from pixel centers of provided ids
102
111
  theta_c, phi_c = hp.pix2ang(nside, uniq, nest=nest) # colat, lon (rad)
103
- # circular mean for lon, median for colat
104
112
  lon0_deg = np.degrees(np.angle(np.mean(np.exp(1j * phi_c))))
105
113
  lat0_deg = 90.0 - np.degrees(np.median(theta_c))
106
114
  psi_deg = 0.0
@@ -111,7 +119,7 @@ def lgnomproject(
111
119
  lat0 = np.deg2rad(lat0_deg)
112
120
  psi = np.deg2rad(psi_deg)
113
121
 
114
- # -------- 3) Tangent-plane grid (gnomonic) --------
122
+ # -------- 3) Tangent-plane grid --------
115
123
  if reso is not None:
116
124
  dx = np.tan(np.deg2rad(reso))
117
125
  dy = dx
@@ -119,8 +127,7 @@ def lgnomproject(
119
127
  half_y = 0.5 * ysize * dy
120
128
  else:
121
129
  if fov_deg is None:
122
- fov_deg=np.rad2deg(np.sqrt(cell_ids.shape[0])/nside)*1.4
123
-
130
+ fov_deg = np.rad2deg(np.sqrt(cell_ids.shape[0]) / nside) * 1.4
124
131
  if np.isscalar(fov_deg):
125
132
  fx, fy = float(fov_deg), float(fov_deg)
126
133
  else:
@@ -129,13 +136,12 @@ def lgnomproject(
129
136
  ay = np.deg2rad(0.5 * fy)
130
137
  half_x = np.tan(ax)
131
138
  half_y = np.tan(ay)
132
-
139
+
133
140
  xs = np.linspace(-half_x, +half_x, xsize, endpoint=False) + (half_x / xsize)
134
141
  ys = np.linspace(-half_y, +half_y, ysize, endpoint=False) + (half_y / ysize)
135
-
136
- X, Y = np.meshgrid(xs, ys) # (ysize, xsize)
142
+ X, Y = np.meshgrid(xs, ys)
137
143
 
138
- # in-plane rotation psi
144
+ # rotate plane
139
145
  c, s = np.cos(psi), np.sin(psi)
140
146
  Xr = c * X + s * Y
141
147
  Yr = -s * X + c * Y
@@ -154,79 +160,165 @@ def lgnomproject(
154
160
  theta_img = (np.pi / 2.0) - lat
155
161
  outside = (cosc <= 0.0) if mask_outside else np.zeros_like(cosc, dtype=bool)
156
162
 
157
- # -------- 5) Map image pixels to HEALPix ids --------
158
- ip_img = hp.ang2pix(nside, theta_img.ravel(), lon.ravel(), nest=nest).astype(np.int64)
163
+ if interp:
164
+
165
+ # Inputs attendus (déjà présents dans ton code):
166
+ # - nside, nest
167
+ # - theta_img, lon (angles pour chaque pixel de l'image, shape (ysize, xsize))
168
+ # - uniq (np.ndarray trié d'indices HEALPix présents)
169
+ # - agg_vals : valeurs associées à uniq
170
+ # * si is_rgb: shape (uniq.size, 3)
171
+ # * sinon: shape (uniq.size,)
172
+ # - ysize, xsize
173
+ # - mask_outside (bool)
174
+ # - outside (masque bool à plat ou 2D selon ton code)
175
+ # - unseen_value (float), p.ex. np.nan ou autre
176
+
177
+ # -------- 5) (NOUVEAU) Interpolation bilinéaire via poids HEALPix --------
178
+ # Aplatis les angles de l'image
179
+ theta_flat = theta_img.ravel()
180
+ phi_flat = lon.ravel()
181
+
182
+ # Récupère pour chaque direction les indices de 4 voisins et leurs poids
183
+ # inds: shape (npix_img, 4) ; w: shape (npix_img, 4)
184
+ inds, w = hp.get_interp_weights(nside, theta_flat, phi_flat, nest=nest)
185
+
186
+ # On mappe les indices 'inds' (HEALPix) vers positions dans 'uniq' en O(log N)
187
+ pos = np.searchsorted(uniq, inds, side="left")
188
+ in_range = pos < uniq.size
189
+ match = np.zeros_like(in_range, dtype=bool)
190
+ match[in_range] = (uniq[pos[in_range]] == inds[in_range])
191
+
192
+ # Construit un masque 'valid' des voisins présents dans tes données
193
+ # valid shape (npix_img, 4)
194
+ valid = match
195
+
196
+ if is_rgb:
197
+ # Récupère les valeurs des 4 voisins (3 canaux)
198
+ # vals shape (npix_img, 4, 3) avec NaN pour voisins absents
199
+ vals = np.full((inds.shape[0], inds.shape[1], 3), np.nan, dtype=float)
200
+ # positions valides -> on insère les vraies valeurs RGB
201
+ vals[valid, :] = agg_vals[pos[valid], :]
202
+
203
+ # pondération : on annule le poids là où la valeur est absente
204
+ w_eff = w.copy()
205
+ w_eff[~valid] = 0.0
206
+ ws = np.sum(w_eff, axis=0, keepdims=False) # somme des poids utiles
207
+
208
+ # éviter division par 0 : pixels hors couverture -> unseen
209
+ nonzero = ws.squeeze() > 0
210
+ img_flat = np.full((inds.shape[1], 3), unseen_value, dtype=float)
211
+
212
+ # combinaison pondérée (en ignorant les NaN via w_eff)
213
+ num = np.nansum(vals * w_eff[..., None], axis=0) # (npix_img, 3)
214
+ img_flat[nonzero, :] = (num[nonzero, :] / ws[nonzero,None]).astype(float)
215
+
216
+ img = img_flat.reshape(ysize, xsize, 3)
217
+
218
+ if mask_outside:
219
+ mask = outside.reshape(ysize, xsize)
220
+ img[mask, :] = np.nan # ou unseen_value
221
+
222
+ else:
223
+ # Scalaire : vals shape (npix_img, 4)
224
+ vals = np.full(inds.shape, np.nan, dtype=float)
225
+ vals[valid] = agg_vals[pos[valid]]
159
226
 
160
- # -------- 6) Assign values by matching ip_img ∈ uniq (safe searchsorted) --------
161
- # uniq is sorted; build insertion pos then check matches only where pos < len(uniq)
162
- pos = np.searchsorted(uniq, ip_img, side="left")
163
- valid = pos < uniq.size
164
- match = np.zeros_like(valid, dtype=bool)
165
- match[valid] = (uniq[pos[valid]] == ip_img[valid])
227
+ w_eff = w.copy()
228
+ w_eff[~valid] = 0.0
229
+ ws = np.sum(w_eff, axis=0) # (,npix_img)
166
230
 
167
- img_flat = np.full(ip_img.shape, unseen_value, dtype=float)
168
- img_flat[match] = agg_vals[pos[match]]
169
- img = img_flat.reshape(ysize, xsize)
231
+ img_flat = np.full(inds.shape[1], unseen_value, dtype=float)
232
+ nonzero = ws > 0
170
233
 
171
- # Mask out-of-hemisphere gnomonic region
172
- if mask_outside:
173
- img[outside] = unseen_value
234
+ num = np.nansum(vals * w_eff, axis=0) # (npix_img,)
235
+ img_flat[nonzero] = (num[nonzero] / ws[nonzero]).astype(float)
236
+
237
+ img = img_flat.reshape(ysize, xsize)
238
+ if mask_outside:
239
+ img[outside] = unseen_value
240
+
241
+ else:
242
+ # -------- 5) Map image pixels to HEALPix ids --------
243
+ ip_img = hp.ang2pix(nside, theta_img.ravel(), lon.ravel(), nest=nest).astype(np.int64)
244
+
245
+ # -------- 6) Assign values by matching ip_img ∈ uniq --------
246
+ pos = np.searchsorted(uniq, ip_img, side="left")
247
+ valid = pos < uniq.size
248
+ match = np.zeros_like(valid, dtype=bool)
249
+ match[valid] = (uniq[pos[valid]] == ip_img[valid])
250
+
251
+ if is_rgb:
252
+ img_flat = np.full((ip_img.size, 3), np.nan, dtype=float)
253
+ img_flat[match, :] = agg_vals[pos[match], :]
254
+ img = img_flat.reshape(ysize, xsize, 3)
255
+ if mask_outside:
256
+ mask = outside.reshape(ysize, xsize)
257
+ img[mask, :] = np.nan
258
+ else:
259
+ img_flat = np.full(ip_img.shape, unseen_value, dtype=float)
260
+ img_flat[match] = agg_vals[pos[match]]
261
+ img = img_flat.reshape(ysize, xsize)
262
+ if mask_outside:
263
+ img[outside] = unseen_value
174
264
 
175
265
  # -------- 7) Return / plot --------
176
266
  if return_image_only:
177
267
  return img
178
268
 
179
- # Axes in approx. "gnomonic degrees" (atan of plane coords)
269
+ # axes extents (approx)
180
270
  x_deg = np.degrees(np.arctan(xs))
181
271
  y_deg = np.degrees(np.arctan(ys))
182
-
183
- longitude_min=x_deg[0]/np.cos(np.deg2rad(lat0_deg))+lon0_deg
184
- longitude_max=x_deg[-1]/np.cos(np.deg2rad(lat0_deg))+lon0_deg
185
-
186
- if longitude_min>180:
187
- longitude_min-=360
188
- longitude_max-=360
189
-
190
- extent = (longitude_min,longitude_max,
191
- y_deg[0]+lat0_deg, y_deg[-1]+lat0_deg)
272
+ longitude_min = x_deg[0]/np.cos(np.deg2rad(lat0_deg)) + lon0_deg
273
+ longitude_max = x_deg[-1]/np.cos(np.deg2rad(lat0_deg)) + lon0_deg
274
+ if longitude_min > 180:
275
+ longitude_min -= 360
276
+ longitude_max -= 360
277
+ extent = (longitude_min, longitude_max, y_deg[0]+lat0_deg, y_deg[-1]+lat0_deg)
192
278
 
193
279
  if hold:
194
280
  fig, ax = plt.subplots(figsize=(xsize/100, ysize/100), dpi=100)
195
281
  else:
196
- ax=plt.subplot(sub[0],sub[1],sub[2])
197
-
198
- im = ax.imshow(
199
- np.where(img == unseen_value, np.nan, img),
200
- origin="lower",
201
- extent=extent,
202
- cmap=cmap,
203
- vmin=vmin, vmax=vmax,
204
- interpolation="nearest",
205
- aspect="auto"
206
- )
282
+ ax = plt.subplot(sub[0], sub[1], sub[2])
283
+
284
+ if is_rgb:
285
+ shown = ax.imshow(
286
+ np.clip(img, rgb_clip[0], rgb_clip[1]),
287
+ origin="lower",
288
+ extent=extent,
289
+ interpolation="nearest",
290
+ aspect="auto"
291
+ )
292
+ # pas de cmap/cbar en RGB
293
+ else:
294
+ shown = ax.imshow(
295
+ np.where(img == unseen_value, np.nan, img),
296
+ origin="lower",
297
+ extent=extent,
298
+ cmap=cmap,
299
+ vmin=vmin, vmax=vmax,
300
+ interpolation="nearest",
301
+ aspect="auto"
302
+ )
303
+ if cbar:
304
+ if hold:
305
+ cb = fig.colorbar(shown, ax=ax)
306
+ cb.set_label("value")
307
+ else:
308
+ plt.colorbar(shown, ax=ax, orientation="horizontal", label=unit)
309
+
207
310
  if not notext:
208
311
  ax.set_xlabel("Longitude (deg)")
209
312
  ax.set_ylabel("Latitude (deg)")
210
313
  else:
211
314
  ax.set_xticks([])
212
315
  ax.set_yticks([])
213
-
316
+
214
317
  if title:
215
318
  ax.set_title(title)
216
-
217
- if cbar:
218
- if hold:
219
- cb = fig.colorbar(im, ax=ax)
220
- cb.set_label("value")
221
- else:
222
- plt.colorbar(im, ax=ax, orientation="horizontal", label="value")
223
-
224
- plt.tight_layout()
225
- if hold:
226
- return fig, ax #, img
227
- else:
228
- return ax
229
319
 
320
+ plt.tight_layout()
321
+ return (fig, ax) if hold else ax
230
322
 
231
323
  def plot_scat(s1,s2,s3,s4):
232
324
 
@@ -329,3 +421,878 @@ def plot_scat(s1,s2,s3,s4):
329
421
  plt.xticks(l_pos,l_name, fontsize=6, rotation=90)
330
422
  plt.xlabel(r"$j_{1},j_{2},j_{3}$", fontsize=9)
331
423
  plt.ylabel(r"$S_{4}$", fontsize=9)
424
+
425
+
426
+ import numpy as np
427
+
428
def power_spectrum_1d(data, dx=1.0):
    """
    Compute the isotropic (azimuthally averaged) 1D power spectrum of a 2D field.

    Parameters
    ----------
    data : array-like (ny, nx)
        Input 2D field.
    dx : float
        Pixel size. Returned frequencies are in cycles per unit of ``dx``
        (e.g. m^-1 if ``dx`` is in metres; NOT radians).

    Returns
    -------
    f_centers : ndarray (nbins,)
        Centres of the radial frequency bins, ``nbins = min(nx, ny) // 2``.
        Empty when the image is too small to form a bin.
    Pk : ndarray (nbins,)
        Mean of ``|FFT(data)|**2`` over each annulus (no normalisation);
        NaN for bins that contain no Fourier sample.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError("data must be a 2D array of shape (ny, nx)")
    ny, nx = data.shape

    nbins = min(nx, ny) // 2
    if nbins == 0:
        # Too small to form a single radial bin.
        return np.empty(0), np.empty(0)

    # 2D periodogram, zero frequency moved to the centre.
    P2D = np.abs(np.fft.fftshift(np.fft.fft2(data))) ** 2

    # Radial spatial frequency (cycles per unit length) on the same shifted grid.
    fx = np.fft.fftshift(np.fft.fftfreq(nx, d=dx))
    fy = np.fft.fftshift(np.fft.fftfreq(ny, d=dx))
    fx2d, fy2d = np.meshgrid(fx, fy, indexing="xy")
    fr = np.sqrt(fx2d**2 + fy2d**2)

    f_bins = np.linspace(0.0, fr.max(), nbins + 1)
    # np.digitize places values equal to the last edge (fr == fr.max(), the
    # Nyquist corner) one past the final bin; fold them back into the last bin
    # so the highest-frequency power is not silently discarded.
    bin_idx = np.minimum(np.digitize(fr.ravel(), f_bins) - 1, nbins - 1)

    # Mean power per annulus = sum / count.
    sum_per_bin = np.bincount(bin_idx, weights=P2D.ravel(), minlength=nbins)
    cnt_per_bin = np.bincount(bin_idx, minlength=nbins)
    with np.errstate(invalid="ignore", divide="ignore"):
        Pk = sum_per_bin / cnt_per_bin
    Pk[cnt_per_bin == 0] = np.nan  # empty bins

    f_centers = 0.5 * (f_bins[1:] + f_bins[:-1])
    return f_centers, Pk
479
+
480
+ import numpy as np
481
+
482
+ def _freq_grids(ny, nx, dx=1.0):
483
+ """Return 2D radial spatial frequency grid fr (cycles per unit), with fftshift."""
484
+ fx = np.fft.fftshift(np.fft.fftfreq(nx, d=dx))
485
+ fy = np.fft.fftshift(np.fft.fftfreq(ny, d=dx))
486
+ fx2d, fy2d = np.meshgrid(fx, fy, indexing="xy")
487
+ fr = np.sqrt(fx2d**2 + fy2d**2)
488
+ return fr
489
+
490
+ def _hann2d(ny, nx):
491
+ """2D separable Hann apodization."""
492
+ wy = np.hanning(ny)
493
+ wx = np.hanning(nx)
494
+ return np.outer(wy, wx)
495
+
496
def estimate_psd_slope(img, dx=1.0, fmin_frac=0.02, fmax_frac=0.4):
    """
    Estimate beta in P(f) ~ f^-beta from the isotropic 1D PSD (log-log linear fit).

    The fit only uses radial bins whose centre lies strictly inside
    (fmin_frac * f_max, fmax_frac * f_max), where f_max is the largest bin
    centre. This keeps the fit away from the DC term and the Nyquist corner.

    Parameters
    ----------
    img : array-like (ny, nx)
        Input image.
    dx : float
        Pixel size; frequencies are in cycles per unit of ``dx``.
    fmin_frac, fmax_frac : float
        Fit band as fractions of the largest bin-centre frequency.

    Returns
    -------
    float
        Estimated ``beta``. NaN when fewer than 5 usable (finite, positive)
        bins fall in the band, including images too small to form any bin.
    """
    img = np.asarray(img)
    ny, nx = img.shape
    nbins = min(nx, ny) // 2
    if nbins == 0:
        # No radial bin can be formed (np.nanmax below would raise on empty input).
        return np.nan

    # 2D periodogram on the fftshift-ed grid.
    P2D = np.abs(np.fft.fftshift(np.fft.fft2(img))) ** 2
    fr = _freq_grids(ny, nx, dx=dx)

    # Radial bins; values equal to the last edge are folded into the final
    # bin (np.digitize would otherwise put them one past the end).
    f_bins = np.linspace(0.0, fr.max(), nbins + 1)
    bin_idx = np.minimum(np.digitize(fr.ravel(), f_bins) - 1, nbins - 1)
    sum_bin = np.bincount(bin_idx, weights=P2D.ravel(), minlength=nbins)
    cnt_bin = np.bincount(bin_idx, minlength=nbins)
    with np.errstate(invalid="ignore", divide="ignore"):
        Pk = sum_bin / cnt_bin
    Pk[cnt_bin == 0] = np.nan
    f_centers = 0.5 * (f_bins[1:] + f_bins[:-1])

    # Fit on a safe band, ignoring empty / non-positive bins (log10 undefined).
    fmax = np.nanmax(f_centers)
    mask = (
        (f_centers > fmin_frac * fmax)
        & (f_centers < fmax_frac * fmax)
        & np.isfinite(Pk)
        & (Pk > 0)
    )
    if np.count_nonzero(mask) < 5:
        return np.nan
    # log10 P = slope * log10 f + intercept, and P ~ f^-beta  =>  beta = -slope
    slope, _intercept = np.polyfit(np.log10(f_centers[mask]), np.log10(Pk[mask]), 1)
    return -slope
528
+
529
def adjust_psd_slope(img, dx=1.0, delta_beta=0.0,
                     f_ref=None, band=None,
                     apodize=True, preserve_mean=True, match_variance=True, eps=None):
    """
    Tilt the isotropic power spectrum of ``img`` by ``delta_beta``.

    Fourier amplitudes are multiplied by
    ``((f + eps) / (f_ref + eps)) ** (-delta_beta / 2)``. This multiplies the
    PSD by roughly ``f ** -delta_beta`` and leaves it unchanged at ``f_ref``.

    * ``delta_beta > 0``: steeper spectrum (relatively more large-scale power, smoother image).
    * ``delta_beta < 0``: flatter spectrum (relatively more small-scale power, rougher image).

    Parameters
    ----------
    img : array-like (ny, nx)
        Input image (converted to float).
    dx : float
        Pixel size; frequencies are in cycles per unit of ``dx``.
    delta_beta : float
        Desired change of spectral slope.
    f_ref : float or None
        Frequency where the gain is exactly 1. Defaults to the median of the
        non-zero radial frequencies.
    band : (f_lo, f_hi) or None
        Restrict the tilt to [f_lo, f_hi] (``None`` for either end means 0 /
        the maximum radial frequency). The tilt is blended in with raised-cosine
        ramps of half-width 10% of the band. Outside the band the gain is 1.
    apodize : bool
        Multiply by a 2D Hann window before the FFT (the window is not divided
        back out afterwards).
    preserve_mean : bool
        Force the gain at f == 0 to 1. With ``match_variance`` the output is
        also re-centred on ``img.mean()`` instead of 0.
    match_variance : bool
        Rescale the output so its standard deviation equals that of ``img``
        (skipped when the output is constant).
    eps : float or None
        Offset added to f so the gain stays finite at f == 0. Defaults to
        ``1 / (max(nx, ny) * dx)``.

    Returns
    -------
    out : ndarray (ny, nx)
        Real part of the filtered image.
    """
    img = np.asarray(img, float)
    ny, nx = img.shape

    # Optional Hann taper to limit edge ringing in the FFT.
    source = img * _hann2d(ny, nx) if apodize else img
    spectrum = np.fft.fftshift(np.fft.fft2(source))

    freq = _freq_grids(ny, nx, dx=dx)
    eps = 1.0 / (max(nx, ny) * dx) if eps is None else eps
    f_ref = np.median(freq[freq > 0]) if f_ref is None else f_ref

    # Amplitude gain = square root of the desired PSD factor.
    gain = ((freq + eps) / (f_ref + eps)) ** (-0.5 * delta_beta)

    if band is not None:
        lo, hi = band
        lo = 0.0 if lo is None else lo
        hi = freq.max() if hi is None else hi
        half_width = 0.1 * (hi - lo) if hi > lo else 0.0

        def ramp(f, start, stop):
            # 0 below `start`, 1 above `stop`, raised-cosine transition in between.
            if stop <= start:
                return (f >= stop).astype(float)
            t = np.clip((f - start) / (stop - start), 0, 1)
            return 0.5 - 0.5 * np.cos(np.pi * t)

        rise = ramp(freq, lo - half_width, lo + half_width)
        fall = 1.0 - ramp(freq, hi - half_width, hi + half_width)
        # Blend: full tilt inside the band, unit gain outside.
        gain = 1.0 + (rise * fall) * (gain - 1.0)

    if preserve_mean:
        gain[freq == 0] = 1.0

    out = np.fft.ifft2(np.fft.ifftshift(spectrum * gain)).real

    # Variance matching also absorbs the amplitude loss caused by apodization.
    if match_variance:
        s_in, s_out = np.std(img), np.std(out)
        if s_out > 0:
            target_mean = img.mean() if preserve_mean else 0.0
            out = (out - out.mean()) * (s_in / s_out) + target_mean

    return out
624
+
625
+ # --- 1) Lat/Lon -> fractional XYZ tile coords at zoom z ---
626
def latlon_to_xyz_frac(lat, lon, z):
    """
    Return fractional XYZ tile coordinates (xf, yf) at zoom ``z`` (Web Mercator).

    Parameters
    ----------
    lat, lon : float or array-like
        Latitude / longitude in degrees.
    z : int
        Zoom level; the world is ``2**z`` tiles wide and high.

    Returns
    -------
    xf, yf : float or ndarray
        Fractional tile coordinates in ``[0, 2**z)``. ``floor`` gives the tile
        index and the fractional part gives the position inside the tile.

    Notes
    -----
    Latitudes are clamped to the Web Mercator limit (atan(sinh(pi)), about
    ±85.0511°), so polar inputs no longer yield NaN/inf or out-of-range rows.
    Longitudes are wrapped into [-180, 180), so lon=180 maps to column 0
    instead of the non-existent column ``2**z``. Inputs already inside these
    ranges give the same result as the plain formula.
    """
    n = 2 ** z
    max_lat = np.degrees(np.arctan(np.sinh(np.pi)))
    lat_rad = np.radians(np.clip(lat, -max_lat, max_lat))

    xf = np.mod(np.asarray(lon, dtype=float) + 180.0, 360.0) / 360.0 * n
    yf = (1.0 - np.log(np.tan(lat_rad) + 1 / np.cos(lat_rad)) / np.pi) / 2.0 * n

    # Keep both coordinates in [0, n): rounding at the clamped latitude (or in
    # np.mod for tiny negative inputs) can land exactly on n or slightly below 0.
    upper = np.nextafter(float(n), 0.0)
    xf = np.clip(xf, 0.0, upper)
    yf = np.clip(yf, 0.0, upper)
    return xf, yf
633
+
634
+ # --- 2) Fractional tile coords -> (xtile, ytile, px, py) inside 256×256 tile ---
635
def xyz_frac_to_tile_pixel(xf, yf, tile_size=256):
    """
    Split fractional tile coordinates into integer tile indices and in-tile pixels.

    Returns ``(xtile, ytile, px, py)``. ``px``/``py`` are clamped to
    ``[0, tile_size - 1]`` to guard against rounding at tile borders.
    """
    def split(frac):
        tile = np.floor(frac).astype(int)
        pixel = np.floor((frac - tile) * tile_size).astype(int)
        return tile, np.clip(pixel, 0, tile_size - 1)

    xtile, px = split(xf)
    ytile, py = split(yf)
    return xtile, ytile, px, py
644
+
645
+ # --- 3) Simple tile cache + fetcher for Esri World Imagery ---
646
# Esri World Imagery XYZ tile URL template; note the path order is {z}/{y}/{x}.
ESRI_WORLD_IMAGERY = "https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}"
650
+
651
class TileCache:
    """
    In-memory cache of Esri World Imagery tiles fetched over HTTP.

    Tiles are keyed by ``(z, x, y)`` and kept as RGB ``PIL.Image`` objects for
    the lifetime of the instance (the cache is unbounded). A single
    ``requests.Session`` is reused for all downloads. Call :meth:`close`, or
    use the instance in a ``with`` statement, to release its connection pool.
    """

    def __init__(self):
        self.cache = {}  # (z, x, y) -> PIL.Image in RGB mode
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": "research-sampler/1.0"})

    def get_tile(self, z, x, y, timeout=10):
        """
        Return tile ``(z, x, y)`` as an RGB image, downloading it on first use.

        HTTP errors surface through ``raise_for_status``; undecodable content
        raises from PIL. Failed fetches are not cached.
        """
        key = (z, x, y)
        if key in self.cache:
            return self.cache[key]
        url = ESRI_WORLD_IMAGERY.format(z=z, x=x, y=y)
        r = self.session.get(url, timeout=timeout)
        r.raise_for_status()
        img = Image.open(io.BytesIO(r.content)).convert("RGB")
        self.cache[key] = img
        return img

    def close(self):
        """Close the underlying HTTP session; already cached tiles stay available."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions
666
+
667
+ # --- 4) Main sampler ---
668
def sample_esri_world_imagery(lat, lon, zoom=17, tile_size=256):
    """
    Sample Esri World Imagery RGB values at the given coordinates.

    Each point is mapped to its Web Mercator tile at ``zoom`` and to the pixel
    it falls in. Every distinct tile is downloaded only once.

    Parameters
    ----------
    lat, lon : array-like, same shape
        Latitude / longitude in degrees. Points are processed in flattened
        (C) order.
    zoom : int
        Web Mercator zoom level.
    tile_size : int
        Tile edge length in pixels used to locate the in-tile pixel
        (assumed to match the served tiles, 256 by default — TODO confirm
        if another tile size is used).

    Returns
    -------
    ndarray (N, 3) uint8
        RGB values with ``N = lat.size``. Points whose tile could not be
        fetched are left as zeros and a ``RuntimeWarning`` is emitted.

    Raises
    ------
    ValueError
        If ``lat`` and ``lon`` do not have the same shape.
    """
    import warnings

    lat = np.asarray(lat, dtype=float)
    lon = np.asarray(lon, dtype=float)
    if lat.shape != lon.shape:
        raise ValueError(
            f"lat and lon must have the same shape, got {lat.shape} and {lon.shape}"
        )
    lat = lat.ravel()
    lon = lon.ravel()
    N = lat.size

    rgb = np.zeros((N, 3), dtype=np.uint8)
    if N == 0:
        return rgb  # nothing to sample; avoid opening a network session

    xf, yf = latlon_to_xyz_frac(lat, lon, zoom)
    xt, yt, px, py = xyz_frac_to_tile_pixel(xf, yf, tile_size=tile_size)

    # Group points by tile: a stable sort of the tile labels yields one
    # contiguous run of point indices per unique tile (O(N log N) overall).
    tiles, inv = np.unique(np.stack([xt, yt], axis=1), axis=0, return_inverse=True)
    inv = inv.ravel()  # guard against NumPy versions returning a non-1D inverse
    order = np.argsort(inv, kind="stable")
    counts = np.bincount(inv, minlength=len(tiles))
    groups = np.split(order, np.cumsum(counts)[:-1])

    tile_cache = TileCache()
    try:
        for (x_tile, y_tile), sel in zip(tiles, groups):
            try:
                img = tile_cache.get_tile(zoom, x_tile, y_tile)
            except Exception as e:  # best effort: leave zeros for this tile
                warnings.warn(
                    f"failed to fetch tile z={zoom} x={x_tile} y={y_tile}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            # RGB image -> (rows, cols, 3) uint8; row index is y, column index is x.
            pixels = np.asarray(img)
            rgb[sel] = pixels[py[sel], px[sel]]
    finally:
        # Release the HTTP connection pool held by the cache's session.
        tile_cache.session.close()
    return rgb
703
+
704
+ # ---------- EXAMPLE ----------
705
+ # latN, lonN are your arrays of length N
706
+ # latN = np.array([...], dtype=float)
707
+ # lonN = np.array([...], dtype=float)
708
+ # zoom = 17 # adjust to your scale needs
709
+ # vals_rgb = sample_esri_world_imagery(latN, lonN, zoom=zoom)
710
+ # vals_rgb.shape -> (N, 3)
711
+
712
+ import numpy as np
713
+ import healpy as hp
714
+
715
+ # --- NESTED helpers ---
716
+ def _compact_bits_u64(z):
717
+ z = z & np.uint64(0x5555555555555555)
718
+ z = (z | (z >> 1)) & np.uint64(0x3333333333333333)
719
+ z = (z | (z >> 2)) & np.uint64(0x0F0F0F0F0F0F0F0F)
720
+ z = (z | (z >> 4)) & np.uint64(0x00FF00FF00FF00FF)
721
+ z = (z | (z >> 8)) & np.uint64(0x0000FFFF0000FFFF)
722
+ z = (z | (z >> 16)) & np.uint64(0x00000000FFFFFFFF)
723
+ return z
724
+
725
+ def _spread_bits_u64(v):
726
+ v = v & np.uint64(0x00000000FFFFFFFF)
727
+ v = (v | (v << 16)) & np.uint64(0x0000FFFF0000FFFF)
728
+ v = (v | (v << 8)) & np.uint64(0x00FF00FF00FF00FF)
729
+ v = (v | (v << 4)) & np.uint64(0x0F0F0F0F0F0F0F0F)
730
+ v = (v | (v << 2)) & np.uint64(0x3333333333333333)
731
+ v = (v | (v << 1)) & np.uint64(0x5555555555555555)
732
+ return v
733
+
734
def _nest_to_fxy(ipix, nside):
    """
    Split NESTED HEALPix indices into (face, x, y).

    ``face = ipix // nside**2``. The in-face remainder is bit-interleaved, with
    x in the even bits and y in the odd bits.

    Returns
    -------
    face, x, y : ndarray of int64
    """
    pix_u64 = ipix.astype(np.uint64)
    per_face = np.uint64(nside) * np.uint64(nside)
    face, inface = pix_u64 // per_face, pix_u64 % per_face
    x_coord = _compact_bits_u64(inface)
    y_coord = _compact_bits_u64(inface >> np.uint64(1))
    return tuple(arr.astype(np.int64) for arr in (face, x_coord, y_coord))
742
+
743
def _fxy_to_nest(face, x, y, nside):
    """
    Build NESTED HEALPix indices from (face, x, y); the inverse of ``_nest_to_fxy``.

    x fills the even bits and y the odd bits of the in-face index. The result
    is offset by ``face * nside**2`` and returned as int64.
    """
    pixels_per_face = np.uint64(nside) * np.uint64(nside)
    x_bits = _spread_bits_u64(x.astype(np.uint64))
    y_bits = _spread_bits_u64(y.astype(np.uint64)) << np.uint64(1)
    face_offset = face.astype(np.uint64) * pixels_per_face
    return (face_offset + (x_bits | y_bits)).astype(np.int64)
747
+
748
+ # --- Main ---
749
+ def get_half_interp_weights_ang_general(nside_full, theta, phi, edge_mode="nearest"):
750
+ """
751
+ Bilinear weights from the 'half-level' lattice (EVEN pixels of full grid) to arbitrary directions.
752
+ Returns I (4 even NESTED ids) and W (4 weights summing to 1).
753
+ """
754
+ theta = np.asarray(theta, dtype=np.float64).ravel()
755
+ phi = np.asarray(phi, dtype=np.float64).ravel()
756
+ N = theta.size
757
+
758
+ # 1) Use the full-grid containing pixel to locate face/x/y neighborhood
759
+ ids_full = hp.ang2pix(nside_full, theta, phi, nest=True)
760
+ face, x, y = _nest_to_fxy(ids_full, nside_full)
761
+
762
+ # 2) Even anchors for 2×2 block on the even lattice
763
+ x0 = (x // 2) * 2
764
+ y0 = (y // 2) * 2
765
+ x1 = x0 + 2
766
+ y1 = y0 + 2
767
+
768
+ Xs = [x0, x1, x0, x1]
769
+ Ys = [y0, y0, y1, y1]
770
+
771
+ if edge_mode == "nearest":
772
+ Xs = [np.clip(X, 0, nside_full - 1) for X in Xs]
773
+ Ys = [np.clip(Y, 0, nside_full - 1) for Y in Ys]
774
+ drop_mask = [np.zeros(N, dtype=bool) for _ in range(4)]
775
+ elif edge_mode == "drop":
776
+ drop_mask = [
777
+ (Xs[0] < 0) | (Xs[0] >= nside_full) | (Ys[0] < 0) | (Ys[0] >= nside_full),
778
+ (Xs[1] < 0) | (Xs[1] >= nside_full) | (Ys[1] < 0) | (Ys[1] >= nside_full),
779
+ (Xs[2] < 0) | (Xs[2] >= nside_full) | (Ys[2] < 0) | (Ys[2] >= nside_full),
780
+ (Xs[3] < 0) | (Xs[3] >= nside_full) | (Ys[3] < 0) | (Ys[3] >= nside_full),
781
+ ]
782
+ Xs = [np.clip(X, 0, nside_full - 1) for X in Xs]
783
+ Ys = [np.clip(Y, 0, nside_full - 1) for Y in Ys]
784
+ else:
785
+ raise ValueError("edge_mode must be 'nearest' or 'drop'")
786
+
787
+ # 3) Map the four even corners to ids
788
+ I = np.empty((4, N), dtype=np.int64)
789
+ for k in range(4):
790
+ I[k] = _fxy_to_nest(face, Xs[k], Ys[k], nside_full)
791
+
792
+ # 4) Build 3D vectors (STACK tuples -> (N,3))
793
+ v_tgt = np.vstack(hp.ang2vec(theta, phi,nest=True)).T # (N,3)
794
+ v00 = np.vstack(hp.pix2vec(nside_full, I[0], nest=True)).T
795
+ v10 = np.vstack(hp.pix2vec(nside_full, I[1], nest=True)).T
796
+ v01 = np.vstack(hp.pix2vec(nside_full, I[2], nest=True)).T
797
+ v11 = np.vstack(hp.pix2vec(nside_full, I[3], nest=True)).T
798
+
799
+ # 5) Tangent-plane basis at average corner direction (robust)
800
+ v_c = v00 + v10 + v01 + v11
801
+ v_c /= np.linalg.norm(v_c, axis=1, keepdims=True) + 1e-15
802
+
803
+ tmp = v10 - v00
804
+ tmp -= (tmp * v_c).sum(1, keepdims=True) * v_c
805
+ # if degenerate, pick an arbitrary perpendicular
806
+ bad = (np.linalg.norm(tmp, axis=1) < 1e-12)
807
+ if np.any(bad):
808
+ ref = np.zeros_like(v_c)
809
+ ref[:, 0] = 1.0
810
+ # if nearly collinear, use y-axis
811
+ mask = (np.abs((ref * v_c).sum(1)) > 0.99)
812
+ ref[mask] = np.array([0.0, 1.0, 0.0])
813
+ tmp[bad] = ref[bad] - (ref[bad] * v_c[bad]).sum(1, keepdims=True) * v_c[bad]
814
+
815
+ e1 = tmp / (np.linalg.norm(tmp, axis=1, keepdims=True) + 1e-15)
816
+ e2 = np.cross(v_c, e1)
817
+
818
+ def proj(v): # (N,3) -> (N,2)
819
+ return np.stack([(v * e1).sum(1), (v * e2).sum(1)], axis=1)
820
+
821
+ p_tgt = proj(v_tgt)
822
+ p00 = proj(v00)
823
+ p10 = proj(v10)
824
+ p01 = proj(v01)
825
+ p11 = proj(v11)
826
+
827
+ # 6) Solve for (tx, ty) in bilinear map via GN refinement
828
+ a = p10 - p00
829
+ b = p01 - p00
830
+ c = p11 - p10 - p01 + p00
831
+ rhs = p_tgt - p00
832
+
833
+ AtA_00 = (a * a).sum(1)
834
+ AtA_11 = (b * b).sum(1)
835
+ AtA_01 = (a * b).sum(1)
836
+ det = AtA_00 * AtA_11 - AtA_01 * AtA_01
837
+ det[det == 0] = 1e-15
838
+ At_rhs0 = (a * rhs).sum(1)
839
+ At_rhs1 = (b * rhs).sum(1)
840
+ tx = ( AtA_11 * At_rhs0 - AtA_01 * At_rhs1) / det
841
+ ty = (-AtA_01 * At_rhs0 + AtA_00 * At_rhs1) / det
842
+
843
+ P_est = p00 + a * tx[:, None] + b * ty[:, None] + c * (tx * ty)[:, None]
844
+ res = rhs - (P_est - p00)
845
+ J0 = a + c * ty[:, None]
846
+ J1 = b + c * tx[:, None]
847
+ JTJ_00 = (J0 * J0).sum(1)
848
+ JTJ_11 = (J1 * J1).sum(1)
849
+ JTJ_01 = (J0 * J1).sum(1)
850
+ det2 = JTJ_00 * JTJ_11 - JTJ_01 * JTJ_01
851
+ det2[det2 == 0] = 1e-15
852
+ JTres0 = (J0 * res).sum(1)
853
+ JTres1 = (J1 * res).sum(1)
854
+ dtx = ( JTJ_11 * JTres0 - JTJ_01 * JTres1) / det2
855
+ dty = (-JTJ_01 * JTres0 + JTJ_00 * JTres1) / det2
856
+ tx += dtx
857
+ ty += dty
858
+
859
+ tx = np.clip(tx, 0.0, 1.0)
860
+ ty = np.clip(ty, 0.0, 1.0)
861
+
862
+ # 7) Bilinear weights
863
+ W = np.empty((4, N), dtype=np.float64)
864
+ W[0] = (1 - tx) * (1 - ty)
865
+ W[1] = tx * (1 - ty)
866
+ W[2] = (1 - tx) * ty
867
+ W[3] = tx * ty
868
+
869
+ if edge_mode == "drop":
870
+ for k in range(4):
871
+ W[k, drop_mask[k]] = 0.0
872
+ s = W.sum(axis=0)
873
+ ok = s > 0
874
+ W[:, ok] /= s[ok]
875
+
876
+ return I, W
877
+
878
+
879
def conjugate_gradient_normal_equation(data, x0, www, all_idx,
                                       LPT=None,
                                       LP=None,
                                       max_iter=100,
                                       tol=1e-8,
                                       verbose=True):
    """
    Solve the normal equation (Pᵗ P) x = Pᵗ y using the Conjugate Gradient method.

    Parameters
    ----------
    data : array_like
        Observed UTM data y ∈ ℝᵐ (one value per output sample).
    x0 : array_like
        Initial guess for the solution x ∈ ℝⁿ (HEALPix domain). Not modified.
    www : ndarray
        Interpolation weights. With the default operators it is indexed as
        ``www[k, j]``: weight of the k-th contributing cell for output sample j.
    all_idx : ndarray of int
        Interpolation indices into ``x``, same shape as ``www``.
    LPT : callable, optional
        Adjoint operator, called as ``LPT(y, www, all_idx, hit)``. Defaults to a
        weighted scatter-add onto the cells followed by multiplication by ``hit``.
    LP : callable, optional
        Forward operator, called as ``LP(x, www, all_idx)``. Defaults to
        ``sum_k x[all_idx[k, j]] * www[k, j]``.
        LP and LPT fall back to their defaults independently of each other.
    max_iter : int
        Maximum number of CG iterations.
    tol : float
        Stopping tolerance on the residual norm.
    verbose : bool
        Print the residual every 50 iterations and on convergence.

    Returns
    -------
    x : ndarray of float64
        Estimated HEALPix solution, same length as ``x0``.

    Notes
    -----
    ``hit`` is the inverse of the per-cell sum of weights (0 for cells that no
    sample touches). The operator iterated here is therefore diag(hit)·PᵗP.
    NOTE(review): that operator is not symmetric unless ``hit`` is constant,
    whereas CG assumes a symmetric positive-definite system, so convergence
    is not guaranteed in general.
    """

    def default_P(x, W, indices):
        """Forward operator: interpolate HEALPix map x onto the UTM samples."""
        return np.sum(x[indices] * W, 0)

    def default_PT(y, W, indices, hit):
        """
        Adjoint operator: scatter UTM values y back onto the HEALPix cells
        using the weights W, then apply the hit normalisation.
        """
        # minlength keeps the output aligned with `hit` (and therefore with x),
        # even when the highest cells of x are never sampled.
        value = np.bincount(indices.flatten(),
                            weights=(W * y[None, :]).flatten(),
                            minlength=hit.size) * hit
        return value

    # Each operator falls back to its default on its own. Previously a custom
    # LP was discarded whenever LPT was None, and LP stayed None if only LPT
    # was given.
    if LP is None:
        LP = default_P
    if LPT is None:
        LPT = default_PT

    # Work on a float copy so the in-place updates below never touch x0 and
    # do not fail on integer input.
    x = np.array(x0, dtype=np.float64)
    all_idx = np.asarray(all_idx)
    www = np.asarray(www)

    # Pixel coverage normalisation (hit map), sized like x.
    hit = np.bincount(all_idx.flatten(), weights=www.flatten(), minlength=x.size)
    covered = hit > 0
    hit[covered] = 1 / hit[covered]

    # b = Pᵗ y
    b = LPT(data, www, all_idx, hit)

    # Initial residual r = b - A x, with A = Pᵗ P
    Ax = LPT(LP(x, www, all_idx), www, all_idx, hit)
    r = b - Ax

    p = r.copy()
    rs_old = np.dot(r, r)

    # If x0 already solves the system, stop here. Otherwise alpha would be 0/0.
    if np.sqrt(rs_old) < tol:
        if verbose:
            print(f"Converged at start: residual = {np.sqrt(rs_old):.3e}")
        return x

    for i in range(max_iter):
        # A p = Pᵗ P p
        Ap = LPT(LP(p, www, all_idx), www, all_idx, hit)

        pAp = np.dot(p, Ap)
        if pAp == 0:
            # The search direction has no curvature, so CG cannot progress further.
            break

        alpha = rs_old / pAp
        x += alpha * p
        r -= alpha * Ap

        rs_new = np.dot(r, r)

        if verbose and i % 50 == 0:
            print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")

        if np.sqrt(rs_new) < tol:
            if verbose:
                print(f"Converged. Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")
            break

        p = r + (rs_new / rs_old) * p
        rs_old = rs_new

    return x
974
+
975
+
976
def spectrum_polar_to_cartesian(
    w,
    scales=None,            # radial *values* (see scale_kind)
    orientations=None,      # angles in radians (uniform)
    n_pixels=512,
    r_max=None,
    method="bilinear",
    fill_value=0.0,
    *,
    scale_kind="frequency",     # "frequency" or "size"
    size_to_freq_factor=1.0,    # if scale_kind="size": freq = size_to_freq_factor / size
):
    """
    Resample a polar (scale, orientation) spectrum onto a Cartesian (kx, ky) grid.

    Parameters
    ----------
    w : array_like, shape (Nscale, Norientation)
        Spectrum values per radial scale and orientation.
    scales : array_like, optional
        Radial values, one per scale. If scale_kind == "frequency", they are
        already radii in frequency units. If scale_kind == "size", they are
        spatial sizes (e.g. km, px) and are converted to frequency radii by
        ``freq = size_to_freq_factor / size``. Choose size_to_freq_factor to get
        the units you want (e.g. 1.0 for cycles/size). Defaults to 1, 2, 4, ...
        interpreted according to ``scale_kind``. After conversion the radii are
        sorted increasingly, and ``w`` is reordered to match.
    orientations : array_like, optional
        Orientation angles in radians. Defaults to Norientation values uniformly
        spread over [0, 2π). Only the count is checked here.
    n_pixels, r_max, method, fill_value :
        Forwarded to ``_spectrum_polar_to_cartesian_core``.

    Returns
    -------
    img, kx, ky
        Whatever ``_spectrum_polar_to_cartesian_core`` returns.

    Raises
    ------
    ValueError
        If ``w`` is not 2-D, ``scale_kind`` is unknown, the lengths of
        ``scales``/``orientations`` do not match ``w``, sizes are not strictly
        positive, or two scales map to the same radius.
    """
    w = np.asarray(w)
    if w.ndim != 2:
        raise ValueError("w must be (Nscale, Norientation)")
    if scale_kind not in ("frequency", "size"):
        raise ValueError("scale_kind must be 'frequency' or 'size'")
    ns, no = w.shape

    # ---- radial values ----
    if scales is None:
        # Default dyadic ladder 1, 2, 4, ... (frequencies or sizes per scale_kind).
        scales = 2.0 ** np.arange(ns, dtype=float)
    else:
        scales = np.asarray(scales, dtype=float)
        if len(scales) != ns:
            raise ValueError("len(scales) must match Nscale")

    if scale_kind == "size":
        # A zero or negative size would yield an infinite or negative radius.
        # `~(x > 0)` also rejects NaN.
        if np.any(~(scales > 0)):
            raise ValueError("scales must be strictly positive sizes when scale_kind='size'")
        scales = size_to_freq_factor / scales

    # The core needs strictly increasing radii (low -> high frequency).
    # Increasing sizes give decreasing frequencies, so sort and reorder w.
    if not np.all(np.diff(scales) > 0):
        order = np.argsort(scales)
        scales = scales[order]
        w = w[order, :]
        if not np.all(np.diff(scales) > 0):
            # Duplicate (or NaN) radii make np.interp ill-defined.
            raise ValueError("scales must map to distinct, finite radii")

    # ---- orientations (uniform over [0, 2π) ) ----
    if orientations is None:
        orientations = np.linspace(0.0, 2 * np.pi, no, endpoint=False)
    else:
        orientations = np.asarray(orientations, dtype=float)
        if len(orientations) != no:
            raise ValueError("len(orientations) must match Norientation")

    return _spectrum_polar_to_cartesian_core(
        w, scales, orientations, n_pixels, r_max, method, fill_value
    )
1045
+
1046
+ def _spectrum_polar_to_cartesian_core(
1047
+ w, scales, orientations, n_pixels, r_max, method, fill_value
1048
+ ):
1049
+ """Core function from before (unchanged logic), expects increasing frequency radii."""
1050
+ ns, no = w.shape
1051
+ if r_max is None:
1052
+ r_max = float(np.max(scales))
1053
+
1054
+ kx = np.linspace(-r_max, r_max, n_pixels, dtype=float)
1055
+ ky = np.linspace(-r_max, r_max, n_pixels, dtype=float)
1056
+ KX, KY = np.meshgrid(kx, ky, indexing="xy")
1057
+ R = np.hypot(KX, KY)
1058
+ Theta = -np.mod(np.arctan2(KY, KX), 2.0*np.pi)
1059
+
1060
+ radial_index = np.interp(R, scales, np.arange(ns, dtype=float),
1061
+ left=np.nan, right=np.nan)
1062
+ dtheta = (2.0*np.pi) / no
1063
+ angular_index = Theta / dtheta
1064
+
1065
+ valid = np.isfinite(radial_index)
1066
+ try:
1067
+ from scipy.ndimage import map_coordinates
1068
+ order = 3 if method.lower() == "bicubic" else 1
1069
+ coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
1070
+ eps = 1e-6
1071
+ coords[0, :] = np.where(np.isfinite(coords[0, :]),
1072
+ np.clip(coords[0, :], 0.0+eps, (ns-1)-eps), 0.0)
1073
+ sampled = map_coordinates(
1074
+ w, coords, order=order, mode="wrap", cval=fill_value, prefilter=True
1075
+ ).reshape(n_pixels, n_pixels)
1076
+ img = np.where(valid, sampled, fill_value)
1077
+ except Exception:
1078
+ # bilinear fallback
1079
+ r_idx = np.floor(radial_index).astype(np.int64)
1080
+ t_idx = np.floor(angular_index).astype(np.int64)
1081
+ r_idx = np.clip(r_idx, 0, ns-2)
1082
+ t0 = np.mod(t_idx, no)
1083
+ t1 = np.mod(t_idx+1, no)
1084
+ tr = np.clip(radial_index - r_idx, 0.0, 1.0)
1085
+ ta = np.clip(angular_index - t_idx, 0.0, 1.0)
1086
+ f00 = w[r_idx, t0]
1087
+ f01 = w[r_idx, t1]
1088
+ f10 = w[r_idx+1, t0]
1089
+ f11 = w[r_idx+1, t1]
1090
+ g0 = (1.0 - ta) * f00 + ta * f01
1091
+ g1 = (1.0 - ta) * f10 + ta * f11
1092
+ img = (1.0 - tr) * g0 + tr * g1
1093
+ img = np.where(valid, img, fill_value)
1094
+
1095
+ return img, kx, ky
1096
+
1097
def plot_wave(wave, title="spectrum", unit="Amplitude", cmap="viridis"):
    """
    Show a polar (scale, orientation) spectrum as a Cartesian k-space image.

    Row i of ``wave`` is treated as a spatial size of 2**i. Each size is
    converted to a frequency radius 50 / size, then the result is resampled
    with bicubic interpolation onto a 512x512 grid and drawn with
    ``plt.imshow`` on the current axes, together with a colorbar labelled
    ``unit`` and k-axis labels in cycles/km.
    """
    # Increasing sizes 1, 2, 4, ...; spectrum_polar_to_cartesian turns them
    # into frequency radii.
    sizes = 2 ** np.arange(wave.shape[0])

    # size_to_freq_factor=50 maps size 1 to 50 cycles/km. Presumably size 1
    # stands for ~20 m (Sentinel-2 10 m resolution, ~20 m at the smallest
    # scale); confirm against the data source.
    img, kx, ky = spectrum_polar_to_cartesian(
        wave,
        scales=sizes,
        scale_kind="size",
        size_to_freq_factor=50.0,
        method="bicubic",
        n_pixels=512,
    )

    bounds = [kx[0], kx[-1], ky[0], ky[-1]]
    plt.imshow(img, extent=bounds, origin="lower", aspect="equal", cmap=cmap)
    plt.colorbar(label=unit, shrink=0.5)
    plt.xlabel(r"$k_x$ [cycles / km]")
    plt.ylabel(r"$k_y$ [cycles / km]")
    plt.title(title)
1117
+
1118
def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
    """
    Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.

    Parameters
    ----------
    shape : tuple(int, int)
        (H, W) of the image.
    ref_lon, ref_lat : float
        Reference coordinate in degrees. With anchor="center" it is the centre
        of the whole grid. With a corner anchor it is the centre of that corner
        pixel.
    dlon, dlat : float
        Pixel size in degrees along x (lon) and y (lat). Use positives.
    anchor : {"center","topleft","topright","bottomleft","bottomright"}
        Where (ref_lon, ref_lat) sits relative to the image. "top" means the
        highest latitude and "left" the lowest longitude (for positive steps).

    Returns
    -------
    lon_edges, lat_edges : 2D arrays of shape (H+1, W+1)
        Corner coordinates suitable for `pcolormesh`. Both increase with the
        column / row index (for positive steps), so row 0 is the lowest
        latitude edge.

    Raises
    ------
    ValueError
        If ``anchor`` is not one of the names above.
    """
    n_rows, n_cols = shape
    step_lon = float(dlon)
    step_lat = float(dlat)

    # (lon sign, lat sign) of the shift from the reference point to the grid
    # centre, in units of the half extent measured between corner-pixel centres.
    anchor_signs = (
        ("center", (0, 0)),
        ("topleft", (1, -1)),
        ("topright", (-1, -1)),
        ("bottomleft", (1, 1)),
        ("bottomright", (-1, 1)),
    )
    signs = next((s for name, s in anchor_signs if anchor == name), None)
    if signs is None:
        raise ValueError("anchor must be one of: center/topleft/topright/bottomleft/bottomright")
    sign_lon, sign_lat = signs

    half_lon = (n_cols / 2.0 - 0.5) * step_lon
    half_lat = (n_rows / 2.0 - 0.5) * step_lat
    centre_lon = ref_lon if sign_lon == 0 else ref_lon + sign_lon * half_lon
    centre_lat = ref_lat if sign_lat == 0 else ref_lat + sign_lat * half_lat

    # Corner positions along each axis, symmetric about the grid centre.
    lon_1d = centre_lon + (np.arange(n_cols + 1) - n_cols / 2.0) * step_lon
    lat_1d = centre_lat + (np.arange(n_rows + 1) - n_rows / 2.0) * step_lat

    lon_grid, lat_grid = np.meshgrid(lon_1d, lat_1d, indexing="xy")
    return lon_grid, lat_grid
1168
+
1169
+
1170
def plot_image_lonlat(img, lon_edges, lat_edges, cmap="viridis", vmin=None, vmax=None):
    """
    Plot a 2D image on a lon/lat grid using pcolormesh (no reprojection).

    Draws into the current matplotlib Axes, as ``plt.pcolormesh`` does.

    Parameters
    ----------
    img : (H, W) array_like
        Values per cell.
    lon_edges, lat_edges : (H+1, W+1) array_like
        Cell-corner coordinates, e.g. from ``lonlat_edges_from_ref``.
    cmap, vmin, vmax :
        Matplotlib colormap settings.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
        The current figure and axes that received the plot.
    """
    # Previously `fig, ax` were returned without being defined (their creation
    # was commented out), so every call raised NameError. Return the current
    # figure/axes that plt.pcolormesh actually drew into.
    plt.pcolormesh(lon_edges, lat_edges, img, cmap=cmap, vmin=vmin, vmax=vmax, shading="flat")
    fig = plt.gcf()
    ax = plt.gca()
    return fig, ax
1184
+
1185
+ import matplotlib.tri as mtri
1186
+
1187
+ def _edges_from_centers_2d(C):
1188
+ """
1189
+ Compute (H+1, W+1) cell-corner coordinates from a 2D array of cell centers (H, W).
1190
+ This works for a *structured* grid indexed by (i, j) even if the physical spacing
1191
+ is non-uniform and warped.
1192
+
1193
+ Strategy (robust and common in geosciences):
1194
+ 1) Extrapolate one ghost cell around the array using first-order linear extrapolation.
1195
+ 2) Corners are the mean of the 2x2 block of surrounding centers in the padded array.
1196
+
1197
+ Parameters
1198
+ ----------
1199
+ C : (H, W) ndarray
1200
+ 2D field of *centers* (e.g., lat or lon at pixel centers).
1201
+
1202
+ Returns
1203
+ -------
1204
+ E : (H+1, W+1) ndarray
1205
+ 2D field of *edges* (corners), suitable for pcolormesh.
1206
+ """
1207
+ C = np.asarray(C)
1208
+ H, W = C.shape
1209
+
1210
+ # 1) Pad by one cell on all sides using linear extrapolation
1211
+ Cp = np.empty((H + 2, W + 2), dtype=C.dtype)
1212
+ Cp[1:-1, 1:-1] = C
1213
+
1214
+ # Edges (extrapolate outward along each axis)
1215
+ Cp[0, 1:-1] = 2*C[0, :] - C[1, :] # top row
1216
+ Cp[-1, 1:-1] = 2*C[-1, :] - C[-2, :] # bottom row
1217
+ Cp[1:-1, 0] = 2*C[:, 0] - C[:, 1] # left col
1218
+ Cp[1:-1, -1] = 2*C[:, -1] - C[:, -2] # right col
1219
+
1220
+ # Corners of the padded array: extrapolate diagonally
1221
+ Cp[0, 0] = 2*C[0, 0] - C[1, 1]
1222
+ Cp[0, -1] = 2*C[0, -1] - C[1, -2]
1223
+ Cp[-1, 0] = 2*C[-1, 0] - C[-2, 1]
1224
+ Cp[-1, -1] = 2*C[-1, -1] - C[-2, -2]
1225
+
1226
+ # 2) Average 2x2 blocks to get corners (H+1, W+1)
1227
+ E = 0.25 * (Cp[:-1, :-1] + Cp[1:, :-1] + Cp[:-1, 1:] + Cp[1:, 1:])
1228
+ return E
1229
+
1230
+
1231
def plot_image_latlon(fig, ax, img, lat, lon, mode="structured", cmap="viridis", vmin=None, vmax=None,
                      shading="flat", aspect="equal", cbar_label="reflectance"):
    """
    Plot an image given per-pixel lat/lon coordinates.

    Parameters
    ----------
    fig : matplotlib Figure or None
        Figure owning ``ax``. If None while ``ax`` is given, ``ax.figure`` is used.
    ax : matplotlib Axes or None
        Axes to draw into (in both modes). If None, a new 7x5 inch figure and
        axes are created.
    img : (H, W) ndarray
        Image values per pixel.
    lat, lon : (H, W) ndarray
        Latitude and longitude at *pixel centers* (same shape as `img`).
    mode : {"structured", "scattered"}
        - "structured": (i, j) grid is regular (rectangular index space), possibly warped.
          Per-cell corners are computed and drawn with pcolormesh.
        - "scattered" : pixels are not on a regular (i, j) grid. Non-finite
          samples are dropped, the rest are triangulated in the lon/lat plane
          and drawn with tripcolor.
    cmap, vmin, vmax : matplotlib colormap settings.
    shading : {"flat","gouraud"} for pcolormesh/tripcolor. "flat" = one color per cell/triangle.
    aspect : matplotlib aspect for axes, e.g. "equal" or "auto".
    cbar_label : str
        Colorbar label (defaults to the previously hard-coded "reflectance").

    Returns
    -------
    "structured" mode : (fig, ax, artist), where artist is the QuadMesh.
    "scattered" mode  : artist only (the collection returned by ``ax.tripcolor``),
        kept this way for backward compatibility.

    Raises
    ------
    ValueError
        If ``mode`` is unknown, or if shapes differ in "structured" mode.
    """
    img = np.asarray(img)
    lat = np.asarray(lat)
    lon = np.asarray(lon)

    # Validate and prepare the geometry before any figure gets created.
    if mode == "structured":
        if img.shape != lat.shape or img.shape != lon.shape:
            raise ValueError("For 'structured' mode, img, lat, lon must have the same (H, W) shape.")
        # Compute *corner* grids (H+1, W+1) from center grids (H, W)
        lat_edges = _edges_from_centers_2d(lat)
        lon_edges = _edges_from_centers_2d(lon)
    elif mode == "scattered":
        # Flatten and remove NaNs before triangulation
        z = img.ravel()
        x = lon.ravel()
        y = lat.ravel()
        keep = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
        x, y, z = x[keep], y[keep], z[keep]
        tri = mtri.Triangulation(x, y)
    else:
        raise ValueError("mode must be 'structured' or 'scattered'")

    # Honour the caller's axes in both modes. The scattered mode used to open
    # a new figure and ignore fig/ax.
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 5))
    elif fig is None:
        fig = ax.figure

    if mode == "structured":
        artist = ax.pcolormesh(lon_edges, lat_edges, img, cmap=cmap, vmin=vmin, vmax=vmax, shading=shading)
        x_lim = (np.nanmin(lon_edges), np.nanmax(lon_edges))
        y_lim = (np.nanmin(lat_edges), np.nanmax(lat_edges))
    else:
        artist = ax.tripcolor(tri, z, cmap=cmap, vmin=vmin, vmax=vmax, shading=shading)
        x_lim = (np.nanmin(x), np.nanmax(x))
        y_lim = (np.nanmin(y), np.nanmax(y))

    # Attach the colorbar to the figure that owns `ax`, not whichever is current.
    ax.figure.colorbar(artist, ax=ax, label=cbar_label, shrink=0.5)
    ax.set_xlabel("Longitude (deg)")
    ax.set_ylabel("Latitude (deg)")
    ax.set_aspect(aspect)
    ax.set_xlim(*x_lim)
    ax.set_ylim(*y_lim)

    if mode == "structured":
        return fig, ax, artist
    return artist