arize-phoenix 0.0.32rc1__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (71)
  1. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/METADATA +11 -5
  2. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/RECORD +69 -40
  3. phoenix/__init__.py +3 -1
  4. phoenix/config.py +23 -1
  5. phoenix/core/model_schema.py +14 -37
  6. phoenix/core/model_schema_adapter.py +0 -1
  7. phoenix/core/traces.py +285 -0
  8. phoenix/datasets/dataset.py +14 -21
  9. phoenix/datasets/errors.py +4 -1
  10. phoenix/datasets/schema.py +1 -1
  11. phoenix/datetime_utils.py +87 -0
  12. phoenix/experimental/callbacks/__init__.py +0 -0
  13. phoenix/experimental/callbacks/langchain_tracer.py +228 -0
  14. phoenix/experimental/callbacks/llama_index_trace_callback_handler.py +364 -0
  15. phoenix/experimental/evals/__init__.py +33 -0
  16. phoenix/experimental/evals/functions/__init__.py +4 -0
  17. phoenix/experimental/evals/functions/binary.py +156 -0
  18. phoenix/experimental/evals/functions/common.py +31 -0
  19. phoenix/experimental/evals/functions/generate.py +50 -0
  20. phoenix/experimental/evals/models/__init__.py +4 -0
  21. phoenix/experimental/evals/models/base.py +130 -0
  22. phoenix/experimental/evals/models/openai.py +128 -0
  23. phoenix/experimental/evals/retrievals.py +2 -2
  24. phoenix/experimental/evals/templates/__init__.py +24 -0
  25. phoenix/experimental/evals/templates/default_templates.py +126 -0
  26. phoenix/experimental/evals/templates/template.py +107 -0
  27. phoenix/experimental/evals/utils/__init__.py +0 -0
  28. phoenix/experimental/evals/utils/downloads.py +33 -0
  29. phoenix/experimental/evals/utils/threads.py +27 -0
  30. phoenix/experimental/evals/utils/types.py +9 -0
  31. phoenix/experimental/evals/utils.py +33 -0
  32. phoenix/metrics/binning.py +0 -1
  33. phoenix/metrics/timeseries.py +2 -3
  34. phoenix/server/api/context.py +2 -0
  35. phoenix/server/api/input_types/SpanSort.py +60 -0
  36. phoenix/server/api/schema.py +85 -4
  37. phoenix/server/api/types/DataQualityMetric.py +10 -1
  38. phoenix/server/api/types/Dataset.py +2 -4
  39. phoenix/server/api/types/DatasetInfo.py +10 -0
  40. phoenix/server/api/types/ExportEventsMutation.py +4 -1
  41. phoenix/server/api/types/Functionality.py +15 -0
  42. phoenix/server/api/types/MimeType.py +16 -0
  43. phoenix/server/api/types/Model.py +3 -5
  44. phoenix/server/api/types/SortDir.py +13 -0
  45. phoenix/server/api/types/Span.py +229 -0
  46. phoenix/server/api/types/TimeSeries.py +9 -2
  47. phoenix/server/api/types/pagination.py +2 -0
  48. phoenix/server/app.py +24 -4
  49. phoenix/server/main.py +60 -24
  50. phoenix/server/span_handler.py +39 -0
  51. phoenix/server/static/index.js +956 -479
  52. phoenix/server/thread_server.py +10 -2
  53. phoenix/services.py +39 -16
  54. phoenix/session/session.py +99 -27
  55. phoenix/trace/exporter.py +71 -0
  56. phoenix/trace/filter.py +181 -0
  57. phoenix/trace/fixtures.py +23 -8
  58. phoenix/trace/schemas.py +59 -6
  59. phoenix/trace/semantic_conventions.py +141 -1
  60. phoenix/trace/span_json_decoder.py +60 -6
  61. phoenix/trace/span_json_encoder.py +1 -9
  62. phoenix/trace/trace_dataset.py +100 -8
  63. phoenix/trace/tracer.py +26 -3
  64. phoenix/trace/v1/__init__.py +522 -0
  65. phoenix/trace/v1/trace_pb2.py +52 -0
  66. phoenix/trace/v1/trace_pb2.pyi +351 -0
  67. phoenix/core/dimension_data_type.py +0 -6
  68. phoenix/core/dimension_type.py +0 -9
  69. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/WHEEL +0 -0
  70. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/IP_NOTICE +0 -0
  71. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/LICENSE +0 -0
phoenix/core/traces.py CHANGED
@@ -0,0 +1,285 @@
+ import weakref
+ from collections import defaultdict
+ from datetime import datetime, timezone
+ from queue import SimpleQueue
+ from threading import RLock, Thread
+ from types import MethodType
+ from typing import (
+     Any,
+     DefaultDict,
+     Dict,
+     Iterable,
+     Iterator,
+     List,
+     Optional,
+     SupportsFloat,
+     Tuple,
+     Union,
+     cast,
+ )
+ from uuid import UUID
+
+ from sortedcontainers import SortedKeyList
+ from typing_extensions import TypeAlias
+ from wrapt import ObjectProxy
+
+ import phoenix.trace.v1.trace_pb2 as pb
+ from phoenix.datetime_utils import right_open_time_range
+ from phoenix.trace import semantic_conventions
+ from phoenix.trace.schemas import (
+     ATTRIBUTE_PREFIX,
+     COMPUTED_PREFIX,
+     CONTEXT_PREFIX,
+     Span,
+     SpanAttributes,
+     SpanID,
+     TraceID,
+ )
+ from phoenix.trace.v1 import decode, encode
+
+ NAME = "name"
+ STATUS_CODE = "status_code"
+ SPAN_KIND = "span_kind"
+ TRACE_ID = CONTEXT_PREFIX + "trace_id"
+ SPAN_ID = CONTEXT_PREFIX + "span_id"
+ PARENT_ID = "parent_id"
+ START_TIME = "start_time"
+ END_TIME = "end_time"
+ LLM_TOKEN_COUNT_TOTAL = ATTRIBUTE_PREFIX + semantic_conventions.LLM_TOKEN_COUNT_TOTAL
+ LLM_TOKEN_COUNT_PROMPT = ATTRIBUTE_PREFIX + semantic_conventions.LLM_TOKEN_COUNT_PROMPT
+ LLM_TOKEN_COUNT_COMPLETION = ATTRIBUTE_PREFIX + semantic_conventions.LLM_TOKEN_COUNT_COMPLETION
+ LATENCY_MS = COMPUTED_PREFIX + "latency_ms"  # the latency (or duration) of the span in milliseconds
+ CUMULATIVE_LLM_TOKEN_COUNT_TOTAL = COMPUTED_PREFIX + "cumulative_token_count_total"
+ CUMULATIVE_LLM_TOKEN_COUNT_PROMPT = COMPUTED_PREFIX + "cumulative_token_count_prompt"
+ CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION = COMPUTED_PREFIX + "cumulative_token_count_completion"
+
+
+ class ReadableSpan(ObjectProxy):  # type: ignore
+     """
+     A wrapper around a protobuf Span, with access methods and the ability to
+     decode to a Python span. It's meant to be an interface layer separating
+     use from implementation. It can also provide computed values that are
+     not intrinsic to the span, e.g. the latency rank percent, which can
+     change as more spans are ingested and would need to be re-computed on
+     the fly.
+     """
+
+     __wrapped__: pb.Span
+
+     def __init__(self, span: pb.Span) -> None:
+         super().__init__(span)
+         self._self_computed_values: Dict[str, SupportsFloat] = {}
+
+     @property
+     def span(self) -> Span:
+         span = decode(self.__wrapped__)
+         span.attributes.update(cast(SpanAttributes, self._self_computed_values))
+         # TODO: compute latency rank percent (which can change depending on
+         # how many spans have already been ingested).
+         return span
+
+     def __getitem__(self, key: str) -> Any:
+         if key.startswith(COMPUTED_PREFIX):
+             return self._self_computed_values.get(key)
+         if key.startswith(CONTEXT_PREFIX):
+             suffix_key = key[len(CONTEXT_PREFIX) :]
+             return getattr(self.__wrapped__.context, suffix_key, None)
+         if key.startswith(ATTRIBUTE_PREFIX):
+             suffix_key = key[len(ATTRIBUTE_PREFIX) :]
+             if suffix_key not in self.__wrapped__.attributes:
+                 return None
+             return self.__wrapped__.attributes[suffix_key]
+         return getattr(self.__wrapped__, key, None)
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         if not key.startswith(COMPUTED_PREFIX):
+             raise KeyError(f"{key} is not a computed value")
+         self._self_computed_values[key] = value
+
+
+ ParentSpanID: TypeAlias = SpanID
+ ChildSpanID: TypeAlias = SpanID
+
+
+ class Traces:
+     def __init__(self, spans: Optional[Iterable[Span]] = None) -> None:
+         self._queue: "SimpleQueue[Optional[pb.Span]]" = SimpleQueue()
+         # Putting `None` as the sentinel value for queue termination.
+         weakref.finalize(self, self._queue.put, None)
+         for span in spans or ():
+             self.put(span)
+         self._lock = RLock()
+         self._spans: Dict[SpanID, ReadableSpan] = {}
+         self._parent_span_ids: Dict[SpanID, ParentSpanID] = {}
+         self._traces: Dict[TraceID, List[SpanID]] = defaultdict(list)
+         self._child_span_ids: DefaultDict[SpanID, List[ChildSpanID]] = defaultdict(list)
+         self._orphan_spans: DefaultDict[ParentSpanID, List[pb.Span]] = defaultdict(list)
+         self._start_time_sorted_span_ids: SortedKeyList[SpanID] = SortedKeyList(
+             key=lambda span_id: self._spans[span_id].start_time.ToDatetime(timezone.utc),
+         )
+         self._start_time_sorted_root_span_ids: SortedKeyList[SpanID] = SortedKeyList(
+             key=lambda span_id: self._spans[span_id].start_time.ToDatetime(timezone.utc),
+         )
+         self._latency_sorted_root_span_ids: SortedKeyList[SpanID] = SortedKeyList(
+             key=lambda span_id: self._spans[span_id][LATENCY_MS],
+         )
+         self._min_start_time: Optional[datetime] = None
+         self._max_start_time: Optional[datetime] = None
+         self._start_consumer()
+
+     def put(self, span: Optional[Union[Span, pb.Span]] = None) -> None:
+         self._queue.put(encode(span) if isinstance(span, Span) else span)
+
+     def get_trace(self, trace_id: TraceID) -> Iterator[Span]:
+         for span_id in self._traces[trace_id]:
+             if span := self[span_id]:
+                 yield span
+
+     def get_spans(
+         self,
+         start_time: Optional[datetime] = None,
+         stop_time: Optional[datetime] = None,
+         root_spans_only: Optional[bool] = False,
+     ) -> Iterator[Span]:
+         if not self._spans:
+             return
+         min_start_time, max_stop_time = cast(
+             Tuple[datetime, datetime],
+             self.right_open_time_range,
+         )
+         start_time = start_time or min_start_time
+         stop_time = stop_time or max_stop_time
+         sorted_span_ids = (
+             self._start_time_sorted_root_span_ids
+             if root_spans_only
+             else self._start_time_sorted_span_ids
+         )
+         for span_id in sorted_span_ids.irange_key(
+             start_time.astimezone(timezone.utc),
+             stop_time.astimezone(timezone.utc),
+             inclusive=(True, False),
+             reverse=True,  # most recent spans first
+         ):
+             if span := self[span_id]:
+                 yield span
+
+     def latency_rank_percent(self, latency_ms: float) -> Optional[float]:
+         """
+         Returns a value between 0 and 100 approximating the rank of the
+         latency value as a percent of the total count of root spans. E.g.,
+         for a latency value at the 75th percentile, the result is roughly 75.
+         """
+         root_span_ids = self._latency_sorted_root_span_ids
+         if not (n := len(root_span_ids)):
+             return None
+         rank = cast(int, root_span_ids.bisect_key_left(latency_ms))
+         return rank / n * 100
+
+     def get_descendant_span_ids(self, span_id: SpanID) -> Iterator[SpanID]:
+         for child_span_id in self._child_span_ids.get(span_id) or ():
+             yield child_span_id
+             yield from self.get_descendant_span_ids(child_span_id)
+
+     @property
+     def span_count(self) -> int:
+         """Total number of spans (excluding orphan spans, if any)."""
+         return len(self._spans)
+
+     @property
+     def right_open_time_range(self) -> Tuple[Optional[datetime], Optional[datetime]]:
+         return right_open_time_range(self._min_start_time, self._max_start_time)
+
+     def __getitem__(self, span_id: SpanID) -> Optional[Span]:
+         with self._lock:
+             if span := self._spans.get(span_id):
+                 return span.span
+         return None
+
+     def _start_consumer(self) -> None:
+         Thread(
+             target=MethodType(
+                 self.__class__._consume_spans,
+                 weakref.proxy(self),
+             ),
+             daemon=True,
+         ).start()
+
+     def _consume_spans(self) -> None:
+         while True:
+             if not (span := self._queue.get()):
+                 return
+             with self._lock:
+                 self._process_span(span)
+
+     def _process_span(self, span: pb.Span) -> None:
+         span_id = UUID(bytes=span.context.span_id)
+         existing_span = self._spans.get(span_id)
+         if existing_span and existing_span.end_time:
+             # Reject updates if the span has ended.
+             return
+         is_root_span = not span.HasField("parent_span_id")
+         if not is_root_span:
+             parent_span_id = UUID(bytes=span.parent_span_id.value)
+             if parent_span_id not in self._spans:
+                 # A span can't be processed before its parent.
+                 self._orphan_spans[parent_span_id].append(span)
+                 return
+             self._child_span_ids[parent_span_id].append(span_id)
+             self._parent_span_ids[span_id] = parent_span_id
+         new_span = ReadableSpan(span)
+         start_time = span.start_time.ToDatetime(timezone.utc)
+         end_time = span.end_time.ToDatetime(timezone.utc) if span.HasField("end_time") else None
+         if end_time:
+             new_span[LATENCY_MS] = (end_time - start_time).total_seconds() * 1000
+         self._spans[span_id] = new_span
+         if is_root_span and end_time:
+             self._latency_sorted_root_span_ids.add(span_id)
+         if not existing_span:
+             trace_id = UUID(bytes=span.context.trace_id)
+             self._traces[trace_id].append(span_id)
+             if is_root_span:
+                 self._start_time_sorted_root_span_ids.add(span_id)
+             self._start_time_sorted_span_ids.add(span_id)
+             self._min_start_time = (
+                 start_time
+                 if self._min_start_time is None
+                 else min(self._min_start_time, start_time)
+             )
+             self._max_start_time = (
+                 start_time
+                 if self._max_start_time is None
+                 else max(self._max_start_time, start_time)
+             )
+         # Update cumulative values for the span's ancestors.
+         for attribute_name, cumulative_attribute_name in (
+             (LLM_TOKEN_COUNT_TOTAL, CUMULATIVE_LLM_TOKEN_COUNT_TOTAL),
+             (LLM_TOKEN_COUNT_PROMPT, CUMULATIVE_LLM_TOKEN_COUNT_PROMPT),
+             (LLM_TOKEN_COUNT_COMPLETION, CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION),
+         ):
+             existing_value = (existing_span[attribute_name] or 0) if existing_span else 0
+             new_value = new_span[attribute_name] or 0
+             if not (difference := new_value - existing_value):
+                 continue
+             existing_cumulative_value = (
+                 (existing_span[cumulative_attribute_name] or 0) if existing_span else 0
+             )
+             new_span[cumulative_attribute_name] = difference + existing_cumulative_value
+             self._add_value_to_span_ancestors(
+                 span_id,
+                 cumulative_attribute_name,
+                 difference,
+             )
+         # Process previously orphaned spans, if any.
+         for orphan_span in self._orphan_spans[span_id]:
+             self._process_span(orphan_span)
+
+     def _add_value_to_span_ancestors(
+         self,
+         span_id: SpanID,
+         attribute_name: str,
+         value: float,
+     ) -> None:
+         while parent_span_id := self._parent_span_ids.get(span_id):
+             parent_span = self._spans[parent_span_id]
+             cumulative_value = parent_span[attribute_name] or 0
+             parent_span[attribute_name] = cumulative_value + value
+             span_id = parent_span_id
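For orientation, a minimal usage sketch of the new `Traces` store follows. The `Traces`, `put`, `get_spans`, `right_open_time_range`, and `latency_rank_percent` APIs are as defined above; the `spans` and `another_span` values are assumptions (any `phoenix.trace.schemas.Span` objects, e.g. loaded from a trace fixture, would do).

from phoenix.core.traces import Traces

# `spans` is assumed: an iterable of phoenix.trace.schemas.Span objects.
traces = Traces(spans)  # starts a daemon thread that drains the ingest queue

traces.put(another_span)  # Span objects are encoded to protobuf and queued

# Ingestion is asynchronous; once processed, spans become queryable.
root_spans = list(traces.get_spans(root_spans_only=True))  # most recent first
start, stop = traces.right_open_time_range  # minute-floored, right-open bounds

# Approximate percentile rank of a 250 ms latency among ended root spans.
pct = traces.latency_rank_percent(250.0)

Note that ingestion and querying are decoupled by design: `put` only enqueues, and the consumer thread owns all index updates under the lock, so readers never observe a half-processed span.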
phoenix/datasets/dataset.py CHANGED
@@ -9,17 +9,14 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
  
  import numpy as np
  import pandas as pd
- import pytz
- from pandas import DataFrame, Series, Timestamp, read_parquet, to_datetime
+ from pandas import DataFrame, Series, Timestamp, read_parquet
  from pandas.api.types import (
-     is_datetime64_any_dtype,
-     is_datetime64tz_dtype,
      is_numeric_dtype,
-     is_object_dtype,
  )
  from typing_extensions import TypeAlias
  
  from phoenix.config import DATASET_DIR, GENERATED_DATASET_NAME_PREFIX
+ from phoenix.datetime_utils import normalize_timestamps
  
  from . import errors as err
  from .schema import (
@@ -70,6 +67,7 @@ class Dataset:
      _data_file_name: str = "data.parquet"
      _schema_file_name: str = "schema.json"
      _is_persisted: bool = False
+     _is_empty: bool = False
  
      def __init__(
          self,
@@ -85,8 +83,6 @@
              schema=schema,
          )
          if errors:
-             for e in errors:
-                 logger.error(e)
              raise err.DatasetError(errors)
          dataframe, schema = _parse_dataframe_and_schema(dataframe, schema)
          dataframe, schema = _normalize_timestamps(
@@ -98,6 +94,7 @@
          self.__name: str = (
              name if name is not None else f"{GENERATED_DATASET_NAME_PREFIX}{str(uuid.uuid4())}"
          )
+         self._is_empty = self.dataframe.empty
          logger.info(f"""Dataset: {self.__name} initialized""")
  
      def __repr__(self) -> str:
@@ -531,7 +528,7 @@ def _create_and_normalize_dataframe_and_schema(
          if column_name_to_include.get(str(column_name), False):
              included_column_names.append(str(column_name))
      parsed_dataframe = dataframe[included_column_names].copy()
-     parsed_schema = replace(schema, excluded_column_names=None, **schema_patch)
+     parsed_schema = replace(schema, excluded_column_names=None, **schema_patch)  # type: ignore
      pred_id_col_name = parsed_schema.prediction_id_column_name
      if pred_id_col_name is None:
          parsed_schema = replace(parsed_schema, prediction_id_column_name="prediction_id")
@@ -594,8 +591,9 @@ def _normalize_timestamps(
      """
      Ensures that the dataframe has a timestamp column and the schema has a timestamp field. If the
      input dataframe contains a Unix or datetime timestamp or ISO8601 timestamp strings column, it
-     is converted to UTC timestamps. If the input dataframe and schema do not contain timestamps,
-     the default timestamp is used.
+     is converted to a UTC timezone-aware timestamp. If the input dataframe and schema do not
+     contain timestamps, the default timestamp is used. If a timestamp is timezone-naive, it is
+     localized to the local timezone and then converted to UTC.
      """
      timestamp_column: Series[Timestamp]
      if (timestamp_column_name := schema.timestamp_column_name) is None:
@@ -606,18 +604,9 @@
              if len(dataframe)
              else Series([default_timestamp]).iloc[:0].set_axis(dataframe.index, axis=0)
          )
-     elif is_numeric_dtype(timestamp_column_dtype := dataframe[timestamp_column_name].dtype):
-         timestamp_column = to_datetime(dataframe[timestamp_column_name], unit="s", utc=True)
-     elif is_datetime64tz_dtype(timestamp_column_dtype):
-         timestamp_column = dataframe[timestamp_column_name].dt.tz_convert(pytz.utc)
-     elif is_datetime64_any_dtype(timestamp_column_dtype):
-         timestamp_column = dataframe[timestamp_column_name].dt.tz_localize(pytz.utc)
-     elif is_object_dtype(timestamp_column_dtype):
-         timestamp_column = to_datetime(dataframe[timestamp_column_name], utc=True)
      else:
-         raise ValueError(
-             "When provided, input timestamp column must have numeric or datetime dtype, "
-             f"but found {timestamp_column_dtype} instead."
+         timestamp_column = normalize_timestamps(
+             dataframe[timestamp_column_name],
          )
      dataframe[timestamp_column_name] = timestamp_column
      return dataframe, schema
@@ -731,3 +720,7 @@ def _parse_open_inference_column_name(column_name: str) -> _OpenInferenceColumnName:
              name=extract.get("name", ""),
          )
      raise ValueError(f"Invalid format for column name: {column_name}")
+ 
+ 
+ # A dataset with no data. Useful for stubs.
+ EMPTY_DATASET = Dataset(pd.DataFrame(), schema=Schema())
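The net effect of the dataset changes is sketched below: the dtype-dispatch logic moves into `normalize_timestamps`, timezone-naive timestamps are now localized to the machine's timezone before conversion to UTC (previously they were localized as UTC directly), and `EMPTY_DATASET` provides a module-level stub. The column names here are illustrative; the `Schema` fields used are those appearing in this diff.

import pandas as pd

from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
from phoenix.datasets.schema import Schema

df = pd.DataFrame(
    {
        "prediction_id": ["a", "b"],
        # timezone-naive datetimes: localized to the local timezone,
        # then converted to UTC by normalize_timestamps()
        "ts": pd.to_datetime(["2023-06-01 12:00", "2023-06-01 13:00"]),
    }
)
ds = Dataset(
    df,
    Schema(prediction_id_column_name="prediction_id", timestamp_column_name="ts"),
)

assert EMPTY_DATASET.dataframe.empty  # the new stub dataset has no rows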
phoenix/datasets/errors.py CHANGED
@@ -58,7 +58,10 @@ class DatasetError(Exception):
      """An error raised when the dataset is invalid or incomplete"""
  
      def __init__(self, errors: Union[ValidationError, List[ValidationError]]):
-         self.errors = errors
+         self.errors: List[ValidationError] = errors if isinstance(errors, list) else [errors]
+ 
+     def __str__(self) -> str:
+         return "\n".join(map(str, self.errors))
  
  
  class InvalidColumnType(ValidationError):
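A sketch of what the `DatasetError` change buys: a single `ValidationError` is now wrapped in a list, so `self.errors` is always a `List[ValidationError]`, and `str()` joins the messages (previously each error was logged individually before raising). The `raise_if_invalid` helper and `collected` value are hypothetical stand-ins for the dataset validators.

from phoenix.datasets import errors as err

# `collected` is assumed: either one ValidationError or a list of them,
# as produced by the dataset validators.
def raise_if_invalid(collected):
    if collected:
        raise err.DatasetError(collected)

try:
    raise_if_invalid(collected)
except err.DatasetError as e:
    assert isinstance(e.errors, list)  # a single error is wrapped in a list
    print(e)  # the new __str__ renders one message per line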
phoenix/datasets/schema.py CHANGED
@@ -105,7 +105,7 @@ class Schema:
          object.__setattr__(self, "prediction_id_column_name", self.id_column_name)
          object.__setattr__(self, "id_column_name", None)
  
-     def replace(self, **changes: str) -> "Schema":
+     def replace(self, **changes: Any) -> "Schema":
          return replace(self, **changes)
  
      def asdict(self) -> Dict[str, str]:
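The `**changes: str` to `**changes: Any` loosening matters because not all `Schema` fields are strings. A sketch, assuming the `feature_column_names` list field of the phoenix `Schema` dataclass:

from phoenix.datasets.schema import Schema

schema = Schema(timestamp_column_name="ts")
# Passing a list (or None) through replace() is now consistent with the
# annotation; under **changes: str a type checker would have flagged it.
patched = schema.replace(feature_column_names=["f1", "f2"])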
phoenix/datetime_utils.py CHANGED
@@ -0,0 +1,87 @@
+ from datetime import datetime, timedelta, timezone
+ from typing import Any, Optional, Tuple, cast
+
+ import pandas as pd
+ import pytz
+ from pandas import Timestamp, to_datetime
+ from pandas.core.dtypes.common import (
+     is_datetime64_any_dtype,
+     is_datetime64tz_dtype,
+     is_numeric_dtype,
+     is_object_dtype,
+ )
+
+
+ def normalize_timestamps(
+     timestamps: "pd.Series[Any]",
+ ) -> "pd.Series[Timestamp]":
+     """
+     If the input timestamps are Unix timestamps, datetime values, or ISO 8601
+     timestamp strings, they are converted to UTC timezone-aware timestamps.
+     If a timestamp is timezone-naive, it is localized to the local timezone
+     and then converted to UTC.
+     """
+     if is_numeric_dtype(timestamps):
+         return to_datetime(timestamps, unit="s", utc=True)
+     if is_datetime64tz_dtype(timestamps):
+         return timestamps.dt.tz_convert(pytz.utc)
+     if is_datetime64_any_dtype(timestamps):
+         return timestamps.dt.tz_localize(
+             datetime.now().astimezone().tzinfo,
+         ).dt.tz_convert(
+             timezone.utc,
+         )
+     if is_object_dtype(timestamps):
+         timestamps = to_datetime(timestamps)
+         if timestamps.dt.tz is None:
+             timestamps = timestamps.dt.tz_localize(
+                 datetime.now().astimezone().tzinfo,
+             )
+         return timestamps.dt.tz_convert(
+             timezone.utc,
+         )
+     raise ValueError(
+         "When provided, input timestamp column must have numeric or datetime dtype, "
+         f"but found {timestamps.dtype} instead."
+     )
+
+
+ def floor_to_minute(dt: datetime) -> datetime:
+     """Floors the datetime to the minute by taking a round trip through
+     string format, because there isn't always an available function to
+     strip the nanoseconds, if present."""
+     try:
+         dt_as_string = dt.astimezone(
+             timezone.utc,
+         ).strftime(
+             MINUTE_DATETIME_FORMAT,
+         )
+     except ValueError:
+         # NOTE: as of Python 3.8.16 and pandas 1.5.3:
+         # >>> isinstance(pd.NaT, datetime.datetime)
+         # True
+         return cast(datetime, pd.NaT)
+     return datetime.strptime(
+         dt_as_string,
+         MINUTE_DATETIME_FORMAT,
+     ).astimezone(
+         timezone.utc,
+     )
+
+
+ MINUTE_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:00%z"
+
+
+ def right_open_time_range(
+     min_time: Optional[datetime],
+     max_time: Optional[datetime],
+ ) -> Tuple[Optional[datetime], Optional[datetime]]:
+     """
+     First adds one minute to `max_time`, because time intervals are right
+     open and one minute is the smallest interval allowed, then rounds the
+     times down to the nearest minute.
+     """
+     return (
+         floor_to_minute(min_time) if min_time else None,
+         floor_to_minute(max_time + timedelta(minutes=1)) if max_time else None,
+     )