ACID-code 2.0.0a2__py3-none-any.whl → 2.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ACID_code/acid.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+ import traceback
2
3
  import warnings
3
4
  warnings.filterwarnings("ignore")
4
5
  import sys, emcee, os, time, inspect, inspect, contextlib
@@ -15,6 +16,10 @@ from .result import Result
15
16
  from .data import Data, Config, MaskingLines, LineList, DataList
16
17
  from .errors import ContinuumError
17
18
  from .utils import IntLike, Scalar, Array1D, Array2D
19
+ try:
20
+ import dynesty
21
+ except ImportError:
22
+ dynesty = None
18
23
 
19
24
  @beartype
20
25
  class Acid:
@@ -96,7 +101,7 @@ class Acid:
96
101
  By default 2 (medium).
97
102
  sampler_progress : :py:type:`bool`, optional
98
103
  A verbosity override for just the MCMC sampling progress.
99
- By default None which does not override, but if True/False, it will overwrite with that value.
104
+ By default None which does not override, but if True/False, it will overwrite with that value, and use/don't use a tqdm output for the sampler.
100
105
  masking_lines : :py:type:`dict` | :py:class:`MaskingLines`, optional
101
106
  Telluric lines (in angstroms) and widths in (km/s) to mask from the wavelength regions from. Unless you'd like to change the default masking
102
107
  lines, we recommend just using the defaults (leaving this as None), which are based on telluric lines and strong hydrogen/metal lines in the
@@ -231,6 +236,8 @@ class Acid:
231
236
  dev_perc : IntLike|None = None, # Config
232
237
  n_sig : IntLike|None = None, # Config
233
238
  skips : IntLike|None = None, # Config
239
+ od : bool|None = None, # Config
240
+ sampler_type : str|None = None, # Config
234
241
  parallel : bool|None = None, # Config
235
242
  cores : IntLike|None = None, # Config
236
243
  nwalkers : IntLike|None = None, # Config, then Data just before MCMC
@@ -308,6 +315,11 @@ class Acid:
308
315
  skips : :py:type:`IntLike`, optional
309
316
  An option to only run acid on one in every n pixels, where n is the integer argument. This is only useful for
310
317
  testing to get a quicker result especially for larger wavelength ranges or datasets, by default 1 (no skipping)
318
+ sampler_type : :py:type:`str`, optional
319
+ If you really try to wish to use the dynesty nested sampler, you can set this to "dynesty". It is almost entirely unsupported
320
+ by the rest of the code other than to just get a finished result object, and much slower. We highly recommend using None or "emcee" (default).
321
+ The only reason I added this was to get the Bayesian evidence for model comparison.
322
+ If "dynesty" is chosen, the dynesty package needs to be installed, and the nsteps parameter is treated as "nlive" to be passed to the NestedSampler.
311
323
  parallel : :py:type:`bool`, optional
312
324
  If True uses multiprocessing to calculate the profiles for each frame in parallel, see
313
325
  https://acid-code.readthedocs.io/en/stable/using_ACID.html#multiprocessing for more details. By default True
@@ -411,6 +423,8 @@ class Acid:
411
423
  "dev_perc" : dev_perc,
412
424
  "n_sig" : n_sig,
413
425
  "skips" : skips,
426
+ "od" : od,
427
+ "sampler_type" : sampler_type,
414
428
  "parallel" : parallel,
415
429
  "cores" : cores,
416
430
  "nwalkers" : nwalkers,
@@ -435,6 +449,14 @@ class Acid:
435
449
  print("Parallel MCMC on Windows is not currently supported. Running MCMC serially.")
436
450
  self.config.parallel = False
437
451
 
452
+ if self.config.sampler_type == "dynesty":
453
+ if dynesty is None:
454
+ raise ImportError("The 'dynesty' sampler requires the 'dynesty' package to be installed.\nPlease install it with 'pip install dynesty' or choose a different sampler type.")
455
+ if self.config.sampler_type == "dynesty" and not self.config.deterministic_profile:
456
+ raise ValueError("The 'dynesty' sampler can only be run with deterministic_profile=True (otherwise you'll be waiting hours for a single result)")
457
+ if self.config.sampler_type == "dynesty" and self.config.max_steps is not None:
458
+ raise ValueError("Cannot use max_steps as dynesty already natively supports this with live points, set nsteps=nlive. See the dynesty docs for more details.")
459
+
438
460
  # --- Start of the ACID method ---
439
461
 
440
462
  # Setup and data validation done in data class and applies skips
@@ -484,8 +506,9 @@ class Acid:
484
506
  # The code for telluric masking is contained without the MaskingLines class, which both telluric_lines
485
507
  # and hydrogen_lines are instances of.
486
508
  line_mask = self.config.masking_lines.get_masks(self.data.wavelengths["combined"])
487
- line_mask = np.all(line_mask, axis=0)
488
- self.data.errors["combined"][line_mask] = 1e12
509
+ if line_mask != []:
510
+ line_mask = np.all(line_mask, axis=0)
511
+ self.data.errors["combined"][line_mask] = 1e12
489
512
 
490
513
  # Get the initial polynomial coefficents
491
514
  if not hasattr(self.data.wavelengths, "combined_normalized"):
@@ -558,22 +581,25 @@ class Acid:
558
581
  self.data.nwalkers = self.data.ndim * 3 if self.config.nwalkers is None else self.config.nwalkers
559
582
  rng = np.random.default_rng(self.config.seed)
560
583
 
561
- # Starting values of walkers with independent variation
562
- sigma = 0.8 * 0.005
563
- initial_state = []
564
- for i in range(0, len(self.data.model_inputs)):
565
- if i < len(self.data.velocities):
566
- if not self.config.deterministic_profile:
567
- pos = rng.normal(self.data.model_inputs[i], sigma, (self.data.nwalkers, ))
584
+ # Starting values of walkers with independent variation#
585
+ if self.config.sampler_type == "emcee":
586
+ sigma = 0.8 * 0.005
587
+ initial_state = []
588
+ for i in range(0, len(self.data.model_inputs)):
589
+ if i < len(self.data.velocities):
590
+ if not self.config.deterministic_profile:
591
+ pos = rng.normal(self.data.model_inputs[i], sigma, (self.data.nwalkers, ))
592
+ else:
593
+ continue
568
594
  else:
569
- continue
570
- else:
571
- x1 = self.data.model_inputs[i]
572
- rounded_sigma = round(x1, 1-int(floor(log10(abs(x1))))-1)
573
- sigma = abs(rounded_sigma) / 10
574
- pos = rng.normal(self.data.model_inputs[i], sigma, (self.data.nwalkers, ))
575
- initial_state.append(pos)
576
- initial_state = np.array(initial_state).T
595
+ x1 = self.data.model_inputs[i]
596
+ rounded_sigma = round(x1, 1-int(floor(log10(abs(x1))))-1)
597
+ sigma = abs(rounded_sigma) / 10
598
+ pos = rng.normal(self.data.model_inputs[i], sigma, (self.data.nwalkers, ))
599
+ initial_state.append(pos)
600
+ initial_state = np.array(initial_state).T
601
+ else:
602
+ initial_state = None
577
603
 
578
604
  ### ACID initialialised ###
579
605
  self.data.setup_time += time.time() - init_t0
@@ -617,12 +643,12 @@ class Acid:
617
643
  """
618
644
  This method is no longer supported in ACID. Please use the ACID function with the appropriate inputs for HARPS spectra instead.
619
645
  Future versions of ACID will provide functions to load and configure data from a range of different standard instruments.
620
- If you still really wish to use ACID_HARPS, the last stable version of ACID with the method is 1.4.5. Try: pip install ACID_code==1.4.5
646
+ If you still really wish to use ACID_HARPS, the last stable version of ACID with the method is 1.4.5. Try: pip install ACID_code_v2==1.4.5
621
647
  """
622
648
  raise NotImplementedError(f"ACID_HARPS is no longer supported in ACID. \n"
623
649
  f"Please use the ACID function with the appropriate inputs for HARPS spectra instead. \n"
624
650
  f"Future versions of ACID will provide functions to load and configure data from a range of different standard instruments. \n"
625
- f"If you still really wish to use ACID_HARPS, the last stable version of ACID with the method is 1.4.5. Try: pip install ACID_code==1.4.5")
651
+ f"If you still really wish to use ACID_HARPS, the last stable version of ACID with the method is 1.4.5. Try: pip install ACID_code_v2==1.4.5")
626
652
 
