imap-processing 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of imap-processing might be problematic. Click here for more details.

Files changed (43) hide show
  1. imap_processing/_version.py +2 -2
  2. imap_processing/cdf/config/imap_codice_global_cdf_attrs.yaml +13 -1
  3. imap_processing/cdf/config/imap_codice_l2-hi-omni_variable_attrs.yaml +635 -0
  4. imap_processing/cdf/config/imap_codice_l2-hi-sectored_variable_attrs.yaml +422 -0
  5. imap_processing/cdf/config/imap_enamaps_l2-common_variable_attrs.yaml +28 -21
  6. imap_processing/cdf/config/imap_enamaps_l2-healpix_variable_attrs.yaml +2 -0
  7. imap_processing/cdf/config/imap_enamaps_l2-rectangular_variable_attrs.yaml +12 -2
  8. imap_processing/cli.py +6 -11
  9. imap_processing/codice/codice_l2.py +640 -127
  10. imap_processing/codice/constants.py +61 -0
  11. imap_processing/ena_maps/ena_maps.py +111 -60
  12. imap_processing/ena_maps/utils/coordinates.py +5 -0
  13. imap_processing/ena_maps/utils/corrections.py +268 -0
  14. imap_processing/ena_maps/utils/map_utils.py +143 -42
  15. imap_processing/hi/hi_l2.py +3 -8
  16. imap_processing/ialirt/constants.py +7 -1
  17. imap_processing/ialirt/generate_coverage.py +1 -1
  18. imap_processing/ialirt/l0/process_codice.py +66 -0
  19. imap_processing/ialirt/utils/create_xarray.py +1 -0
  20. imap_processing/idex/idex_l2a.py +2 -2
  21. imap_processing/idex/idex_l2b.py +1 -1
  22. imap_processing/lo/l1c/lo_l1c.py +61 -3
  23. imap_processing/lo/l2/lo_l2.py +79 -11
  24. imap_processing/mag/l1a/mag_l1a.py +2 -2
  25. imap_processing/mag/l1a/mag_l1a_data.py +71 -13
  26. imap_processing/mag/l1c/interpolation_methods.py +34 -13
  27. imap_processing/mag/l1c/mag_l1c.py +117 -67
  28. imap_processing/mag/l1d/mag_l1d_data.py +3 -1
  29. imap_processing/spice/geometry.py +11 -9
  30. imap_processing/spice/pointing_frame.py +77 -50
  31. imap_processing/swapi/l1/swapi_l1.py +12 -4
  32. imap_processing/swe/utils/swe_constants.py +7 -7
  33. imap_processing/ultra/l1b/extendedspin.py +1 -1
  34. imap_processing/ultra/l1b/ultra_l1b_culling.py +2 -2
  35. imap_processing/ultra/l1b/ultra_l1b_extended.py +1 -1
  36. imap_processing/ultra/l1c/helio_pset.py +1 -1
  37. imap_processing/ultra/l1c/spacecraft_pset.py +2 -2
  38. imap_processing-1.0.1.dist-info/METADATA +121 -0
  39. {imap_processing-1.0.0.dist-info → imap_processing-1.0.1.dist-info}/RECORD +42 -40
  40. imap_processing-1.0.0.dist-info/METADATA +0 -120
  41. {imap_processing-1.0.0.dist-info → imap_processing-1.0.1.dist-info}/LICENSE +0 -0
  42. {imap_processing-1.0.0.dist-info → imap_processing-1.0.1.dist-info}/WHEEL +0 -0
  43. {imap_processing-1.0.0.dist-info → imap_processing-1.0.1.dist-info}/entry_points.txt +0 -0
@@ -11,9 +11,14 @@ from scipy.stats import binned_statistic_dd
11
11
  from imap_processing.cdf.imap_cdf_manager import ImapCdfAttributes
12
12
  from imap_processing.lo import lo_ancillary
