arize-phoenix 8.32.1__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/METADATA +2 -2
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/RECORD +76 -56
- phoenix/db/constants.py +1 -0
- phoenix/db/facilitator.py +55 -0
- phoenix/db/insertion/document_annotation.py +31 -13
- phoenix/db/insertion/evaluation.py +15 -3
- phoenix/db/insertion/helpers.py +2 -1
- phoenix/db/insertion/span_annotation.py +26 -9
- phoenix/db/insertion/trace_annotation.py +25 -9
- phoenix/db/insertion/types.py +7 -0
- phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
- phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
- phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
- phoenix/db/models.py +151 -10
- phoenix/db/types/annotation_configs.py +97 -0
- phoenix/db/types/db_models.py +41 -0
- phoenix/db/types/trace_retention.py +267 -0
- phoenix/experiments/functions.py +5 -1
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/context.py +5 -0
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
- phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
- phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
- phoenix/server/api/helpers/annotations.py +9 -0
- phoenix/server/api/helpers/prompts/models.py +34 -67
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
- phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
- phoenix/server/api/mutations/dataset_mutations.py +62 -39
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
- phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
- phoenix/server/api/queries.py +86 -0
- phoenix/server/api/routers/v1/__init__.py +4 -0
- phoenix/server/api/routers/v1/annotation_configs.py +449 -0
- phoenix/server/api/routers/v1/annotations.py +161 -0
- phoenix/server/api/routers/v1/evaluations.py +6 -0
- phoenix/server/api/routers/v1/projects.py +1 -50
- phoenix/server/api/routers/v1/spans.py +35 -8
- phoenix/server/api/routers/v1/traces.py +22 -13
- phoenix/server/api/routers/v1/utils.py +60 -0
- phoenix/server/api/types/Annotation.py +7 -0
- phoenix/server/api/types/AnnotationConfig.py +124 -0
- phoenix/server/api/types/AnnotationSource.py +9 -0
- phoenix/server/api/types/AnnotationSummary.py +28 -14
- phoenix/server/api/types/AnnotatorKind.py +1 -0
- phoenix/server/api/types/CronExpression.py +15 -0
- phoenix/server/api/types/Evaluation.py +4 -30
- phoenix/server/api/types/Project.py +50 -2
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
- phoenix/server/api/types/Span.py +78 -0
- phoenix/server/api/types/SpanAnnotation.py +24 -0
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/TraceAnnotation.py +23 -0
- phoenix/server/app.py +20 -0
- phoenix/server/retention.py +76 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
- phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
- phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
- phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
- phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
- phoenix/trace/dsl/filter.py +25 -5
- phoenix/utilities/__init__.py +18 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
- phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
- phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/db/models.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from datetime import datetime, timezone
|
|
2
2
|
from enum import Enum
|
|
3
|
-
from typing import Any, Iterable, Optional, Sequence, TypedDict, cast
|
|
3
|
+
from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict, cast
|
|
4
4
|
|
|
5
5
|
import sqlalchemy.sql as sql
|
|
6
6
|
from openinference.semconv.trace import RerankerAttributes, SpanAttributes
|
|
@@ -45,8 +45,15 @@ from sqlalchemy.sql.functions import coalesce
|
|
|
45
45
|
|
|
46
46
|
from phoenix.config import get_env_database_schema
|
|
47
47
|
from phoenix.datetime_utils import normalize_datetime
|
|
48
|
+
from phoenix.db.types.annotation_configs import (
|
|
49
|
+
AnnotationConfig as AnnotationConfigModel,
|
|
50
|
+
)
|
|
51
|
+
from phoenix.db.types.annotation_configs import (
|
|
52
|
+
AnnotationConfigType,
|
|
53
|
+
)
|
|
48
54
|
from phoenix.db.types.identifier import Identifier
|
|
49
55
|
from phoenix.db.types.model_provider import ModelProvider
|
|
56
|
+
from phoenix.db.types.trace_retention import TraceRetentionCronExpression, TraceRetentionRule
|
|
50
57
|
from phoenix.server.api.helpers.prompts.models import (
|
|
51
58
|
PromptInvocationParameters,
|
|
52
59
|
PromptInvocationParametersRootModel,
|
|
@@ -267,7 +274,7 @@ class _PromptTemplate(TypeDecorator[PromptTemplate]):
|
|
|
267
274
|
class _Tools(TypeDecorator[PromptTools]):
|
|
268
275
|
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
|
|
269
276
|
cache_ok = True
|
|
270
|
-
impl =
|
|
277
|
+
impl = JSON
|
|
271
278
|
|
|
272
279
|
def process_bind_param(
|
|
273
280
|
self, value: Optional[PromptTools], _: Dialect
|
|
@@ -283,7 +290,7 @@ class _Tools(TypeDecorator[PromptTools]):
|
|
|
283
290
|
class _PromptResponseFormat(TypeDecorator[PromptResponseFormat]):
|
|
284
291
|
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
|
|
285
292
|
cache_ok = True
|
|
286
|
-
impl =
|
|
293
|
+
impl = JSON
|
|
287
294
|
|
|
288
295
|
def process_bind_param(
|
|
289
296
|
self, value: Optional[PromptResponseFormat], _: Dialect
|
|
@@ -332,6 +339,60 @@ class _TemplateFormat(TypeDecorator[PromptTemplateFormat]):
|
|
|
332
339
|
return None if value is None else PromptTemplateFormat(value)
|
|
333
340
|
|
|
334
341
|
|
|
342
|
+
class _TraceRetentionCronExpression(TypeDecorator[TraceRetentionCronExpression]):
    """Column type that stores a ``TraceRetentionCronExpression`` in a ``String`` column.

    Values are dumped to their string form when bound to a statement and
    re-validated into ``TraceRetentionCronExpression`` when read back.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = String

    def process_bind_param(
        self, value: Optional[TraceRetentionCronExpression], _: Dialect
    ) -> Optional[str]:
        """Convert the model into the string written to the database.

        Raises:
            AssertionError: If ``value`` is not a ``TraceRetentionCronExpression``
                or does not dump to a ``str``.
        """
        assert isinstance(value, TraceRetentionCronExpression)
        # Dump outside of the assert: `python -O` strips assert statements, which
        # would also drop an assignment made inside one and leave `ans` unbound.
        ans = value.model_dump()
        assert isinstance(ans, str)
        return ans

    def process_result_value(
        self, value: Optional[str], _: Dialect
    ) -> Optional[TraceRetentionCronExpression]:
        """Validate the stored string back into a ``TraceRetentionCronExpression``.

        Raises:
            AssertionError: If ``value`` is empty or not a ``str``.
        """
        assert value and isinstance(value, str)
        return TraceRetentionCronExpression.model_validate(value)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class _TraceRetentionRule(TypeDecorator[TraceRetentionRule]):
    """Column type that stores a ``TraceRetentionRule`` in a ``JSON_`` column.

    Values are dumped to a plain dict when bound to a statement and re-validated
    into ``TraceRetentionRule`` when read back.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(
        self, value: Optional[TraceRetentionRule], _: Dialect
    ) -> Optional[dict[str, Any]]:
        """Convert the model into the dict written to the database.

        Raises:
            AssertionError: If ``value`` is not a ``TraceRetentionRule`` or does
                not dump to a ``dict``.
        """
        assert isinstance(value, TraceRetentionRule)
        # Dump outside of the assert: `python -O` strips assert statements, which
        # would also drop an assignment made inside one and leave `ans` unbound.
        ans = value.model_dump()
        assert isinstance(ans, dict)
        return ans

    def process_result_value(
        self, value: Optional[dict[str, Any]], _: Dialect
    ) -> Optional[TraceRetentionRule]:
        """Validate the stored dict back into a ``TraceRetentionRule``.

        Raises:
            AssertionError: If ``value`` is empty or not a ``dict``.
        """
        assert value and isinstance(value, dict)
        return TraceRetentionRule.model_validate(value)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class _AnnotationConfig(TypeDecorator[AnnotationConfigType]):
    """Column type that stores an annotation config in a ``JSON_`` column.

    The config is wrapped in ``AnnotationConfigModel`` for dumping and
    validation; ``None`` passes through unchanged in both directions.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(
        self, value: Optional[AnnotationConfigType], _: Dialect
    ) -> Optional[dict[str, Any]]:
        """Dump the config into the value written to the database."""
        if value is None:
            return None
        return AnnotationConfigModel(root=value).model_dump()

    def process_result_value(
        self, value: Optional[dict[str, Any]], _: Dialect
    ) -> Optional[AnnotationConfigType]:
        """Validate the stored value back into a concrete annotation config."""
        # The parameter is annotated as a dict (not `str`) to match what
        # `process_bind_param` writes and the sibling JSON-backed decorators.
        if value is None:
            return None
        return AnnotationConfigModel.model_validate(value).root
|
|
394
|
+
|
|
395
|
+
|
|
335
396
|
class ExperimentRunOutput(TypedDict, total=False):
|
|
336
397
|
task_output: Any
|
|
337
398
|
|
|
@@ -357,6 +418,19 @@ class Base(DeclarativeBase):
|
|
|
357
418
|
}
|
|
358
419
|
|
|
359
420
|
|
|
421
|
+
class ProjectTraceRetentionPolicy(Base):
    """ORM model for the ``project_trace_retention_policies`` table."""

    __tablename__ = "project_trace_retention_policies"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)
    # Persisted through the _TraceRetentionCronExpression type decorator.
    cron_expression: Mapped[TraceRetentionCronExpression] = mapped_column(
        _TraceRetentionCronExpression, nullable=False
    )
    # Persisted through the _TraceRetentionRule type decorator.
    rule: Mapped[TraceRetentionRule] = mapped_column(_TraceRetentionRule, nullable=False)
    # Projects assigned to this policy; paired with Project.trace_retention_policy.
    projects: Mapped[list["Project"]] = relationship(
        "Project", back_populates="trace_retention_policy", uselist=True
    )
|
|
432
|
+
|
|
433
|
+
|
|
360
434
|
class Project(Base):
|
|
361
435
|
__tablename__ = "projects"
|
|
362
436
|
name: Mapped[str]
|
|
@@ -374,7 +448,15 @@ class Project(Base):
|
|
|
374
448
|
updated_at: Mapped[datetime] = mapped_column(
|
|
375
449
|
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
|
|
376
450
|
)
|
|
377
|
-
|
|
451
|
+
trace_retention_policy_id: Mapped[Optional[int]] = mapped_column(
|
|
452
|
+
ForeignKey("project_trace_retention_policies.id", ondelete="SET NULL"),
|
|
453
|
+
nullable=True,
|
|
454
|
+
index=True,
|
|
455
|
+
)
|
|
456
|
+
trace_retention_policy: Mapped[Optional[ProjectTraceRetentionPolicy]] = relationship(
|
|
457
|
+
"ProjectTraceRetentionPolicy",
|
|
458
|
+
back_populates="projects",
|
|
459
|
+
)
|
|
378
460
|
traces: WriteOnlyMapped[list["Trace"]] = relationship(
|
|
379
461
|
"Trace",
|
|
380
462
|
back_populates="project",
|
|
@@ -602,6 +684,7 @@ class Span(Base):
|
|
|
602
684
|
)
|
|
603
685
|
|
|
604
686
|
trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
|
|
687
|
+
span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span")
|
|
605
688
|
document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
|
|
606
689
|
dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")
|
|
607
690
|
|
|
@@ -732,17 +815,31 @@ class SpanAnnotation(Base):
|
|
|
732
815
|
score: Mapped[Optional[float]] = mapped_column(Float, index=True)
|
|
733
816
|
explanation: Mapped[Optional[str]]
|
|
734
817
|
metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
|
|
735
|
-
annotator_kind: Mapped[
|
|
736
|
-
CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
|
|
818
|
+
annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
|
|
819
|
+
CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
|
|
737
820
|
)
|
|
738
821
|
created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
|
|
739
822
|
updated_at: Mapped[datetime] = mapped_column(
|
|
740
823
|
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
|
|
741
824
|
)
|
|
825
|
+
identifier: Mapped[str] = mapped_column(
|
|
826
|
+
String,
|
|
827
|
+
server_default="",
|
|
828
|
+
nullable=False,
|
|
829
|
+
)
|
|
830
|
+
source: Mapped[Literal["API", "APP"]] = mapped_column(
|
|
831
|
+
CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
|
|
832
|
+
)
|
|
833
|
+
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
|
|
834
|
+
|
|
835
|
+
span: Mapped["Span"] = relationship(back_populates="span_annotations")
|
|
836
|
+
user: Mapped[Optional["User"]] = relationship("User")
|
|
837
|
+
|
|
742
838
|
__table_args__ = (
|
|
743
839
|
UniqueConstraint(
|
|
744
840
|
"name",
|
|
745
841
|
"span_rowid",
|
|
842
|
+
"identifier",
|
|
746
843
|
),
|
|
747
844
|
)
|
|
748
845
|
|
|
@@ -758,17 +855,28 @@ class TraceAnnotation(Base):
|
|
|
758
855
|
score: Mapped[Optional[float]] = mapped_column(Float, index=True)
|
|
759
856
|
explanation: Mapped[Optional[str]]
|
|
760
857
|
metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
|
|
761
|
-
annotator_kind: Mapped[
|
|
762
|
-
CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
|
|
858
|
+
annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
|
|
859
|
+
CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
|
|
763
860
|
)
|
|
764
861
|
created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
|
|
765
862
|
updated_at: Mapped[datetime] = mapped_column(
|
|
766
863
|
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
|
|
767
864
|
)
|
|
865
|
+
identifier: Mapped[str] = mapped_column(
|
|
866
|
+
String,
|
|
867
|
+
server_default="",
|
|
868
|
+
nullable=False,
|
|
869
|
+
)
|
|
870
|
+
source: Mapped[Literal["API", "APP"]] = mapped_column(
|
|
871
|
+
CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
|
|
872
|
+
)
|
|
873
|
+
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
|
|
874
|
+
|
|
768
875
|
__table_args__ = (
|
|
769
876
|
UniqueConstraint(
|
|
770
877
|
"name",
|
|
771
878
|
"trace_rowid",
|
|
879
|
+
"identifier",
|
|
772
880
|
),
|
|
773
881
|
)
|
|
774
882
|
|
|
@@ -785,13 +893,23 @@ class DocumentAnnotation(Base):
|
|
|
785
893
|
score: Mapped[Optional[float]] = mapped_column(Float, index=True)
|
|
786
894
|
explanation: Mapped[Optional[str]]
|
|
787
895
|
metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
|
|
788
|
-
annotator_kind: Mapped[
|
|
789
|
-
CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
|
|
896
|
+
annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
|
|
897
|
+
CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
|
|
790
898
|
)
|
|
791
899
|
created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
|
|
792
900
|
updated_at: Mapped[datetime] = mapped_column(
|
|
793
901
|
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
|
|
794
902
|
)
|
|
903
|
+
identifier: Mapped[str] = mapped_column(
|
|
904
|
+
String,
|
|
905
|
+
server_default="",
|
|
906
|
+
nullable=False,
|
|
907
|
+
)
|
|
908
|
+
source: Mapped[Literal["API", "APP"]] = mapped_column(
|
|
909
|
+
CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
|
|
910
|
+
)
|
|
911
|
+
user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
|
|
912
|
+
|
|
795
913
|
span: Mapped["Span"] = relationship(back_populates="document_annotations")
|
|
796
914
|
|
|
797
915
|
__table_args__ = (
|
|
@@ -799,6 +917,7 @@ class DocumentAnnotation(Base):
|
|
|
799
917
|
"name",
|
|
800
918
|
"span_rowid",
|
|
801
919
|
"document_position",
|
|
920
|
+
"identifier",
|
|
802
921
|
),
|
|
803
922
|
)
|
|
804
923
|
|
|
@@ -1301,3 +1420,25 @@ class PromptVersionTag(Base):
|
|
|
1301
1420
|
)
|
|
1302
1421
|
|
|
1303
1422
|
__table_args__ = (UniqueConstraint("name", "prompt_id"),)
|
|
1423
|
+
|
|
1424
|
+
|
|
1425
|
+
class AnnotationConfig(Base):
    """ORM model for the ``annotation_configs`` table."""

    __tablename__ = "annotation_configs"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Config names are unique across the table.
    name: Mapped[str] = mapped_column(String, nullable=False, unique=True)
    # Persisted through the _AnnotationConfig type decorator.
    config: Mapped[AnnotationConfigType] = mapped_column(_AnnotationConfig, nullable=False)
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
class ProjectAnnotationConfig(Base):
    """ORM model for ``project_annotation_configs``, linking projects to annotation configs."""

    __tablename__ = "project_annotation_configs"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Both foreign keys cascade on delete, so a link row is removed together
    # with either the project or the annotation config it references.
    project_id: Mapped[int] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True
    )
    annotation_config_id: Mapped[int] = mapped_column(
        ForeignKey("annotation_configs.id", ondelete="CASCADE"), nullable=False, index=True
    )

    # A given config can be attached to a given project at most once.
    __table_args__ = (UniqueConstraint("project_id", "annotation_config_id"),)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Annotated, Literal, Optional, Union
|
|
3
|
+
|
|
4
|
+
from pydantic import AfterValidator, Field, RootModel, model_validator
|
|
5
|
+
from typing_extensions import Self, TypeAlias
|
|
6
|
+
|
|
7
|
+
from .db_models import DBBaseModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AnnotationType(Enum):
    """Kinds of annotation config; the values are used as ``type`` discriminators."""

    CATEGORICAL = "CATEGORICAL"
    CONTINUOUS = "CONTINUOUS"
    FREEFORM = "FREEFORM"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OptimizationDirection(Enum):
    """Optimization direction attached to categorical and continuous annotation configs."""

    MINIMIZE = "MINIMIZE"
    MAXIMIZE = "MAXIMIZE"
    NONE = "NONE"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _BaseAnnotationConfig(DBBaseModel):
    """Fields shared by every annotation config variant."""

    # Optional description of the config; defaults to None.
    description: Optional[str] = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _categorical_value_label_is_non_empty_string(label: str) -> str:
|
|
27
|
+
if not label:
|
|
28
|
+
raise ValueError("Label must be non-empty")
|
|
29
|
+
return label
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CategoricalAnnotationValue(DBBaseModel):
    """One selectable value of a categorical annotation config."""

    # Validated after parsing to be a non-empty string.
    label: Annotated[str, AfterValidator(_categorical_value_label_is_non_empty_string)]
    # Optional numeric score for this label; defaults to None.
    score: Optional[float] = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _categorical_values_are_non_empty_list(
    values: list[CategoricalAnnotationValue],
) -> list[CategoricalAnnotationValue]:
    """Return ``values`` unchanged, rejecting an empty list.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if values:
        return values
    raise ValueError("Values must be non-empty")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _categorical_values_have_unique_labels(
    values: list[CategoricalAnnotationValue],
) -> list[CategoricalAnnotationValue]:
    """Return ``values`` unchanged after checking that no label repeats.

    Raises:
        ValueError: On the first label that was already seen earlier in the list.
    """
    seen = set()
    for item in values:
        current = item.label
        if current not in seen:
            seen.add(current)
            continue
        raise ValueError(
            f'Values for categorical annotation config has duplicate label: "{current}"'
        )
    return values
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CategoricalAnnotationConfig(_BaseAnnotationConfig):
    """Annotation config whose annotations pick one of a fixed list of labelled values."""

    # Discriminator: always the CATEGORICAL enum value.
    type: Literal[AnnotationType.CATEGORICAL.value]  # type: ignore[name-defined]
    optimization_direction: OptimizationDirection
    # Validators run in order: the list must be non-empty, then labels must be unique.
    values: Annotated[
        list[CategoricalAnnotationValue],
        AfterValidator(_categorical_values_are_non_empty_list),
        AfterValidator(_categorical_values_have_unique_labels),
    ]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class ContinuousAnnotationConfig(_BaseAnnotationConfig):
    """Annotation config for numeric scores with optional lower and upper bounds."""

    # Discriminator: always the CONTINUOUS enum value.
    type: Literal[AnnotationType.CONTINUOUS.value]  # type: ignore[name-defined]
    optimization_direction: OptimizationDirection
    lower_bound: Optional[float] = None
    upper_bound: Optional[float] = None

    @model_validator(mode="after")
    def check_bounds(self) -> Self:
        """Ensure that, when both bounds are given, lower is strictly below upper.

        Raises:
            ValueError: If both bounds are set and ``lower_bound >= upper_bound``.
        """
        lower, upper = self.lower_bound, self.upper_bound
        # With either bound missing there is nothing to compare.
        if lower is None or upper is None:
            return self
        if lower >= upper:
            raise ValueError("Lower bound must be strictly less than upper bound")
        return self
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class FreeformAnnotationConfig(_BaseAnnotationConfig):
    """Annotation config that adds no fields beyond the shared base and its discriminator."""

    # Discriminator: always the FREEFORM enum value.
    type: Literal[AnnotationType.FREEFORM.value]  # type: ignore[name-defined]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Union of the concrete annotation config models, discriminated by their
# `type` field so that validation selects the matching variant.
AnnotationConfigType: TypeAlias = Annotated[
    Union[CategoricalAnnotationConfig, ContinuousAnnotationConfig, FreeformAnnotationConfig],
    Field(..., discriminator="type"),
]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class AnnotationConfig(RootModel[AnnotationConfigType]):
    """Root model wrapping the discriminated union of annotation config variants."""

    # The concrete config; callers unwrap it via `.root`.
    root: AnnotationConfigType
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DBBaseModel(BaseModel):
    """
    A base Pydantic model suitable for use with JSON columns in the database.
    """

    model_config = ConfigDict(
        extra="forbid",  # disallow extra attributes
        use_enum_values=True,
        validate_assignment=True,
    )

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the model, dropping keyword arguments set to ``UNDEFINED``.

        Dropped fields count as unset, so they are left out by ``model_dump``.
        """
        kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
        super().__init__(*args, **kwargs)

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the model, defaulting to ``exclude_unset=True`` and ``by_alias=True``.

        The two options are applied as defaults rather than passed alongside
        ``**kwargs``, so a caller supplying either one explicitly overrides it
        instead of triggering a "multiple values for keyword argument" TypeError.
        """
        kwargs.setdefault("exclude_unset", True)
        kwargs.setdefault("by_alias", True)
        return super().model_dump(*args, **kwargs)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Undefined:
    """
    Singleton marker for a value that was never provided.

    Pydantic has no built-in way to tell "not provided" apart from an explicit
    None, so this sentinel stands in for the former.
    """

    def __new__(cls) -> Any:
        # Hand back the cached instance when one exists; create and cache it
        # on the first call only.
        try:
            return cls._instance
        except AttributeError:
            cls._instance = super().__new__(cls)
            return cls._instance

    def __bool__(self) -> bool:
        # The sentinel is always falsy.
        return False


UNDEFINED: Any = Undefined()
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
|
+
from typing import Annotated, Iterable, Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
import sqlalchemy as sa
|
|
7
|
+
from pydantic import AfterValidator, BaseModel, Field, RootModel
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from phoenix.utilities import hour_of_week
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _MaxDays(BaseModel):
    """Mixin holding a ``max_days`` age limit and the SQL filter derived from it."""

    max_days: Annotated[float, Field(ge=0)]

    @property
    def max_days_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces that started more than ``max_days`` days ago.

        A non-positive ``max_days`` yields a constant-false predicate, so no
        trace matches.
        """
        if self.max_days > 0:
            # Imported lazily, inside the property, rather than at module import time.
            from phoenix.db.models import Trace

            cutoff = datetime.now(timezone.utc) - timedelta(days=self.max_days)
            return Trace.start_time < cutoff
        return sa.literal(False)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class _MaxCount(BaseModel):
    """Mixin holding a ``max_count`` limit and the SQL filters derived from it."""

    max_count: Annotated[int, Field(ge=0)]

    @property
    def max_count_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces older than the ``max_count``-th newest trace.

        NOTE: the threshold subquery here is not restricted to any project, so
        the cutoff is computed over all traces. It is kept unchanged for
        backward compatibility; deletion uses ``scoped_max_count_filter``.
        """
        if self.max_count <= 0:
            return sa.literal(False)
        from phoenix.db.models import Trace

        return Trace.start_time < (
            sa.select(Trace.start_time)
            .order_by(Trace.start_time.desc())
            .offset(self.max_count - 1)
            .limit(1)
            .scalar_subquery()
        )

    def scoped_max_count_filter(
        self,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> sa.ColumnElement[bool]:
        """SQL predicate like ``max_count_filter``, limited to the given projects.

        The threshold is the start time of the ``max_count``-th newest trace
        among ``project_rowids`` only. When those projects hold fewer than
        ``max_count`` traces the subquery yields no row and nothing matches.

        Args:
            project_rowids: Project row ids (or a scalar select producing them)
                whose traces are counted.

        Returns:
            A boolean SQL expression; constant false when ``max_count`` is
            not positive.
        """
        if self.max_count <= 0:
            return sa.literal(False)
        from phoenix.db.models import Trace

        return Trace.start_time < (
            sa.select(Trace.start_time)
            .where(Trace.project_rowid.in_(project_rowids))
            .order_by(Trace.start_time.desc())
            .offset(self.max_count - 1)
            .limit(1)
            .scalar_subquery()
        )


def _materialize_project_rowids(
    project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
) -> Union[Iterable[int], sa.ScalarSelect[int]]:
    """Return ``project_rowids`` in a form that is safe to use more than once."""
    from collections.abc import Iterator

    # A one-shot iterator would be exhausted by its first use in a statement;
    # containers and SQL expressions can be reused as-is.
    if isinstance(project_rowids, Iterator):
        return list(project_rowids)
    return project_rowids


class MaxDaysRule(_MaxDays, BaseModel):
    """Retention rule that deletes traces older than ``max_days`` days."""

    type: Literal["max_days"] = "max_days"

    def __bool__(self) -> bool:
        # The rule is falsy when it would never delete anything.
        return self.max_days > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces of the given projects that exceed the age limit.

        Args:
            session: Session used to execute the DELETE statement.
            project_rowids: Project row ids (or a scalar select producing them)
                whose traces are subject to the rule.

        Returns:
            The row ids of the projects from which at least one trace was deleted.
        """
        if self.max_days <= 0:
            return set()
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(self.max_days_filter)
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))


class MaxCountRule(_MaxCount, BaseModel):
    """Retention rule that keeps only the ``max_count`` newest traces."""

    type: Literal["max_count"] = "max_count"

    def __bool__(self) -> bool:
        # The rule is falsy when it would never delete anything.
        return self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces of the given projects beyond the ``max_count`` newest.

        The count threshold is computed over the given projects only, so the
        volume of traces in other projects does not influence what is deleted.

        Args:
            session: Session used to execute the DELETE statement.
            project_rowids: Project row ids (or a scalar select producing them)
                whose traces are subject to the rule.

        Returns:
            The row ids of the projects from which at least one trace was deleted.
        """
        if self.max_count <= 0:
            return set()
        from phoenix.db.models import Trace

        # The ids are used twice below (outer filter and threshold subquery).
        project_rowids = _materialize_project_rowids(project_rowids)
        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(self.scoped_max_count_filter(project_rowids))
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))


class MaxDaysOrCountRule(_MaxDays, _MaxCount, BaseModel):
    """Retention rule that deletes traces exceeding either the age or the count limit."""

    type: Literal["max_days_or_count"] = "max_days_or_count"

    def __bool__(self) -> bool:
        # The rule is falsy only when both limits are disabled.
        return self.max_days > 0 or self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces of the given projects that violate either limit.

        The count threshold is computed over the given projects only.

        Args:
            session: Session used to execute the DELETE statement.
            project_rowids: Project row ids (or a scalar select producing them)
                whose traces are subject to the rule.

        Returns:
            The row ids of the projects from which at least one trace was deleted.
        """
        if self.max_days <= 0 and self.max_count <= 0:
            return set()
        from phoenix.db.models import Trace

        # The ids are used twice below (outer filter and threshold subquery).
        project_rowids = _materialize_project_rowids(project_rowids)
        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(
                sa.or_(
                    self.max_days_filter,
                    self.scoped_max_count_filter(project_rowids),
                )
            )
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class TraceRetentionRule(RootModel[Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule]]):
    """Root model wrapping one concrete retention rule, selected by its ``type`` field."""

    root: Annotated[
        Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule], Field(discriminator="type")
    ]

    def __bool__(self) -> bool:
        # Truthiness is delegated to the wrapped rule.
        return bool(self.root)

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delegate trace deletion to the wrapped rule and return its result."""
        return await self.root.delete_traces(session, project_rowids)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _time_of_next_run(
    cron_expression: str,
    after: Optional[datetime] = None,
) -> datetime:
    """
    Parse a cron expression and calculate the UTC datetime of the next run.
    Only processes hour, and day of week fields; day-of-month and
    month fields must be '*'; minute field must be 0.

    Args:
        cron_expression (str): Standard cron expression with 5 fields:
            minute hour day-of-month month day-of-week
            (minute must be '0'; day-of-month and month must be '*')
        after: Optional[datetime]: The datetime to start searching from. If None,
            the current time is used. Should be timezone-aware; an aware value is
            converted to UTC, while a naive value is interpreted as already
            being in UTC.

    Returns:
        datetime: The datetime of the next run. Timezone is UTC.

    Raises:
        ValueError: If the expression has non-wildcard values for day-of-month or month, if the
            minute field is not '0', or if no match is found within the next 7 days (168 hours).
    """
    fields: list[str] = cron_expression.strip().split()
    if len(fields) != 5:
        raise ValueError(
            "Invalid cron expression. Expected 5 fields "
            "(minute hour day-of-month month day-of-week)."
        )
    if fields[0] != "0":
        raise ValueError("Invalid cron expression. Minute field must be '0'.")
    if fields[2] != "*" or fields[3] != "*":
        raise ValueError("Invalid cron expression. Day-of-month and month fields must be '*'.")
    hours: set[int] = _parse_field(fields[1], 0, 23)
    # Parse days of week (0-6, where 0 is Sunday)
    days_of_week: set[int] = _parse_field(fields[4], 0, 6)
    # Convert to Python's weekday format (0-6, where 0 is Monday)
    # Sunday (0 in cron) becomes 6 in Python's weekday()
    python_days_of_week = {(day_of_week + 6) % 7 for day_of_week in days_of_week}
    if after is None:
        t = datetime.now(timezone.utc)
    elif after.utcoffset() is None:
        # Naive datetime: keep the wall-clock value and label it as UTC.
        t = after.replace(tzinfo=timezone.utc)
    else:
        # Aware datetime: convert the instant to UTC. Merely swapping tzinfo
        # would shift the instant by the original UTC offset.
        t = after.astimezone(timezone.utc)
    # Truncate to the start of the hour; the search advances in whole hours.
    t = t.replace(minute=0, second=0, microsecond=0)
    for _ in range(168):  # Check up to 7 days (168 hours)
        t += timedelta(hours=1)
        if t.hour in hours and t.weekday() in python_days_of_week:
            return t
    raise ValueError("No matching execution time found within the next 7 days.")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class TraceRetentionCronExpression(RootModel[str]):
    """Root model wrapping a cron expression string accepted by ``_time_of_next_run``."""

    # The validator calls _time_of_next_run only for its ValueError on an
    # unsupported expression; the tuple indexing then returns the original string.
    root: Annotated[str, AfterValidator(lambda x: (_time_of_next_run(x), x)[1])]

    def get_hour_of_prev_run(self) -> int:
        """
        Calculate the hour of the previous run before now.

        Implemented as the next scheduled run searched from one hour before
        the current time, so the first candidate is the current hour.

        Returns:
            int: The hour of the previous run (0-167), where 0 is midnight Sunday UTC.
        """
        after = datetime.now(timezone.utc) - timedelta(hours=1)
        return hour_of_week(_time_of_next_run(self.root, after))
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _parse_field(field: str, min_val: int, max_val: int) -> set[int]:
|
|
194
|
+
"""
|
|
195
|
+
Parse a cron field and return the set of matching values.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
field (str): The cron field to parse
|
|
199
|
+
min_val (int): Minimum allowed value for this field
|
|
200
|
+
max_val (int): Maximum allowed value for this field
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
set[int]: Set of all valid values represented by the field expression
|
|
204
|
+
|
|
205
|
+
Raises:
|
|
206
|
+
ValueError: If the field contains invalid values or formats
|
|
207
|
+
"""
|
|
208
|
+
if field == "*":
|
|
209
|
+
return set(range(min_val, max_val + 1))
|
|
210
|
+
values: set[int] = set()
|
|
211
|
+
for part in field.split(","):
|
|
212
|
+
if "/" in part:
|
|
213
|
+
# Handle steps
|
|
214
|
+
range_part, step_str = part.split("/")
|
|
215
|
+
try:
|
|
216
|
+
step = int(step_str)
|
|
217
|
+
except ValueError:
|
|
218
|
+
raise ValueError(f"Invalid step value: {step_str}")
|
|
219
|
+
if step <= 0:
|
|
220
|
+
raise ValueError(f"Step value must be positive: {step}")
|
|
221
|
+
if range_part == "*":
|
|
222
|
+
start, end = min_val, max_val
|
|
223
|
+
elif "-" in range_part:
|
|
224
|
+
try:
|
|
225
|
+
start_str, end_str = range_part.split("-")
|
|
226
|
+
start, end = int(start_str), int(end_str)
|
|
227
|
+
except ValueError:
|
|
228
|
+
raise ValueError(f"Invalid range format: {range_part}")
|
|
229
|
+
if start < min_val or end > max_val:
|
|
230
|
+
raise ValueError(
|
|
231
|
+
f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
|
|
232
|
+
)
|
|
233
|
+
if start > end:
|
|
234
|
+
raise ValueError(f"Invalid range: {start}-{end} (start > end)")
|
|
235
|
+
else:
|
|
236
|
+
try:
|
|
237
|
+
start = int(range_part)
|
|
238
|
+
except ValueError:
|
|
239
|
+
raise ValueError(f"Invalid value: {range_part}")
|
|
240
|
+
if start < min_val or start > max_val:
|
|
241
|
+
raise ValueError(f"Value {start} out of range ({min_val}-{max_val})")
|
|
242
|
+
end = max_val
|
|
243
|
+
values.update(range(start, end + 1, step))
|
|
244
|
+
elif "-" in part:
|
|
245
|
+
# Handle ranges
|
|
246
|
+
try:
|
|
247
|
+
start_str, end_str = part.split("-")
|
|
248
|
+
start, end = int(start_str), int(end_str)
|
|
249
|
+
except ValueError:
|
|
250
|
+
raise ValueError(f"Invalid range format: {part}")
|
|
251
|
+
if start < min_val or end > max_val:
|
|
252
|
+
raise ValueError(
|
|
253
|
+
f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
|
|
254
|
+
)
|
|
255
|
+
if start > end:
|
|
256
|
+
raise ValueError(f"Invalid range: {start}-{end} (start > end)")
|
|
257
|
+
values.update(range(start, end + 1))
|
|
258
|
+
else:
|
|
259
|
+
# Handle single values
|
|
260
|
+
try:
|
|
261
|
+
value = int(part)
|
|
262
|
+
except ValueError:
|
|
263
|
+
raise ValueError(f"Invalid value: {part}")
|
|
264
|
+
if value < min_val or value > max_val:
|
|
265
|
+
raise ValueError(f"Value {value} out of range ({min_val}-{max_val})")
|
|
266
|
+
values.add(value)
|
|
267
|
+
return values
|