627
653
  def combine_spec(
628
654
  self,
@@ -834,9 +860,12 @@ class Acid:
834
860
  self.data.plot_continuum_fit(plot_type=plot_type)
835
861
 
836
862
  if np.any(flux_obs <= 0) or np.any(new_errors <= 0):
837
- raise ContinuumError("Continuum fit resulted in non-positive flux or errors, which is not physical.\n " \
863
+ error = ContinuumError("Continuum fit resulted in non-positive flux or errors, which is not physical.\n " \
838
864
  "Consider adjusting the polynomial order or continuum percentile. Use verbose=3 to see the plot of the continuum fit.\n " \
839
865
  "Note that this will only work for interactive terminals or displays which work with plt.show()")
866
+ self.data.exception = error
867
+ self.data.traceback = traceback.format_stack()
868
+ raise error
840
869
 
841
870
  return poly_coeffs, flux_obs, new_errors
842
871
 
@@ -856,7 +885,7 @@ class Acid:
856
885
  sn = self.data.sn["combined"]
857
886
 
858
887
  # Use the initial LSD run to get the forward model and scaled residuals
859
- forward, _profile = mcmc.MCMC(x, y, yerr, self.data.alpha).full_model(self.data.model_inputs)
888
+ forward, _profile = mcmc.MCMC(x, y, yerr, self.data.alpha, od=self.config.od).full_model(self.data.model_inputs)
860
889
  residuals = (y - forward) / forward
861
890
 
862
891
  # Chunk masking based on deviation from residuals
@@ -953,7 +982,8 @@ class Acid:
953
982
  """
954
983
 
955
984
  # Get default sampler kwargs from initial state
956
- sampler_kwargs, mcmc_kwargs = self._get_sampler_kwargs(nsteps, state)
985
+ if self.config.sampler_type == "emcee":
986
+ sampler_kwargs, mcmc_kwargs = self._get_sampler_kwargs(nsteps, state)
957
987
  pool_context = nullcontext(None)
958
988
 
959
989
  if self.config.parallel:
@@ -964,14 +994,25 @@ class Acid:
964
994
 
965
995
  ctx = mp.get_context("fork")
966
996
  pool_context = ctx.Pool(processes=self.config.cores, initializer=mcmc._mp_init_worker, initargs=(self.data,))
967
- log_prob_fn = mcmc._mp_log_probability
997
+ log_prob = mcmc._mp_log_probability if self.config.sampler_type == "emcee" else mcmc._mp_log_likelihood
998
+ ptform = mcmc._mp_ptform
999
+ queue_size = os.cpu_count()
968
1000
  else:
969
1001
  MCMC = mcmc.MCMC(self.data)
970
- log_prob_fn = MCMC
971
-
1002
+ log_prob = MCMC if self.config.sampler_type == "emcee" else MCMC.dynesty_logprob
1003
+ ptform = MCMC.ptform
1004
+ queue_size = None
1005
+
972
1006
  with pool_context as pool:
973
- self.sampler = EnsembleSampler(log_prob_fn=log_prob_fn, pool=pool, **sampler_kwargs)
974
- self.sampler.run_mcmc(**mcmc_kwargs)
1007
+ if self.config.sampler_type == "emcee":
1008
+ self.sampler = EnsembleSampler(log_prob_fn=log_prob, pool=pool, **sampler_kwargs)
1009
+ self.sampler.run_mcmc(**mcmc_kwargs)
1010
+ else:
1011
+ import dynesty
1012
+ if self.config.parallel:
1013
+ pool.size = self.config.cores
1014
+ self.sampler = dynesty.NestedSampler(log_prob, ptform, self.data.ndim, self.config.nsteps, pool=pool, queue_size=queue_size)
1015
+ self.sampler.run_nested(print_progress=self.config.verbose>1)
975
1016
 
976
1017
  def run_mcmc_until_converged(self, max_steps:IntLike, state=None) -> None:
977
1018
  """
ACID_code/data.py CHANGED
@@ -207,6 +207,8 @@ class Config:
207
207
  "dev_perc" : 25,
208
208
  "n_sig" : 3,
209
209
  "skips" : 1,
210
+ "od" : True,
211
+ "sampler_type" : "emcee",
210
212
  "parallel" : True,
211
213
  "cores" : None,
212
214
  "nwalkers" : None,
@@ -518,6 +520,12 @@ class Data:
518
520
  combined_profile : Optional[list] = None
519
521
  #: The final fitted continuum model and errors
520
522
  continuum_model : Optional[np.ndarray] = None
523
+ #: The forward model using the final profile, alpha matrix, and continuum model
524
+ forward_model : Optional[np.ndarray] = None
525
+ #: Errors on the above forward model, usually not needed
526
+ forward_errors : Optional[np.ndarray] = None
527
+ #: The x-axis for the above forward model, which is just the combined wavelength grid, and set in Result.process_results
528
+ forward_x : Optional[np.ndarray] = None
521
529
  #: The number of steps taken in the MCMC sampling, used for checking convergence and for resuming
522
530
  nsteps : Optional[int] = 0
523
531
  #: A flag for whether the profiles have been fully calculated to avoid recalculating
@@ -535,6 +543,10 @@ class Data:
535
543
  results_time : Optional[float] = 0
536
544
  #: total_time (float) - The total time for the full run
537
545
  total_time : Optional[float] = 0
546
+ #: The exception class if an error was raised during the run
547
+ exception : Optional[Exception] = None
548
+ #: The traceback string if an error was raised during the run
549
+ traceback : Optional[str] = None
538
550
 
539
551
  # Initialise the properties
540
552
  # -------------------------
@@ -573,14 +585,18 @@ class Data:
573
585
  if os.path.exists(sampler):
574
586
  self._sampler = utils.backend_to_sampler(HDFBackend(sampler), log_prob_fn)
575
587
  else:
576
- raise ValueError(f"The provided sampler path '{sampler}' does not exist.")
588
+ if self.config.verbose > 0:
589
+ print(f"Warning: The sampler was not found at the provided path '{sampler}', it may have been moved or deleted. \n"
590
+ f"The sampler will be set to None.", flush=True)
591
+ self._sampler = None
592
+ # TODO: Allow sampler to have completed results, but no sampler, and configured methods with _requiresampler property that need them
577
593
  elif sampler is None:
578
594
  if self.config.verbose > 0 and self._sampler is not None:
579
595
  print("Warning, you have discarded the sampler.")
580
596
  self._sampler = None
581
597
 
582
598
  if self._sampler is not None and isinstance(self._sampler.backend, HDFBackend):
583
- self.config.sampler_path = self._sampler.backend.filename
599
+ self.config.sampler_path = os.path.abspath(self._sampler.backend.filename)
584
600
 
585
601
  @property
586
602
  def velocities(self) -> Array1D|None:
@@ -1196,7 +1212,20 @@ class Data:
1196
1212
  """
1197
1213
  Converts the data object to a dictionary payload for saving. This is used internally in the save method,
1198
1214
  but can also be used for debugging or other purposes.
1215
+
1216
+ Parameters
1217
+ ----------
1218
+ store_sampler : bool, optional
1219
+ Whether to include the MCMC sampler in the dictionary payload, by default True.
1220
+ size_limit : Scalar | None, optional
1221
+ A hard size limit to the sampler in GB.
1222
+ If the sampler exceeds this size, it will not be stored regardless of the store_sampler flag.
1223
+ This is to avoid accidentally storing very large samplers. If None, no limit is set. Default is 1GB.
1224
+ A warning will be printed if this size_limit forces the store_sampler to be False if store_sampler was set to True.
1199
1225
  """
1226
+ if self.sampler is not None and self.config.sampler_type == "dynesty":
1227
+ raise ValueError("Storing the sampler is not currently supported for dynesty samplers.\n" \
1228
+ "If you really want to, separate the sampler with data.sampler.save('sampler') and add it back later.\n")
1200
1229
 
1201
1230
  payload: dict[str, Any] = {}
1202
1231
  for f in fields(self):
@@ -1239,14 +1268,26 @@ class Data:
1239
1268
 
1240
1269
  # Handle sampler separately
1241
1270
  self.sampler = payload.get("sampler", None) # property handles the loading of the sampler
1242
-
1271
+ if self.sampler is None and self.config.sampler_path is not None:
1272
+ try:
1273
+ self.sampler = self.config.sampler_path
1274
+ except:
1275
+ self.sampler = None
1276
+
1243
1277
  return self
1244
1278
 
1245
1279
  @property
1246
1280
  def result(self):
1281
+ if self.exception is not None:
1282
+ if self.config.verbose > 0:
1283
+ print(f"An exception was raised during the run, cannot return results object.\n"
1284
+ f"Returning None instead.")
1285
+ return None
1247
1286
  if not self.complete:
1248
- raise ValueError(f"Results have not yet been calculated, cannot return results object.\n"
1249
- f"Please run the MCMC sampling and process the results first.")
1287
+ if self.config.verbose > 0:
1288
+ print(f"Results for order {self.config.order} have not yet been calculated, cannot return results object.\n"
1289
+ f"Returning None instead.")
1290
+ return None
1250
1291
  from .result import Result
1251
1292
  return Result(self)
1252
1293
 
@@ -1403,8 +1444,9 @@ class DataList:
1403
1444
  order_range : Array1D|None = None,
1404
1445
  config : Config|list[Config]|None = None,
1405
1446
  save_dir : str|None = None,
1447
+ overwrite : bool = False,
1406
1448
  verbose : IntLike|bool|str|None = None,
1407
- load = None,
1449
+ _load = None,
1408
1450
  _data_list : list[Data]|None = None,
1409
1451
  **config_kwargs,
1410
1452
  ) -> None:
@@ -1453,11 +1495,17 @@ class DataList:
1453
1495
  By default the DataList will save data.pkl and sampler.h5 to the directory (named by the order number) to in this directory.
1454
1496
  If the Configs or kwargs passed contain their own save_path or sampler_path (see :py:class:`Acid`), those instead are used.
1455
1497
  If None, no saving will be done, this is however, not recommended. Default is None.
1498
+ overwrite : bool, optional
1499
+ Whether to overwrite existing with new Data instances when using run_ACID, or to load and use existing Data instance if they exist.
1500
+ If True, if a Data instance already exists for an order, it will be overwritten with the new Data instance generated from the ACID run for that order.
1501
+ Note, that the saving of this new Data instance only applies when run_ACID is run, otherwise it is just held in memory.
1502
+ If False, if a Data instance already exists for an order, it will be loaded and used instead of generating a new Data instance from the ACID run for that order.
1503
+ Default is False.
1456
1504
  verbose : int | bool | str | None, optional
1457
1505
  The verbosity level for printing information during the initialization.
1458
1506
  Follows the same format as the "verbose" input in the :py:class:`Config` class.
1459
1507
  Default is None.
1460
- load : Any, optional
1508
+ _load : Any, optional
1461
1509
  Not yet implemented, do not use. The idea is that you can input a Load object which has its own tools to pull s2d data from common instruments
1462
1510
  such as ESPRESSO, HARPS, etc. If you want to use this feature, please open an issue or contribute a pull request with the implementation.
1463
1511
  _data_list : list[:py:class:`Data`] | None, optional
@@ -1471,7 +1519,7 @@ class DataList:
1471
1519
  """
1472
1520
 
1473
1521
  # Raise if load was used
1474
- if load is not None:
1522
+ if _load is not None:
1475
1523
  raise NotImplementedError(f"The 'load' argument is not yet implemented. \n"
1476
1524
  f"The idea is that you can input a Load object which has its own tools to pull s2d data from common "\
1477
1525
  f"instruments such as ESPRESSO, HARPS, etc. \nIf you want to use this feature, please open an issue or "\
@@ -1485,19 +1533,19 @@ class DataList:
1485
1533
  # Configure verbosity
1486
1534
  self.verbose = Config(verbose=verbose).verbose
1487
1535
 
1488
- # Configure velocities
1536
+ # All orders should have the same velocity grid and line list
1489
1537
  self.velocities = velocities
1490
1538
 
1491
1539
  # Configure order_range, creates one if not input from the shape of wavelengths
1492
1540
  self.order_range = order_range # if None, will be set later, otherwise self.from_datalist handles the range from configs
1493
1541
 
1494
- # Configure save_dir, for saving intermediate results and figures per order
1542
+ # Set class attributes
1495
1543
  self._save_dir = None
1496
- self.save_dir = save_dir if save_dir is not None else None
1497
-
1498
- # Set empty class attributes
1544
+ self._data_list = None
1499
1545
  self._combined_profile = None
1546
+ self.overwrite = overwrite
1500
1547
  self.excluded_orders = []
1548
+ self.save_dir = save_dir
1501
1549
 
1502
1550
  if _data_list is not None:
1503
1551
  self.data_list = _data_list # datalist property handles the rest
@@ -1548,11 +1596,23 @@ class DataList:
1548
1596
  data.velocities = velocities
1549
1597
 
1550
1598
  if self.save_dir is not None:
1551
- save_path = os.path.join(self.save_dir, f"order_{order}", "data.pkl")
1552
- sampler_path = os.path.join(self.save_dir, f"order_{order}", "sampler.h5")
1553
- data.config.update_lowpri(save_path=save_path, # set default save path for this order which can be overwritten by user
1599
+ save_path = os.path.abspath(os.path.join(self.save_dir, f"order_{order}", "data.pkl"))
1600
+ sampler_path = os.path.abspath(os.path.join(self.save_dir, f"order_{order}", "sampler.h5"))
1601
+ data.config.update_hipri(save_path=save_path, # set default save path for this order which can be overwritten by user
1554
1602
  sampler_path=sampler_path) # set default sampler path for this order which can be overwritten by user
1555
1603
 
1604
+ # Check if file already exists
1605
+ if os.path.exists(data.config.save_path):
1606
+ if self.overwrite:
1607
+ if self.verbose > 1:
1608
+ print(f"File {data.config.save_path} already exists, but will be overwritten (when using run_ACID) due to setting.")
1609
+ else:
1610
+ if self.verbose > 0:
1611
+ print(f"File {data.config.save_path} already exists. The data for this order will be loaded from this file.")
1612
+ data = Data.load(data.config.save_path) # load the existing data from the file instead of using the newly initialized data
1613
+ else:
1614
+ data.save() # save the newly initialized, but mostly empty data instance to the file for future reference and use
1615
+
1556
1616
  datalist.append(data) # finally append to the datalist
1557
1617
 
1558
1618
  self.data_list = datalist # datalist property handles the rest
@@ -1717,7 +1777,7 @@ class DataList:
1717
1777
  use_index_mapping : bool = True,
1718
1778
  worker : IntLike|None = None,
1719
1779
  nworkers : IntLike|None = None,
1720
- allow_overwrite : bool = False,
1780
+ overwrite : bool|None = None,
1721
1781
  overwrite_kwargs : bool = False,
1722
1782
  **kwargs,
1723
1783
  ) -> None:
@@ -1748,8 +1808,9 @@ class DataList:
1748
1808
  If the sampler exceeds this size, it will not be stored regardless of the store_sampler flag.
1749
1809
  This is to avoid accidentally storing very large samplers. If None, no limit is set. Default is 1GB.
1750
1810
  A warning will be printed if this size_limit forces the store_sampler to be False if store_sampler was set to True.
1751
- allow_overwrite : bool, optional
1752
- If True, will allow overwriting existing result pickles in the save_dir. Default is False, which will skip running ACID on orders
1811
+ overwrite : bool, optional
1812
+ If True, will allow overwriting existing data and sampler pickles in the save_dir. Default is None, which will use the class
1813
+ default behaviour set in initialization (which is False). If False, this will skip running ACID on orders
1753
1814
  that already have result pickles in the save_dir.
1754
1815
  overwrite_kwargs : bool, optional
1755
1816
  If True, any keys in the kwargs that are also in the config for the Data instance will be overwritten by the kwargs values.
@@ -1761,6 +1822,10 @@ class DataList:
1761
1822
  """
1762
1823
  from .acid import Acid # local import to avoid circular imports, since Acid imports Data
1763
1824
 
1825
+ # Configure overwrite from class default if not input in the method call
1826
+ if overwrite is None:
1827
+ overwrite = self.overwrite
1828
+
1764
1829
  # Validate worker and nworkers inputs for splitting orders across workers, and set defaults if not provided for easier logic below.
1765
1830
  if worker is not None or nworkers is not None:
1766
1831
  if worker is None or nworkers is None:
@@ -1802,16 +1867,22 @@ class DataList:
1802
1867
  iterable = tqdm(orders, "Running ACID on orders", unit="order") if self.verbose > 1 else orders
1803
1868
  for order in iterable:
1804
1869
 
1870
+ data = self.data_list[self.o2i[order]]
1871
+
1805
1872
  # Check if ACID already ran for this order
1806
- if os.path.exists(self.data_list[self.o2i[order]].config.save_path) and not allow_overwrite:
1807
- if self.verbose > 1:
1808
- print(f"ACID result for order {order} already exists at {self.data_list[self.o2i[order]].config.save_path}. \n"
1809
- f"Skipping this order. To overwrite existing results, set allow_overwrite=True.")
1810
- # else the sampler and data instance is overwritten
1811
- continue
1873
+ if os.path.exists(data.config.save_path) and overwrite is False:
1874
+ if data.complete:
1875
+ if self.verbose > 1:
1876
+ print(f"An ACID completed result for order {order} already exists. \n"
1877
+ f"Skipping this order. To overwrite existing results, set overwrite=True.")
1878
+ continue
1879
+ elif data.exception is not None:
1880
+ if self.verbose > 1:
1881
+ print(f"An ACID run for order {order} previously encountered an exception. \n"
1882
+ f"Skipping this order. To retry and overwrite existing results, set overwrite=True.")
1883
+ continue
1812
1884
 
1813
- # Handling if any kwargs were input
1814
- data = self.data_list[self.o2i[order]]
1885
+ # Handling if any kwargs were input
1815
1886
  # Only overwrite if overwrite_kwargs is True, otherwise keep the existing linelist/velocities in the Data instance
1816
1887
  if "linelist" in kwargs:
1817
1888
  ll = kwargs.pop("linelist")
@@ -1866,20 +1937,23 @@ class DataList:
1866
1937
  save_dir : str | None, optional
1867
1938
  The directory to save the DataList pickle file. If None, self.save_dir is used. Default is None.
1868
1939
  """
1869
- d = {}
1870
1940
  if save_dir is not None:
1871
1941
  self.save_dir = save_dir
1872
1942
  if self.save_dir is None:
1873
- raise ValueError("No save path provided and save_dir was not set.")
1874
- d["dict_list"] = [data.to_dict() for data in self.data_list]
1943
+ raise ValueError("No save directory provided and save_dir was not set.")
1944
+ for data in self.data_list:
1945
+ # Ensures that the save paths for each data instance are correct and updated to match the current save_dir,
1946
+ # even if it was changed since initialization.
1947
+ self._set_paths_for_data(data, self.save_dir)
1875
1948
  save_loc = os.path.join(self.save_dir, "datalist.pkl")
1876
- d["save_dir"] = self.save_dir
1949
+ d = {}
1877
1950
  d["verbose"] = self.verbose
1951
+ # and maybe other class attributes later
1878
1952
  with open(save_loc, "wb") as f:
1879
1953
  pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
1880
1954
 
1881
1955
  @classmethod
1882
- def load(cls, path:str, print_progress:bool=True) -> DataList:
1956
+ def load(cls, path:str) -> DataList:
1883
1957
  """
1884
1958
  Loads a DataList from a pickle file. The pickle file should contain a dictionary with the list of Data objects (converted to dictionaries) and the save_dir.
1885
1959
  Will attempt to load from datalist.pkl in the provided path if it is a directory, otherwise will attempt to load from the provided path directly.
@@ -1888,55 +1962,55 @@ class DataList:
1888
1962
  Parameters
1889
1963
  ----------
1890
1964
  path : str
1891
- The path to the pickle file or directory to search for the pickle file.
1892
- print_progress : bool, optional
1893
- If True, will print progress messages during the loading process.
1894
- Similar to verbosity, but only for this function.
1895
- Default is True.
1965
+ The directory containing the datalist.pkl file, or the datalist.pkl itself. Note that the directories containing the results should also be in here.
1896
1966
 
1897
1967
  Returns
1898
1968
  -------
1899
1969
  DataList
1900
1970
  The loaded DataList object.
1901
1971
  """
1902
- if os.path.isdir(path):
1903
- path_check = os.path.join(path, "datalist.pkl")
1904
- if not os.path.exists(path_check):
1905
- # Final attempt to directly load the data pickles from order folders
1906
- all_files = os.listdir(path)
1907
- data_list = []
1908
- if print_progress:
1909
- all_files = tqdm(all_files, "Opening data pickles in order folders", unit="folder")
1910
- for folder in all_files:
1911
- folder_path = os.path.join(path, folder)
1912
- if os.path.isdir(folder_path) and folder.startswith("order_"):
1913
- pickle_path = os.path.join(folder_path, "data.pkl")
1914
- if os.path.exists(pickle_path):
1915
- with open(pickle_path, "rb") as f:
1916
- d = pickle.load(f)
1917
- data_list.append(Data().from_dict(d))
1918
- if len(data_list) > 0:
1919
- if print_progress:
1920
- print(f"Successfully loaded {len(data_list)} Data instances from order folders in {path}.")
1921
- return cls.from_datalist(data_list, save_dir=path)
1922
- else:
1923
- raise ValueError(f"No datalist.pkl found in {path}, and no data pickles found in order folders within that path.")
1972
+ abspath = os.path.abspath
1973
+ join = os.path.join
1974
+ isdir = os.path.isdir
1975
+ exists = os.path.exists
1976
+
1977
+ path = abspath(path)
1978
+ if path.endswith("datalist.pkl"):
1979
+ if not exists(path):
1980
+ raise ValueError(f"No pickle file found at {path} to load the DataList from.")
1924
1981
  else:
1925
- path = path_check
1982
+ path = os.path.dirname(path)
1983
+ elif not isdir(path):
1984
+ raise ValueError(f"The provided path {path} is not a directory, or a datalist pickle file.\n"
1985
+ f"You should provide a path to a directory containing the folders with the data pickles and sampler files.")
1986
+
1987
+ if exists(join(path, "datalist.pkl")):
1988
+ with open(join(path, "datalist.pkl"), "rb") as f:
1989
+ d = pickle.load(f)
1990
+ verbose = d["verbose"]
1926
1991
  else:
1927
- if not os.path.exists(path):
1928
- raise ValueError(f"No pickle file found at {path} to load the DataList from.")
1992
+ verbose = None
1993
+ verbose = Config(verbose=verbose).verbose
1929
1994
 
1930
- if print_progress:
1931
- print(f"Loading DataList from {path}...")
1932
- with open(path, "rb") as f:
1933
- d = pickle.load(f)
1934
- data_list = [Data().from_dict(d) for d in d["dict_list"]]
1935
- verbose = d["verbose"] if "verbose" in d else None
1936
- # We use a new save_dir depending on the path location in case the directory has changed since last saved
1937
- obj = cls.from_datalist(data_list, save_dir=os.path.dirname(path), verbose=verbose)
1938
- if print_progress:
1939
- print(f"Successfully loaded DataList from {path}.")
1995
+ dir_list = os.listdir(path)
1996
+ data_list = []
1997
+ folder_moved_flag = False
1998
+ dir_list = dir_list if verbose < 2 else tqdm(dir_list, "Loading Data instances from directory", unit="folder")
1999
+ for folder in dir_list:
2000
+ if isdir(join(path, folder)) and folder.startswith("order_"):
2001
+ save_path = join(path, folder, "data.pkl")
2002
+ sampler_path = join(path, folder, "sampler.h5")
2003
+ if exists(save_path):
2004
+ data = Data.load(save_path)
2005
+ if abspath(data.config.save_path) != save_path or abspath(data.config.sampler_path) != sampler_path:
2006
+ folder_moved_flag = True
2007
+ cls._set_paths_for_data(data, path)
2008
+ data_list.append(data)
2009
+
2010
+ if folder_moved_flag and verbose is not None and verbose > 0:
2011
+ print(f"Warning: At least one of the Data instances found in the directory does not match the current location, it has been updated.")
2012
+
2013
+ obj = cls.from_datalist(data_list, save_dir=path, verbose=verbose)
1940
2014
  return obj
1941
2015
 
1942
2016
  @property
@@ -1947,10 +2021,10 @@ class DataList:
1947
2021
  def save_dir(self, dir):
1948
2022
  if dir is not None:
1949
2023
  os.makedirs(dir, exist_ok=True)
1950
- self._save_dir = dir
1951
- if self._save_dir is None:
1952
- if self.verbose > 1:
2024
+ elif self._save_dir is None:
2025
+ if self.verbose > 0:
1953
2026
  print("Warning: save_dir is set to None. No results will be saved. This is not recommended.")
2027
+ self._save_dir = dir
1954
2028
  return
1955
2029
 
1956
2030
  @property
@@ -2075,3 +2149,14 @@ class DataList:
2075
2149
  from .profiles import Profiles
2076
2150
  profiles = Profiles(self.velocities, *self.combined_profile)
2077
2151
  return profiles.plot_fit(**kwargs)
2152
+
2153
+ @staticmethod
2154
+ def _set_paths_for_data(data: Data, save_dir: str) -> None:
2155
+ "Helper to set paths for a Data instance to a new one for a given order."
2156
+ order = data.config.order
2157
+ save_path = os.path.abspath(os.path.join(save_dir, f"order_{order}", "data.pkl"))
2158
+ sampler_path = os.path.abspath(os.path.join(save_dir, f"order_{order}", "sampler.h5"))
2159
+
2160
+ data.config.save_path = save_path
2161
+ data.config.sampler_path = sampler_path
2162
+ data.save()
ACID_code/lsd.py CHANGED
@@ -1,10 +1,9 @@
1
1
  from __future__ import annotations
2
2
  import numpy as np
3
3
  from astropy.io import fits
4
- import glob, psutil, os
4
+ import glob, psutil, os, traceback
5
5
  import matplotlib.pyplot as plt
6
6
  from scipy.signal import find_peaks
7
- from scipy.interpolate import LSQUnivariateSpline
8
7
  from tqdm import tqdm
9
8
  from scipy.linalg import cho_factor, cho_solve
10
9
  from beartype import beartype
@@ -25,7 +24,7 @@ class LSD:
25
24
  def __init__(
26
25
  self,
27
26
  data : object|None = None,
28
- OD : bool = True,
27
+ od : bool = None,
29
28
  verbose : IntLike|bool|str|None = None,
30
29
  ) -> None:
31
30
  """Initialises the LSD class, optionally with a Data instance to take parameters from.
@@ -34,8 +33,9 @@ class LSD:
34
33
  ----------
35
34
  data : object | None, optional
36
35
  A data instance to draw parameters and configs from, by default None
37
- OD : bool, optional
38
- Whether to perform LSD in optical depth space (True) or flux space (False), by default True.
36
+ od : bool, optional
37
+ Whether to perform LSD in optical depth space (True) or flux space (False), by default None.
38
+ If None, takes from Data instance if provided, else defaults to True.
39
39
  We generally recommend always using optical depth as ACID was always intended, but you can set
40
40
  this to False if you wish to do your own testing. See :ref:`LSD` in the documentation for more details.
41
41
  verbose : :py:type:`IntLike | bool | str | None`, optional
@@ -46,10 +46,10 @@ class LSD:
46
46
  # Set class variables, taking from input data if it exists, else setting to defaults
47
47
  self.slurm = "SLURM_JOB_ID" in os.environ
48
48
  self.data = data if data is not None else Data()
49
- self.linelist = data.linelist if data is not None else None
50
- self.OD = OD
49
+ self.linelist = self.data.linelist if self.data is not None else None
50
+ self.od = od if od is not None else self.data.config.od
51
51
  try:
52
- self.config = data.config
52
+ self.config = self.data.config
53
53
  except:
54
54
  self.config = Config() # uses defaults
55
55
  self.config.update_hipri(verbose=verbose) # Update config with new values, if not None
@@ -87,9 +87,9 @@ class LSD:
87
87
  decomposition and solving for the profile, by default None
88
88
  """
89
89
  # Ensure inputs are numpy arrays
90
- wavelengths = np.asarray(wavelengths)
91
- flux = np.asarray(flux)
92
- errors = np.asarray(errors)
90
+ wavelengths = np.array(wavelengths)
91
+ flux = np.array(flux)
92
+ errors = np.array(errors)
93
93
 
94
94
  # Ensure dimensions match
95
95
  if not wavelengths.shape == flux.shape == errors.shape:
@@ -107,16 +107,23 @@ class LSD:
107
107
  # Clip linelist to wavelength range of spectrum
108
108
  wavelengths_linelist, depths_linelist = utils.clip_wavelengths(wavelengths, wavelengths_linelist, depths_linelist)
109
109
  if len(wavelengths_linelist) == 0:
110
- raise LineListRangeError(f"No lines in linelist are within the wavelength range of the observed spectrum. \n"\
111
- f"You may have mismatched wavelengths units between linelist and spectrum or an empty linelist.\n"\
112
- f"Please check your linelist and input spectrum.")
110
+ error = LineListRangeError(
111
+ "No lines in linelist are within the wavelength range of the observed spectrum.\n"
112
+ "You may have mismatched wavelength units between linelist and spectrum or an empty linelist.\n"
113
+ "Please check your linelist and input spectrum."
114
+ )
115
+ self.data.exception = error
116
+ self.data.traceback = traceback.format_stack()
117
+ raise error
113
118
 
114
119
  # Apply S/N cut (of 1/(3*SN)) to linelist
115
120
  wavelengths_linelist, depths_linelist = self.sn_clip(wavelengths_linelist, depths_linelist, sn)
116
121
 
117
122
  # Convert to optical depth space for the linelist and the spectrum if needed, and convert errors accordingly
118
- if self.OD:
123
+ if self.od:
119
124
  flux, errors, depths_linelist = utils.flux_to_od(flux, errors, depths_linelist)
125
+ else:
126
+ flux -= 1
120
127
 
121
128
  # Calculates alpha in optical depth, selects lines greater than 1/(3*sn)
122
129
  if alpha is None:
@@ -130,11 +137,17 @@ class LSD:
130
137
  # Solve for profile and profile errors using Cholesky factors
131
138
  self.profile, self.profile_errors, self.cov_z = self.solve_z(self.alpha, flux, errors, self.c_factor, return_cov=True)
132
139
 
140
+ self.forward_model = self.alpha @ self.profile
141
+ self.forward_model_errors = np.sqrt(np.sum((self.alpha * self.profile_errors)**2, axis=1))
142
+
133
143
  # Convert profile back to flux if needed
134
- if self.OD:
144
+ if self.od:
135
145
  self.profile_F, self.profile_errors_F, self.cov_z_F = utils.od_to_flux(self.profile, self.profile_errors, cov_matrix=self.cov_z)
146
+ self.forward_model, self.forward_model_errors = utils.od_to_flux(self.forward_model, self.forward_model_errors)
136
147
  else:
148
+ self.profile += 1
137
149
  self.profile_F, self.profile_errors_F, self.cov_z_F = self.profile, self.profile_errors, self.cov_z
150
+ self.forward_model += 1
138
151
 
139
152
  return
140
153
 
@@ -171,7 +184,10 @@ class LSD:
171
184
  nrest = np.sum(idx)
172
185
  perc = 100 * nrest / (nrest + ncut)
173
186
  if nrest == 0:
174
- raise SNCutError(f"No lines remain in the linelist after S/N cut. Please check your linelist and S/N value.")
187
+ error = SNCutError(f"No lines remain in the linelist after S/N cut. Please check your linelist and S/N value.")
188
+ self.data.exception = error
189
+ self.data.traceback = traceback.format_stack()
190
+ raise error
175
191
  if self.config.verbose > 0:
176
192
  if perc < 5:
177
193
  print("Warning: Less than 5% of lines remain after S/N cut. Check your linelist and S/N value.")
ACID_code/mcmc.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
  import numpy as np
3
+ from numpy.linalg import norm
3
4
  from . import utils
4
5
  from .utils import Array1D, Array2D
5
6
  from beartype import beartype
@@ -20,6 +21,12 @@ def _mp_log_probability(theta):
20
21
  """Wrapper for log probability function for multiprocessing."""
21
22
  return _MCMC(theta)
22
23
 
24
+ def _mp_log_likelihood(theta):
25
+ return _MCMC.dynesty_logprob(theta)
26
+
27
+ def _mp_ptform(u):
28
+ return _MCMC.ptform(u)
29
+
23
30
  class MCMC:
24
31
 
25
32
  """
@@ -38,6 +45,9 @@ class MCMC:
38
45
  velocities : Array1D|None = None,
39
46
  c_factor = None,
40
47
  deterministic_profile : bool = False,
48
+ sampler_type : str = "emcee",
49
+ model_inputs : Array1D|None = None,
50
+ od : bool = True
41
51
  ) -> None:
