asteroid_spinprops 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssolib/dataprep.py ADDED
@@ -0,0 +1,554 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import sys
5
+ import os
6
+ from ssolib.modelfit import (
7
+ get_fit_params,
8
+ get_residuals,
9
+ )
10
+
11
+ import ssolib.utils as utils
12
+
13
+
14
def errorbar_filtering(data, mlimit):
    """
    Filter out data points with large photometric uncertainties.

    A point is kept when its uncertainty ``csigmapsf`` is at most
    ``mlimit / 2``; every other point is moved to ``rejects``.

    Parameters
    -----------
    data : pd.DataFrame
        A single-row DataFrame where each column contains an array of values
        for a solar system object. The input frame is not modified.
    mlimit : float
        Threshold value to filter out points with uncertainties greater than mlimit / 2.

    Returns
    -------
    data : pd.DataFrame
        Filtered copy of the input.
    rejects : pd.DataFrame
        DataFrame containing the rejected measurements. The columns
        "index", "kast" and "name" are copied unchanged into both outputs.
    """
    # Work on copies so the caller's DataFrame is left untouched.
    data = data.copy()
    rejects = data.copy()

    keep = np.asarray(data["csigmapsf"].values[0]) <= mlimit / 2

    for c in data.columns:
        # Per-object scalar columns are carried over as-is.
        if c in ["index", "kast", "name"]:
            continue
        values = np.asarray(data[c].values[0])
        rejects.at[0, c] = values[~keep]
        data.at[0, c] = values[keep]

    return data, rejects
42
+
43
+
44
def projection_filtering(data):
    """
    Filter out photometric outliers in reduced magnitude space, per filter,
    using a 3 sigma criterion.

    For every filter ID in ``cfid``, a point is kept when its reduced
    magnitude ``cmred`` lies within 3 standard deviations of that filter's
    mean reduced magnitude. The bound is inclusive, so a filter whose
    magnitudes have no spread (for example a single observation) is kept
    instead of being discarded entirely.

    Parameters
    -----------
    data : pd.DataFrame
        A single-row DataFrame where each column contains an array of values.
        The input frame is not modified.

    Returns
    --------
    data : pd.DataFrame
        Filtered copy of the input (original point order preserved).
    rejects : pd.DataFrame
        DataFrame containing the rejected measurements
    """
    data = data.copy()
    rejects = data.copy()

    fids = np.asarray(data["cfid"].values[0])
    mags = np.asarray(data["cmred"].values[0])

    keep = np.zeros(fids.shape, dtype=bool)
    for f in np.unique(fids):
        in_filter = fids == f
        mags_f = mags[in_filter]
        # Inclusive bound: with zero spread every point sits exactly on the
        # mean and must not be flagged as an outlier. NaNs compare False and
        # are rejected.
        keep[in_filter] = np.abs(mags_f - np.mean(mags_f)) <= 3 * np.std(mags_f)

    for c in data.columns:
        if c in ["index", "kast", "name"]:
            continue
        values = np.asarray(data[c].values[0])
        rejects.at[0, c] = values[~keep]
        data.at[0, c] = values[keep]

    return data, rejects
