arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (186)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
  2. arize_phoenix-5.7.0.dist-info/RECORD +330 -0
  3. phoenix/config.py +50 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/helpers/playground_clients.py +671 -0
  72. phoenix/server/api/helpers/playground_registry.py +70 -0
  73. phoenix/server/api/helpers/playground_spans.py +325 -0
  74. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  75. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  76. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  77. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  78. phoenix/server/api/input_types/ClusterInput.py +2 -2
  79. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  80. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  81. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  82. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  83. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  84. phoenix/server/api/input_types/Granularity.py +1 -1
  85. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  86. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  87. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  88. phoenix/server/api/mutations/__init__.py +4 -0
  89. phoenix/server/api/mutations/chat_mutations.py +374 -0
  90. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  91. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  92. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  93. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  94. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  95. phoenix/server/api/mutations/user_mutations.py +4 -4
  96. phoenix/server/api/openapi/schema.py +2 -2
  97. phoenix/server/api/queries.py +61 -72
  98. phoenix/server/api/routers/oauth2.py +4 -4
  99. phoenix/server/api/routers/v1/datasets.py +22 -36
  100. phoenix/server/api/routers/v1/evaluations.py +6 -5
  101. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  102. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  103. phoenix/server/api/routers/v1/experiments.py +4 -4
  104. phoenix/server/api/routers/v1/spans.py +13 -12
  105. phoenix/server/api/routers/v1/traces.py +5 -5
  106. phoenix/server/api/routers/v1/utils.py +5 -5
  107. phoenix/server/api/schema.py +42 -10
  108. phoenix/server/api/subscriptions.py +347 -494
  109. phoenix/server/api/types/AnnotationSummary.py +3 -3
  110. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  111. phoenix/server/api/types/Cluster.py +8 -7
  112. phoenix/server/api/types/Dataset.py +5 -4
  113. phoenix/server/api/types/Dimension.py +3 -3
  114. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  115. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  116. phoenix/server/api/types/EvaluationSummary.py +3 -3
  117. phoenix/server/api/types/Event.py +7 -7
  118. phoenix/server/api/types/Experiment.py +3 -3
  119. phoenix/server/api/types/ExperimentComparison.py +2 -4
  120. phoenix/server/api/types/GenerativeProvider.py +27 -3
  121. phoenix/server/api/types/Inferences.py +9 -8
  122. phoenix/server/api/types/InferencesRole.py +2 -2
  123. phoenix/server/api/types/Model.py +2 -2
  124. phoenix/server/api/types/Project.py +11 -18
  125. phoenix/server/api/types/Segments.py +3 -3
  126. phoenix/server/api/types/Span.py +45 -7
  127. phoenix/server/api/types/TemplateLanguage.py +9 -0
  128. phoenix/server/api/types/TimeSeries.py +8 -7
  129. phoenix/server/api/types/Trace.py +2 -2
  130. phoenix/server/api/types/UMAPPoints.py +6 -6
  131. phoenix/server/api/types/User.py +3 -3
  132. phoenix/server/api/types/node.py +1 -3
  133. phoenix/server/api/types/pagination.py +4 -4
  134. phoenix/server/api/utils.py +2 -4
  135. phoenix/server/app.py +76 -37
  136. phoenix/server/bearer_auth.py +4 -10
  137. phoenix/server/dml_event.py +3 -3
  138. phoenix/server/dml_event_handler.py +10 -24
  139. phoenix/server/grpc_server.py +3 -2
  140. phoenix/server/jwt_store.py +22 -21
  141. phoenix/server/main.py +17 -4
  142. phoenix/server/oauth2.py +3 -2
  143. phoenix/server/rate_limiters.py +5 -8
  144. phoenix/server/static/.vite/manifest.json +31 -31
  145. phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
  146. phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
  147. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
  148. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  149. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  150. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  151. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  152. phoenix/server/templates/index.html +1 -0
  153. phoenix/server/thread_server.py +1 -1
  154. phoenix/server/types.py +17 -29
  155. phoenix/services.py +8 -3
  156. phoenix/session/client.py +12 -24
  157. phoenix/session/data_extractor.py +3 -3
  158. phoenix/session/evaluation.py +1 -2
  159. phoenix/session/session.py +26 -21
  160. phoenix/trace/attributes.py +16 -28
  161. phoenix/trace/dsl/filter.py +17 -21
  162. phoenix/trace/dsl/helpers.py +3 -3
  163. phoenix/trace/dsl/query.py +13 -22
  164. phoenix/trace/fixtures.py +11 -17
  165. phoenix/trace/otel.py +5 -15
  166. phoenix/trace/projects.py +3 -2
  167. phoenix/trace/schemas.py +2 -2
  168. phoenix/trace/span_evaluations.py +9 -8
  169. phoenix/trace/span_json_decoder.py +3 -3
  170. phoenix/trace/span_json_encoder.py +2 -2
  171. phoenix/trace/trace_dataset.py +6 -5
  172. phoenix/trace/utils.py +6 -6
  173. phoenix/utilities/deprecation.py +3 -2
  174. phoenix/utilities/error_handling.py +3 -2
  175. phoenix/utilities/json.py +2 -1
  176. phoenix/utilities/logging.py +2 -2
  177. phoenix/utilities/project.py +1 -1
  178. phoenix/utilities/re.py +3 -4
  179. phoenix/utilities/template_formatters.py +16 -5
  180. phoenix/version.py +1 -1
  181. arize_phoenix-5.5.2.dist-info/RECORD +0 -321
  182. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  183. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  184. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  185. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  186. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
phoenix/inferences/inferences.py CHANGED
@@ -5,7 +5,7 @@ from copy import deepcopy
  from dataclasses import dataclass, fields, replace
  from enum import Enum
  from itertools import groupby
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
+ from typing import Any, Optional, Union

  import numpy as np
  import pandas as pd
@@ -154,7 +154,7 @@ class Inferences:
      @deprecated("Inferences.from_open_inference is deprecated and will be removed.")
      def from_open_inference(cls, dataframe: DataFrame) -> "Inferences":
          schema = Schema()
-         column_renaming: Dict[str, str] = {}
+         column_renaming: dict[str, str] = {}
          for group_name, group in groupby(
              sorted(
                  map(_parse_open_inference_column_name, dataframe.columns),
@@ -351,7 +351,7 @@ def _parse_open_inference_column_name(column_name: str) -> _OpenInferenceColumnN
      raise ValueError(f"Invalid format for column name: {column_name}")


- def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[DataFrame, Schema]:
+ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> tuple[DataFrame, Schema]:
      """
      Parses a dataframe according to a schema, infers feature columns names when
      they are not explicitly provided, and removes excluded column names from
@@ -364,12 +364,12 @@ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[D
      names present in the dataframe but not included in any other schema fields.
      """

-     unseen_excluded_column_names: Set[str] = (
+     unseen_excluded_column_names: set[str] = (
          set(schema.excluded_column_names) if schema.excluded_column_names is not None else set()
      )
-     unseen_column_names: Set[str] = set(dataframe.columns.to_list())
-     column_name_to_include: Dict[str, bool] = {}
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue] = {}
+     unseen_column_names: set[str] = set(dataframe.columns.to_list())
+     column_name_to_include: dict[str, bool] = {}
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue] = {}

      for schema_field_name in SINGLE_COLUMN_SCHEMA_FIELD_NAMES:
          _check_single_column_schema_field_for_excluded_columns(
@@ -434,10 +434,10 @@ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[D
  def _check_single_column_schema_field_for_excluded_columns(
      schema: Schema,
      schema_field_name: str,
-     unseen_excluded_column_names: Set[str],
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
-     column_name_to_include: Dict[str, bool],
-     unseen_column_names: Set[str],
+     unseen_excluded_column_names: set[str],
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: dict[str, bool],
+     unseen_column_names: set[str],
  ) -> None:
      """
      Checks single-column schema fields for excluded column names.
@@ -455,18 +455,18 @@ def _check_single_column_schema_field_for_excluded_columns(
  def _check_multi_column_schema_field_for_excluded_columns(
      schema: Schema,
      schema_field_name: str,
-     unseen_excluded_column_names: Set[str],
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
-     column_name_to_include: Dict[str, bool],
-     unseen_column_names: Set[str],
+     unseen_excluded_column_names: set[str],
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: dict[str, bool],
+     unseen_column_names: set[str],
  ) -> None:
      """
      Checks multi-column schema fields for excluded columns names.
      """
-     column_names: Optional[List[str]] = getattr(schema, schema_field_name)
+     column_names: Optional[list[str]] = getattr(schema, schema_field_name)
      if column_names:
-         included_column_names: List[str] = []
-         excluded_column_names: List[str] = []
+         included_column_names: list[str] = []
+         excluded_column_names: list[str] = []
          for column_name in column_names:
              is_included_column = column_name not in unseen_excluded_column_names
              column_name_to_include[column_name] = is_included_column
@@ -482,10 +482,10 @@ def _check_multi_column_schema_field_for_excluded_columns(

  def _check_embedding_features_schema_field_for_excluded_columns(
      embedding_features: EmbeddingFeatures,
-     unseen_excluded_column_names: Set[str],
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
-     column_name_to_include: Dict[str, bool],
-     unseen_column_names: Set[str],
+     unseen_excluded_column_names: set[str],
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: dict[str, bool],
+     unseen_column_names: set[str],
  ) -> None:
      """
      Check embedding features for excluded column names.
@@ -527,8 +527,8 @@ def _check_embedding_features_schema_field_for_excluded_columns(

  def _check_embedding_column_names_for_excluded_columns(
      embedding_column_name_mapping: EmbeddingColumnNames,
-     column_name_to_include: Dict[str, bool],
-     unseen_column_names: Set[str],
+     column_name_to_include: dict[str, bool],
+     unseen_column_names: set[str],
  ) -> None:
      """
      Check embedding column names for excluded column names.
@@ -542,10 +542,10 @@ def _check_embedding_column_names_for_excluded_columns(

  def _discover_feature_columns(
      dataframe: DataFrame,
-     unseen_excluded_column_names: Set[str],
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
-     column_name_to_include: Dict[str, bool],
-     unseen_column_names: Set[str],
+     unseen_excluded_column_names: set[str],
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: dict[str, bool],
+     unseen_column_names: set[str],
  ) -> None:
      """
      Adds unseen and un-excluded columns as features, with the exception of "prediction_id"
@@ -559,10 +559,10 @@ def _discover_feature_columns(
          else:
              unseen_excluded_column_names.discard(column_name)
              logger.debug(f"excluded feature: {column_name}")
-     original_column_positions: List[int] = dataframe.columns.get_indexer(
+     original_column_positions: list[int] = dataframe.columns.get_indexer(
          discovered_feature_column_names
      )  # type: ignore
-     feature_column_name_to_position: Dict[str, int] = dict(
+     feature_column_name_to_position: dict[str, int] = dict(
          zip(discovered_feature_column_names, original_column_positions)
      )
      discovered_feature_column_names.sort(key=lambda col: feature_column_name_to_position[col])
@@ -575,16 +575,16 @@ def _discover_feature_columns(
  def _create_and_normalize_dataframe_and_schema(
      dataframe: DataFrame,
      schema: Schema,
-     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
-     column_name_to_include: Dict[str, bool],
- ) -> Tuple[DataFrame, Schema]:
+     schema_patch: dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: dict[str, bool],
+ ) -> tuple[DataFrame, Schema]:
      """
      Creates new dataframe and schema objects to reflect excluded column names
      and discovered features. This also normalizes dataframe columns to ensure a
      standard set of columns (i.e. timestamp and prediction_id) and datatypes for
      those columns.
      """
-     included_column_names: List[str] = []
+     included_column_names: list[str] = []
      for column_name in dataframe.columns:
          if column_name_to_include.get(str(column_name), False):
              included_column_names.append(str(column_name))
@@ -648,7 +648,7 @@ def _normalize_timestamps(
      dataframe: DataFrame,
      schema: Schema,
      default_timestamp: Timestamp,
- ) -> Tuple[DataFrame, Schema]:
+ ) -> tuple[DataFrame, Schema]:
      """
      Ensures that the dataframe has a timestamp column and the schema has a timestamp field. If the
      input dataframe contains a Unix or datetime timestamp or ISO8601 timestamp strings column, it
@@ -686,7 +686,7 @@ def _get_schema_from_unknown_schema_param(schemaLike: SchemaLike) -> Schema:
      if not isinstance(schemaLike, ArizeSchema):
          raise ValueError("Unknown schema passed to Dataset. Please pass a phoenix Schema")

-     embedding_feature_column_names: Dict[str, EmbeddingColumnNames] = {}
+     embedding_feature_column_names: dict[str, EmbeddingColumnNames] = {}
      if schemaLike.embedding_feature_column_names is not None:
          for (
              embedding_name,
@@ -734,7 +734,7 @@ def _get_schema_from_unknown_schema_param(schemaLike: SchemaLike) -> Schema:
      )


- def _add_prediction_id(num_rows: int) -> List[str]:
+ def _add_prediction_id(num_rows: int) -> list[str]:
      return [str(uuid.uuid4()) for _ in range(num_rows)]

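Nearly every hunk above, and most of the per-file changes that follow, apply the same mechanical migration: aliases from the typing module (Dict, List, Set, Tuple) become the built-in generics standardized by PEP 585 on Python 3.9+, and the abstract container types (Mapping, Iterable, Callable, ...) move to collections.abc. A minimal before/after illustration, not taken from the package itself:

    # Hypothetical example of the annotation migration seen throughout this diff.
    from typing import Dict, List, Optional  # pre-3.9 style: typing aliases required

    def old_style(names: List[str]) -> Dict[str, Optional[int]]:
        return {name: None for name in names}

    # Python 3.9+: the built-in types are generic themselves, so the aliases can go.
    def new_style(names: list[str]) -> dict[str, Optional[int]]:
        return {name: None for name in names}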
phoenix/inferences/schema.py CHANGED
@@ -1,13 +1,14 @@
  import json
+ from collections.abc import Mapping
  from dataclasses import asdict, dataclass, replace
- from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+ from typing import Any, Optional, Union

- EmbeddingFeatures = Dict[str, "EmbeddingColumnNames"]
+ EmbeddingFeatures = dict[str, "EmbeddingColumnNames"]
  SchemaFieldName = str
- SchemaFieldValue = Union[Optional[str], Optional[List[str]], Optional[EmbeddingFeatures]]
+ SchemaFieldValue = Union[Optional[str], Optional[list[str]], Optional[EmbeddingFeatures]]

- MULTI_COLUMN_SCHEMA_FIELD_NAMES: Tuple[str, ...] = ("feature_column_names", "tag_column_names")
- SINGLE_COLUMN_SCHEMA_FIELD_NAMES: Tuple[str, ...] = (
+ MULTI_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = ("feature_column_names", "tag_column_names")
+ SINGLE_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = (
      "prediction_id_column_name",
      "timestamp_column_name",
      "prediction_label_column_name",
@@ -19,7 +20,7 @@ LLM_SCHEMA_FIELD_NAMES = ["prompt_column_names", "response_column_names"]


  @dataclass(frozen=True)
- class EmbeddingColumnNames(Dict[str, Any]):
+ class EmbeddingColumnNames(dict[str, Any]):
      """
      A dataclass to hold the column names for the embedding features.
      An embedding feature is a feature that is represented by a vector.
@@ -80,8 +81,8 @@ class Schema:
      prediction_id_column_name: Optional[str] = None
      id_column_name: Optional[str] = None  # Syntax sugar for prediction_id_column_name
      timestamp_column_name: Optional[str] = None
-     feature_column_names: Optional[List[str]] = None
-     tag_column_names: Optional[List[str]] = None
+     feature_column_names: Optional[list[str]] = None
+     tag_column_names: Optional[list[str]] = None
      prediction_label_column_name: Optional[str] = None
      prediction_score_column_name: Optional[str] = None
      actual_label_column_name: Optional[str] = None
@@ -91,7 +92,7 @@ class Schema:
      # document_column_names is used explicitly when the schema is used to capture a corpus
      document_column_names: Optional[EmbeddingColumnNames] = None
      embedding_feature_column_names: Optional[EmbeddingFeatures] = None
-     excluded_column_names: Optional[List[str]] = None
+     excluded_column_names: Optional[list[str]] = None

      def __post_init__(self) -> None:
          # re-map document_column_names to be in the prompt_column_names position
@@ -107,7 +108,7 @@ class Schema:
      def replace(self, **changes: Any) -> "Schema":
          return replace(self, **changes)

-     def asdict(self) -> Dict[str, str]:
+     def asdict(self) -> dict[str, str]:
          return asdict(self)

      def to_json(self) -> str:
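For orientation, a hedged sketch of populating the Schema fields touched above; the EmbeddingColumnNames keyword used here (vector_column_name) is inferred from the validation code below rather than from this hunk, so treat it as an assumption:

    # A minimal sketch, not from the package docs: a Schema built from the
    # fields shown in this diff. `vector_column_name` is an assumption.
    from phoenix.inferences.schema import EmbeddingColumnNames, Schema

    schema = Schema(
        prediction_id_column_name="prediction_id",
        timestamp_column_name="timestamp",
        feature_column_names=["age", "income"],
        tag_column_names=["region"],
        prediction_label_column_name="predicted",
        actual_label_column_name="actual",
        embedding_feature_column_names={
            "text_embedding": EmbeddingColumnNames(vector_column_name="embedding"),
        },
        excluded_column_names=["debug_notes"],
    )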
phoenix/inferences/validation.py CHANGED
@@ -1,5 +1,4 @@
  import math
- from typing import List

  import numpy as np
  from pandas import DataFrame, Series
@@ -11,8 +10,8 @@ from .schema import EmbeddingColumnNames, Schema
  RESERVED_EMBEDDING_NAMES = ("prompt", "response")


- def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
-     errs: List[str] = []
+ def _check_valid_schema(schema: Schema) -> list[err.ValidationError]:
+     errs: list[str] = []
      if schema.excluded_column_names is None:
          return []

@@ -34,7 +33,7 @@ def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
      return []


- def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+ def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
      errors = _check_missing_columns(dataframe, schema)
      if errors:
          return errors
@@ -53,12 +52,12 @@ def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> List[err
      return []


- def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+ def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
      embedding_col_names = schema.embedding_feature_column_names
      if embedding_col_names is None:
          return []

-     embedding_errors: List[err.ValidationError] = []
+     embedding_errors: list[err.ValidationError] = []
      for embedding_name, column_names in embedding_col_names.items():
          if embedding_name in RESERVED_EMBEDDING_NAMES:
              embedding_errors += _validate_reserved_embedding_name(embedding_name, schema)
@@ -71,8 +70,8 @@ def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> List[er

  def _check_valid_prompt_response_data(
      dataframe: DataFrame, schema: Schema
- ) -> List[err.ValidationError]:
-     prompt_response_errors: List[err.ValidationError] = []
+ ) -> list[err.ValidationError]:
+     prompt_response_errors: list[err.ValidationError] = []

      prompt_response_column_names = {
          "prompt": schema.prompt_column_names,
@@ -89,7 +88,7 @@ def _check_valid_prompt_response_data(

  def _validate_reserved_embedding_name(
      embedding_name: str, schema: Schema
- ) -> List[err.ValidationError]:
+ ) -> list[err.ValidationError]:
      if embedding_name == "prompt" and schema.prompt_column_names is not None:
          return [err.InvalidEmbeddingReservedName(embedding_name, "schema.prompt_column_names")]
      elif embedding_name == "response" and schema.response_column_names is not None:
@@ -99,9 +98,9 @@ def _validate_reserved_embedding_name(

  def _validate_embedding_vector(
      dataframe: DataFrame, name: str, vector_column_name: str
- ) -> List[err.ValidationError]:
+ ) -> list[err.ValidationError]:
      vector_column = dataframe[vector_column_name]
-     errors: List[err.ValidationError] = []
+     errors: list[err.ValidationError] = []
      vector_length = None

      for vector in vector_column:
@@ -156,8 +155,8 @@ def _validate_embedding_vector(
      return errors


- def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
-     wrong_type_cols: List[str] = []
+ def _check_column_types(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
+     wrong_type_cols: list[str] = []
      if schema.prediction_id_column_name is not None:
          if not (
              is_numeric_dtype(dataframe.dtypes[schema.prediction_id_column_name])
@@ -172,7 +171,7 @@ def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.Valida
      return []


- def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+ def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
      # converting to a set first makes the checks run a lot faster
      existing_columns = set(dataframe.columns)
      missing_columns = []
phoenix/logging/_formatter.py CHANGED
@@ -1,7 +1,7 @@
  import datetime as dt
  import json
  import logging
- from typing import Dict, Optional
+ from typing import Optional

  LOG_RECORD_BUILTIN_ATTRS = {
      "args",
@@ -34,7 +34,7 @@ class PhoenixJSONFormatter(logging.Formatter):
      def __init__(
          self,
          *,
-         fmt_keys: Optional[Dict[str, str]] = None,
+         fmt_keys: Optional[dict[str, str]] = None,
      ):
          super().__init__()
          self.fmt_keys = fmt_keys if fmt_keys is not None else {}
@@ -43,7 +43,7 @@ class PhoenixJSONFormatter(logging.Formatter):
          message = self._prepare_log_dict(record)
          return json.dumps(message, default=str)

-     def _prepare_log_dict(self, record: logging.LogRecord) -> Dict[str, str]:
+     def _prepare_log_dict(self, record: logging.LogRecord) -> dict[str, str]:
          always_fields = {
              "message": record.getMessage(),
              "timestamp": dt.datetime.fromtimestamp(record.created, tz=dt.timezone.utc).isoformat(),
phoenix/metrics/__init__.py CHANGED
@@ -1,8 +1,9 @@
  import logging
  import warnings
  from abc import ABC, abstractmethod
+ from collections.abc import Iterable, Mapping
  from dataclasses import dataclass
- from typing import Any, Iterable, List, Mapping, Optional, Union
+ from typing import Any, Optional, Union

  import numpy as np
  import pandas as pd
@@ -36,13 +37,13 @@ class Metric(ABC):
      def calc(self, dataframe: pd.DataFrame) -> Any: ...

      @abstractmethod
-     def operands(self) -> List[Column]: ...
+     def operands(self) -> list[Column]: ...

      def __call__(
          self,
          df: pd.DataFrame,
          /,
-         subset_rows: Optional[Union[slice, List[int]]] = None,
+         subset_rows: Optional[Union[slice, list[int]]] = None,
      ) -> Any:
          """
          Computes the metric on a dataframe.
@@ -51,7 +52,7 @@ class Metric(ABC):
          ----------
          df: pandas DataFrame
              The dataframe input to the metric.
-         subset_rows: Optional[Union[slice, List[int]]] = None
+         subset_rows: Optional[Union[slice, list[int]]] = None
              Optionally specifying a subset of rows for the computation.
              Can be a list or slice (e.g. `slice(100, 200)`) of integers.
          """
phoenix/metrics/binning.py CHANGED
@@ -1,8 +1,9 @@
  import warnings
  from abc import ABC, abstractmethod
+ from collections.abc import Iterable, Sequence
  from dataclasses import dataclass
  from functools import partial
- from typing import Any, Iterable, Optional, Sequence, cast
+ from typing import Any, Optional, cast

  import numpy as np
  import pandas as pd
phoenix/metrics/metrics.py CHANGED
@@ -1,8 +1,9 @@
  import math
  import warnings
+ from collections.abc import Callable
  from dataclasses import dataclass, field
  from functools import cached_property
- from typing import Callable, Union, cast
+ from typing import Union, cast

  import numpy as np
  import numpy.typing as npt
phoenix/metrics/mixins.py CHANGED
@@ -7,10 +7,11 @@ on cooperative multiple inheritance and method resolution order in Python.
  import collections
  import inspect
  from abc import ABC, abstractmethod
+ from collections.abc import Callable
  from dataclasses import dataclass, field, fields, replace
  from functools import cached_property
  from itertools import repeat
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Mapping, Optional
+ from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional

  import numpy as np
  import pandas as pd
@@ -42,7 +43,7 @@ class VectorOperator(ABC):

  @dataclass(frozen=True)
  class NullaryOperator(Metric, ABC):
-     def operands(self) -> List[Column]:
+     def operands(self) -> list[Column]:
          return []


@@ -55,7 +56,7 @@ class UnaryOperator(Metric, ABC):

      operand: Column = Column()

-     def operands(self) -> List[Column]:
+     def operands(self) -> list[Column]:
          return [self.operand]


@@ -98,10 +99,10 @@ class EvaluationMetricKeywordParameters(_BaseMapping):
          return sum(1 for _ in self)

      @property
-     def columns(self) -> List[Column]:
+     def columns(self) -> list[Column]:
          return [v for v in self.values() if isinstance(v, Column)]

-     def __call__(self, df: pd.DataFrame) -> Dict[str, Any]:
+     def __call__(self, df: pd.DataFrame) -> dict[str, Any]:
          return {k: v(df) if isinstance(v, Column) else v for k, v in self.items()}


@@ -142,7 +143,7 @@ class EvaluationMetric(Metric, ABC):
          ),
      )

-     def operands(self) -> List[Column]:
+     def operands(self) -> list[Column]:
          return [self.actual, self.predicted] + self.parameters.columns

      def calc(self, df: pd.DataFrame) -> float:
phoenix/metrics/retrieval_metrics.py CHANGED
@@ -1,5 +1,6 @@
+ from collections.abc import Iterable
  from dataclasses import dataclass, field
- from typing import Iterable, Optional, cast
+ from typing import Optional, cast

  import numpy as np
  import pandas as pd
phoenix/metrics/timeseries.py CHANGED
@@ -1,7 +1,8 @@
+ from collections.abc import Callable, Iterable, Iterator
  from datetime import datetime, timedelta, timezone
  from functools import partial
  from itertools import accumulate, repeat
- from typing import Callable, Iterable, Iterator, Tuple, cast
+ from typing import cast

  import pandas as pd
  from typing_extensions import TypeAlias
@@ -41,12 +42,12 @@ def row_interval_from_sorted_time_index(
      time_index: pd.DatetimeIndex,
      time_start: datetime,
      time_stop: datetime,
- ) -> Tuple[StartIndex, StopIndex]:
+ ) -> tuple[StartIndex, StopIndex]:
      """
      Returns end exclusive time slice from sorted index.
      """
      return cast(
-         Tuple[StartIndex, StopIndex],
+         tuple[StartIndex, StopIndex],
          time_index.searchsorted((time_start, time_stop)),
      )

@@ -86,7 +87,7 @@ def _groupers(
      end_time: datetime,
      evaluation_window: timedelta,
      sampling_interval: timedelta,
- ) -> Iterator[Tuple[StartTime, EndTime, pd.Grouper]]:
+ ) -> Iterator[tuple[StartTime, EndTime, pd.Grouper]]:
      """
      Yields pandas.Groupers from time series parameters.
      """
phoenix/metrics/wrappers.py CHANGED
@@ -18,7 +18,7 @@ from abc import ABC
  from enum import Enum
  from inspect import Signature
  from itertools import chain, islice
- from typing import Any, Dict, List, Tuple, cast
+ from typing import Any, cast

  import numpy as np
  import pandas as pd
@@ -157,7 +157,7 @@ def _coerce_dtype_if_necessary(
  def _eliminate_missing_values_from_all_series(
      *args: Any,
      **kwargs: Any,
- ) -> Tuple[List[Any], Dict[str, Any]]:
+ ) -> tuple[list[Any], dict[str, Any]]:
      positional_arguments = list(args)
      keyword_arguments = dict(kwargs)
      all_series = [
phoenix/pointcloud/clustering.py CHANGED
@@ -1,12 +1,11 @@
  from dataclasses import asdict, dataclass
- from typing import List, Set

  import numpy as np
  import numpy.typing as npt
  from typing_extensions import TypeAlias

  RowIndex: TypeAlias = int
- RawCluster: TypeAlias = Set[RowIndex]
+ RawCluster: TypeAlias = set[RowIndex]
  Matrix: TypeAlias = npt.NDArray[np.float64]


@@ -16,11 +15,11 @@ class Hdbscan:
      min_samples: float = 1
      cluster_selection_epsilon: float = 0.0

-     def find_clusters(self, mat: Matrix) -> List[RawCluster]:
+     def find_clusters(self, mat: Matrix) -> list[RawCluster]:
          from fast_hdbscan import HDBSCAN

          cluster_ids: npt.NDArray[np.int_] = HDBSCAN(**asdict(self)).fit_predict(mat)
-         ans: List[RawCluster] = [set() for _ in range(np.max(cluster_ids) + 1)]
+         ans: list[RawCluster] = [set() for _ in range(np.max(cluster_ids) + 1)]
          for row_idx, cluster_id in enumerate(cluster_ids):
              if cluster_id > -1:
                  ans[cluster_id].add(row_idx)
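A hedged usage sketch of the clustering helper above; it assumes any Hdbscan fields not shown in this hunk also have defaults, and it requires the fast_hdbscan dependency, which find_clusters imports lazily:

    # A minimal sketch: cluster two well-separated blobs of 2-D points.
    import numpy as np

    from phoenix.pointcloud.clustering import Hdbscan

    rng = np.random.default_rng(0)
    mat = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(5.0, 0.1, (50, 2))])

    clusters = Hdbscan(min_samples=1, cluster_selection_epsilon=0.0).find_clusters(mat)
    # clusters is a list[set[int]]: each set holds the row indices of one cluster;
    # rows that HDBSCAN labels as noise (-1) are skipped by the loop above.
    print([len(c) for c in clusters])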
phoenix/pointcloud/pointcloud.py CHANGED
@@ -1,9 +1,9 @@
+ from collections.abc import Hashable, Mapping
  from dataclasses import dataclass
- from typing import Dict, List, Mapping, Protocol, Set, Tuple
+ from typing import Protocol, TypeVar

  import numpy as np
  import numpy.typing as npt
- from strawberry import ID
  from typing_extensions import TypeAlias

  from phoenix.pointcloud.clustering import RawCluster
@@ -12,13 +12,15 @@ Vector: TypeAlias = npt.NDArray[np.float64]
  Matrix: TypeAlias = npt.NDArray[np.float64]
  RowIndex: TypeAlias = int

+ _IdType = TypeVar("_IdType", bound=Hashable)
+

  class DimensionalityReducer(Protocol):
      def project(self, mat: Matrix, n_components: int) -> Matrix: ...


  class ClustersFinder(Protocol):
-     def find_clusters(self, mat: Matrix) -> List[RawCluster]: ...
+     def find_clusters(self, mat: Matrix) -> list[RawCluster]: ...


  @dataclass(frozen=True)
@@ -28,9 +30,9 @@ class PointCloud:

      def generate(
          self,
-         data: Mapping[ID, Vector],
+         data: Mapping[_IdType, Vector],
          n_components: int = 3,
-     ) -> Tuple[Dict[ID, Vector], Dict[str, Set[ID]]]:
+     ) -> tuple[dict[_IdType, Vector], dict[str, set[_IdType]]]:
          """
          Given a set of vectors, projects them onto lower dimensions, and
          finds clusters among the projections.
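The two protocols changed above define what PointCloud consumes; a hedged sketch of minimal implementations follows (the PointCloud dataclass's field names are not visible in this diff, so wiring the objects into it is omitted):

    # Illustrative objects satisfying the protocol shapes shown above; names are
    # hypothetical and not part of the package.
    import numpy as np
    import numpy.typing as npt

    Matrix = npt.NDArray[np.float64]


    class SvdReducer:
        """Satisfies DimensionalityReducer: project via a plain SVD projection."""

        def project(self, mat: Matrix, n_components: int) -> Matrix:
            centered = mat - mat.mean(axis=0)
            _, _, vt = np.linalg.svd(centered, full_matrices=False)
            return centered @ vt[:n_components].T


    class SingleCluster:
        """Satisfies ClustersFinder: put every row index into one cluster."""

        def find_clusters(self, mat: Matrix) -> list[set[int]]:
            return [set(range(mat.shape[0]))]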
phoenix/pointcloud/umap_parameters.py CHANGED
@@ -1,5 +1,6 @@
+ from collections.abc import Mapping
  from dataclasses import dataclass
- from typing import Any, Mapping, Optional
+ from typing import Any, Optional

  DEFAULT_MIN_DIST = 0.0
  DEFAULT_N_NEIGHBORS = 30