42
52
  """
43
53
  Initialise MCMC functions with necessary data.
@@ -75,6 +85,9 @@ class MCMC:
75
85
  self.velocities = data.velocities
76
86
  self.c_factor = data.c_factor
77
87
  self.deterministic_profile = data.config.deterministic_profile
88
+ self.sampler_type = data.config.sampler_type
89
+ self.model_inputs = data.model_inputs
90
+ self.od = data.config.od
78
91
  else:
79
92
  self.x = x_or_data
80
93
  self.y = y
@@ -83,6 +96,10 @@ class MCMC:
83
96
  self.velocities = velocities
84
97
  self.c_factor = c_factor
85
98
  self.deterministic_profile = deterministic_profile
99
+ self.sampler_type = sampler_type
100
+ self.model_inputs = model_inputs
101
+ self.od = od
102
+ data = None
86
103
 
87
104
  self.k_max = self.alpha.shape[1] # the number of velocity points in the profile
88
105
 
@@ -91,10 +108,15 @@ class MCMC:
91
108
  a, b = utils.get_normalisation_coeffs(self.x)
92
109
  self.u = (a * self.x) + b # These are the normalized wavelengths used throughout the fitting process
93
110
 
94
- # For deterministic model, the below variables are used, and are precomputed for speed
95
- err_od = self.yerr / self.y # independent of continuum, since it's a ratio
96
- V = 1.0 / (err_od ** 2) # variance vector in log space, error already in log space
97
- self.AtV = self.alpha.T * V # precompute alpha matrix multiplication for _mcmc_solve_z input
111
+ if self.od:
112
+ # For deterministic model, the below variables are used, and are precomputed for speed
113
+ err_od = self.yerr / self.y # independent of continuum, since it's a ratio
114
+ V = 1.0 / (err_od ** 2) # variance vector in log space, error already in log space
115
+ else:
116
+ # For non-OD case, we need to precompute the variance vector in flux space for the likelihood calculation
117
+ V = 1.0 / (self.yerr ** 2) # variance vector in flux space
118
+
119
+ self.AtV = self.alpha.T * V # precompute alpha matrix multiplication for
98
120
 
99
121
  # Configure whether to use full or deterministic model
100
122
  if self.deterministic_profile is False:
@@ -125,10 +147,11 @@ class MCMC:
125
147
  """
