fastlisaresponse 1.0.9__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fastlisaresponse might be problematic. Click here for more details.

@@ -0,0 +1,784 @@
1
+ from multiprocessing.sharedctypes import Value
2
+ import numpy as np
3
+ from typing import Optional, List
4
+ import warnings
5
+ from typing import Tuple
6
+ from copy import deepcopy
7
+
8
+
9
+ try:
10
+ import cupy as cp
11
+ from .cutils.pyresponse_gpu import get_response_wrap as get_response_wrap_gpu
12
+ from .cutils.pyresponse_gpu import get_tdi_delays_wrap as get_tdi_delays_wrap_gpu
13
+
14
+ gpu = True
15
+
16
+ except (ImportError, ModuleNotFoundError) as e:
17
+ pass
18
+
19
+ gpu = False
20
+
21
+ from .cutils.pyresponse_cpu import get_response_wrap as get_response_wrap_cpu
22
+ from .cutils.pyresponse_cpu import get_tdi_delays_wrap as get_tdi_delays_wrap_cpu
23
+ import time
24
+ import h5py
25
+
26
+ from scipy.interpolate import CubicSpline
27
+
28
+ from lisatools.detector import EqualArmlengthOrbits, Orbits
29
+ from lisatools.utils.utility import AET
30
+ from lisatools.utils.pointeradjust import pointer_adjust
31
+
32
# Seconds in one year (31558149.76 s ~= 365.256 days). Multiplies observation
# times that are given in years.
YRSID_SI = 31558149.763545603


def get_factorial(n):
    """Return ``n!`` computed by repeated integer multiplication.

    Args:
        n (int): Integer argument. Any value below 2 (including 0, 1 and
            negatives) yields ``1``.

    Returns:
        int: The product ``1 * 2 * ... * n``.
    """
    product = 1
    # multiplying by 1 is a no-op, so the running product starts at factor 2
    for factor in range(2, n + 1):
        product *= factor
    return product


from math import factorial

# Exact integer factorials 0! .. 29!. From 21! on the values exceed the int64
# range, so NumPy is expected to store this as an object-dtype array
# (NOTE(review): confirm no caller depends on a numeric dtype here).
factorials = np.array(list(map(factorial, range(30))))

