jax-datetime 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """JAX compatible datetime types."""
15
+ # pylint: disable=g-multiple-import,g-importing-member,useless-import-alias
16
+
17
+ # Note: import <name> as <name> is required for names to be exported.
18
+ # See PEP 484 & https://github.com/jax-ml/jax/issues/7570
19
+ from jax_datetime._src.core import (
20
+ Datetime as Datetime,
21
+ Timedelta as Timedelta,
22
+ to_datetime as to_datetime,
23
+ to_timedelta as to_timedelta,
24
+ )
25
+ from jax_datetime._src.numpy_funcs import (
26
+ interp as interp,
27
+ searchsorted as searchsorted,
28
+ )
29
+
30
# NOTE: must be kept in sync with the version declared in pyproject.toml.
__version__ = '0.1.0'
File without changes
@@ -0,0 +1,666 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Core implementation of jax_datetime."""
15
+ from __future__ import annotations
16
+
17
+ import datetime
18
+ import functools
19
+ import math
20
+ import operator
21
+ from typing import Self, overload
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+
28
+ Array = np.ndarray | jax.Array
29
+ Integer = int | np.integer | Array
30
+ Float = float | np.floating | Integer
31
+
32
+ DeviceLike = jax.Device | jax.sharding.Sharding | None
33
+
34
+
35
+ class PytreeArray:
36
+ """Base class for an array-like object implemented as a pytree.
37
+
38
+ The convention for these objects is that they are built out of a collection
39
+ of JAX arrays with the same shape.
40
+
41
+ In principle, they could be implemented as a custom JAX dtype, but such APIs
42
+ are not yet available.
43
+
44
+ There is a long list of methods we could potentially implement here. For now,
45
+ we only include the basics.
46
+ """
47
+
48
+ @property
49
+ def shape(self) -> tuple[int, ...]:
50
+ shapes = [jnp.shape(x) for x in jax.tree.leaves(self)]
51
+ if len(set(shapes)) != 1:
52
+ raise ValueError('all leaves must have the same shape')
53
+ return shapes[0]
54
+
55
+ @property
56
+ def size(self) -> int:
57
+ return math.prod(self.shape)
58
+
59
+ def __len__(self) -> int:
60
+ return self.shape[0]
61
+
62
+ @property
63
+ def ndim(self) -> int:
64
+ return len(self.shape)
65
+
66
+ def transpose(self, axes: tuple[int, ...]) -> PytreeArray:
67
+ return jax.tree.map(lambda x: x.transpose(axes), self)
68
+
69
+ def __getitem__(self, index) -> Array:
70
+ return jax.tree.map(lambda x: x[index], self)
71
+
72
+ # Take precedence over numpy arrays in binary arithmetic.
73
+ __array_priority__ = 100
74
+
75
+
76
def _as_integer_array(x: Integer, name: str) -> np.ndarray | jax.Array:
  """Return `x` as an array, verifying that its dtype is an integer type.

  JAX arrays are returned unchanged; anything else goes through `np.asarray`.

  Args:
    x: value to convert.
    name: argument name used in the error message.

  Raises:
    ValueError: if the resulting dtype is not an integer dtype.
  """
  arr = x if isinstance(x, jax.Array) else np.asarray(x)
  dtype = arr.dtype
  if np.issubdtype(dtype, np.integer):
    return arr
  raise ValueError(f'{name} must be an integer array, got {dtype}')
82
+
83
+
84
def _zeros_like(other: np.ndarray | jax.Array):
  """Zeros with the shape/dtype of `other`, keeping its array library.

  JAX arrays produce a JAX array; anything else produces a NumPy array.
  """
  if isinstance(other, jax.Array):
    return jnp.zeros_like(other)
  return np.zeros_like(other)
87
+
88
+
89
_SECONDS_PER_DAY = 86_400  # 24 hours * 60 minutes * 60 seconds
90
+
91
+
92
def _asarray(x: np.integer | Array) -> Array:
  """Promote a NumPy integer scalar to a 0-d ndarray; pass arrays through."""
  if isinstance(x, np.integer):
    return np.asarray(x)
  return x
94
+
95
+
96
def _normalize_days_seconds(days: Array, seconds: Array) -> tuple[Array, Array]:
  """Carry whole days out of `seconds` so that seconds is in [0, 86400).

  Both inputs must already have integer dtypes (checked with asserts, as this
  is an internal invariant rather than user-input validation).

  Returns:
    `(days, seconds)` where any NumPy integer scalars have been promoted to
    0-d arrays.
  """
  assert np.issubdtype(days.dtype, np.integer)
  assert np.issubdtype(seconds.dtype, np.integer)
  # divmod floors, so negative seconds borrow from days and the remainder is
  # always non-negative.
  carried_days, remainder = divmod(seconds, _SECONDS_PER_DAY)
  return _asarray(days + carried_days), _asarray(remainder)
103
+
104
+
105
def _to_int_seconds(delta: Timedelta) -> jnp.ndarray:
  """Total duration of `delta` in whole seconds, as an integer.

  With int32 arithmetic this is only exact for durations under 2**31 seconds
  (roughly 68 years); callers must guard against larger values.
  """
  return delta.seconds + _SECONDS_PER_DAY * delta.days
109
+
110
+
111
_INT32_MAX = (1 << 31) - 1  # largest value representable as a signed int32
112
+
113
+
114
@jax.jit
def _timedelta_floordiv(
    numerator: Timedelta, divisor: Timedelta
) -> jnp.ndarray:
  """Compute `numerator // divisor` for Timedeltas as accurately as possible.

  Exact integer arithmetic is used when both durations fit in int32 seconds;
  otherwise the result falls back to floor division of float total seconds.
  This helper is jitted so that int32 overflow in the unused branch does not
  raise OverflowError eagerly.
  """
  num_total = numerator.total_seconds()
  div_total = divisor.total_seconds()
  exact = _to_int_seconds(numerator) // _to_int_seconds(divisor)
  approx = (num_total // div_total).astype(int)
  fits_in_int32 = (jnp.abs(num_total) < _INT32_MAX) & (
      jnp.abs(div_total) < _INT32_MAX
  )
  return jnp.where(fits_in_int32, exact, approx)
129
+
130
+
131
@jax.tree_util.register_pytree_node_class
class Timedelta(PytreeArray):
  """JAX compatible time duration, stored in days and seconds.

  Like datetime.timedelta, the `Timedelta` constructor normalizes the seconds
  field to fall in the range `[0, 24*60*60)`, with whole days moved into `days`.

  Timedelta is implemented internally as a JAX pytree of integer arrays of
  days and seconds. Using JAX's default int32 precision, Timedelta can exactly
  represent durations over 5 million years.

  You can either use the `Timedelta` constructor directly, or use `to_timedelta`
  to convert from `datetime.timedelta`, `np.timedelta64` or integers with
  units::

    >>> import jax_datetime as jdt
    >>> import datetime
    >>> jdt.Timedelta(days=1)
    jax_datetime.Timedelta(days=1, seconds=0)
    >>> jdt.to_timedelta(datetime.timedelta(days=1))
    jax_datetime.Timedelta(days=1, seconds=0)
    >>> jdt.to_timedelta(1, 'D')
    jax_datetime.Timedelta(days=1, seconds=0)

  Attributes:
    days: integer JAX array indicating the number of days in the duration.
    seconds: integer JAX array indicating the number of seconds in the duration,
      normalized to fall in the range `[0, 24*60*60)`.
  """

  # TODO(shoyer): can we rewrite this a custom JAX dtype, like jax.random.key?

  def __init__(
      self, days: Integer | None = None, seconds: Integer | None = None
  ):
    """Construct a Timedelta object.

    Args:
      days: optional number of days in the duration provided as an int or int
        array, defaulting to zero.
      seconds: optional number of seconds in the duration provided as an int or
        int array, defaulting to zero. If both seconds and days are provided,
        they must have the same shape.

    Raises:
      ValueError: if an argument does not have an integer dtype, or if days and
        seconds have different shapes.
    """
    if days is None and seconds is None:
      days = np.asarray(0)
      seconds = np.asarray(0)
    elif days is None:
      seconds = _as_integer_array(seconds, name='seconds')
      days = _zeros_like(seconds)
    elif seconds is None:
      days = _as_integer_array(days, name='days')
      seconds = _zeros_like(days)
    else:
      days = _as_integer_array(days, name='days')
      seconds = _as_integer_array(seconds, name='seconds')

    if days.shape != seconds.shape:
      raise ValueError(
          f'days and seconds must have the same shape, got {days.shape} and'
          f' {seconds.shape}'
      )

    self._days, self._seconds = _normalize_days_seconds(days, seconds)

  @property
  def days(self) -> Array:
    return self._days

  @property
  def seconds(self) -> Array:
    return self._seconds

  def __repr__(self) -> str:
    return f'jax_datetime.Timedelta(days={self.days}, seconds={self.seconds})'

  @classmethod
  def from_normalized(cls, days: Integer, seconds: Integer) -> Self:
    """Fast-path constructor from pre-normalized days and seconds."""
    result = object.__new__(cls)
    result._days = _as_integer_array(days, name='days')
    result._seconds = _as_integer_array(seconds, name='seconds')
    return result

  @classmethod
  def from_timedelta64(cls, values: np.timedelta64 | np.ndarray) -> Self:
    """Construct a Timedelta from a NumPy timedelta64 scalar or array."""
    seconds = values // np.timedelta64(1, 's')  # round down
    # normalize with numpy int64 arrays to avoid overflow in int32
    days, seconds = divmod(seconds, _SECONDS_PER_DAY)
    return cls.from_normalized(days, seconds)

  @classmethod
  def from_pytimedelta(cls, values: datetime.timedelta) -> Self:
    """Construct a Timedelta from a datetime.timedelta object."""
    return cls.from_normalized(values.days, values.seconds)

  def to_timedelta64(self) -> np.timedelta64 | np.ndarray:
    """Convert this value to a np.timedelta64 scalar or array."""
    seconds = np.int64(self.days) * _SECONDS_PER_DAY + np.int64(self.seconds)
    return seconds.astype(dtype='timedelta64[s]')

  def to_pytimedelta(self) -> datetime.timedelta:
    """Convert this value to a datetime.timedelta object."""
    return datetime.timedelta(
        days=operator.index(self.days), seconds=operator.index(self.seconds)
    )

  # The implementation of all methods should match datetime.timedelta, except
  # extended to handle jax.Array objects:
  # https://docs.python.org/3/library/datetime.html#timedelta-objects

  def total_seconds(self) -> jnp.ndarray:
    """Total number of seconds in the duration, as a JAX array of floats."""
    return jnp.asarray(self.days, float) * _SECONDS_PER_DAY + self.seconds

  @overload
  def __add__(self, other: DatetimeLike) -> Datetime:
    ...

  @overload
  def __add__(self, other: TimedeltaLike) -> Timedelta:
    ...

  def __add__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    if isinstance(other, DatetimeLike):
      other = to_datetime(other)
      return other + self
    elif isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      days = self.days + other.days
      seconds = self.seconds + other.seconds
      return Timedelta(days, seconds)  # type: ignore
    elif isinstance(other, np.ndarray):
      # TODO(shoyer): consider handling np.ndarray objects. This is tricky to
      # type check because the correct return type depends on the array dtype.
      raise TypeError(
          'arithmetic between jax_datetime.Timedelta and np.ndarray objects is'
          ' not yet supported. Use jdt.to_datetime() or jdt.to_timedelta() to'
          ' explicitly cast the NumPy array to a Datetime or Timedelta.'
      )
    else:
      return NotImplemented  # type: ignore

  __radd__ = __add__

  def __pos__(self) -> Timedelta:
    return self

  def __neg__(self) -> Timedelta:
    return Timedelta(-self.days, -self.seconds)

  def __abs__(self) -> Timedelta:
    return jax.tree.map(
        functools.partial(jnp.where, self.days < 0), -self, self
    )

  def __sub__(self, other: TimedeltaLike) -> Timedelta:
    # TODO(shoyer): consider handling timedelta64 np.ndarray objects
    if not isinstance(other, TimedeltaLike):
      return NotImplemented  # type: ignore
    other = to_timedelta(other)
    return self + (-other)  # type: ignore

  @overload
  def __rsub__(self, other: DatetimeLike) -> Datetime:
    ...

  @overload
  def __rsub__(self, other: TimedeltaLike) -> Timedelta:
    ...

  def __rsub__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    """Implements `other - self` when `other` is the left-hand operand.

    Python falls back to this method when the left operand (for example a
    `datetime.datetime` or `datetime.timedelta`) does not know how to subtract
    a Timedelta. Datetime-like operands produce a Datetime; timedelta-like
    operands produce a Timedelta.
    """
    if isinstance(other, DatetimeLike):
      return to_datetime(other) - self  # type: ignore
    elif isinstance(other, TimedeltaLike):
      return to_timedelta(other) - self  # type: ignore
    elif isinstance(other, np.ndarray):
      # Same restriction as __add__: the result type depends on the dtype.
      raise TypeError(
          'arithmetic between jax_datetime.Timedelta and np.ndarray objects is'
          ' not yet supported. Use jdt.to_datetime() or jdt.to_timedelta() to'
          ' explicitly cast the NumPy array to a Datetime or Timedelta.'
      )
    else:
      return NotImplemented  # type: ignore

  def __mul__(self, other: Float | bool) -> Timedelta:
    if not isinstance(other, Float | bool):
      return NotImplemented
    other = jnp.asarray(other)
    if jnp.issubdtype(other.dtype, jnp.integer) or jnp.issubdtype(
        other.dtype, jnp.bool
    ):
      return Timedelta(self.days * other, self.seconds * other)
    elif jnp.issubdtype(other.dtype, jnp.floating):
      # Split the scaled days into whole days plus a fractional day, and fold
      # the fractional part into seconds before rounding.
      float_days, day_fraction = jnp.divmod(self.days * other, 1)
      float_seconds = day_fraction * _SECONDS_PER_DAY + self.seconds * other
      days = jnp.around(float_days).astype(int)
      seconds = jnp.around(float_seconds).astype(int)
      return Timedelta(days, seconds)
    else:
      return NotImplemented  # type: ignore

  __rmul__ = __mul__

  @overload
  def __truediv__(self, other: TimedeltaLike) -> jnp.ndarray:
    ...

  @overload
  def __truediv__(self, other: Float) -> Timedelta:
    ...

  def __truediv__(
      self, other: TimedeltaLike | Float
  ) -> jnp.ndarray | Timedelta:
    if isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return self.total_seconds() / other.total_seconds()  # type: ignore
    elif isinstance(other, Float):
      other = jnp.asarray(other)
      if jnp.issubdtype(other.dtype, jnp.integer):
        days, remaining_days = jnp.divmod(self.days, other)
        float_seconds = (
            remaining_days * _SECONDS_PER_DAY + self.seconds
        ) / other
        seconds = jnp.around(float_seconds).astype(int)
        return Timedelta(days, seconds)  # type: ignore
      elif jnp.issubdtype(other.dtype, jnp.floating):
        float_days, remaining_days = jnp.divmod(self.days, other)
        float_seconds = (
            remaining_days * _SECONDS_PER_DAY + self.seconds
        ) / other
        days = jnp.around(float_days).astype(int)
        seconds = jnp.around(float_seconds).astype(int)
        return Timedelta(days, seconds)  # type: ignore
      else:
        return NotImplemented  # type: ignore
    else:
      return NotImplemented  # type: ignore

  @overload
  def __floordiv__(self, other: TimedeltaLike) -> jnp.ndarray:
    ...

  @overload
  def __floordiv__(self, other: Float) -> Timedelta:
    ...

  def __floordiv__(
      self, other: TimedeltaLike | Float
  ) -> jnp.ndarray | Timedelta:
    if isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return _timedelta_floordiv(self, other)
    elif isinstance(other, Float):
      other = jnp.asarray(other)
      if not jnp.issubdtype(other.dtype, jnp.integer):
        return NotImplemented  # type: ignore
      days, remaining_days = jnp.divmod(self.days, other)
      seconds = (remaining_days * _SECONDS_PER_DAY + self.seconds) // other
      return Timedelta(days, seconds)  # type: ignore
    else:
      return NotImplemented  # type: ignore

  # TODO(shoyer): implement __divmod__ and __mod__

  def _comparison_op(wrapped):  # pylint: disable=no-self-argument
    """Private decorator for implementing comparison ops."""

    # Disable type errors for mismatched signatures with the base class
    # comparison method (object), which always returns bool.
    def wrapper(self, other: TimedeltaLike) -> jnp.ndarray:  # type: ignore
      if not isinstance(other, TimedeltaLike):
        return NotImplemented  # type: ignore
      other = to_timedelta(other)
      return wrapped(self, other)

    return wrapper

  @_comparison_op
  def __eq__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days == other.days) & (self.seconds == other.seconds)

  @_comparison_op
  def __ne__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days != other.days) | (self.seconds != other.seconds)

  # Ordering is lexicographic on (days, seconds); this is valid because seconds
  # is normalized to [0, 86400).

  @_comparison_op
  def __lt__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days < other.days) | (
        (self.days == other.days) & (self.seconds < other.seconds)
    )

  @_comparison_op
  def __le__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days < other.days) | (
        (self.days == other.days) & (self.seconds <= other.seconds)
    )

  @_comparison_op
  def __gt__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days > other.days) | (
        (self.days == other.days) & (self.seconds > other.seconds)
    )

  @_comparison_op
  def __ge__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days > other.days) | (
        (self.days == other.days) & (self.seconds >= other.seconds)
    )

  def tree_flatten(self):
    """Custom flatten method for pytree serialization."""
    leaves = (self.days, self.seconds)
    aux_data = None
    return leaves, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, leaves):
    """Custom unflatten method for pytree serialization."""
    assert aux_data is None
    # JAX uses non-numeric values for pytree leaves inside transformations, so
    # we skip __init__ by constructing the object directly:
    # https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
    result = object.__new__(cls)
    result._days, result._seconds = leaves
    return result
438
+
439
+
440
# The Unix epoch (1970-01-01T00:00:00), as a naive NumPy datetime64 with second
# resolution and as a naive Python datetime.
_NUMPY_UNIX_EPOCH = np.datetime64(0, 's')
_PY_UNIX_EPOCH = datetime.datetime(year=1970, month=1, day=1)
442
+
443
+
444
@jax.tree_util.register_pytree_node_class
class Datetime(PytreeArray):
  """JAX compatible datetime, stored as a delta from the Unix epoch.

  The easiest way to create a Datetime is to use `to_datetime`, which
  supports `datetime.datetime`, `np.datetime64` and strings in ISO 8601 format:

    >>> import jax_datetime as jdt
    >>> jdt.to_datetime('1970-01-02')
    jax_datetime.Datetime(delta=jax_datetime.Timedelta(days=1, seconds=0))

  Attributes:
    delta: difference between this date and the Unix epoch (1970-01-01).
  """

  def __init__(self, delta: Timedelta):
    # No validation here: tree_unflatten calls this with whatever children JAX
    # supplies, which may not be concrete Timedelta values.
    self._delta = delta

  @property
  def delta(self) -> Timedelta:
    return self._delta

  def __repr__(self) -> str:
    return f'jax_datetime.Datetime(delta={self.delta})'

  @classmethod
  def from_datetime64(cls, values: np.datetime64 | np.ndarray) -> Datetime:
    """Construct a Datetime from a NumPy datetime64 scalar or array."""
    return cls(Timedelta.from_timedelta64(values - _NUMPY_UNIX_EPOCH))

  @classmethod
  def from_pydatetime(cls, value: datetime.datetime) -> Datetime:
    """Construct a Datetime from a Python datetime.datetime object.

    Naive datetimes are measured against the naive Unix epoch
    (1970-01-01T00:00:00). Timezone-aware datetimes are first converted to UTC
    and made naive, so e.g. 2024-01-01T01:00+01:00 maps to the same value as
    the naive 2024-01-01T00:00.
    """
    if value.utcoffset() is not None:
      # Subtracting the naive epoch from an aware datetime raises TypeError, so
      # normalize aware inputs to naive UTC first.
      value = value.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return cls(Timedelta.from_pytimedelta(value - _PY_UNIX_EPOCH))

  @classmethod
  def from_isoformat(cls, value: str) -> Datetime:
    """Construct a Datetime from an ISO 8601 string, e.g., '2024-01-01T00'.

    Strings carrying a UTC offset (e.g. '+01:00') are converted to UTC.
    """
    return cls.from_pydatetime(datetime.datetime.fromisoformat(value))

  def to_datetime64(self) -> np.datetime64 | np.ndarray:
    """Convert this Datetime to a NumPy datetime64 scalar or array."""
    return self.delta.to_timedelta64() + _NUMPY_UNIX_EPOCH

  def to_pydatetime(self) -> datetime.datetime:
    """Convert this Datetime to a naive Python datetime.datetime object."""
    return self.delta.to_pytimedelta() + _PY_UNIX_EPOCH

  def __add__(self, other: TimedeltaLike | np.ndarray) -> Datetime:
    if not isinstance(other, TimedeltaLike | np.ndarray):
      return NotImplemented  # type: ignore
    other = to_timedelta(other)
    return Datetime(self.delta + other)  # type: ignore

  __radd__ = __add__

  @overload
  def __sub__(self, other: DatetimeLike) -> Timedelta:
    ...

  @overload
  def __sub__(self, other: TimedeltaLike) -> Datetime:
    ...

  def __sub__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    if isinstance(other, DatetimeLike):
      other = to_datetime(other)
      return self.delta - other.delta
    elif isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return Datetime(self.delta - other)  # type: ignore
    elif isinstance(other, np.ndarray):
      # TODO(shoyer): consider handling np.ndarray objects. This is tricky to
      # type check because the correct return type depends on the array dtype.
      raise TypeError(
          'arithmetic between jax_datetime.Datetime and np.ndarray objects is'
          ' not yet supported. Use jdt.to_datetime() or jdt.to_timedelta() to'
          ' explicitly cast the NumPy array to a Datetime or Timedelta.'
      )
    else:
      return NotImplemented  # type: ignore

  def __rsub__(self, other: DatetimeLike | np.ndarray) -> Timedelta:
    # TODO(shoyer): consider handling datetime64 np.ndarray objects
    if isinstance(other, DatetimeLike | np.ndarray):
      other = to_datetime(other)
      return other.delta - self.delta
    else:
      return NotImplemented  # type: ignore

  def _comparison_op(wrapped):  # pylint: disable=no-self-argument
    """Private decorator for implementing comparison ops."""

    # Disable type errors for mismatched signatures with the base class
    # comparison method (object), which always returns bool.
    def wrapper(self, other: DatetimeLike) -> jnp.ndarray:  # type: ignore
      if not isinstance(other, DatetimeLike):
        return NotImplemented  # type: ignore
      other = to_datetime(other)
      return wrapped(self, other)

    return wrapper

  @_comparison_op
  def __eq__(self, other: Datetime) -> jnp.ndarray:
    return self.delta == other.delta

  @_comparison_op
  def __ne__(self, other: Datetime) -> jnp.ndarray:
    return self.delta != other.delta

  @_comparison_op
  def __lt__(self, other: Datetime) -> jnp.ndarray:
    return self.delta < other.delta

  @_comparison_op
  def __le__(self, other: Datetime) -> jnp.ndarray:
    return self.delta <= other.delta

  @_comparison_op
  def __gt__(self, other: Datetime) -> jnp.ndarray:
    return self.delta > other.delta

  @_comparison_op
  def __ge__(self, other: Datetime) -> jnp.ndarray:
    return self.delta >= other.delta

  def tree_flatten(self):
    """Flatten to a single child: the Timedelta offset from the epoch."""
    leaves = (self.delta,)
    aux_data = None
    return leaves, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, leaves):
    """Rebuild a Datetime from the child produced by tree_flatten."""
    assert aux_data is None
    return cls(*leaves)
582
+
583
+
584
# Runtime unions (usable with isinstance) of values that to_timedelta and
# to_datetime accept without extra arguments.
TimedeltaLike = Timedelta | datetime.timedelta | np.timedelta64
DatetimeLike = Datetime | datetime.datetime | np.datetime64
586
+
587
+
588
def to_datetime(value: DatetimeLike | np.ndarray | str) -> Datetime:
  """Coerce `value` to a jax_datetime.Datetime.

  Args:
    value: a jax_datetime.Datetime (returned unchanged), a datetime.datetime,
      an np.datetime64, an np.ndarray with a datetime64 dtype, or a string in
      ISO 8601 format.

  Returns:
    The value as a jax_datetime.Datetime object.

  Raises:
    TypeError: if `value` is not one of the supported types.
  """
  if isinstance(value, Datetime):
    return value
  if isinstance(value, datetime.datetime):
    return Datetime.from_pydatetime(value)
  if isinstance(value, (np.datetime64, np.ndarray)):
    return Datetime.from_datetime64(value)
  if isinstance(value, str):
    return Datetime.from_isoformat(value)
  raise TypeError(f'unsupported type for to_datetime: {type(value)}')
611
+
612
+
613
def _to_timedelta_from_units(value: Integer, unit: str) -> Timedelta:
  """Create a Timedelta from an integer value (or integer array) and a unit.

  Args:
    value: a Python int, NumPy integer scalar, or array. Arrays must have an
      integer dtype; otherwise the Timedelta constructor raises ValueError.
    unit: one of D/day/days, h/hr/hour/hours, m/min/minute/minutes or
      s/sec/second/seconds.

  Returns:
    The corresponding Timedelta.

  Raises:
    TypeError: if `value` is not an int, NumPy integer or array.
    ValueError: if `unit` is not recognized.
  """
  # valid units are a subset of those supported by pd.to_timedelta:
  # https://pandas.pydata.org/docs/reference/api/pandas.to_timedelta.html
  if not isinstance(value, Integer):
    # Floats are rejected here, so the message must say "integer", not
    # "number".
    raise TypeError(
        'to_timedelta with units requires either an integer or an array of'
        f' integers, got {type(value)}: {value!r}'
    )
  if unit in {'D', 'day', 'days'}:
    return Timedelta(days=value)
  elif unit in {'h', 'hr', 'hour', 'hours'}:
    return Timedelta(seconds=value * 3600)
  elif unit in {'m', 'min', 'minute', 'minutes'}:
    return Timedelta(seconds=value * 60)
  elif unit in {'s', 'sec', 'second', 'seconds'}:
    return Timedelta(seconds=value)
  else:
    raise ValueError(f'unsupported unit for to_timedelta: {unit!r}')
632
+
633
+
634
+ def to_timedelta(
635
+ value: TimedeltaLike | Integer,
636
+ unit: str | None = None,
637
+ ) -> Timedelta:
638
+ """Convert a value into a Timedelta object.
639
+
640
+ Args:
641
+ value: a jax_datetime.Timedelta, datetime.timedelta, np.timedelta64,
642
+ np.ndarray with a timedelta64 dtype, array with a numeric dtype or number.
643
+ unit: optional units string. Required if `value` is given as a number.
644
+ Supported values are D/days/days, hours/hour/hr/h/H, m/minute/min/minutes,
645
+ and s/seconds/sec/second, i.e., NumPy's supported datetime units plus
646
+ standard abbreviations.
647
+
648
+ Returns:
649
+ Value cast to a jax_datetime.Timedelta object.
650
+ """
651
+ if unit is not None:
652
+ return _to_timedelta_from_units(value, unit)
653
+
654
+ match value:
655
+ case Timedelta():
656
+ return value
657
+ case datetime.timedelta():
658
+ return Timedelta.from_pytimedelta(value)
659
+ case np.timedelta64():
660
+ return Timedelta.from_timedelta64(value)
661
+ case np.ndarray():
662
+ return Timedelta.from_timedelta64(value)
663
+ case _:
664
+ raise TypeError(
665
+ f'unsupported type for to_timedelta without unit: {type(value)}'
666
+ )