126
148
  # Extract profile points and continuum coefficients from theta
127
149
  z = theta[:self.k_max]
150
+ z -= 1 if not self.od else 0 # if not using OD, profile points are in flux space and need to be shifted by 1
128
151
  mdl = self.alpha @ z
129
152
 
130
153
  # Converting model from optical depth to flux
131
- mdl = np.exp(-mdl)
154
+ mdl = np.exp(-mdl) if self.od else mdl+1 # if not using OD, just use flux directly
132
155
 
133
156
  # Calculate continuum polynomial
134
157
  coefs = np.asarray(theta[self.k_max:], dtype=float)
@@ -163,13 +186,22 @@ class MCMC:
163
186
 
164
187
  # Calculate fitted flux and convert to OD
165
188
  fitted_flux = self.y/mdl
166
- flux_od = - np.log(fitted_flux)
189
+
190
+ # Do OD/non-OD conversions
191
+ if self.od:
192
+ flux_od = (-np.log(fitted_flux))
193
+ AtV = self.AtV
194
+ else:
195
+ AtV = self.AtV * (mdl * mdl)
196
+ flux_od = fitted_flux - 1
167
197
 
168
198
  # Solve for the profile points
169
- z = cho_solve(self.c_factor, self.AtV @ flux_od)
199
+ z = cho_solve(self.c_factor, AtV @ flux_od, check_finite=False)
170
200
 