89
+
90
+
91
def iterative_filtering(data, max_iter=10):
    """
    Iteratively remove outliers based on residuals from fitting the SHG1G2
    model until convergence.

    Each iteration fits the SHG1G2 model to the currently retained points
    (``get_fit_params``) and flags points whose absolute residual is not
    below 3 times the standard deviation of the residuals. Flagged points are
    removed at the start of the next iteration. The loop stops when a fit
    flags no point, when ``get_residuals`` raises ``KeyError``, or after
    ``max_iter`` fits. Points flagged by the very last fit of an exhausted
    loop are not removed.

    Parameters
    -----------
    data : pd.DataFrame
        A single-row DataFrame where each column contains an array of values.
        The input frame is not modified.

    max_iter : int
        Maximum number of filtering iterations (default is 10).

    Returns
    --------
    data : pd.DataFrame
        Filtered DataFrame

    rejects : pd.DataFrame
        DataFrame containing the rejected measurements (column dtypes are
        preserved).
    """
    data = data.copy()
    rejects = data.copy()

    columns = [c for c in data.columns if c not in ["index", "kast", "name"]]
    kept = {c: np.asarray(data[c].values[0]) for c in columns}
    # Empty slices keep each column's dtype; np.array([]) would coerce
    # integer or string columns to float when appended to.
    dropped = {c: kept[c][:0] for c in columns}

    mask = np.ones(len(kept["cfid"]), dtype=bool)

    for _ in range(max_iter):
        # Move the points flagged by the previous fit into the reject pool.
        for c in columns:
            dropped[c] = np.append(dropped[c], kept[c][~mask])
            kept[c] = kept[c][mask]

        mparams = get_fit_params(pd.DataFrame([kept]), "SHG1G2")
        try:
            residuals = get_residuals(pd.DataFrame([kept]), mparams)
        except KeyError:
            # NOTE(review): presumably raised when the fit returns no usable
            # parameters; confirm against ssolib.modelfit.get_residuals.
            break
        mask = np.abs(residuals) < 3 * np.std(residuals)

        # Converged: the fit on the current points flags nothing new.
        if mask.all():
            break

    for c in columns:
        data.at[0, c] = kept[c]
        rejects.at[0, c] = dropped[c]
    return data, rejects
144
+
145
+
146
def lightcurve_filtering(data, window=10, maglim=0.4):
    """
    Filter out lightcurve points that deviate from the median by more than a
    given magnitude limit within time bins.

    The time series is split into consecutive bins of roughly ``window`` (in
    ``cjd`` units) whose edges are snapped to observed dates. Within each
    (bin, filter) group, points whose reduced magnitude ``cmred`` lies more
    than ``maglim`` above or below the group median are rejected.

    The slicing below assumes ``cjd`` is sorted in ascending order, as
    produced by ``prepare_sso_data``.

    Parameters
    ----------
    data : pd.DataFrame
        Single-row DataFrame. The input frame is not modified.
    window : float
        Time bin size (default is 10 days).
    maglim : float
        Magnitude deviation threshold from the median (default is 0.4 mag).

    Returns
    -------
    data : pd.DataFrame
        Filtered data, sorted by ``cjd``.
    rejects : pd.DataFrame
        DataFrame containing the rejected measurements
    """
    data = data.copy()
    rejects = data.copy()

    dates = data["cjd"].values[0]
    magnitudes = data["cmred"].values[0]
    filters = data["cfid"].values[0]
    indices = np.arange(len(filters))

    if len(dates) == 0:
        # Nothing to bin; both outputs are (empty) copies of the input.
        return data, rejects

    # --- Split the series into contiguous time bins.
    # TODO: Use np.digitize instead of this
    mag_binned, filt_binned, ind_binned = [], [], []
    date0 = dates.min()
    date0_plus_step = date0 + window
    # Start of the final bin if the loop below never runs (all dates equal).
    plus_ten_index = 0
    while date0 < dates.max():
        prev_ind = np.where(dates == utils.find_nearest(dates, date0))[0][0]
        plus_ten_index = np.where(
            dates == utils.find_nearest(dates, date0_plus_step)
        )[0][0]

        mag_binned.append(magnitudes[prev_ind:plus_ten_index])
        filt_binned.append(filters[prev_ind:plus_ten_index])
        ind_binned.append(indices[prev_ind:plus_ten_index])

        date0 = dates[plus_ten_index]
        date0_plus_step = date0_plus_step + window

    mag_binned.append(magnitudes[plus_ten_index:])
    filt_binned.append(filters[plus_ten_index:])
    ind_binned.append(indices[plus_ten_index:])

    # --- Within each (filter, bin) group, flag points far from the median.
    valid_indices, reject_indices = [], []
    for f in np.unique(filters):
        for mags, fids, inds in zip(mag_binned, filt_binned, ind_binned):
            fcond = fids == f
            mags_f = mags[fcond]
            median = np.median(mags_f)  # NaN (with a warning) for empty groups
            outlier = (mags_f > median + maglim) | (mags_f < median - maglim)
            valid_indices.append(inds[fcond][~outlier])
            reject_indices.append(inds[fcond][outlier])

    valid_indices = np.array(utils.flatten_list(valid_indices), dtype=int)
    reject_indices = np.array(utils.flatten_list(reject_indices), dtype=int)

    for c in data.columns:
        if c not in ["index", "kast", "name"]:
            rejects.at[0, c] = data[c].values[0][reject_indices]
            data.at[0, c] = data[c].values[0][valid_indices]

    # Kept indices were collected filter by filter; restore time order.
    data = utils.sort_by_cjd(data)

    return data, rejects