13
13
  from imap_processing.lo.l1b.lo_l1b import set_bad_or_goodtimes
14
+ from imap_processing.spice.geometry import SpiceFrame, frame_transform_az_el
14
15
  from imap_processing.spice.repoint import get_pointing_times
15
16
  from imap_processing.spice.spin import get_spin_number
16
- from imap_processing.spice.time import met_to_ttj2000ns, ttj2000ns_to_met
17
+ from imap_processing.spice.time import (
18
+ met_to_ttj2000ns,
19
+ ttj2000ns_to_et,
20
+ ttj2000ns_to_met,
21
+ )
17
22
 
18
23
  N_ESA_ENERGY_STEPS = 7
19
24
  N_SPIN_ANGLE_BINS = 3600
@@ -164,6 +169,10 @@ def lo_l1c(sci_dependencies: dict, anc_dependencies: list) -> list[xr.Dataset]:
164
169
  attr_mgr,
165
170
  )
166
171
 
172
+ pset["hae_longitude"], pset["hae_latitude"] = set_pointing_directions(
173
+ pset["epoch"].item()
174
+ )
175
+
167
176
  pset.attrs = attr_mgr.get_global_attributes(logical_source)
168
177
 
169
178
  pset = pset.assign_coords(
@@ -295,7 +304,7 @@ def create_pset_counts(
295
304
  lat_edges = np.arange(41)
296
305
  energy_edges = np.arange(8)
297
306
 
298
- hist, edges = np.histogramdd(
307
+ hist, _edges = np.histogramdd(
299
308
  data,
300
309
  bins=[energy_edges, lon_edges, lat_edges],
301
310
  )
@@ -572,7 +581,7 @@ def set_background_rates(
572
581
  if row["type"] == "rate":
573
582
  bg_rates[esa_step, bin_start:bin_end, :] = value
574
583
  elif row["type"] == "sigma":
575
- bg_stat_uncert[esa_step, bin_start:bin_end, :] = value
584
+ bg_sys_err[esa_step, bin_start:bin_end, :] = value
576
585
  else:
577
586
  raise ValueError("Unknown background type in ancillary file.")
578
587
  # set the background rates, uncertainties, and systematic errors
@@ -597,3 +606,52 @@ def set_background_rates(
597
606
  )
598
607
 
599
608
  return bg_rates_data, bg_stat_uncert_data, bg_sys_err_data
609
+
610
+
611
+ def set_pointing_directions(epoch: float) -> tuple[xr.DataArray, xr.DataArray]:
612
+ """
613
+ Set the pointing directions for the given epoch.
614
+
615
+ The pointing directions are calculated by transforming Spin and off angles
616
+ to HAE longitude and latitude using SPICE. This returns the HAE longitude and
617
+ latitude as (3600, 40) arrays for each the latitude and longitude.
618
+
619
+ Parameters
620
+ ----------
621
+ epoch : float
622
+ The epoch time in TTJ2000ns.
623
+
624
+ Returns
625
+ -------
626
+ hae_longitude : xr.DataArray
627
+ The HAE longitude for each spin and off angle bin.
628
+ hae_latitude : xr.DataArray
629
+ The HAE latitude for each spin and off angle bin.
630
+ """
631
+ et = ttj2000ns_to_et(epoch)
632
+ # create a meshgrid of spin and off angles using the bin centers
633
+ spin, off = np.meshgrid(
634
+ SPIN_ANGLE_BIN_CENTERS, OFF_ANGLE_BIN_CENTERS, indexing="ij"
635
+ )
636
+ dps_az_el = np.stack([spin, off], axis=-1)
637
+
638
+ # Transform from DPS Az/El to HAE lon/lat
639
+ hae_az_el = frame_transform_az_el(
640
+ et, dps_az_el, SpiceFrame.IMAP_DPS, SpiceFrame.IMAP_HAE, degrees=True
641
+ )
642
+
643
+ return xr.DataArray(
644
+ data=hae_az_el[:, :, 0].astype(np.float64),
645
+ dims=["spin_angle", "off_angle"],
646
+ # TODO: Add hae_longitude to yaml
647
+ # attrs=attr_mgr.get_variable_attributes(
648
+ # "hae_longitude"
649
+ # )
650
+ ), xr.DataArray(
651
+ data=hae_az_el[:, :, 1].astype(np.float64),
652
+ dims=["spin_angle", "off_angle"],
653
+ # TODO: Add hae_longitude to yaml
654
+ # attrs=attr_mgr.get_variable_attributes(
655
+ # "hae_latitude"
656
+ # )
657
+ )
@@ -75,7 +75,18 @@ def lo_l2(
75
75
  dataset = add_geometric_factors(dataset, map_descriptor.species)
76
76
 
77
77
  logger.info("Step 4: Calculating rates and intensities")
78
- dataset = calculate_all_rates_and_intensities(dataset)
78
+
79
+ # Determine if corrections are needed and prepare oxygen data if required
80
+ sputtering_correction, bootstrap_correction, o_map_dataset = _prepare_corrections(
81
+ map_descriptor, descriptor, sci_dependencies, anc_dependencies
82
+ )
83
+
84
+ dataset = calculate_all_rates_and_intensities(
85
+ dataset,
86
+ sputtering_correction=sputtering_correction,
87
+ bootstrap_correction=bootstrap_correction,
88
+ o_map_dataset=o_map_dataset,
89
+ )
79
90
 
80
91
  logger.info("Step 5: Finalizing dataset with attributes")
81
92
  dataset = finalize_dataset(dataset, descriptor)
@@ -84,6 +95,59 @@ def lo_l2(
84
95
  return [dataset]
85
96
 
86
97
 
98
+ def _prepare_corrections(
99
+ map_descriptor: MapDescriptor,
100
+ descriptor: str,
101
+ sci_dependencies: dict,
102
+ anc_dependencies: list,
103
+ ) -> tuple[bool, bool, xr.Dataset | None]:
104
+ """
105
+ Determine what corrections are needed and prepare oxygen dataset if required.
106
+
107
+ This helper function encapsulates the logic for determining when sputtering
108
+ and bootstrap corrections should be applied, and handles the creation of
109
+ the oxygen dataset needed for sputtering corrections.
110
+
111
+ Parameters
112
+ ----------
113
+ map_descriptor : MapDescriptor
114
+ The parsed map descriptor containing species and data type information.
115
+ descriptor : str
116
+ The original descriptor string for creating the oxygen variant.
117
+ sci_dependencies : dict
118
+ Dictionary of datasets needed for L2 data product creation.
119
+ anc_dependencies : list
120
+ List of ancillary file paths.
121
+
122
+ Returns
123
+ -------
124
+ tuple[bool, bool, xr.Dataset | None]
125
+ A tuple containing:
126
+ - sputtering_correction: Whether to apply sputtering corrections
127
+ - bootstrap_correction: Whether to apply bootstrap corrections
128
+ - o_map_dataset: Oxygen dataset if needed, None otherwise
129
+ """
130
+ # Default values - no corrections needed
131
+ sputtering_correction = False
132
+ bootstrap_correction = False
133
+ o_map_dataset = None
134
+
135
+ # Sputtering and bootstrap corrections are only applied to hydrogen ENA data
136
+ # Guard against recursion: don't process oxygen for oxygen maps
137
+ if (
138
+ map_descriptor.species == "h"
139
+ and map_descriptor.principal_data == "ena"
140
+ and "-o-" not in descriptor
141
+ ): # Safety check to prevent infinite recursion
142
+ logger.info("Creating map for oxygen for sputtering corrections")
143
+ o_descriptor = descriptor.replace("-h-", "-o-")
144
+ o_map_dataset = lo_l2(sci_dependencies, anc_dependencies, o_descriptor)[0]
145
+ sputtering_correction = True
146
+ bootstrap_correction = True
147
+
148
+ return sputtering_correction, bootstrap_correction, o_map_dataset
149
+
150
+
87
151
  # =============================================================================
88
152
  # SETUP AND INITIALIZATION HELPERS
89
153
  # =============================================================================
@@ -600,6 +664,7 @@ def calculate_all_rates_and_intensities(
600
664
  dataset: xr.Dataset,
601
665
  sputtering_correction: bool = False,
602
666
  bootstrap_correction: bool = False,
667
+ o_map_dataset: xr.Dataset | None = None,
603
668
  ) -> xr.Dataset:
604
669
  """
605
670
  Calculate rates and intensities with proper error propagation.
@@ -614,6 +679,8 @@ def calculate_all_rates_and_intensities(
614
679
  bootstrap_correction : bool, optional
615
680
  Whether to apply bootstrap corrections to intensities.
616
681
  Default is False.
682
+ o_map_dataset : xr.Dataset, optional
683
+ Dataset specifically for oxygen, needed for sputtering corrections.
617
684
 
618
685
  Returns
619
686
  -------
@@ -632,12 +699,9 @@ def calculate_all_rates_and_intensities(
632
699
 
633
700
  # Optional Step 4: Calculate sputtering corrections
634
701
  if sputtering_correction:
635
- # TODO: The second dataset is for Oxygen specifically,
636
- # if we get an H dataset in, we may need to calculate
637
- # the O dataset separately before calling here.
638
- dataset = calculate_sputtering_corrections(dataset, dataset)
702
+ dataset = calculate_sputtering_corrections(dataset, o_map_dataset)
639
703
 
640
- # Optional Step 5: Clean up intermediate variables
704
+ # Optional Step 5: Calculate bootstrap corrections
641
705
  if bootstrap_correction:
642
706
  dataset = calculate_bootstrap_corrections(dataset)
643
707
 
@@ -764,7 +828,7 @@ def calculate_sputtering_corrections(
764
828
  ----------
765
829
  dataset : xr.Dataset
766
830
  Dataset with count rates, geometric factors, and center energies.
767
- This could be either an H or O dataset.
831
+ This is an H dataset that we are applying the corrections to.
768
832
  o_dataset : xr.Dataset
769
833
  Dataset specifically for oxygen, needed to access oxygen intensities
770
834
  and uncertainties.
@@ -773,9 +837,9 @@ def calculate_sputtering_corrections(
773
837
  -------
774
838
  xr.Dataset
775
839
  Dataset with calculated sputtering-corrected intensities and their
776
- uncertainties for hydrogen and oxygen.
840
+ uncertainties.
777
841
  """
778
- logger.info("Applying sputtering corrections to oxygen intensities")
842
+ logger.info("Applying sputtering corrections to hydrogen intensities")
779
843
  # Only apply sputtering correction to esa levels 5 and 6 (indices 4 and 5)
780
844
  energy_indices = [4, 5]
781
845
  small_dataset = dataset.isel(epoch=0, energy=energy_indices)
@@ -789,6 +853,10 @@ def calculate_sputtering_corrections(
789
853
  "bg_rates_stat_uncert"
790
854
  ] / (o_small_dataset["geometric_factor"] * o_small_dataset["energy"])
791
855
 
856
+ # We need to align the energy dimensions from the oxygen dataset to the
857
+ # Hydrogen dataset so the calculations below get aligned by xarray correctly.
858
+ o_small_dataset["energy"] = small_dataset["energy"]
859
+
792
860
  # Equation 9
793
861
  j_o_prime = o_small_dataset["ena_intensity"] - o_small_dataset["bg_intensity"]
794
862
  j_o_prime.values[j_o_prime.values < 0] = 0 # No negative intensities
@@ -801,10 +869,10 @@ def calculate_sputtering_corrections(
801
869
 
802
870
  # NOTE: From table 2 of the mapping document, for energy level 5 and 6
803
871
  sputter_correction_factor = xr.DataArray(
804
- [0.15, 0.01], dims=["energy"], coords={"energy": energy_indices}
872
+ [0.15, 0.01], dims=["energy"], coords={"energy": small_dataset["energy"]}
805
873
  )
806
874
  # Equation 11
807
- # Remove the sputtered oxygen intensity to correct the original O intensity
875
+ # Remove the sputtered oxygen intensity to correct the original H intensity
808
876
  sputter_corrected_intensity = (
809
877
  small_dataset["ena_intensity"] - sputter_correction_factor * j_o_prime
810
878
  )
@@ -43,7 +43,7 @@ def mag_l1a(packet_filepath: Path) -> list[xr.Dataset]:
43
43
  A list of generated filenames.
44
44
  """
45
45
  packets = decom_mag.decom_packets(packet_filepath)
46
-
46
+ logger.info("Packet decoding complete, beginning L1A processing.")
47
47
  norm_data = packets["norm"]
48
48
  burst_data = packets["burst"]
49
49
 
@@ -188,7 +188,7 @@ def process_packets(
188
188
  secondary_packet_data.start_time,
189
189
  )
190
190
 
191
- # Sort primary and secondary into MAGo and MAGi by 24 hour chunks
191
+ # Sort primary and secondary into MAGo and MAGi
192
192
 
193
193
  if mago is None:
194
194
  mago = MagL1a(
@@ -205,13 +205,25 @@ class MagL1a:
205
205
  1 if the sensor is active, 0 if not
206
206
  shcoarse : int
207
207
  Mission elapsed time for the first packet, the start time for the whole day
208
- vectors : numpy.ndarray
209
- List of magnetic vector samples, starting at start_time. [x, y, z, range, time],
210
- where time is numpy.datetime64[ns]
208
+ starting_vectors : InitVar[numpy.ndarray]
209
+ Initvar to create the first entry in the vector list. This is to preserve the
210
+ external API of creating the object with the first set of vectors.
211
+ This cannot be accessed from an instance of the class. Instead, vectors
212
+ should be used.
211
213
  starting_packet : InitVar[MagL1aPacketProperties]
212
214
  The packet properties for the first packet in the day. As an InitVar, this
213
215
  cannot be accessed from an instance of the class. Instead, packet_definitions
214
216
  should be used.
217
+ vectors : numpy.ndarray
218
+ List of magnetic vector samples, starting at start_time. [x, y, z, range, time],
219
+ where time is numpy.datetime64[ns]. This is a property that concatenates the
220
+ internal vector list on demand.
221
+ compression_flags : numpy.ndarray
222
+ Array of flags to indicate compression and width for all timestamps in the
223
+ L1A file. Shaped like (n, 2) where n is the number of vectors. First value
224
+ is a boolean for compressed/uncompressed, second vector is a number between 0-20
225
+ if the data is compressed, which is the width in bits of the compressed data.
226
+ This is a property that concatenates the internal compression flags list.
215
227
  packet_definitions : dict[numpy.datetime64, MagL1aPacketProperties]
216
228
  Dictionary of packet properties for each packet in the day. The key is the start
217
229
  time of the packet, and the value is a dataclass of packet properties.
@@ -221,11 +233,20 @@ class MagL1a:
221
233
  List of missing sequence numbers in the day
222
234
  start_time : numpy.int64
223
235
  Start time of the day, in ns since J2000 epoch
224
- compression_flags : np.ndarray
236
+ _compression_flags_list : np.ndarray
225
237
  Array of flags to indication compression and width for all timestamps in the
226
238
  L1A file. Shaped like (n, 2) where n is the number of vectors. First value
227
239
  is a boolean for compressed/uncompressed, second vector is a number between 0-20
228
240
  if the data is compressed, which is the width in bits of the compressed data.
241
+ Transformed into a numpy array upon retrieval.
242
+ _vector_list : list
243
+ Internal list of vectors, used to build the final vectors attribute.
244
+ This is a list of numpy arrays, each with shape (n, 5) where n is the
245
+ number of vectors in that packet, and each vector is (x, y, z, range, time).
246
+ _vector_cache : numpy.ndarray | None
247
+ A cache of the concatenated vector list. This is None until the vectors
248
+ property is accessed, at which point it is created and stored here for future
249
+ access.
229
250
 
230
251
  Methods
231
252
  -------
@@ -248,23 +269,30 @@ class MagL1a:
248
269
  is_mago: bool
249
270
  is_active: int
250
271
  shcoarse: int
251
- vectors: np.ndarray
272
+ starting_vectors: InitVar[np.ndarray]
252
273
  starting_packet: InitVar[MagL1aPacketProperties]
253
274
  packet_definitions: dict[np.int64, MagL1aPacketProperties] = field(init=False)
254
275
  most_recent_sequence: int = field(init=False)
255
276
  missing_sequences: list[int] = field(default_factory=list)
256
277
  start_time: np.int64 = field(init=False)
257
- compression_flags: np.ndarray | None = field(init=False, default=None)
278
+ _compression_flags_list: list = field(default_factory=list)
279
+ _vector_list: list = field(init=False)
280
+ _vector_cache: np.ndarray | None = field(init=False, default=None)
258
281
 
259
- def __post_init__(self, starting_packet: MagL1aPacketProperties) -> None:
282
+ def __post_init__(
283
+ self, starting_vectors: np.ndarray, starting_packet: MagL1aPacketProperties
284
+ ) -> None:
260
285
  """
261
- Initialize the packet_definition dictionary and most_recent_sequence.
286
+ Initialize the vector list, packet_definition dictionary & most_recent_sequence.
262
287
 
263
288
  Parameters
264
289
  ----------
290
+ starting_vectors : numpy.ndarray
291
+ The vectors for the first packet in the day.
265
292
  starting_packet : MagL1aPacketProperties
266
293
  The packet properties for the first packet in the day, including start time.
267
294
  """
295
+ self._vector_list = [starting_vectors]
268
296
  self.start_time = np.int64(met_to_ttj2000ns(starting_packet.shcoarse))
269
297
  self.packet_definitions = {self.start_time: starting_packet}
270
298
  # most_recent_sequence is the sequence number of the packet used to initialize
@@ -272,6 +300,36 @@ class MagL1a:
272
300
  self.most_recent_sequence = starting_packet.src_seq_ctr
273
301
  self.update_compression_array(starting_packet, self.vectors.shape[0])
274
302
 
303
+ @property
304
+ def vectors(self) -> np.ndarray:
305
+ """
306
+ Concatenate the internal vector list into a numpy array.
307
+
308
+ If the array has already been created, return the cached version.
309
+
310
+ Returns
311
+ -------
312
+ np.ndarray
313
+ Array of vectors with shape (n, 5) where n is the number of vectors,
314
+ and each vector is (x, y, z, range, time).
315
+ """
316
+ if self._vector_cache is None:
317
+ self._vector_cache = np.concatenate(self._vector_list, axis=0)
318
+ return self._vector_cache
319
+
320
+ @property
321
+ def compression_flags(self) -> np.ndarray:
322
+ """
323
+ Return the compression flags array.
324
+
325
+ Returns
326
+ -------
327
+ np.ndarray
328
+ Array of compression flags with shape (n, 2) where n is the number of
329
+ vectors, and each entry is (is_compressed, compression_width).
330
+ """
331
+ return np.concatenate(self._compression_flags_list, axis=0)
332
+
275
333
  def append_vectors(
276
334
  self, additional_vectors: np.ndarray, packet_properties: MagL1aPacketProperties
277
335
  ) -> None:
@@ -285,9 +343,12 @@ class MagL1a:
285
343
  packet_properties : MagL1aPacketProperties
286
344
  Additional vector definition to add to the l0_packets dictionary.
287
345
  """
346
+ self._vector_list.append(additional_vectors)
347
+ # Invalidate the cache
348
+ self._vector_cache = None
349
+
288
350
  vector_sequence = packet_properties.src_seq_ctr
289
351
 
290
- self.vectors = np.concatenate([self.vectors, additional_vectors])
291
352
  start_time = np.int64(met_to_ttj2000ns(packet_properties.shcoarse))
292
353
  self.packet_definitions[start_time] = packet_properties
293
354
 
@@ -322,10 +383,7 @@ class MagL1a:
322
383
  [packet_properties.compression, packet_properties.compression_width],
323
384
  dtype=np.int8,
324
385
  )
325
- if self.compression_flags is None:
326
- self.compression_flags = new_flags
327
- else:
328
- self.compression_flags = np.concatenate([self.compression_flags, new_flags])
386
+ self._compression_flags_list.append(new_flags)
329
387
 
330
388
  @staticmethod
331
389
  def calculate_vector_time(
@@ -33,9 +33,11 @@ def remove_invalid_output_timestamps(
33
33
  numpy.ndarray
34
34
  All valid output timestamps where there exists input data.
35
35
  """
36
- if input_timestamps[0] > output_timestamps[0]:
37
- # Chop data where we don't have input timestamps to interpolate
38
- output_timestamps = output_timestamps[output_timestamps >= input_timestamps[0]]
36
+ # Chop data where we don't have input timestamps to interpolate
37
+ output_timestamps = output_timestamps[
38
+ (output_timestamps >= input_timestamps[0])
39
+ & (output_timestamps <= input_timestamps[-1])
40
+ ]
39
41
  return output_timestamps
40
42
 
41
43
 
@@ -45,7 +47,8 @@ def linear(
45
47
  output_timestamps: np.ndarray,
46
48
  input_rate: VecSec | None = None,
47
49
  output_rate: VecSec | None = None,
48
- ) -> np.ndarray:
50
+ extrapolate: bool = False,
51
+ ) -> tuple[np.ndarray, np.ndarray]:
49
52
  """
50
53
  Linear interpolation of input vectors to output timestamps.
51
54
 
@@ -63,6 +66,9 @@ def linear(
63
66
  Not required for this interpolation method.
64
67
  output_rate : VecSec, optional
65
68
  Not required for this interpolation method.
69
+ extrapolate : bool, optional
70
+ Whether to allow extrapolation of output timestamps outside the range of input
71
+ timestamps. Default is False.
66
72
 
67
73
  Returns
68
74
  -------
@@ -70,9 +76,12 @@ def linear(
70
76
  Interpolated vectors of shape (m, 3) where m is equal to the number of output
71
77
  timestamps. Contains x, y, z components of the vector.
72
78
  """
73
- # TODO: Remove invalid timestamps using remove_invalid_output_timestamps
79
+ if not extrapolate:
80
+ output_timestamps = remove_invalid_output_timestamps(
81
+ input_timestamps, output_timestamps
82
+ )
74
83
  spline = make_interp_spline(input_timestamps, input_vectors, k=1)
75
- return spline(output_timestamps)
84
+ return output_timestamps, spline(output_timestamps)
76
85
 
77
86
 
78
87
  def quadratic(
@@ -81,7 +90,7 @@ def quadratic(
81
90
  output_timestamps: np.ndarray,
82
91
  input_rate: VecSec | None = None,
83
92
  output_rate: VecSec | None = None,
84
- ) -> np.ndarray:
93
+ ) -> tuple[np.ndarray, np.ndarray]:
85
94
  """
86
95
  Quadratic interpolation of input vectors to output timestamps.
87
96
 
@@ -106,8 +115,11 @@ def quadratic(
106
115
  Interpolated vectors of shape (m, 3) where m is equal to the number of output
107
116
  timestamps. Contains x, y, z components of the vector.
108
117
  """
118
+ output_timestamps = remove_invalid_output_timestamps(
119
+ input_timestamps, output_timestamps
120
+ )
109
121
  spline = make_interp_spline(input_timestamps, input_vectors, k=2)
110
- return spline(output_timestamps)
122
+ return output_timestamps, spline(output_timestamps)
111
123
 
112
124
 
113
125
  def cubic(
@@ -116,7 +128,7 @@ def cubic(
116
128
  output_timestamps: np.ndarray,
117
129
  input_rate: VecSec | None = None,
118
130
  output_rate: VecSec | None = None,
119
- ) -> np.ndarray:
131
+ ) -> tuple[np.ndarray, np.ndarray]:
120
132
  """
121
133
  Cubic interpolation of input vectors to output timestamps.
122
134
 
@@ -141,8 +153,11 @@ def cubic(
141
153
  Interpolated vectors of shape (m, 3) where m is equal to the number of output
142
154
  timestamps. Contains x, y, z components of the vector.
143
155
  """
156
+ output_timestamps = remove_invalid_output_timestamps(
157
+ input_timestamps, output_timestamps
158
+ )
144
159
  spline = make_interp_spline(input_timestamps, input_vectors, k=3)
145
- return spline(output_timestamps)
160
+ return output_timestamps, spline(output_timestamps)
146
161
 
147
162
 
148
163
  def estimate_rate(timestamps: np.ndarray) -> VecSec:
@@ -245,7 +260,7 @@ def linear_filtered(
245
260
  output_timestamps: np.ndarray,
246
261
  input_rate: VecSec | None = None,
247
262
  output_rate: VecSec | None = None,
248
- ) -> np.ndarray:
263
+ ) -> tuple[np.ndarray, np.ndarray]:
249
264
  """
250
265
  Linear filtered interpolation of input vectors to output timestamps.
251
266
 
@@ -290,7 +305,7 @@ def quadratic_filtered(
290
305
  output_timestamps: np.ndarray,
291
306
  input_rate: VecSec | None = None,
292
307
  output_rate: VecSec | None = None,
293
- ) -> np.ndarray:
308
+ ) -> tuple[np.ndarray, np.ndarray]:
294
309
  """
295
310
  Quadratic filtered interpolation of input vectors to output timestamps.
296
311
 
@@ -317,6 +332,9 @@ def quadratic_filtered(
317
332
  Interpolated vectors of shape (m, 3) where m is equal to the number of output
318
333
  timestamps. Contains x, y, z components of the vector.
319
334
  """
335
+ output_timestamps = remove_invalid_output_timestamps(
336
+ input_timestamps, output_timestamps
337
+ )
320
338
  input_filtered, vectors_filtered = cic_filter(
321
339
  input_vectors, input_timestamps, output_timestamps, input_rate, output_rate
322
340
  )
@@ -329,7 +347,7 @@ def cubic_filtered(
329
347
  output_timestamps: np.ndarray,
330
348
  input_rate: VecSec | None = None,
331
349
  output_rate: VecSec | None = None,
332
- ) -> np.ndarray:
350
+ ) -> tuple[np.ndarray, np.ndarray]:
333
351
  """
334
352
  Cubic filtered interpolation of input vectors to output timestamps.
335
353
 
@@ -356,6 +374,9 @@ def cubic_filtered(
356
374
  Interpolated vectors of shape (m, 3) where m is equal to the number of output
357
375
  timestamps. Contains x, y, z components of the vector.
358
376
  """
377
+ output_timestamps = remove_invalid_output_timestamps(
378
+ input_timestamps, output_timestamps
379
+ )
359
380
  input_filtered, vectors_filtered = cic_filter(
360
381
  input_vectors, input_timestamps, output_timestamps, input_rate, output_rate
361
382
  )