171
201
  # Convert back from optical depth to flux
172
- forward = np.exp(- (self.alpha @ z)) * mdl
202
+ dot_prod = self.alpha @ z
203
+ dot_prod = np.exp(-dot_prod) if self.od else dot_prod + 1
204
+ forward = dot_prod * mdl
173
205
 
174
206
  return forward, z
175
207
 
@@ -190,8 +222,12 @@ class MCMC:
190
222
  """
191
223
 
192
224
  # Hard box prior on each z[i]
193
- if np.any((z < -0.4) | (z > 1.6)):
194
- return -np.inf
225
+ if self.od:
226
+ if np.any((z < -0.4) | (z > 1.6)):
227
+ return -np.inf
228
+ else:
229
+ if np.any((z > 0.5) | (z <= -1)):
230
+ return -np.inf
195
231
 
196
232
  # # excluding the continuum points in the profile (in flux)
197
233
  # z_cont = []
@@ -289,4 +325,53 @@ class MCMC:
289
325
  tol_str = f"{last_tolerance:.4f}{tol_str}{config.tau_tol}"
290
326
  neff_str = ">" if last_neff > config.min_tau_factor else "<"
291
327
  neff_str = f"{last_neff:.2f}{neff_str}{config.min_tau_factor}"
292
- return tol_str, neff_str
328
+ return tol_str, neff_str
329
+
330
+ def dynesty_logprob(self, theta):
331
+ """Log likelihood function for dynesty nested sampling."""
332
+
333
+ forward, z = self.model_function(theta)
334
+
335
+ if not np.all(np.isfinite(forward)):
336
+ return -np.inf
337
+
338
+ lp = self.log_prior(z)
339
+ if not np.isfinite(lp):
340
+ return -np.inf
341
+
342
+ diff = self.y - forward
343
+ var = self.yerr * self.yerr
344
+
345
+ return -0.5 * np.sum(diff * diff / var + np.log(2 * np.pi * var))
346
+
347
+ def ptform(self, u):
348
+ """
349
+ Prior transform for dynesty.
350
+
351
+ Maps unit-cube samples u in [0, 1] to continuum polynomial
352
+ coefficients using uniform priors centred on self.model_inputs.
353
+ """
354
+
355
+ u = np.asarray(u, dtype=float)
356
+ theta0 = np.asarray(self.model_inputs, dtype=float)
357
+ theta0 = theta0[self.k_max:] # only continuum coefficents, not profile points
358
+
359
+ # Width of uniform prior around curve_fit solution.
360
+ # The floor matters because higher-order polynomial coefficients
361
+ # may be close to zero.
362
+ frac_width = 5
363
+ abs_floor = 0.05
364
+
365
+ width = np.maximum(frac_width * np.abs(theta0), abs_floor)
366
+
367
+ # Usually the zeroth-order continuum coefficient is close to 1,
368
+ # so give it a slightly wider absolute floor.
369
+ if len(width) > 0:
370
+ width[0] = max(width[0], 0.25)
371
+
372
+ lower = theta0 - width
373
+ upper = theta0 + width
374
+
375
+ return lower + u * (upper - lower)
376
+
377
+
ACID_code/result.py CHANGED
@@ -2,17 +2,21 @@ from __future__ import annotations
2
2
  from time import time
3
3
  import numpy as np
4
4
  import matplotlib.pyplot as plt
5
- import corner, sys, os, pickle, warnings, contextlib, functools, inspect, psutil
5
+ import corner, sys, os, warnings, contextlib, functools, inspect
6
6
  from emcee import EnsembleSampler
7
- import emcee.backends.backend as emceebackend
8
7
  from beartype import beartype
9
8
  from scipy.interpolate import interp1d
10
9
  from numpy.polynomial import polynomial as P
11
10
  from .lsd import LSD
12
- from . import mcmc
13
11
  from . import utils
14
12
  from .data import Data
15
13
  from .utils import IntLike, Scalar
14
+ try:
15
+ from dynesty.sampler import Sampler
16
+ from dynesty import plotting as dyplot
17
+ except ImportError:
18
+ Sampler = None
19
+ dyplot = None
16
20
  #TODO: utils.set_dict_defaults for plots
17
21
 
18
22
  warnings.filterwarnings("ignore")
@@ -59,7 +63,7 @@ class Result:
59
63
  def __init__(
60
64
  self,
61
65
  data : Data|object,
62
- sampler : EnsembleSampler|None = None,
66
+ sampler : EnsembleSampler|Sampler|None = None, # type:ignore
63
67
  process_results : bool = True,
64
68
  verbose : IntLike|bool|str|None = None,
65
69
  ) -> None:
@@ -74,7 +78,7 @@ class Result:
74
78
  provided, a sampler can be provided in the second argument. If a sampler object
75
79
  is provided, it will be used as the sampler, but all other attributes will need
76
80
  to be set manually for the Result object to be fully functional.
77
- sampler : :py:class:`emcee.EnsembleSampler`, optional
81
+ sampler : :py:class:`emcee.EnsembleSampler` | :py:class:`dynesty.Sampler`, optional
78
82
  Sets and overwrites the sampler in the Data object with this if provided, by default None.
79
83
  process_results : bool, optional
80
84
  Whether to process the results from the Acid object upon initialisation, by default True.
@@ -108,7 +112,8 @@ class Result:
108
112
  # Handle the sampler if input, initiate if one exists
109
113
  self.sampler = sampler if sampler is not None else self.sampler # update sampler if provided, otherwise keep the same
110
114
  if self.sampler is not None:
111
- self.initiate_sampler(self.sampler) # set internal variables based on sampler, sets sampler_initialized to True
115
+ self.dynesty = isinstance(self.sampler, Sampler)
116
+ self.initiate_sampler(self.sampler) # set internal variables based on sampler, sets sampler_initialiated to True
112
117
 
113
118
  if not self.data.complete:
114
119
  if process_results:
@@ -135,7 +140,10 @@ class Result:
135
140
  t0 = time()
136
141
 
137
142
  # Obtain flattened samples
138
- flat_samples = self.sampler.get_chain(discard=self.burnin, thin=self.thin, flat=True)
143
+ if self.dynesty:
144
+ flat_samples = self.sampler.results.samples_equal()
145
+ else:
146
+ flat_samples = self.sampler.get_chain(discard=self.burnin, thin=self.thin, flat=True)
139
147
 
140
148
  # Getting the final profile and continuum values
141
149
  nvel = len(self.data.velocities) if self.config.deterministic_profile is False else 0
@@ -213,15 +221,21 @@ class Result:
213
221
  alpha = self.data.alpha if condition else None
214
222
 
215
223
  LSD_profiles = LSD(self.data)
216
- LSD_profiles.run_LSD(wavelengths, flux, error, sn=sn, alpha=alpha)
224
+ LSD_profiles.run_LSD(wavelengths, flux, error, sn, alpha=alpha)
217
225
 
218
226
  profile_f = LSD_profiles.profile_F
219
227
  profile_errors_f = LSD_profiles.profile_errors_F
220
228
  cov_z_f = LSD_profiles.cov_z_F
221
229
 
222
230
  if counter == 0:
231
+ # Set combined profile params
223
232
  self.data.combined_profile = [profile_f, profile_errors_f, cov_z_f]
224
233
  self.data.continuum_model = mdl
234
+
235
+ # Set the forward model params, multiplied by mdl as LSD is run on normalized flux
236
+ self.data.forward_model = LSD_profiles.forward_model * mdl
237
+ self.data.forward_errors = LSD_profiles.forward_model_errors * mdl
238
+ self.data.forward_x = wavelengths
225
239
  else:
226
240
  profiles.append([profile_f, profile_errors_f, cov_z_f])
227
241
 
@@ -379,13 +393,23 @@ class Result:
379
393
  return fig, ax
380
394
  plt.show()
381
395
 
396
+ @_require_sampler
397
+ def plot_traceplot(self, return_fig:bool=False) -> None | tuple:
398
+ if not self.dynesty:
399
+ raise ValueError("Traceplot is only available for dynesty samplers, as emcee traceplots are already plotted in plot_walkers.")
400
+ fig, ax = dyplot.traceplot(self.sampler.results, labels=self.default_param_labels)
401
+ plt.suptitle('Dynesty Traceplot')
402
+ if return_fig:
403
+ return fig, ax
404
+ plt.show()
405
+
382
406
  @_require_sampler
383
407
  def plot_corner(
384
408
  self,
385
409
  sampler :EnsembleSampler|None = None,
386
410
  return_fig :bool = False,
387
411
  **kwargs,
388
- ) -> None | tuple:
412
+ ) -> None | plt.Figure:
389
413
  """Creates a corner plot for at maximum the last 8 LSD profile and continuum polynomial coefficients.