254
+
255
+
256
def plot_filtering(
    clean_data, rejects, lc_filtering=True, iter_filtering=True, xaxis="Phase"
):
    """
    Plot the cleaned photometry together with the points rejected at each
    filtering step, one panel per filter.

    Panels sit on a 2x2 grid. Filter IDs 1 and 2 (ZTF g, ZTF r) go on the top
    row and 3 and 4 (ATLAS orange, ATLAS cyan) on the bottom row. Any other
    filter ID is not plotted.

    Parameters
    ----------
    clean_data : pd.DataFrame
        Single-row DataFrame of retained measurements (first output of
        ``filter_sso_data``).
    rejects : sequence
        Reject DataFrames in the order errorbar, projection, iterative,
        lightcurve (second output of ``filter_sso_data``). The iterative and
        lightcurve entries may be None.
    lc_filtering : bool, optional
        Plot the lightcurve-filtering rejects (default True). Ignored when
        ``rejects[3]`` is None.
    iter_filtering : bool, optional
        Plot the iterative-filtering rejects (default True). Ignored when
        ``rejects[2]`` is None.
    xaxis : {"Phase", "Date"}, optional
        Plot against phase angle (``Phase`` column) or Julian date (``cjd``
        column). Default is "Phase".

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray
        The 2x2 array of Axes.

    Raises
    ------
    ValueError
        If ``xaxis`` is neither "Phase" nor "Date".
    """
    x_columns = {"Date": "cjd", "Phase": "Phase"}
    x_labels = {"Date": "JD", "Phase": "Phase / deg"}
    if xaxis not in x_columns:
        raise ValueError(f"xaxis must be 'Phase' or 'Date', got {xaxis!r}")
    coll = x_columns[xaxis]

    errorbar_rejects, projection_rejects, iterative_rejects, lightcurve_rejects = (
        rejects[0],
        rejects[1],
        rejects[2],
        rejects[3],
    )

    # (frame, errorbar style) pairs, drawn in this order in every panel.
    layers = [
        (
            clean_data,
            dict(fmt=".", capsize=2, ms=5, elinewidth=1, label="Valid points"),
        ),
        (
            errorbar_rejects,
            dict(
                fmt="x",
                capsize=2,
                ms=15,
                elinewidth=1,
                c="tab:red",
                label=r"$\delta m > 3\sigma_{LCDB}$",
            ),
        ),
        (
            projection_rejects,
            dict(
                fmt="+",
                capsize=2,
                ms=15,
                elinewidth=1,
                c="tab:green",
                # NOTE(review): \substack is an amsmath command; matplotlib's
                # built-in mathtext may not support it without text.usetex.
                label=r"$\substack{m > \bar{m} + 3\sigma_m \\ m < \bar{m} - 3\sigma_m}$",
            ),
        ),
    ]
    if iter_filtering and iterative_rejects is not None:
        layers.append(
            (
                iterative_rejects,
                dict(fmt=">", capsize=2, ms=7, elinewidth=1, label="Iterative"),
            )
        )
    if lc_filtering and lightcurve_rejects is not None:
        layers.append(
            (
                lightcurve_rejects,
                dict(
                    fmt="P",
                    capsize=2,
                    ms=7,
                    elinewidth=1,
                    c="black",
                    label="Lightcurve",
                ),
            )
        )

    fig, ax = plt.subplots(2, 2, figsize=(12, 6))

    filter_names = ["ZTF g", "ZTF r", "ATLAS orange", "ATLAS cyan"]

    for f in np.unique(clean_data["cfid"].values[0]):
        if f not in (1, 2, 3, 4):
            continue  # no panel reserved for this filter ID
        # Panel position and label follow the filter ID itself, so a missing
        # filter does not shift or mislabel the others.
        slot = int(f) - 1
        axis = ax[divmod(slot, 2)]

        for frame, style in layers:
            in_filter = np.array(frame["cfid"].values[0]) == f
            axis.errorbar(
                x=frame[coll].values[0][in_filter],
                y=frame["cmred"].values[0][in_filter],
                yerr=frame["csigmapsf"].values[0][in_filter],
                **style,
            )

        axis.invert_yaxis()  # brighter (smaller) magnitudes at the top
        axis.text(
            0.05,
            0.05,
            filter_names[slot],
            transform=axis.transAxes,
            va="bottom",
            ha="left",
            bbox=dict(
                facecolor="white",
                edgecolor="black",
                boxstyle="round,pad=0.3",
                alpha=0.8,
            ),
        )

    for axis in ax[1, :]:
        axis.set_xlabel(x_labels[xaxis])
    for axis in ax[:, 0]:
        axis.set_ylabel("Reduced magnitude")

    ax[0, 1].legend(loc="upper right")
    return fig, ax
