jax-datetime 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_datetime/__init__.py +30 -0
- jax_datetime/_src/__init__.py +0 -0
- jax_datetime/_src/core.py +666 -0
- jax_datetime/_src/core_test.py +573 -0
- jax_datetime/_src/numpy_funcs.py +57 -0
- jax_datetime/_src/numpy_funcs_test.py +62 -0
- jax_datetime-0.1.0.dist-info/METADATA +337 -0
- jax_datetime-0.1.0.dist-info/RECORD +11 -0
- jax_datetime-0.1.0.dist-info/WHEEL +5 -0
- jax_datetime-0.1.0.dist-info/licenses/LICENSE +202 -0
- jax_datetime-0.1.0.dist-info/top_level.txt +1 -0
jax_datetime/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""JAX compatible datetime types."""
|
|
15
|
+
# pylint: disable=g-multiple-import,g-importing-member,useless-import-alias
|
|
16
|
+
|
|
17
|
+
# Note: import <name> as <name> is required for names to be exported.
|
|
18
|
+
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
|
19
|
+
from jax_datetime._src.core import (
|
|
20
|
+
Datetime as Datetime,
|
|
21
|
+
Timedelta as Timedelta,
|
|
22
|
+
to_datetime as to_datetime,
|
|
23
|
+
to_timedelta as to_timedelta,
|
|
24
|
+
)
|
|
25
|
+
from jax_datetime._src.numpy_funcs import (
|
|
26
|
+
interp as interp,
|
|
27
|
+
searchsorted as searchsorted,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Package version string; keep in sync with the version in pyproject.toml.
__version__ = '0.1.0'
|
|
File without changes
|
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Core implementation of jax_datetime."""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import datetime
|
|
18
|
+
import functools
|
|
19
|
+
import math
|
|
20
|
+
import operator
|
|
21
|
+
from typing import Self, overload
|
|
22
|
+
|
|
23
|
+
import jax
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
Array = np.ndarray | jax.Array
|
|
29
|
+
Integer = int | np.integer | Array
|
|
30
|
+
Float = float | np.floating | Integer
|
|
31
|
+
|
|
32
|
+
DeviceLike = jax.Device | jax.sharding.Sharding | None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PytreeArray:
|
|
36
|
+
"""Base class for an array-like object implemented as a pytree.
|
|
37
|
+
|
|
38
|
+
The convention for these objects is that they are built out of a collection
|
|
39
|
+
of JAX arrays with the same shape.
|
|
40
|
+
|
|
41
|
+
In principle, they could be implemented as a custom JAX dtype, but such APIs
|
|
42
|
+
are not yet available.
|
|
43
|
+
|
|
44
|
+
There is a long list of methods we could potentially implement here. For now,
|
|
45
|
+
we only include the basics.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
def shape(self) -> tuple[int, ...]:
|
|
50
|
+
shapes = [jnp.shape(x) for x in jax.tree.leaves(self)]
|
|
51
|
+
if len(set(shapes)) != 1:
|
|
52
|
+
raise ValueError('all leaves must have the same shape')
|
|
53
|
+
return shapes[0]
|
|
54
|
+
|
|
55
|
+
@property
|
|
56
|
+
def size(self) -> int:
|
|
57
|
+
return math.prod(self.shape)
|
|
58
|
+
|
|
59
|
+
def __len__(self) -> int:
|
|
60
|
+
return self.shape[0]
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def ndim(self) -> int:
|
|
64
|
+
return len(self.shape)
|
|
65
|
+
|
|
66
|
+
def transpose(self, axes: tuple[int, ...]) -> PytreeArray:
|
|
67
|
+
return jax.tree.map(lambda x: x.transpose(axes), self)
|
|
68
|
+
|
|
69
|
+
def __getitem__(self, index) -> Array:
|
|
70
|
+
return jax.tree.map(lambda x: x[index], self)
|
|
71
|
+
|
|
72
|
+
# Take precedence over numpy arrays in binary arithmetic.
|
|
73
|
+
__array_priority__ = 100
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _as_integer_array(x: Integer, name: str) -> np.ndarray | jax.Array:
  """Coerces `x` to an array and checks that it has an integer dtype.

  JAX arrays (including tracers) are passed through untouched; anything else
  is converted with `np.asarray`.

  Args:
    x: integer scalar or array.
    name: argument name used in the error message.

  Returns:
    `x` as a NumPy or JAX array.

  Raises:
    ValueError: if the resulting dtype is not an integer dtype.
  """
  arr = x if isinstance(x, jax.Array) else np.asarray(x)
  if np.issubdtype(arr.dtype, np.integer):
    return arr
  raise ValueError(f'{name} must be an integer array, got {arr.dtype}')
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _zeros_like(other: np.ndarray | jax.Array):
  """Returns zeros shaped like `other`, from the same array library.

  JAX inputs produce a JAX array; anything else produces a NumPy array.
  """
  if isinstance(other, jax.Array):
    return jnp.zeros_like(other)
  return np.zeros_like(other)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Seconds in one day: 60 s/min * 60 min/h * 24 h = 86,400.
_SECONDS_PER_DAY = 60 * 60 * 24
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _asarray(x: np.integer | Array) -> Array:
  """Wraps NumPy integer scalars as 0-d arrays; returns arrays unchanged."""
  if isinstance(x, np.integer):
    return np.asarray(x)
  return x
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _normalize_days_seconds(days: Array, seconds: Array) -> tuple[Array, Array]:
  """Carries whole days out of `seconds` so that 0 <= seconds < 86400.

  Both inputs must already have integer dtypes (enforced with asserts, as an
  internal invariant). NumPy integer scalars produced by the arithmetic are
  wrapped back into 0-d arrays.

  Returns:
    `(days, seconds)` after normalization.
  """
  assert np.issubdtype(days.dtype, np.integer)
  assert np.issubdtype(seconds.dtype, np.integer)
  # divmod with a positive divisor floors, so the remainder is non-negative.
  extra_days, leftover_seconds = divmod(seconds, _SECONDS_PER_DAY)
  return _asarray(days + extra_days), _asarray(leftover_seconds)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _to_int_seconds(delta: Timedelta) -> jnp.ndarray:
  """Returns the total duration of `delta` in whole seconds, as an integer.

  The arithmetic runs in the integer dtype of `delta.days`/`delta.seconds`.
  With 32-bit integers it is only exact for timedeltas under 2**31 seconds
  (~68 years).
  """
  seconds_from_days = delta.days * _SECONDS_PER_DAY
  return seconds_from_days + delta.seconds
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# Largest int32 value (2**31 - 1), as a plain Python int.
_INT32_MAX = int(np.iinfo(np.int32).max)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@jax.jit
def _timedelta_floordiv(
    numerator: Timedelta, divisor: Timedelta
) -> jnp.ndarray:
  """Implements Timedelta // Timedelta as accurately as possible.

  Exact integer arithmetic is used when both durations fit in int32 seconds;
  otherwise the result falls back to flooring the float ratio of
  `total_seconds()`.
  """
  # This lives in its own jitted helper so that int32 overflow in the exact
  # branch wraps silently (and is then discarded by jnp.where) rather than
  # raising OverflowError eagerly.
  num_total = numerator.total_seconds()
  div_total = divisor.total_seconds()
  fits_in_int32 = (abs(num_total) < _INT32_MAX) & (abs(div_total) < _INT32_MAX)
  exact_quotient = _to_int_seconds(numerator) // _to_int_seconds(divisor)
  approx_quotient = (num_total // div_total).astype(int)
  return jnp.where(fits_in_int32, exact_quotient, approx_quotient)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@jax.tree_util.register_pytree_node_class
class Timedelta(PytreeArray):
  """JAX compatible time duration, stored in days and seconds.

  Like datetime.timedelta, the `Timedelta` constructor normalizes the seconds
  field to fall in the range `[0, 24*60*60)`, with whole days moved into `days`.

  Timedelta is implemented internally as a JAX pytree of integer arrays of
  days and seconds. Using JAX's default int32 precision, Timedelta can exactly
  represent durations over 5 million years.

  You can either use the `Timedelta` constructor directly, or use `to_timedelta`
  to convert from `datetime.timedelta`, `np.timedelta64` or integers with
  units::

    >>> import jax_datetime as jdt
    >>> import datetime
    >>> jdt.Timedelta(days=1)
    jax_datetime.Timedelta(days=1, seconds=0)
    >>> jdt.to_timedelta(datetime.timedelta(days=1))
    jax_datetime.Timedelta(days=1, seconds=0)
    >>> jdt.to_timedelta(1, 'D')
    jax_datetime.Timedelta(days=1, seconds=0)

  Attributes:
    days: integer array (NumPy or JAX) indicating the number of days in the
      duration.
    seconds: integer array (NumPy or JAX) indicating the number of seconds in
      the duration, normalized to fall in the range `[0, 24*60*60)`.
  """

  # TODO(shoyer): can we rewrite this as a custom JAX dtype, like
  # jax.random.key?

  def __init__(
      self, days: Integer | None = None, seconds: Integer | None = None
  ):
    """Construct a Timedelta object.

    Args:
      days: optional number of days in the duration provided as an int or int
        array, defaulting to zero.
      seconds: optional number of seconds in the duration provided as an int or
        int array, defaulting to zero. If both seconds and days are provided,
        they must have the same shape.

    Raises:
      ValueError: if an argument does not have an integer dtype, or if both
        are given with different shapes.
    """
    if days is None and seconds is None:
      days = np.asarray(0)
      seconds = np.asarray(0)
    elif days is None:
      seconds = _as_integer_array(seconds, name='seconds')
      days = _zeros_like(seconds)
    elif seconds is None:
      days = _as_integer_array(days, name='days')
      seconds = _zeros_like(days)
    else:
      days = _as_integer_array(days, name='days')
      seconds = _as_integer_array(seconds, name='seconds')
      if days.shape != seconds.shape:
        raise ValueError(
            f'days and seconds must have the same shape, got {days.shape} and'
            f' {seconds.shape}'
        )

    # Move whole days out of `seconds` so that 0 <= seconds < 86400.
    self._days, self._seconds = _normalize_days_seconds(days, seconds)

  @property
  def days(self) -> Array:
    return self._days

  @property
  def seconds(self) -> Array:
    return self._seconds

  def __repr__(self) -> str:
    return f'jax_datetime.Timedelta(days={self.days}, seconds={self.seconds})'

  @classmethod
  def from_normalized(cls, days: Integer, seconds: Integer) -> Self:
    """Fast-path constructor from pre-normalized days and seconds.

    Bypasses `__init__`, so no normalization is applied: callers must ensure
    `seconds` already lies in `[0, 24*60*60)`.
    """
    result = object.__new__(cls)
    result._days = _as_integer_array(days, name='days')
    result._seconds = _as_integer_array(seconds, name='seconds')
    return result

  @classmethod
  def from_timedelta64(cls, values: np.timedelta64 | np.ndarray) -> Self:
    """Construct a Timedelta from a NumPy timedelta64 scalar or array.

    Sub-second precision is discarded by rounding down to whole seconds.
    """
    seconds = values // np.timedelta64(1, 's')  # round down
    # normalize with numpy int64 arrays to avoid overflow in int32
    days, seconds = divmod(seconds, _SECONDS_PER_DAY)
    return cls.from_normalized(days, seconds)

  @classmethod
  def from_pytimedelta(cls, values: datetime.timedelta) -> Self:
    """Construct a Timedelta from a datetime.timedelta object.

    Only `days` and `seconds` are used; microseconds are dropped.
    """
    return cls.from_normalized(values.days, values.seconds)

  def to_timedelta64(self) -> np.timedelta64 | np.ndarray:
    """Convert this value to a np.timedelta64 scalar or array."""
    # Widen to int64 first so the days -> seconds product cannot overflow.
    seconds = np.int64(self.days) * _SECONDS_PER_DAY + np.int64(self.seconds)
    return seconds.astype(dtype='timedelta64[s]')

  def to_pytimedelta(self) -> datetime.timedelta:
    """Convert this value to a datetime.timedelta object.

    Requires a scalar value: `operator.index` rejects non-scalar arrays.
    """
    return datetime.timedelta(
        days=operator.index(self.days), seconds=operator.index(self.seconds)
    )

  # The implementation of all methods should match datetime.timedelta, except
  # extended to handle jax.Array objects:
  # https://docs.python.org/3/library/datetime.html#timedelta-objects

  def total_seconds(self) -> jnp.ndarray:
    """Total number of seconds in the duration, as a JAX array of floats."""
    return jnp.asarray(self.days, float) * _SECONDS_PER_DAY + self.seconds

  @overload
  def __add__(self, other: DatetimeLike) -> Datetime:
    ...

  @overload
  def __add__(self, other: TimedeltaLike) -> Timedelta:
    ...

  def __add__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    if isinstance(other, DatetimeLike):
      other = to_datetime(other)
      return other + self
    elif isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      days = self.days + other.days
      seconds = self.seconds + other.seconds
      return Timedelta(days, seconds)  # type: ignore
    elif isinstance(other, np.ndarray):
      # TODO(shoyer): consider handling np.ndarray objects. This is tricky to
      # type check because the correct return type depends on the array dtype.
      raise TypeError(
          'arithmetic between jax_datetime.Timedelta and np.ndarray objects is'
          ' not yet supported. Use jdt.to_datetime() or jdt.to_timedelta() to'
          ' explicitly cast the NumPy array to a Datetime or Timedelta.'
      )
    else:
      return NotImplemented  # type: ignore

  __radd__ = __add__

  def __pos__(self) -> Timedelta:
    return self

  def __neg__(self) -> Timedelta:
    return Timedelta(-self.days, -self.seconds)

  def __abs__(self) -> Timedelta:
    # After normalization `seconds` is non-negative, so the sign of the whole
    # duration is the sign of `days`.
    return jax.tree.map(
        functools.partial(jnp.where, self.days < 0), -self, self
    )

  def __sub__(self, other: TimedeltaLike) -> Timedelta:
    # TODO(shoyer): consider handling timedelta64 np.ndarray objects
    if not isinstance(other, TimedeltaLike):
      return NotImplemented  # type: ignore
    other = to_timedelta(other)
    return self + (-other)  # type: ignore

  @overload
  def __rsub__(self, other: DatetimeLike) -> Datetime:
    ...

  @overload
  def __rsub__(self, other: TimedeltaLike) -> Timedelta:
    ...

  def __rsub__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    """Implements `other - self` when `other` does not know about Timedelta.

    Mirrors `__radd__`: `datetime.timedelta(...) - Timedelta(...)` gives a
    Timedelta and `datetime.datetime(...) - Timedelta(...)` gives a Datetime,
    instead of raising TypeError.
    """
    if isinstance(other, DatetimeLike):
      return to_datetime(other) - self  # type: ignore
    elif isinstance(other, TimedeltaLike):
      return to_timedelta(other) + (-self)  # type: ignore
    else:
      return NotImplemented  # type: ignore

  def __mul__(self, other: Float | bool) -> Timedelta:
    if not isinstance(other, Float | bool):
      return NotImplemented
    other = jnp.asarray(other)
    if jnp.issubdtype(other.dtype, jnp.integer) or jnp.issubdtype(
        other.dtype, jnp.bool
    ):
      return Timedelta(self.days * other, self.seconds * other)
    elif jnp.issubdtype(other.dtype, jnp.floating):
      # Split the scaled days into whole days plus a fractional day, and fold
      # the fraction into seconds before rounding to integers.
      float_days, day_fraction = jnp.divmod(self.days * other, 1)
      float_seconds = day_fraction * _SECONDS_PER_DAY + self.seconds * other
      days = jnp.around(float_days).astype(int)
      seconds = jnp.around(float_seconds).astype(int)
      return Timedelta(days, seconds)
    else:
      return NotImplemented  # type: ignore

  __rmul__ = __mul__

  @overload
  def __truediv__(self, other: TimedeltaLike) -> jnp.ndarray:
    ...

  @overload
  def __truediv__(self, other: Float) -> Timedelta:
    ...

  def __truediv__(
      self, other: TimedeltaLike | Float
  ) -> jnp.ndarray | Timedelta:
    if isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return self.total_seconds() / other.total_seconds()  # type: ignore
    elif isinstance(other, Float):
      other = jnp.asarray(other)
      if jnp.issubdtype(other.dtype, jnp.integer):
        # Divide whole days first; the leftover days join `seconds` so that
        # only the sub-day part goes through floating point.
        days, remaining_days = jnp.divmod(self.days, other)
        float_seconds = (
            remaining_days * _SECONDS_PER_DAY + self.seconds
        ) / other
        seconds = jnp.around(float_seconds).astype(int)
        return Timedelta(days, seconds)  # type: ignore
      elif jnp.issubdtype(other.dtype, jnp.floating):
        float_days, remaining_days = jnp.divmod(self.days, other)
        float_seconds = (
            remaining_days * _SECONDS_PER_DAY + self.seconds
        ) / other
        days = jnp.around(float_days).astype(int)
        seconds = jnp.around(float_seconds).astype(int)
        return Timedelta(days, seconds)  # type: ignore
      else:
        return NotImplemented  # type: ignore
    else:
      return NotImplemented  # type: ignore

  @overload
  def __floordiv__(self, other: TimedeltaLike) -> jnp.ndarray:
    ...

  @overload
  def __floordiv__(self, other: Float) -> Timedelta:
    ...

  def __floordiv__(
      self, other: TimedeltaLike | Float
  ) -> jnp.ndarray | Timedelta:
    if isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return _timedelta_floordiv(self, other)
    elif isinstance(other, Float):
      other = jnp.asarray(other)
      if not jnp.issubdtype(other.dtype, jnp.integer):
        return NotImplemented  # type: ignore
      days, remaining_days = jnp.divmod(self.days, other)
      seconds = (remaining_days * _SECONDS_PER_DAY + self.seconds) // other
      return Timedelta(days, seconds)  # type: ignore
    else:
      return NotImplemented  # type: ignore

  # TODO(shoyer): implement __divmod__ and __mod__

  def _comparison_op(wrapped):  # pylint: disable=no-self-argument
    """Private decorator for implementing comparison ops."""

    # Disable type errors for mismatched signatures with the base class
    # comparison method (object), which always returns bool.
    def wrapper(self, other: TimedeltaLike) -> jnp.ndarray:  # type: ignore
      if not isinstance(other, TimedeltaLike):
        return NotImplemented  # type: ignore
      other = to_timedelta(other)
      return wrapped(self, other)

    return wrapper

  # Comparisons are lexicographic on (days, seconds), which is valid because
  # both values are normalized.

  @_comparison_op
  def __eq__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days == other.days) & (self.seconds == other.seconds)

  @_comparison_op
  def __ne__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days != other.days) | (self.seconds != other.seconds)

  @_comparison_op
  def __lt__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days < other.days) | (
        (self.days == other.days) & (self.seconds < other.seconds)
    )

  @_comparison_op
  def __le__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days < other.days) | (
        (self.days == other.days) & (self.seconds <= other.seconds)
    )

  @_comparison_op
  def __gt__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days > other.days) | (
        (self.days == other.days) & (self.seconds > other.seconds)
    )

  @_comparison_op
  def __ge__(self, other: Timedelta) -> jnp.ndarray:
    return (self.days > other.days) | (
        (self.days == other.days) & (self.seconds >= other.seconds)
    )

  def tree_flatten(self):
    """Custom flatten method for pytree serialization."""
    leaves = (self.days, self.seconds)
    aux_data = None
    return leaves, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, leaves):
    """Custom unflatten method for pytree serialization."""
    assert aux_data is None
    # JAX uses non-numeric values for pytree leaves inside transformations, so
    # we skip __init__ by constructing the object directly:
    # https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
    result = object.__new__(cls)
    result._days, result._seconds = leaves
    return result
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
# The Unix epoch (1970-01-01T00:00:00): a second-resolution NumPy datetime64
# (offset 0 from the epoch) and the equivalent naive Python datetime.
_NUMPY_UNIX_EPOCH = np.datetime64(0, 's')
_PY_UNIX_EPOCH = datetime.datetime(year=1970, month=1, day=1)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
@jax.tree_util.register_pytree_node_class
class Datetime(PytreeArray):
  """JAX compatible datetime, stored as a delta from the Unix epoch.

  The easiest way to create a Datetime is to use `to_datetime`, which
  supports `datetime.datetime`, `np.datetime64` and strings in ISO 8601 format:

    >>> import jax_datetime as jdt
    >>> jdt.to_datetime('1970-01-02')
    jax_datetime.Datetime(delta=jax_datetime.Timedelta(days=1, seconds=0))

  Attributes:
    delta: difference between this date and the Unix epoch (1970-01-01).
  """

  def __init__(self, delta: Timedelta):
    # Note: tree_unflatten calls this constructor directly, so any validation
    # added here would also run inside JAX transformations.
    self._delta = delta

  @property
  def delta(self) -> Timedelta:
    return self._delta

  def __repr__(self) -> str:
    return f'jax_datetime.Datetime(delta={self.delta})'

  @classmethod
  def from_datetime64(cls, values: np.datetime64 | np.ndarray) -> Datetime:
    """Construct a Datetime from a NumPy datetime64 scalar or array."""
    return cls(Timedelta.from_timedelta64(values - _NUMPY_UNIX_EPOCH))

  @classmethod
  def from_pydatetime(cls, value: datetime.datetime) -> Datetime:
    """Construct a Datetime from a Python datetime.datetime object.

    Naive datetimes are measured directly against the naive Unix epoch.
    Timezone-aware datetimes are first converted to UTC and made naive;
    subtracting the naive epoch from an aware datetime would otherwise raise
    TypeError. Microseconds are dropped.
    """
    if value.utcoffset() is not None:
      value = value.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return cls(Timedelta.from_pytimedelta(value - _PY_UNIX_EPOCH))

  @classmethod
  def from_isoformat(cls, value: str) -> Datetime:
    """Construct a Datetime from an ISO 8601 string, e.g., '2024-01-01T00'.

    Strings with a UTC offset (e.g. '2024-01-01T00:00+01:00' or a trailing
    'Z') are converted to UTC.
    """
    return cls.from_pydatetime(datetime.datetime.fromisoformat(value))

  def to_datetime64(self) -> np.datetime64 | np.ndarray:
    """Convert this Datetime to a NumPy datetime64 scalar or array."""
    return self.delta.to_timedelta64() + _NUMPY_UNIX_EPOCH

  def to_pydatetime(self) -> datetime.datetime:
    """Convert this Datetime to a naive Python datetime.datetime object."""
    return self.delta.to_pytimedelta() + _PY_UNIX_EPOCH

  def __add__(self, other: TimedeltaLike | np.ndarray) -> Datetime:
    if not isinstance(other, TimedeltaLike | np.ndarray):
      return NotImplemented  # type: ignore
    other = to_timedelta(other)
    return Datetime(self.delta + other)  # type: ignore

  __radd__ = __add__

  @overload
  def __sub__(self, other: DatetimeLike) -> Timedelta:
    ...

  @overload
  def __sub__(self, other: TimedeltaLike) -> Datetime:
    ...

  def __sub__(
      self, other: TimedeltaLike | DatetimeLike
  ) -> Timedelta | Datetime:
    if isinstance(other, DatetimeLike):
      other = to_datetime(other)
      return self.delta - other.delta
    elif isinstance(other, TimedeltaLike):
      other = to_timedelta(other)
      return Datetime(self.delta - other)  # type: ignore
    elif isinstance(other, np.ndarray):
      # TODO(shoyer): consider handling np.ndarray objects. This is tricky to
      # type check because the correct return type depends on the array dtype.
      raise TypeError(
          'arithmetic between jax_datetime.Datetime and np.ndarray objects is'
          ' not yet supported. Use jdt.to_datetime() or jdt.to_timedelta() to'
          ' explicitly cast the NumPy array to a Datetime or Timedelta.'
      )
    else:
      return NotImplemented  # type: ignore

  def __rsub__(self, other: DatetimeLike | np.ndarray) -> Timedelta:
    # TODO(shoyer): consider handling datetime64 np.ndarray objects
    if isinstance(other, DatetimeLike | np.ndarray):
      other = to_datetime(other)
      return other.delta - self.delta
    else:
      return NotImplemented  # type: ignore

  def _comparison_op(wrapped):  # pylint: disable=no-self-argument
    """Private decorator for implementing comparison ops."""

    # Disable type errors for mismatched signatures with the base class
    # comparison method (object), which always returns bool.
    def wrapper(self, other: DatetimeLike) -> jnp.ndarray:  # type: ignore
      if not isinstance(other, DatetimeLike):
        return NotImplemented  # type: ignore
      other = to_datetime(other)
      return wrapped(self, other)

    return wrapper

  # Both operands share the same epoch, so comparing datetimes reduces to
  # comparing their deltas.

  @_comparison_op
  def __eq__(self, other: Datetime) -> jnp.ndarray:
    return self.delta == other.delta

  @_comparison_op
  def __ne__(self, other: Datetime) -> jnp.ndarray:
    return self.delta != other.delta

  @_comparison_op
  def __lt__(self, other: Datetime) -> jnp.ndarray:
    return self.delta < other.delta

  @_comparison_op
  def __le__(self, other: Datetime) -> jnp.ndarray:
    return self.delta <= other.delta

  @_comparison_op
  def __gt__(self, other: Datetime) -> jnp.ndarray:
    return self.delta > other.delta

  @_comparison_op
  def __ge__(self, other: Datetime) -> jnp.ndarray:
    return self.delta >= other.delta

  def tree_flatten(self):
    """Flattens to a single child (the delta Timedelta) with no aux data."""
    leaves = (self.delta,)
    aux_data = None
    return leaves, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, leaves):
    """Rebuilds a Datetime from its single delta child."""
    assert aux_data is None
    return cls(*leaves)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
# Values accepted wherever a duration is expected.
TimedeltaLike = Timedelta | datetime.timedelta | np.timedelta64
# Values accepted wherever a point in time is expected.
DatetimeLike = Datetime | datetime.datetime | np.datetime64
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def to_datetime(value: DatetimeLike | np.ndarray | str) -> Datetime:
|
|
589
|
+
"""Convert a value into a Datetime object.
|
|
590
|
+
|
|
591
|
+
Args:
|
|
592
|
+
value: a jax_datetime.Datetime, datetime.datetime, np.datetime64, np.ndarray
|
|
593
|
+
with a datetime64 dtype or string in ISO 8601 format.
|
|
594
|
+
|
|
595
|
+
Returns:
|
|
596
|
+
Value cast to a jax_datetime.Datetime object.
|
|
597
|
+
"""
|
|
598
|
+
match value:
|
|
599
|
+
case Datetime():
|
|
600
|
+
return value
|
|
601
|
+
case datetime.datetime():
|
|
602
|
+
return Datetime.from_pydatetime(value)
|
|
603
|
+
case np.datetime64():
|
|
604
|
+
return Datetime.from_datetime64(value)
|
|
605
|
+
case np.ndarray():
|
|
606
|
+
return Datetime.from_datetime64(value)
|
|
607
|
+
case str():
|
|
608
|
+
return Datetime.from_isoformat(value)
|
|
609
|
+
case _:
|
|
610
|
+
raise TypeError(f'unsupported type for to_datetime: {type(value)}')
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _to_timedelta_from_units(value: Integer, unit: str) -> Timedelta:
  """Create Timedelta from an integer value (or integer array) and unit string.

  Args:
    value: integer scalar or array of integers.
    unit: one of D/day/days, h/hr/hour/hours, m/min/minute/minutes or
      s/sec/second/seconds.

  Returns:
    The corresponding Timedelta.

  Raises:
    TypeError: if `value` is not an integer scalar or array.
    ValueError: if `unit` is not supported, or an array `value` does not have
      an integer dtype.
  """
  # valid units are a subset of those supported by pd.to_timedelta:
  # https://pandas.pydata.org/docs/reference/api/pandas.to_timedelta.html
  if not isinstance(value, Integer):
    # Floats are numbers too, so say "integer" explicitly in the message.
    raise TypeError(
        'to_timedelta with units requires either an integer or an array of'
        f' integers, got {type(value)}: {value!r}'
    )
  if unit in {'D', 'day', 'days'}:
    return Timedelta(days=value)
  elif unit in {'h', 'hr', 'hour', 'hours'}:
    return Timedelta(seconds=value * 3600)
  elif unit in {'m', 'min', 'minute', 'minutes'}:
    return Timedelta(seconds=value * 60)
  elif unit in {'s', 'sec', 'second', 'seconds'}:
    return Timedelta(seconds=value)
  else:
    raise ValueError(f'unsupported unit for to_timedelta: {unit!r}')
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def to_timedelta(
|
|
635
|
+
value: TimedeltaLike | Integer,
|
|
636
|
+
unit: str | None = None,
|
|
637
|
+
) -> Timedelta:
|
|
638
|
+
"""Convert a value into a Timedelta object.
|
|
639
|
+
|
|
640
|
+
Args:
|
|
641
|
+
value: a jax_datetime.Timedelta, datetime.timedelta, np.timedelta64,
|
|
642
|
+
np.ndarray with a timedelta64 dtype, array with a numeric dtype or number.
|
|
643
|
+
unit: optional units string. Required if `value` is given as a number.
|
|
644
|
+
Supported values are D/days/days, hours/hour/hr/h/H, m/minute/min/minutes,
|
|
645
|
+
and s/seconds/sec/second, i.e., NumPy's supported datetime units plus
|
|
646
|
+
standard abbreviations.
|
|
647
|
+
|
|
648
|
+
Returns:
|
|
649
|
+
Value cast to a jax_datetime.Timedelta object.
|
|
650
|
+
"""
|
|
651
|
+
if unit is not None:
|
|
652
|
+
return _to_timedelta_from_units(value, unit)
|
|
653
|
+
|
|
654
|
+
match value:
|
|
655
|
+
case Timedelta():
|
|
656
|
+
return value
|
|
657
|
+
case datetime.timedelta():
|
|
658
|
+
return Timedelta.from_pytimedelta(value)
|
|
659
|
+
case np.timedelta64():
|
|
660
|
+
return Timedelta.from_timedelta64(value)
|
|
661
|
+
case np.ndarray():
|
|
662
|
+
return Timedelta.from_timedelta64(value)
|
|
663
|
+
case _:
|
|
664
|
+
raise TypeError(
|
|
665
|
+
f'unsupported type for to_timedelta without unit: {type(value)}'
|
|
666
|
+
)
|