390
414
 
391
415
  Parameters
@@ -402,6 +426,13 @@ class Result:
402
426
  ----------
403
427
  If return_fig is True, returns the figure object containing the corner plot, else None
404
428
  """
429
+ if self.dynesty:
430
+ fig, axes = dyplot.cornerplot(self.sampler.results, labels=self.default_param_labels, show_titles=True, title_fmt=".3f", title_kwargs={"fontsize": 16}, **kwargs)
431
+ plt.suptitle('Dynesty Corner Plot')
432
+ if return_fig:
433
+ return fig, axes
434
+ plt.show()
435
+ return
405
436
 
406
437
  # Get samples and thin and burnin from the class variables
407
438
  samples = self.sampler.get_chain()
@@ -496,6 +527,7 @@ class Result:
496
527
  @_require_profiles
497
528
  def plot_forward_model(
498
529
  self,
530
+ fig_ax :tuple|None = None,
499
531
  grid :bool = True,
500
532
  labels :dict|None = None,
501
533
  return_fig :bool = False,
@@ -506,6 +538,12 @@ class Result:
506
538
 
507
539
  Parameters
508
540
  ----------
541
+ fig_ax: tuple | None
542
+ Optionally provide an existing fig/axis tuple to plot on, by default None and
543
+ creates a new figure and axis. The axis must be a 2 element array of axes,
544
+ where the first axis is for the spectrum and forward model,
545
+ and the second axis is for the residuals.
546
+ If provided, the grid, labels, and titles should be set by you.
509
547
  grid : bool, optional
510
548
  Show or hide grid, by default True
511
549
  labels : dict | None, optional
@@ -544,8 +582,7 @@ class Result:
544
582
 
545
583
  # Get flat_samples which are the same samples used to calculate the final profile, alpha is OD,
546
584
  # so convert profile back to OD and reconvert to flux for forward model
547
- profile = utils.flux_to_od(self[0])
548
- model_flux = utils.od_to_flux(self.data.alpha @ profile) * self.data.continuum_model
585
+ model_flux = self.data.forward_model
549
586
 
550
587
  # Due to distortion at the edges of the profile, we drop the last 2 pixels
551
588
  wavelengths = utils.drop_edges(wavelengths)
@@ -554,22 +591,26 @@ class Result:
554
591
  continuum_model = utils.drop_edges(self.data.continuum_model)
555
592
 
556
593
  # Plotting
557
- fig, ax = plt.subplots(2, 1, **subplot_kwargs)
594
+ if fig_ax is not None:
595
+ fig, ax = fig_ax
596
+ else:
597
+ fig, ax = plt.subplots(2, 1, **subplot_kwargs)
598
+ ax[0].set_title(labels["title"])
599
+ ax[1].set_xlabel(labels["xlabel"])
600
+ ax[0].set_ylabel(labels["ylabel"])
601
+ ax[1].set_ylabel(labels["residuals_ylabel"])
602
+ ax[0].grid(grid)
603
+ ax[1].grid(grid)
604
+ plt.subplots_adjust(hspace=0.05)
605
+
606
+ ax[1].axhline(0, color='black', linestyle='--', linewidth=1)
558
607
  ax[0].plot(wavelengths, flux, color='black', linewidth=1, label='Observed Spectrum')
559
608
  ax[0].plot(wavelengths, model_flux, color='C0', linewidth=1, label='Forward Model Fit')
560
609
  ax[0].plot(wavelengths, continuum_model, color='C1', linewidth=1, label='Fitted Continuum', linestyle='--')
561
610
  ax[1].plot(wavelengths, model_flux-flux, color='C0', linewidth=1, label='Residuals')
562
611
  ax[1].axhline(0, color='black', linestyle='--', linewidth=1)
563
- ax[0].set_title(labels["title"])
564
- ax[1].set_xlabel(labels["xlabel"])
565
- ax[0].set_ylabel(labels["ylabel"])
566
- ax[1].set_ylabel(labels["residuals_ylabel"])
567
- ax[1].axhline(0, color='black', linestyle='--', linewidth=1)
568
612
  ax[0].legend()
569
613
  ax[1].legend()
570
- ax[0].grid(grid)
571
- ax[1].grid(grid)
572
- plt.subplots_adjust(hspace=0.05)
573
614
 
574
615
  if return_fig:
575
616
  return fig, ax
@@ -720,14 +761,14 @@ class Result:
720
761
  return fig, ax
721
762
  plt.show()
722
763
 
723
- def initiate_sampler(self, sampler:EnsembleSampler|None, _method_name=None) -> None:
764
+ def initiate_sampler(self, sampler:EnsembleSampler|Sampler|None, _method_name=None) -> None: # type:ignore
724
765
  """