382
+
383
+
384
def filter_sso_data(
    sso_name,
    path_to_data,
    pqdict,
    ephem_path,
    mlimit=0.7928,
    lc_filtering=True,
    iter_filtering=True,
    max_iter=10,
    window=10,
    maglim=0.4,
):
    """
    Filters data for a given SSO.
    Applies errorbar, projection, iterative sigma-clipping and lightcurve
    filtering, in that order.

    Parameters
    ----------
    sso_name : str
        The name of the solar system object to filter.
    path_to_data : str
        Path to the data files.
    pqdict : dict
        Dictionary linking parquet filename to SSO.
    ephem_path : str | None
        Path to the ephemeris data.
    mlimit : float, optional
        Magnitude limit for errorbar filtering (default is 0.7928).
    lc_filtering : bool, optional
        Whether to apply lightcurve filtering (default is True).
    iter_filtering : bool, optional
        Whether to apply iterative filtering (default is True).
    max_iter : int, optional
        Maximum number of iterations for iterative filtering (default is 10).
    window : float, optional
        Time bin size for lightcurve filtering (default is 10).
    maglim : float, optional
        Magnitude deviation threshold for lightcurve filtering
        (default is 0.4).

    Returns
    -------
    clean_data : pd.DataFrame
        Cleaned data
    rejects : list
        Rejected data points from each filtering step, in the order errorbar,
        projection, iterative, lightcurve. A step that was skipped contributes
        None.

    Examples
    ---------

    >>> from ssolib.dataprep import prepare_sso_data, filter_sso_data
    >>> from ssolib.dataprep import __file__
    >>> import os
    >>> import pickle

    >>> wpath = os.path.dirname(__file__)

    >>> ephem_path = os.path.join(wpath, "testing/ephemeris_testing")
    >>> data_path = os.path.join(wpath, "testing/atlas_x_ztf_testing")
    >>> pq_keys = os.path.join(wpath, "testing/testing_ssoname_keys.pkl")
    >>> available_ssos = os.listdir(ephem_path)

    >>> with open(pq_keys, "rb") as f:
    ...     pqload = pickle.load(f)

    >>> path_args = [data_path, pqload, ephem_path]

    Every original point ends up either in the clean data or in exactly one
    reject set:

    >>> for name in available_ssos:
    ...     origin_data = prepare_sso_data(name, *path_args)
    ...     cdata, rejects = filter_sso_data(name, *path_args)
    ...     n_points = cdata["cmred"].values[0].size
    ...     n_points += sum(r["cmred"].values[0].size for r in rejects)
    ...     assert origin_data["cmred"].values[0].size == n_points
    """
    clean_data = prepare_sso_data(
        sso_name=sso_name,
        path_to_data=path_to_data,
        pqdict=pqdict,
        ephem_path=ephem_path,
    )

    clean_data, errorbar_rejects = errorbar_filtering(data=clean_data, mlimit=mlimit)
    clean_data, projection_rejects = projection_filtering(data=clean_data)

    iterative_rejects = None
    if iter_filtering:
        clean_data, iterative_rejects = iterative_filtering(
            clean_data, max_iter=max_iter
        )

    lightcurve_rejects = None
    if lc_filtering:
        clean_data, lightcurve_rejects = lightcurve_filtering(
            clean_data, window=window, maglim=maglim
        )

    return clean_data, [
        errorbar_rejects,
        projection_rejects,
        iterative_rejects,
        lightcurve_rejects,
    ]