# Inverse speed of light, 1 / 299792458, in s/m (turns distances into
# light-travel times).
C_inv = 3.3356409519815204e-09
49
+
50
+
51
class pyResponseTDI(object):
    """Class container for fast LISA response function generation.

    The class computes the generic time-domain response function for LISA.
    It takes LISA constellation orbital information as input and properly determines
    the response for these orbits numerically. This includes both the projection
    of the gravitational waves onto the LISA constellation arms and combinations
    of projections into TDI observables. The methods and maths used can be found
    [here](https://arxiv.org/abs/2204.06633).

    This class is also GPU-accelerated, which is very helpful for Bayesian inference
    methods.

    Args:
        sampling_frequency (double): The sampling rate in Hz.
        num_pts (int): Number of points to produce for the final output template.
        order (int, optional): Order of Lagrangian interpolation technique. Lower orders
            will be faster. The user must make sure the order is sufficient for the
            waveform being used. (default: 25)
        tdi (str or list, optional): TDI setup. Currently, the stock options are
            :code:`'1st generation'` and :code:`'2nd generation'`. Or the user can provide
            a list of tdi_combinations of the form
            :code:`{"link": 12, "links_for_delay": [21, 13, 31], "sign": 1, "type": "delay"}`.
            :code:`'link'` (`int`) the link index (12, 21, 13, 31, 23, 32) for the projection (:math:`y_{ij}`).
            :code:`'links_for_delay'` (`list`) are the link indexes as a list used for delays
            applied to the link projections.
            ``'sign'`` is the sign in front of the contribution to the TDI observable. It takes the value of `+1` or `-1`.
            ``type`` is either ``"delay"`` or ``"advance"``. It is optional and defaults to ``"delay"``.
            NOTE(review): the ``tdi_combinations`` setter of this class does not read
            ``"type"``; confirm where (if anywhere) advances are handled.
            (default: ``"1st generation"``)
        orbits (:class:`Orbits`, optional): Orbits class from LISA Analysis Tools. Works with LISA Orbits
            outputs: `lisa-simulation.pages.in2p3.fr/orbits/ <https://lisa-simulation.pages.in2p3.fr/orbits/latest/>`_.
            Either an instance or an :class:`Orbits` subclass, which is then
            instantiated with no arguments. ``None`` gives :class:`EqualArmlengthOrbits`.
            (default: :class:`EqualArmlengthOrbits`)
        tdi_orbits (:class:`Orbits`, optional): Set if different orbits from projection.
            Orbits class from LISA Analysis Tools. Works with LISA Orbits
            outputs: `lisa-simulation.pages.in2p3.fr/orbits/ <https://lisa-simulation.pages.in2p3.fr/orbits/latest/>`_.
            Must have the same ``use_gpu`` setting as this class.
            (default: ``None``, meaning the response ``orbits`` are reused)
        tdi_chan (str, optional): Which TDI channel combination to return. Choices are :code:`'XYZ'`,
            :code:`AET`, or :code:`AE`. (default: :code:`'XYZ'`)
        use_gpu (bool, optional): If True, run code on the GPU. (default: :code:`False`)

    Attributes:
        A_in (xp.ndarray): Array containing y values for linear spline of A
            during Lagrangian interpolation.
        buffer_integer (int): Self-determined buffer necessary for the given
            value for :code:`order` (``2 * order + 1``).
        channels (xp.ndarray): Output channel index (0, 1, 2) of each TDI unit.
        deps (double): The spacing between Epsilon values in the interpolant
            for the A quantity in Lagrangian interpolation. Hard coded to
            1/(:code:`num_A` - 1).
        dt (double): Inverse of the sampling_frequency.
        E_in (xp.ndarray): Array containing y values for linear spline of E
            during Lagrangian interpolation.
        half_order (int): Half of :code:`order` adjusted to be :code:`int`.
        nlinks (int): The number of links in the constellation (6).
        num_A (int): Number of points to use for A spline values used in the Lagrangian
            interpolation. This is hard coded to 1001.
        num_pts (int): Number of points to produce for the final output template.
        order (int): Order of Lagrangian interpolation technique.
        sampling_frequency (double): The sampling rate in Hz.
        tdi (str or list): TDI setup.
        tdi_base_links (xp.ndarray): Projection link index of each TDI unit.
        tdi_buffer (int): The buffer necessary for all information needed at early times
            for the TDI computation. This is set to 200.
        tdi_link_combinations (xp.ndarray): Delay link index of each TDI unit
            (``-11`` for units without a delay).
        tdi_signs (xp.ndarray): Sign (+1 or -1) of each TDI unit.
        use_gpu (bool): If True, run on GPU.
        xp (obj): Either Numpy or Cupy.
        y_gw_flat (xp.ndarray or None): Flattened link projections, ``None``
            until computed or supplied.

    """

    def __init__(
        self,
        sampling_frequency,
        num_pts,
        order=25,
        tdi="1st generation",
        orbits: Optional[Orbits] = EqualArmlengthOrbits,
        tdi_orbits: Optional[Orbits] = None,
        tdi_chan="XYZ",
        use_gpu=False,
    ):

        # setup all quantities
        self.sampling_frequency = sampling_frequency
        self.dt = 1 / sampling_frequency
        self.tdi_buffer = 200

        self.num_pts = num_pts

        # Lagrangian interpolation setup
        self.order = order
        self.buffer_integer = self.order * 2 + 1
        self.half_order = int((order + 1) / 2)

        # setup TDI information
        self.tdi = tdi
        self.tdi_chan = tdi_chan

        # setup functions for GPU or CPU
        self.use_gpu = use_gpu

        # six one-way links: 21, 12, 31, 13, 32, 23
        self.nlinks = 6

        # Filled by ``get_projections`` or by passing ``y_gw`` to
        # ``get_tdi_delays``; ``None`` lets the latter raise its documented
        # ValueError instead of an AttributeError.
        self.y_gw_flat = None

        # prepare the interpolation of A and E in the Lagrangian interpolation
        self._fill_A_E()

        # setup orbits
        self.response_orbits = orbits

        if tdi_orbits is None:
            tdi_orbits = self.response_orbits

        self.tdi_orbits = tdi_orbits

        if self.num_pts * self.dt > self.response_orbits.t_base.max():
            warnings.warn(
                "Input number of points is longer in time than available orbital information. Trimming to fit orbital information."
            )
            self.num_pts = int(self.response_orbits.t_base.max() / self.dt)

        # setup TDI info
        self._init_TDI_delays()

    @property
    def response_gen(self) -> callable:
        """CPU/GPU function for generating the projections."""
        return get_response_wrap_cpu if not self.use_gpu else get_response_wrap_gpu

    @property
    def tdi_gen(self) -> callable:
        """CPU/GPU function for generating tdi."""
        return get_tdi_delays_wrap_cpu if not self.use_gpu else get_tdi_delays_wrap_gpu

    @property
    def xp(self) -> object:
        """Array module in use: NumPy on CPU, CuPy on GPU."""
        return np if not self.use_gpu else cp

    @staticmethod
    def _as_orbits_instance(orbits) -> Orbits:
        """Turn ``None`` or an :class:`Orbits` subclass into an instance.

        ``None`` becomes :class:`EqualArmlengthOrbits`. A class (such as the
        ``orbits`` default of ``__init__``, which is the class itself) is
        instantiated with no arguments. Anything else is returned unchanged.
        """
        if orbits is None:
            return EqualArmlengthOrbits()
        if isinstance(orbits, type) and issubclass(orbits, Orbits):
            return orbits()
        return orbits

    @property
    def response_orbits(self) -> Orbits:
        """Response function orbits."""
        return self._response_orbits

    @response_orbits.setter
    def response_orbits(self, orbits: Orbits) -> None:
        """Set response orbits (a configured private copy is stored)."""

        orbits = self._as_orbits_instance(orbits)

        assert isinstance(orbits, Orbits)

        self._response_orbits = deepcopy(orbits)

        if not self._response_orbits.configured:
            self._response_orbits.configure(linear_interp_setup=True)

    @property
    def tdi_orbits(self) -> Orbits:
        """TDI function orbits."""
        return self._tdi_orbits

    @tdi_orbits.setter
    def tdi_orbits(self, orbits: Orbits) -> None:
        """Set TDI orbits (a configured private copy is stored)."""

        orbits = self._as_orbits_instance(orbits)

        assert isinstance(orbits, Orbits)
        assert orbits.use_gpu == self.use_gpu

        self._tdi_orbits = deepcopy(orbits)

        if not self._tdi_orbits.configured:
            self._tdi_orbits.configure(linear_interp_setup=True)

    @property
    def citation(self):
        """Get citations for use of this code"""

        return """
        # TODO add
        """

    def _fill_A_E(self):
        """Set up A and E terms inside the Lagrangian interpolant.

        With ``h = half_order``, ``A`` is tabulated on ``num_A`` equally spaced
        values ``eps`` in ``[0, 1]`` as
        ``prod_{i=1}^{h-1} (i + eps) * (i + 1 - eps) / ((h - 1)! * h!)``, and
        ``E[j - 1] = (-1)**j * (h - 1)! / (h - 1 - j)! * h! / (h + j)!`` for
        ``j = 1 .. h - 1``. The last entry of ``E`` (length ``h``) stays zero.
        """

        # float factorials 0! .. 39! (local; independent of the module table)
        factorials = np.asarray([float(get_factorial(n)) for n in range(40)])

        # base quantities for linear interpolant over A
        self.num_A = 1001
        self.deps = 1.0 / (self.num_A - 1)

        eps = np.arange(self.num_A) * self.deps

        h = self.half_order

        denominator = factorials[h - 1] * factorials[h]

        # prepare A
        A_in = np.zeros_like(eps)
        for j, eps_i in enumerate(eps):
            A = 1.0
            for i in range(1, h):
                A *= (i + eps_i) * (i + 1 - eps_i)

            A /= denominator
            A_in[j] = A

        self.A_in = self.xp.asarray(A_in)

        # prepare E
        E_in = self.xp.zeros((self.half_order,))

        for j in range(1, self.half_order):
            first_term = factorials[h - 1] / factorials[h - 1 - j]
            second_term = factorials[h] / factorials[h + j]
            value = first_term * second_term
            value = value * (-1.0) ** j
            E_in[j - 1] = value

        self.E_in = self.xp.asarray(E_in)

    def _init_TDI_delays(self):
        """Initialize TDI specific information.

        Builds the stock 1st/2nd generation combination lists (or uses a
        user-provided list) and hands them to the ``tdi_combinations`` setter.

        Raises:
            ValueError: If ``tdi`` is neither a stock option nor a list.
        """

        # setup the actual TDI combination
        if self.tdi in ["1st generation", "2nd generation"]:
            # tdi 1.0
            tdi_combinations = [
                {"link": 13, "links_for_delay": [], "sign": +1},
                {"link": 31, "links_for_delay": [13], "sign": +1},
                {"link": 12, "links_for_delay": [13, 31], "sign": +1},
                {"link": 21, "links_for_delay": [13, 31, 12], "sign": +1},
                {"link": 12, "links_for_delay": [], "sign": -1},
                {"link": 21, "links_for_delay": [12], "sign": -1},
                {"link": 13, "links_for_delay": [12, 21], "sign": -1},
                {"link": 31, "links_for_delay": [12, 21, 13], "sign": -1},
            ]

            if self.tdi == "2nd generation":
                # tdi 2.0 is tdi 1.0 + additional terms
                tdi_combinations += [
                    {"link": 12, "links_for_delay": [13, 31, 12, 21], "sign": +1},
                    {"link": 21, "links_for_delay": [13, 31, 12, 21, 12], "sign": +1},
                    {
                        "link": 13,
                        "links_for_delay": [13, 31, 12, 21, 12, 21],
                        "sign": +1,
                    },
                    {
                        "link": 31,
                        "links_for_delay": [13, 31, 12, 21, 12, 21, 13],
                        "sign": +1,
                    },
                    {"link": 13, "links_for_delay": [12, 21, 13, 31], "sign": -1},
                    {"link": 31, "links_for_delay": [12, 21, 13, 31, 13], "sign": -1},
                    {
                        "link": 12,
                        "links_for_delay": [12, 21, 13, 31, 13, 31],
                        "sign": -1,
                    },
                    {
                        "link": 21,
                        "links_for_delay": [12, 21, 13, 31, 13, 31, 12],
                        "sign": -1,
                    },
                ]

        elif isinstance(self.tdi, list):
            tdi_combinations = self.tdi

        else:
            raise ValueError(
                "tdi kwarg should be '1st generation', '2nd generation', or a list with a specific tdi combination."
            )
        self.tdi_combinations = tdi_combinations

    @property
    def tdi_combinations(self) -> List:
        """TDI Combination setup"""
        return self._tdi_combinations

    @tdi_combinations.setter
    def tdi_combinations(self, tdi_combinations: List) -> None:
        """Set TDI combinations and fill out setup.

        Each combination is expanded for the three cyclic permutations
        (channels 0, 1, 2). A combination without delays contributes one unit
        with delay link ``-11``; otherwise one unit is added per entry of
        ``links_for_delay``, all sharing the (permuted) base link and sign.
        NOTE(review): the optional ``"type"`` key is not read here.
        """
        tdi_base_links = []
        tdi_link_combinations = []
        tdi_signs = []
        channels = []

        for permutation_number in range(3):
            for tmp in tdi_combinations:
                base_link = self._cyclic_permutation(tmp["link"], permutation_number)

                if len(tmp["links_for_delay"]) == 0:
                    tdi_base_links.append(base_link)
                    # -11 marks "no delay" for the compiled backend
                    tdi_link_combinations.append(-11)
                    tdi_signs.append(tmp["sign"])
                    channels.append(permutation_number)
                    continue

                for link_delay in tmp["links_for_delay"]:
                    tdi_base_links.append(base_link)
                    tdi_link_combinations.append(
                        self._cyclic_permutation(link_delay, permutation_number)
                    )
                    tdi_signs.append(tmp["sign"])
                    channels.append(permutation_number)

        self.tdi_base_links = self.xp.asarray(tdi_base_links).astype(self.xp.int32)
        self.tdi_link_combinations = self.xp.asarray(tdi_link_combinations).astype(
            self.xp.int32
        )
        self.tdi_signs = self.xp.asarray(tdi_signs).astype(self.xp.int32)
        self.channels = self.xp.asarray(channels).astype(self.xp.int32)
        assert (
            len(self.tdi_base_links)
            == len(self.tdi_link_combinations)
            == len(self.tdi_signs)
            == len(self.channels)
        )

        # store the source setup so the ``tdi_combinations`` getter works
        self._tdi_combinations = tdi_combinations

    def _cyclic_permutation(self, link, permutation):
        """Permute the two spacecraft indexes of ``link`` cyclically.

        Each digit ``sc`` in 1..3 maps to ``(sc + permutation - 1) % 3 + 1``,
        e.g. ``_cyclic_permutation(12, 1) == 23`` and
        ``_cyclic_permutation(31, 2) == 23``.
        """
        link_str = str(link)

        out = ""
        for i in range(2):
            sc = int(link_str[i])
            # wrap back into 1..3 for any non-negative shift
            out += str((sc + permutation - 1) % 3 + 1)

        return int(out)

    @property
    def y_gw(self):
        """Projections along the arms, shaped ``(nlinks, -1)``."""
        return self.y_gw_flat.reshape(self.nlinks, -1)

    def _data_time_check(
        self, t_data: np.ndarray, input_in: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Trim samples whose time lies past the end of the response orbits.

        Args:
            t_data (xp.ndarray): Increasing sample times in seconds.
            input_in (xp.ndarray): Samples aligned with ``t_data``.

        Returns:
            tuple: ``(t_data, input_in)``, both truncated after the last sample
            with ``t_data <= response_orbits.t.max()`` (unchanged if none
            exceed it). A warning is emitted when trimming happens.
        """
        t_max = self.response_orbits.t.max()

        # remove input data that goes beyond orbital information
        if t_data.max() > t_max:
            warnings.warn(
                "Input waveform is longer than available orbital information. Trimming to fit orbital information."
            )

            max_ind = int(np.where(t_data <= t_max)[0][-1])

            # + 1 keeps the last sample that is still inside the orbital range
            t_data = t_data[: max_ind + 1]
            input_in = input_in[: max_ind + 1]
        return (t_data, input_in)

    def get_projections(self, input_in, lam, beta, t0=10000.0):
        """Compute projections of GW signal on to LISA constellation

        Args:
            input_in (xp.ndarray): Input complex time-domain signal. It should be of the form:
                :math:`h_+ + ih_x`. If using the GPU for the response, this should be a CuPy array.
            lam (double): Ecliptic Longitude in radians.
            beta (double): Ecliptic Latitude in radians.
            t0 (double, optional): Time at which to start the waveform. Because of the delays
                and interpolation towards earlier times, the beginning of the waveform
                is garbage. ``t0`` tells the waveform generator where to start the waveform
                compared to ``t=0``.

        Raises:
            ValueError: If ``t0`` is not large enough.
            AssertionError: If fewer than ``num_pts`` input samples remain
                after trimming to the orbital time range.

        """
        self.tdi_start_ind = int(t0 / self.dt)
        # get necessary buffer for TDI: 100 s of samples plus interpolation margin
        self.check_tdi_buffer = int(100.0 * self.sampling_frequency) + 4 * self.order

        # Largest norm of the orbit positions converted to a light-travel time
        # (C_inv = 1/c) and then to samples, so it is comparable with the
        # sample index ``projections_start_ind`` below.
        # NOTE(review): assumes x_base holds positions in metres with xyz on
        # the last axis.
        x_base = self.response_orbits.x_base
        max_distance = float(((x_base * x_base).sum(axis=-1) ** 0.5).max())
        self.projection_buffer = (
            int(max_distance * C_inv * self.sampling_frequency) + 4 * self.order
        )
        self.projections_start_ind = self.tdi_start_ind - 2 * self.check_tdi_buffer

        if self.projections_start_ind < self.projection_buffer:
            raise ValueError(
                "Need to increase t0. The initial buffer is not large enough."
            )

        # determine sky vectors
        k = np.zeros(3, dtype=np.float64)
        u = np.zeros(3, dtype=np.float64)
        v = np.zeros(3, dtype=np.float64)

        self.num_total_points = len(input_in)

        cosbeta = np.cos(beta)
        sinbeta = np.sin(beta)

        coslam = np.cos(lam)
        sinlam = np.sin(lam)

        v[0] = -sinbeta * coslam
        v[1] = -sinbeta * sinlam
        v[2] = cosbeta
        u[0] = sinlam
        u[1] = -coslam
        u[2] = 0.0
        k[0] = -cosbeta * coslam
        k[1] = -cosbeta * sinlam
        k[2] = -sinbeta

        k_in = self.xp.asarray(k)
        u_in = self.xp.asarray(u)
        v_in = self.xp.asarray(v)

        input_in = self.xp.asarray(input_in)

        t_data = self.xp.arange(len(input_in)) * self.dt

        t_data, input_in = self._data_time_check(t_data, input_in)

        assert len(input_in) >= self.num_pts, (
            "Input signal has fewer samples than num_pts after trimming."
        )
        y_gw = self.xp.zeros((self.nlinks * self.num_pts,), dtype=self.xp.float64)

        self.response_gen(
            y_gw,
            t_data,
            k_in,
            u_in,
            v_in,
            self.dt,
            len(input_in),
            input_in,
            len(input_in),
            self.order,
            self.sampling_frequency,
            self.buffer_integer,
            self.A_in,
            self.deps,
            len(self.A_in),
            self.E_in,
            self.projections_start_ind,
            self.response_orbits,
        )

        self.y_gw_flat = y_gw
        self.y_gw_length = self.num_pts

    @property
    def XYZ(self):
        """Return links as an array of shape ``(3, -1)``."""
        return self.delayed_links_flat.reshape(3, -1)

    def get_tdi_delays(self, y_gw=None):
        """Get TDI combinations from projections.

        This functions generates the TDI combinations from the projections
        computed with ``get_projections``. It can return XYZ, AET, or AE depending
        on what was input for ``tdi_chan`` into ``__init__``.

        Args:
            y_gw (xp.ndarray, optional): Projections along each link. Must be
                a 2D ``numpy`` or ``cupy`` array with shape: ``(nlinks, num_pts)``.
                It is converted to a float64 array of the active backend.
                The links must be entered in the proper order in the code:
                21, 12, 31, 13, 32, 23. (Default: None)

        Returns:
            tuple: (X,Y,Z) or (A,E,T) or (A,E)

        Raises:
            ValueError: If ``tdi_chan`` is not one of the options, or if no
                projections are available.

        """
        # one flat output buffer holding the three channels back to back
        self.delayed_links_flat = self.xp.zeros(
            3 * self.num_pts, dtype=self.xp.float64
        )

        # y_gw entered directly
        if y_gw is not None:
            y_gw = self.xp.asarray(y_gw, dtype=self.xp.float64)
            assert y_gw.shape == (self.nlinks, self.num_pts)
            # flatten() returns a copy, so later edits to y_gw do not leak in
            self.y_gw_flat = y_gw.flatten()
            self.y_gw_length = self.num_pts

        elif self.y_gw_flat is None:
            raise ValueError(
                "Need to either enter projection array or have this code determine projections."
            )

        t_data = self.xp.arange(self.y_gw_length) * self.dt

        # NOTE(review): tdi_start_ind is only set by get_projections.
        self.tdi_gen(
            self.delayed_links_flat,
            self.y_gw_flat,
            self.y_gw_length,
            self.num_pts,
            t_data,
            self.tdi_base_links,
            self.tdi_link_combinations,
            self.tdi_signs,
            self.channels,
            len(self.tdi_base_links),  # num_units
            3,  # num channels
            self.order,
            self.sampling_frequency,
            self.buffer_integer,
            self.A_in,
            self.deps,
            len(self.A_in),
            self.E_in,
            self.tdi_start_ind,
            self.tdi_orbits,
        )

        if self.tdi_chan == "XYZ":
            X, Y, Z = self.XYZ
            return X, Y, Z

        elif self.tdi_chan == "AET" or self.tdi_chan == "AE":
            X, Y, Z = self.XYZ
            A, E, T = AET(X, Y, Z)
            if self.tdi_chan == "AET":
                return A, E, T

            else:
                return A, E

        else:
            raise ValueError("tdi_chan must be 'XYZ', 'AET' or 'AE'.")
603
+
604
+
605
class ResponseWrapper:
    r"""Wrapper to produce LISA TDI from TD waveforms

    This class takes a waveform generator that produces :math:`h_+ \pm ih_x`.
    (:code:`flip_hx` is used if the waveform produces :math:`h_+ - ih_x`).
    It takes the complex waveform in the SSB frame and produces the TDI channels
    according to settings chosen for :class:`pyResponseTDI`.

    The waveform generator must have :code:`kwargs` with :code:`T` for the observation
    time in years and :code:`dt` for the time step in seconds.

    Args:
        waveform_gen (obj): Function or class (with a :code:`__call__` function) that takes parameters and produces
            :math:`h_+ \pm ih_x`.
        Tobs (double): Observation time in years.
        dt (double): Time between time samples in seconds. The inverse of the sampling frequency.
        index_lambda (int): The user will input parameters. The code will read these in
            with the :code:`*args` formalism producing a list. :code:`index_lambda`
            tells the class the index of the ecliptic longitude within this list of
            parameters.
        index_beta (int): The user will input parameters. The code will read these in
            with the :code:`*args` formalism producing a list. :code:`index_beta`
            tells the class the index of the ecliptic latitude (or ecliptic polar angle)
            within this list of parameters.
        t0 (double, optional): Start of returned waveform in seconds leaving ample time for garbage at
            the beginning of the waveform. It also removes the same amount from the end. (Default: 10000.0)
        flip_hx (bool, optional): If True, :code:`waveform_gen` produces :math:`h_+ - ih_x`.
            :class:`pyResponseTDI` takes :math:`h_+ + ih_x`, so this setting will
            multiply the cross polarization term out of the waveform generator by -1.
            (Default: :code:`False`)
        remove_sky_coords (bool, optional): If True, remove the sky coordinates from
            the :code:`*args` list. This should be set to True if the waveform
            generator does not take in the sky information. (Default: :code:`False`)
        is_ecliptic_latitude (bool, optional): If True, the latitudinal sky
            coordinate is the ecliptic latitude. If False, the latitudinal sky
            coordinate is the polar angle. In this case, the code will
            convert it with :math:`\beta=\pi / 2 - \Theta`. (Default: :code:`True`)
        use_gpu (bool, optional): If True, use GPU. (Default: :code:`False`)
        remove_garbage (bool or str, optional): If True, it removes everything before ``t0``
            and after the end time - ``t0``. If ``str``, it must be ``"zero"``. If ``"zero"``,
            it will not remove the points, but set them to zero. This is ideal for PE. (Default: ``True``)
        n_overide (int, optional): If not ``None``, this will override the determination of
            the number of points, ``n``, from ``int(T/dt)`` to the ``n_overide``. This is used
            if there is an issue matching points between the waveform generator and the response
            model. Python and NumPy integers are accepted.
        orbits (:class:`Orbits`, optional): Orbits class from LISA Analysis Tools. Works with LISA Orbits
            outputs: `lisa-simulation.pages.in2p3.fr/orbits/ <https://lisa-simulation.pages.in2p3.fr/orbits/latest/>`_.
            Either an instance or an :class:`Orbits` subclass, which is then
            instantiated with no arguments.
            (default: :class:`EqualArmlengthOrbits`)
        **kwargs (dict, optional): Keyword arguments passed to :class:`pyResponseTDI`.

    """

    def __init__(
        self,
        waveform_gen,
        Tobs,
        dt,
        index_lambda,
        index_beta,
        t0=10000.0,
        flip_hx=False,
        remove_sky_coords=False,
        is_ecliptic_latitude=True,
        use_gpu=False,
        remove_garbage=True,
        n_overide=None,
        orbits: Optional[Orbits] = EqualArmlengthOrbits,
        **kwargs,
    ):

        # store all necessary information
        self.waveform_gen = waveform_gen
        self.index_lambda = index_lambda
        self.index_beta = index_beta
        self.dt = dt
        self.t0 = t0
        self.sampling_frequency = 1.0 / dt

        if orbits is None:
            orbits = EqualArmlengthOrbits()
        elif isinstance(orbits, type) and issubclass(orbits, Orbits):
            # the signature default is the class itself; build an instance
            orbits = orbits()

        assert isinstance(orbits, Orbits)

        # clip the observation time (years) to the available orbit time span
        if Tobs * YRSID_SI > orbits.t_base.max():
            warnings.warn(
                f"Tobs is larger than available orbital information time array. Reducing Tobs to {orbits.t_base.max()}"
            )
            Tobs = orbits.t_base.max() / YRSID_SI

        if n_overide is not None:
            if not isinstance(n_overide, (int, np.integer)):
                raise ValueError("n_overide must be an integer if not None.")
            self.n = int(n_overide)

        else:
            self.n = int(Tobs * YRSID_SI / dt)

        self.is_ecliptic_latitude = is_ecliptic_latitude
        self.remove_sky_coords = remove_sky_coords
        self.flip_hx = flip_hx
        self.remove_garbage = remove_garbage

        # initialize response function class
        self.response_model = pyResponseTDI(
            self.sampling_frequency, self.n, orbits=orbits, use_gpu=use_gpu, **kwargs
        )

        self.use_gpu = use_gpu

        # observation time in years spanned by the n samples
        self.Tobs = (self.n * self.response_model.dt) / YRSID_SI

    @property
    def xp(self) -> object:
        """Array module in use: NumPy on CPU, CuPy on GPU."""
        return np if not self.use_gpu else cp

    @property
    def citation(self):
        """Get citations for use of this code"""

        return """
        # TODO add
        """

    def __call__(self, *args, **kwargs):
        """Run the waveform and response generation

        Args:
            *args (list): Arguments to the waveform generator. This must include
                the sky coordinates.
            **kwargs (dict): kwargs necessary for the waveform generator.
                ``T`` and ``dt`` are overwritten with this wrapper's values.

        Return:
            list: TDI Channels.

        Raises:
            ValueError: If ``remove_garbage`` is a string other than ``"zero"``.

        """

        args = list(args)

        # get sky coords
        beta = args[self.index_beta]
        lam = args[self.index_lambda]

        # remove them from the list if waveform generator does not take them
        if self.remove_sky_coords:
            # Pop the larger (normalized) position first so the first pop
            # cannot shift the position of the second one.
            positions = sorted(
                (self.index_beta % len(args), self.index_lambda % len(args)),
                reverse=True,
            )
            for position in positions:
                args.pop(position)

        # transform polar angle
        if not self.is_ecliptic_latitude:
            beta = np.pi / 2.0 - beta

        # add the new Tobs and dt info to the waveform generator kwargs
        kwargs["T"] = self.Tobs
        kwargs["dt"] = self.dt

        # get the waveform
        h = self.waveform_gen(*args, **kwargs)

        if self.flip_hx:
            h = h.real - 1j * h.imag

        self.response_model.get_projections(h, lam, beta, t0=self.t0)
        tdi_out = self.response_model.get_tdi_delays()

        out = list(tdi_out)
        start = self.response_model.tdi_start_ind
        if self.remove_garbage is True:  # bool
            # explicit end index: ``-start`` would give an empty slice when start == 0
            out = [channel[start : len(channel) - start] for channel in out]

        elif isinstance(self.remove_garbage, str):  # str option
            if self.remove_garbage != "zero":
                raise ValueError("remove_garbage must be True, False, or 'zero'.")
            for channel in out:
                channel[:start] = 0.0
                channel[len(channel) - start :] = 0.0

        return out