725
766
  Initiates the sampler attribute from an external sampler.
726
767
 
727
768
  Parameters
728
769
  ----------
729
- sampler : :py:class:`emcee.EnsembleSampler`
730
- An emcee EnsembleSampler object to set as the sampler attribute.
770
+ sampler : :py:class:`emcee.EnsembleSampler` or object, optional
771
+ An emcee EnsembleSampler object or a compatible sampler object to set as the sampler attribute.
731
772
  _method_name : str, optional
732
773
  Internal parameter used to track which method is calling initiate_sampler, for error messages.
733
774
  Not intended for user input, by default None.
@@ -744,6 +785,14 @@ class Result:
744
785
  error_msg = "Cannot initiate sampler without a sampler stored in the instance or passed as a parameter, please pass in a sampler "
745
786
  raise AttributeError(error_msg)
746
787
 
788
+ if self.dynesty:
789
+ a=ord('a')
790
+ alph=[chr(i) for i in range(a,a+26)]
791
+ poly_labels = [alph[i] for i in range(self.config.poly_ord + 1)]
792
+ self.default_param_labels = poly_labels
793
+ self.default_params = None
794
+ return
795
+
747
796
  # Calculate autocorr time, burnin, thin
748
797
  # Suppress output from get_autocorr_time call
749
798
  with open(os.devnull, "w") as devnull, \