475
+
476
+
477
def prepare_sso_data(
    sso_name,
    path_to_data,
    pqdict,
    ephem_path,
):
    """
    Load the photometry of one solar system object (SSO) and attach ephemeris
    quantities to it.

    Steps:

    - look up which parquet file holds ``sso_name`` (``utils.find_sso_in_pqdict``)
      and keep only that object's rows;
    - fetch its ephemeris (``utils.get_atlas_ephem``) and add observer and
      heliocentric distances, phase angle, RA/Dec (from ``px``/``py``/``pz``)
      and elongation;
    - compute the reduced magnitude from ``cmagpsf`` and the two distances;
    - sort the result by Julian date (``utils.sort_by_cjd``).

    Parameters
    -----------------------------
    sso_name : str
        The name of the solar system object to retrieve data for.
    path_to_data : str
        Directory containing the Parquet files.
    pqdict : dict
        Mapping from Parquet file names to the SSOs they contain.
    ephem_path : str | None
        Directory containing cached ephemeris files.

    Returns
    -----------------------------
    pd.DataFrame
        The object's observations plus these added columns:
            - 'Dobs': observer-centric distance [au]
            - 'Dhelio': heliocentric distance [au]
            - 'Phase': phase angle [deg]
            - 'ra', 'dec': right ascension and declination [deg]
            - 'Elongation': solar elongation angle [deg]
            - 'cmred': reduced magnitude
        sorted chronologically by Julian date.
    """
    pq_file = utils.find_sso_in_pqdict(sso_name=sso_name, pqdict=pqdict)
    pdf = pd.read_parquet(os.path.join(path_to_data, pq_file))

    # Rows belonging to the requested object only, re-indexed from 0.
    sso_rows = pdf[pdf["name"] == sso_name].copy().reset_index(drop=True)

    ephem = utils.get_atlas_ephem(
        pdf=pdf, name=sso_name, path_to_cached_ephems=ephem_path
    )

    dobs, dhelio, phase, elong = (
        ephem[key].values for key in ("Dobs", "Dhelio", "Phase", "Elong.")
    )
    ra, dec = utils.c2rd(ephem["px"], ephem["py"], ephem["pz"])

    # Each new column holds one array per row (a single row here).
    new_columns = ["Dobs", "Dhelio", "Phase", "ra", "dec", "Elongation"]
    sso_rows[new_columns] = pd.DataFrame(
        [[dobs, dhelio, phase, ra.values, dec.values, elong]],
        index=sso_rows.index,
    )

    reduced = utils.calculate_reduced_magnitude(
        magnitude=sso_rows["cmagpsf"].values[0],
        D_observer=sso_rows["Dobs"].values[0],
        D_sun=sso_rows["Dhelio"].values[0],
    )
    sso_rows["cmred"] = [reduced]

    return utils.sort_by_cjd(sso_rows)
549
+
550
+
551
if __name__ == "__main__":
    import doctest

    # Run this module's doctests; the exit status is the number of failed
    # examples, so 0 means success.
    sys.exit(doctest.testmod().failed)