@@ -799,12 +848,12 @@ class Result:
799
848
  self.sampler_initialized = True
800
849
 
801
850
  @property
802
- def sampler(self) -> EnsembleSampler|None:
851
+ def sampler(self) -> EnsembleSampler|Sampler|None: # type:ignore
803
852
  """Returns the sampler attribute, by default is None if not saved."""
804
853
  return self.data.sampler
805
854
 
806
855
  @sampler.setter
807
- def sampler(self, value: EnsembleSampler|None) -> None:
856
+ def sampler(self, value: EnsembleSampler|Sampler|None) -> None: # type:ignore
808
857
  """Sets the sampler in the data class."""
809
858
  self.data.sampler = value
810
859
 
ACID_code/utils.py CHANGED
@@ -108,9 +108,12 @@ def mask_invalid(wavelengths, flux, errors=None, return_mask=False, verbose=2):
108
108
 
109
109
  if verbose > 1:
110
110
  num_invalid = np.size(wavelengths) - np.count_nonzero(mask)
111
- if num_invalid > 0:
112
- print(f"Your spectrum includes {num_invalid} out of {np.size(wavelengths)} non-positive/non-finite/nan values, which will be dropped when necessary, \n"
113
- f"but it is still recommended to check your wavelength, spectrum and error arrays for bad pixels and make sure this is intentional.")
111
+ perc_invalid = num_invalid / np.size(wavelengths) * 100
112
+ if perc_invalid > 10:
113
+ print(f"Your spectrum includes {num_invalid} out of {np.size(wavelengths)} non-positive/non-finite/nan values ({perc_invalid:.2f}%), \n"
114
+ f"which will be dropped when necessary, but it is still recommended to check your wavelength, \n"
115
+ f"spectrum and error arrays for bad pixels and make sure this is intentional. \n"
116
+ f"This warning is only printed if more than 10% of pixels are invalid.")
114
117
 
115
118
  output = (w, f, e) if errors is not None else (w, f)
116
119
  output = output + (mask,) if return_mask else output
@@ -146,8 +149,9 @@ def drop_invalid(wavelengths, flux, errors=None, return_mask=False, verbose=2):
146
149
 
147
150
  if verbose > 1:
148
151
  num_invalid = np.size(wavelengths) - np.count_nonzero(mask)
149
- if num_invalid > 0:
150
- print(f"Dropped {num_invalid} invalid pixels out of {np.size(wavelengths)} (non-finite or <= 0).")
152
+ perc_invalid = num_invalid / np.size(wavelengths) * 100
153
+ if perc_invalid > 10:
154
+ print(f"Dropped {num_invalid} invalid pixels out of {np.size(wavelengths)} (non-finite or <= 0), which is {perc_invalid:.2f}% of the total.")
151
155
 
152
156
  output = (w, f, e) if errors is not None else (w, f)
153
157
  output = output + (mask,) if return_mask else output
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ACID_code
3
- Version: 2.0.0a2
3
+ Version: 2.0.0a3
4
4
  Summary: Returns line profiles from input spectra by fitting the stellar continuum and performing LSD
5
5
  Author: Lucy Dolan
6
6
  Author-email: Benjamin Cadell <bcadell01@qub.ac.uk>
@@ -0,0 +1,15 @@
1
+ ACID_code/__init__.py,sha256=-cmUx-kDheDBhk8ZlleWVRrWZyXZRVe-dwciVV9ZxLg,655
2
+ ACID_code/acid.py,sha256=hibDm91O0_HLO3d50I2Q137rXi_xmXdcJSb9KcDNLCs,75532
3
+ ACID_code/data.py,sha256=qaLGYFDPyK7XlXg09b1AZO4xGMZCmrYd7t04SU3LknA,114575
4
+ ACID_code/errors.py,sha256=qqG44x_rVpi5njLIoRkwvw_6owy33wejvsFuAZ-XKKM,525
5
+ ACID_code/load.py,sha256=3gzIZpAv7flgX5ekWdRDI95hkPTxvtp2F1liygpivOQ,5453
6
+ ACID_code/lsd.py,sha256=VfdB-mN2uH20N6jT6lGKhEaepOwB2bJQepyZfIpqM9U,21521
7
+ ACID_code/mcmc.py,sha256=PjGuxUhq-oFRkKKQmQPdSp82NyueG9FPwr7kpLgLk0U,14314
8
+ ACID_code/profiles.py,sha256=KkVsJjCLXx_wjRBialfmnlGwjWsROy0GkggdPPYLelc,17103
9
+ ACID_code/result.py,sha256=itDNK0zxgtNH8-Igqip-zFi6StK5ZyuMS47UYVooSq8,42141
10
+ ACID_code/utils.py,sha256=Xt24WrhV6bKW-NiNl1FsNiJ2krntCJeCUL-hzIQAdLA,27637
11
+ acid_code-2.0.0a3.dist-info/licenses/LICENSE,sha256=L6dUgqjvHmRoobrBCPSHKC4UtRM5Ldp1DJBC4bnLk3w,1070
12
+ acid_code-2.0.0a3.dist-info/METADATA,sha256=XFBGmLrD4od1XTRY4SRGN3Cg-RAvM21VUEClyV6paIQ,2999
13
+ acid_code-2.0.0a3.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
14
+ acid_code-2.0.0a3.dist-info/top_level.txt,sha256=O4OaSabv1ebFYQmHgftr1PGAv6BvC2l81Y3HjgNehQI,10
15
+ acid_code-2.0.0a3.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- ACID_code/__init__.py,sha256=-cmUx-kDheDBhk8ZlleWVRrWZyXZRVe-dwciVV9ZxLg,655
2
- ACID_code/acid.py,sha256=Id3pDMmuQ5OEx3gLarS5p5-pnAiZ0b_tF7_DX3Ou8Wo,72704
3
- ACID_code/data.py,sha256=J8Nz7BAzltxclElWH1iI3rx0yUPDRSjcSCH3q1xjsRE,109492
4
- ACID_code/errors.py,sha256=qqG44x_rVpi5njLIoRkwvw_6owy33wejvsFuAZ-XKKM,525
5
- ACID_code/load.py,sha256=3gzIZpAv7flgX5ekWdRDI95hkPTxvtp2F1liygpivOQ,5453
6
- ACID_code/lsd.py,sha256=iSNjxga-ic1tDCEpEOr5ISSxOsg6GLL_BYqTJi58MTA,20814
7
- ACID_code/mcmc.py,sha256=kNW1Aj5jfWb5xa95yKJU4NBOp1hnGV3KzHV5DrCtKU8,11434
8
- ACID_code/profiles.py,sha256=KkVsJjCLXx_wjRBialfmnlGwjWsROy0GkggdPPYLelc,17103
9
- ACID_code/result.py,sha256=9VXQkovQflDSUIYjwHCAo_cFMs7981IOuKDwlFv8_9w,39819
10
- ACID_code/utils.py,sha256=SY7qlIM15h8yYz7G_-DqXpYhralFbE0ycpOPvnw6DZw,27323
11
- acid_code-2.0.0a2.dist-info/licenses/LICENSE,sha256=L6dUgqjvHmRoobrBCPSHKC4UtRM5Ldp1DJBC4bnLk3w,1070
12
- acid_code-2.0.0a2.dist-info/METADATA,sha256=qVMgFV_XyFpmUQgGZ6zG3gSZXZR4aT7Lbhfdp6mIabA,2999
13
- acid_code-2.0.0a2.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
14
- acid_code-2.0.0a2.dist-info/top_level.txt,sha256=O4OaSabv1ebFYQmHgftr1PGAv6BvC2l81Y3HjgNehQI,10
15
- acid_code-2.0.0a2.dist-info/RECORD,,