arize-phoenix 8.32.1__py3-none-any.whl → 9.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the linked notice on the original page for more details.

Files changed (79)
  1. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/METADATA +5 -5
  2. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/RECORD +76 -56
  3. phoenix/db/constants.py +1 -0
  4. phoenix/db/facilitator.py +55 -0
  5. phoenix/db/insertion/document_annotation.py +31 -13
  6. phoenix/db/insertion/evaluation.py +15 -3
  7. phoenix/db/insertion/helpers.py +2 -1
  8. phoenix/db/insertion/span_annotation.py +26 -9
  9. phoenix/db/insertion/trace_annotation.py +25 -9
  10. phoenix/db/insertion/types.py +7 -0
  11. phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
  12. phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
  13. phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
  14. phoenix/db/models.py +151 -10
  15. phoenix/db/types/annotation_configs.py +97 -0
  16. phoenix/db/types/db_models.py +41 -0
  17. phoenix/db/types/trace_retention.py +267 -0
  18. phoenix/experiments/functions.py +5 -1
  19. phoenix/server/api/auth.py +9 -0
  20. phoenix/server/api/context.py +5 -0
  21. phoenix/server/api/dataloaders/__init__.py +4 -0
  22. phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
  23. phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
  24. phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
  25. phoenix/server/api/helpers/annotations.py +9 -0
  26. phoenix/server/api/helpers/prompts/models.py +34 -67
  27. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
  28. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
  29. phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
  30. phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
  31. phoenix/server/api/mutations/__init__.py +6 -0
  32. phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
  33. phoenix/server/api/mutations/dataset_mutations.py +62 -39
  34. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
  35. phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
  36. phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
  37. phoenix/server/api/queries.py +86 -0
  38. phoenix/server/api/routers/v1/__init__.py +4 -0
  39. phoenix/server/api/routers/v1/annotation_configs.py +449 -0
  40. phoenix/server/api/routers/v1/annotations.py +161 -0
  41. phoenix/server/api/routers/v1/evaluations.py +6 -0
  42. phoenix/server/api/routers/v1/projects.py +1 -50
  43. phoenix/server/api/routers/v1/spans.py +35 -8
  44. phoenix/server/api/routers/v1/traces.py +22 -13
  45. phoenix/server/api/routers/v1/utils.py +60 -0
  46. phoenix/server/api/types/Annotation.py +7 -0
  47. phoenix/server/api/types/AnnotationConfig.py +124 -0
  48. phoenix/server/api/types/AnnotationSource.py +9 -0
  49. phoenix/server/api/types/AnnotationSummary.py +28 -14
  50. phoenix/server/api/types/AnnotatorKind.py +1 -0
  51. phoenix/server/api/types/CronExpression.py +15 -0
  52. phoenix/server/api/types/Evaluation.py +4 -30
  53. phoenix/server/api/types/Project.py +50 -2
  54. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
  55. phoenix/server/api/types/Span.py +78 -0
  56. phoenix/server/api/types/SpanAnnotation.py +24 -0
  57. phoenix/server/api/types/Trace.py +2 -2
  58. phoenix/server/api/types/TraceAnnotation.py +23 -0
  59. phoenix/server/app.py +20 -0
  60. phoenix/server/retention.py +76 -0
  61. phoenix/server/static/.vite/manifest.json +36 -36
  62. phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
  63. phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
  64. phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
  65. phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
  66. phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
  67. phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
  68. phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
  69. phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
  70. phoenix/trace/dsl/filter.py +25 -5
  71. phoenix/utilities/__init__.py +18 -0
  72. phoenix/version.py +1 -1
  73. phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
  74. phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
  75. phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
  76. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py ADDED
@@ -0,0 +1,77 @@
1
+ """create project trace retention policies table
2
+
3
+ Revision ID: bb8139330879
4
+ Revises: 2f9d1a65945f
5
+ Create Date: 2025-02-27 15:57:18.752472
6
+
7
+ """
8
+
9
+ from typing import Any, Sequence, Union
10
+
11
+ import sqlalchemy as sa
12
+ from alembic import op
13
+ from sqlalchemy import JSON
14
+ from sqlalchemy.dialects import postgresql
15
+ from sqlalchemy.ext.compiler import compiles
16
+
17
+
18
class JSONB(JSON):
    """Migration-local JSON subtype whose DDL name is ``JSONB``."""

    # The visit name selects the ``@compiles`` hook registered for this class;
    # see https://docs.sqlalchemy.org/en/20/core/custom_types.html
    __visit_name__ = "JSONB"
21
+
22
+
23
@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
    """Render the ``JSONB`` type as the literal type name ``JSONB`` on SQLite."""
    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    sqlite_type_name = "JSONB"
    return sqlite_type_name
27
+
28
+
29
# Generic JSON column type: native JSONB on PostgreSQL, the custom ``JSONB``
# type name on SQLite, and plain JSON everywhere else.
_POSTGRESQL_JSON_VARIANT = postgresql.JSONB()
_SQLITE_JSON_VARIANT = JSONB()
JSON_ = (
    JSON()
    .with_variant(_POSTGRESQL_JSON_VARIANT, "postgresql")
    .with_variant(_SQLITE_JSON_VARIANT, "sqlite")
)
40
+
41
+
42
# Alembic revision identifiers: this migration and the revision it follows.
revision: str = "bb8139330879"
down_revision: Union[str, None] = "2f9d1a65945f"
# No branch label and no additional cross-revision dependencies.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
47
+
48
+
49
def upgrade() -> None:
    """Create the retention-policy table and give projects a nullable FK to it."""
    policy_columns = (
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String, nullable=False),
        sa.Column("cron_expression", sa.String, nullable=False),
        sa.Column("rule", JSON_, nullable=False),
    )
    op.create_table("project_trace_retention_policies", *policy_columns)

    # ON DELETE SET NULL: removing a policy unassigns it from its projects
    # instead of deleting those project rows.
    policy_fk_column = sa.Column(
        "trace_retention_policy_id",
        sa.Integer,
        sa.ForeignKey("project_trace_retention_policies.id", ondelete="SET NULL"),
        nullable=True,
    )
    with op.batch_alter_table("projects") as batch_op:
        batch_op.add_column(policy_fk_column)

    op.create_index(
        "ix_projects_trace_retention_policy_id",
        "projects",
        ["trace_retention_policy_id"],
    )
71
+
72
+
73
def downgrade() -> None:
    """Revert ``upgrade``: drop the index, then the FK column, then the policy table."""
    # Name the owning table explicitly, mirroring ``op.create_index`` in ``upgrade``;
    # Alembic documents that some backends require it for DROP INDEX.
    op.drop_index("ix_projects_trace_retention_policy_id", table_name="projects")
    # Batch mode, as in ``upgrade``, so the column can be dropped on SQLite too.
    with op.batch_alter_table("projects") as batch_op:
        batch_op.drop_column("trace_retention_policy_id")
    # Dropped last, once no column references it any more.
    op.drop_table("project_trace_retention_policies")
phoenix/db/models.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from datetime import datetime, timezone
2
2
  from enum import Enum
3
- from typing import Any, Iterable, Optional, Sequence, TypedDict, cast
3
+ from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict, cast
4
4
 
5
5
  import sqlalchemy.sql as sql
6
6
  from openinference.semconv.trace import RerankerAttributes, SpanAttributes
@@ -45,8 +45,15 @@ from sqlalchemy.sql.functions import coalesce
45
45
 
46
46
  from phoenix.config import get_env_database_schema
47
47
  from phoenix.datetime_utils import normalize_datetime
48
+ from phoenix.db.types.annotation_configs import (
49
+ AnnotationConfig as AnnotationConfigModel,
50
+ )
51
+ from phoenix.db.types.annotation_configs import (
52
+ AnnotationConfigType,
53
+ )
48
54
  from phoenix.db.types.identifier import Identifier
49
55
  from phoenix.db.types.model_provider import ModelProvider
56
+ from phoenix.db.types.trace_retention import TraceRetentionCronExpression, TraceRetentionRule
50
57
  from phoenix.server.api.helpers.prompts.models import (
51
58
  PromptInvocationParameters,
52
59
  PromptInvocationParametersRootModel,
@@ -267,7 +274,7 @@ class _PromptTemplate(TypeDecorator[PromptTemplate]):
267
274
  class _Tools(TypeDecorator[PromptTools]):
268
275
  # See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
269
276
  cache_ok = True
270
- impl = JSON_
277
+ impl = JSON
271
278
 
272
279
  def process_bind_param(
273
280
  self, value: Optional[PromptTools], _: Dialect
@@ -283,7 +290,7 @@ class _Tools(TypeDecorator[PromptTools]):
283
290
  class _PromptResponseFormat(TypeDecorator[PromptResponseFormat]):
284
291
  # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
285
292
  cache_ok = True
286
- impl = JSON_
293
+ impl = JSON
287
294
 
288
295
  def process_bind_param(
289
296
  self, value: Optional[PromptResponseFormat], _: Dialect
@@ -332,6 +339,60 @@ class _TemplateFormat(TypeDecorator[PromptTemplateFormat]):
332
339
  return None if value is None else PromptTemplateFormat(value)
333
340
 
334
341
 
342
class _TraceRetentionCronExpression(TypeDecorator[TraceRetentionCronExpression]):
    """Store a ``TraceRetentionCronExpression`` in a plain string column.

    Missing or malformed values are rejected in both directions; the visible use of
    this type (``ProjectTraceRetentionPolicy.cron_expression``) is ``nullable=False``.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = String

    def process_bind_param(
        self, value: Optional[TraceRetentionCronExpression], _: Dialect
    ) -> Optional[str]:
        if not isinstance(value, TraceRetentionCronExpression):
            raise TypeError(
                f"Expected TraceRetentionCronExpression, got {type(value).__name__}"
            )
        # Bind the dump outside of any ``assert`` so it still happens under ``python -O``
        # (asserts, including any walrus inside them, are stripped in that mode).
        ans = value.model_dump()
        if not isinstance(ans, str):
            raise TypeError(
                f"Cron expression must serialize to str, got {type(ans).__name__}"
            )
        return ans

    def process_result_value(
        self, value: Optional[str], _: Dialect
    ) -> Optional[TraceRetentionCronExpression]:
        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        if not value or not isinstance(value, str):
            raise ValueError(f"Invalid stored cron expression: {value!r}")
        return TraceRetentionCronExpression.model_validate(value)
359
+
360
+
361
class _TraceRetentionRule(TypeDecorator[TraceRetentionRule]):
    """Store a ``TraceRetentionRule`` as a JSON object (``JSON_`` impl).

    Missing or malformed values are rejected in both directions; the visible use of
    this type (``ProjectTraceRetentionPolicy.rule``) is ``nullable=False``.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(
        self, value: Optional[TraceRetentionRule], _: Dialect
    ) -> Optional[dict[str, Any]]:
        if not isinstance(value, TraceRetentionRule):
            raise TypeError(f"Expected TraceRetentionRule, got {type(value).__name__}")
        # Bind the dump outside of any ``assert`` so it still happens under ``python -O``
        # (asserts, including any walrus inside them, are stripped in that mode).
        ans = value.model_dump()
        if not isinstance(ans, dict):
            raise TypeError(
                f"Trace retention rule must serialize to dict, got {type(ans).__name__}"
            )
        return ans

    def process_result_value(
        self, value: Optional[dict[str, Any]], _: Dialect
    ) -> Optional[TraceRetentionRule]:
        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        if not value or not isinstance(value, dict):
            raise ValueError(f"Invalid stored trace retention rule: {value!r}")
        return TraceRetentionRule.model_validate(value)
378
+
379
+
380
class _AnnotationConfig(TypeDecorator[AnnotationConfigType]):
    """Store an annotation config (categorical / continuous / freeform) as JSON.

    Values round-trip through the ``AnnotationConfigModel`` root model, so stored
    JSON is validated and dispatched on its ``type`` discriminator when read back.
    ``None`` passes through unchanged in both directions.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(
        self, value: Optional[AnnotationConfigType], _: Dialect
    ) -> Optional[dict[str, Any]]:
        if value is None:
            return None
        return AnnotationConfigModel(root=value).model_dump()

    def process_result_value(
        self, value: Optional[dict[str, Any]], _: Dialect
    ) -> Optional[AnnotationConfigType]:
        # The JSON impl hands back the decoded object, i.e. the dict produced by
        # ``process_bind_param`` above, not a raw string.
        if value is None:
            return None
        return AnnotationConfigModel.model_validate(value).root
394
+
395
+
335
396
  class ExperimentRunOutput(TypedDict, total=False):
336
397
  task_output: Any
337
398
 
@@ -357,6 +418,19 @@ class Base(DeclarativeBase):
357
418
  }
358
419
 
359
420
 
421
class ProjectTraceRetentionPolicy(Base):
    """A named trace-retention policy (cron expression + rule) assignable to projects."""

    __tablename__ = "project_trace_retention_policies"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)
    # Persisted through the custom string / JSON type decorators.
    cron_expression: Mapped[TraceRetentionCronExpression] = mapped_column(
        _TraceRetentionCronExpression,
        nullable=False,
    )
    rule: Mapped[TraceRetentionRule] = mapped_column(
        _TraceRetentionRule,
        nullable=False,
    )
    # Every project currently pointing at this policy (inverse of
    # ``Project.trace_retention_policy``).
    projects: Mapped[list["Project"]] = relationship(
        "Project",
        back_populates="trace_retention_policy",
        uselist=True,
    )
432
+
433
+
360
434
  class Project(Base):
361
435
  __tablename__ = "projects"
362
436
  name: Mapped[str]
@@ -374,7 +448,15 @@ class Project(Base):
374
448
  updated_at: Mapped[datetime] = mapped_column(
375
449
  UtcTimeStamp, server_default=func.now(), onupdate=func.now()
376
450
  )
377
-
451
+ trace_retention_policy_id: Mapped[Optional[int]] = mapped_column(
452
+ ForeignKey("project_trace_retention_policies.id", ondelete="SET NULL"),
453
+ nullable=True,
454
+ index=True,
455
+ )
456
+ trace_retention_policy: Mapped[Optional[ProjectTraceRetentionPolicy]] = relationship(
457
+ "ProjectTraceRetentionPolicy",
458
+ back_populates="projects",
459
+ )
378
460
  traces: WriteOnlyMapped[list["Trace"]] = relationship(
379
461
  "Trace",
380
462
  back_populates="project",
@@ -602,6 +684,7 @@ class Span(Base):
602
684
  )
603
685
 
604
686
  trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
687
+ span_annotations: Mapped[list["SpanAnnotation"]] = relationship(back_populates="span")
605
688
  document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
606
689
  dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")
607
690
 
@@ -732,17 +815,31 @@ class SpanAnnotation(Base):
732
815
  score: Mapped[Optional[float]] = mapped_column(Float, index=True)
733
816
  explanation: Mapped[Optional[str]]
734
817
  metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
735
- annotator_kind: Mapped[str] = mapped_column(
736
- CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
818
+ annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
819
+ CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
737
820
  )
738
821
  created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
739
822
  updated_at: Mapped[datetime] = mapped_column(
740
823
  UtcTimeStamp, server_default=func.now(), onupdate=func.now()
741
824
  )
825
+ identifier: Mapped[str] = mapped_column(
826
+ String,
827
+ server_default="",
828
+ nullable=False,
829
+ )
830
+ source: Mapped[Literal["API", "APP"]] = mapped_column(
831
+ CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
832
+ )
833
+ user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
834
+
835
+ span: Mapped["Span"] = relationship(back_populates="span_annotations")
836
+ user: Mapped[Optional["User"]] = relationship("User")
837
+
742
838
  __table_args__ = (
743
839
  UniqueConstraint(
744
840
  "name",
745
841
  "span_rowid",
842
+ "identifier",
746
843
  ),
747
844
  )
748
845
 
@@ -758,17 +855,28 @@ class TraceAnnotation(Base):
758
855
  score: Mapped[Optional[float]] = mapped_column(Float, index=True)
759
856
  explanation: Mapped[Optional[str]]
760
857
  metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
761
- annotator_kind: Mapped[str] = mapped_column(
762
- CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
858
+ annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
859
+ CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
763
860
  )
764
861
  created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
765
862
  updated_at: Mapped[datetime] = mapped_column(
766
863
  UtcTimeStamp, server_default=func.now(), onupdate=func.now()
767
864
  )
865
+ identifier: Mapped[str] = mapped_column(
866
+ String,
867
+ server_default="",
868
+ nullable=False,
869
+ )
870
+ source: Mapped[Literal["API", "APP"]] = mapped_column(
871
+ CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
872
+ )
873
+ user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
874
+
768
875
  __table_args__ = (
769
876
  UniqueConstraint(
770
877
  "name",
771
878
  "trace_rowid",
879
+ "identifier",
772
880
  ),
773
881
  )
774
882
 
@@ -785,13 +893,23 @@ class DocumentAnnotation(Base):
785
893
  score: Mapped[Optional[float]] = mapped_column(Float, index=True)
786
894
  explanation: Mapped[Optional[str]]
787
895
  metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
788
- annotator_kind: Mapped[str] = mapped_column(
789
- CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
896
+ annotator_kind: Mapped[Literal["LLM", "CODE", "HUMAN"]] = mapped_column(
897
+ CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
790
898
  )
791
899
  created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
792
900
  updated_at: Mapped[datetime] = mapped_column(
793
901
  UtcTimeStamp, server_default=func.now(), onupdate=func.now()
794
902
  )
903
+ identifier: Mapped[str] = mapped_column(
904
+ String,
905
+ server_default="",
906
+ nullable=False,
907
+ )
908
+ source: Mapped[Literal["API", "APP"]] = mapped_column(
909
+ CheckConstraint("source IN ('API', 'APP')", name="valid_source"),
910
+ )
911
+ user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id", ondelete="SET NULL"))
912
+
795
913
  span: Mapped["Span"] = relationship(back_populates="document_annotations")
796
914
 
797
915
  __table_args__ = (
@@ -799,6 +917,7 @@ class DocumentAnnotation(Base):
799
917
  "name",
800
918
  "span_rowid",
801
919
  "document_position",
920
+ "identifier",
802
921
  ),
803
922
  )
804
923
 
@@ -1301,3 +1420,25 @@ class PromptVersionTag(Base):
1301
1420
  )
1302
1421
 
1303
1422
  __table_args__ = (UniqueConstraint("name", "prompt_id"),)
1423
+
1424
+
1425
class AnnotationConfig(Base):
    """A uniquely named annotation config whose body is stored as JSON."""

    __tablename__ = "annotation_configs"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique across the whole table, not per project.
    name: Mapped[str] = mapped_column(String, nullable=False, unique=True)
    config: Mapped[AnnotationConfigType] = mapped_column(
        _AnnotationConfig,
        nullable=False,
    )
1431
+
1432
+
1433
class ProjectAnnotationConfig(Base):
    """Association row attaching an annotation config to a project."""

    __tablename__ = "project_annotation_configs"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Both foreign keys cascade, so the link disappears with either endpoint.
    project_id: Mapped[int] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    annotation_config_id: Mapped[int] = mapped_column(
        ForeignKey("annotation_configs.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # A config can be attached to a given project at most once.
    __table_args__ = (UniqueConstraint("project_id", "annotation_config_id"),)
phoenix/db/types/annotation_configs.py ADDED
@@ -0,0 +1,97 @@
1
+ from enum import Enum
2
+ from typing import Annotated, Literal, Optional, Union
3
+
4
+ from pydantic import AfterValidator, Field, RootModel, model_validator
5
+ from typing_extensions import Self, TypeAlias
6
+
7
+ from .db_models import DBBaseModel
8
+
9
+
10
class AnnotationType(Enum):
    """Kind of annotation config; each value doubles as that config's ``type`` tag."""

    CATEGORICAL = "CATEGORICAL"
    CONTINUOUS = "CONTINUOUS"
    FREEFORM = "FREEFORM"
14
+
15
+
16
class OptimizationDirection(Enum):
    """Value of the ``optimization_direction`` field on categorical/continuous configs."""

    MINIMIZE = "MINIMIZE"
    MAXIMIZE = "MAXIMIZE"
    NONE = "NONE"
20
+
21
+
22
class _BaseAnnotationConfig(DBBaseModel):
    """Fields shared by every annotation config variant."""

    # Free-text description; omitted unless explicitly provided.
    description: Optional[str] = None
24
+
25
+
26
+ def _categorical_value_label_is_non_empty_string(label: str) -> str:
27
+ if not label:
28
+ raise ValueError("Label must be non-empty")
29
+ return label
30
+
31
+
32
class CategoricalAnnotationValue(DBBaseModel):
    """One category option: a required non-empty label plus an optional score."""

    label: Annotated[
        str,
        AfterValidator(_categorical_value_label_is_non_empty_string),
    ]
    score: Optional[float] = None
35
+
36
+
37
def _categorical_values_are_non_empty_list(
    values: list[CategoricalAnnotationValue],
) -> list[CategoricalAnnotationValue]:
    """Return ``values`` unchanged, raising ``ValueError`` if the list is empty."""
    if values:
        return values
    raise ValueError("Values must be non-empty")
43
+
44
+
45
def _categorical_values_have_unique_labels(
    values: list[CategoricalAnnotationValue],
) -> list[CategoricalAnnotationValue]:
    """Return ``values`` unchanged, raising ``ValueError`` at the first repeated label."""
    seen_labels: set[str] = set()
    for candidate in values:
        current_label = candidate.label
        if current_label in seen_labels:
            raise ValueError(
                f'Values for categorical annotation config has duplicate label: "{current_label}"'
            )
        seen_labels.add(current_label)
    return values
57
+
58
+
59
class CategoricalAnnotationConfig(_BaseAnnotationConfig):
    """Config whose annotations pick one label from a fixed, non-empty set of values."""

    type: Literal[AnnotationType.CATEGORICAL.value]  # type: ignore[name-defined]
    optimization_direction: OptimizationDirection
    # Validators run in order: first non-empty, then unique labels.
    values: Annotated[
        list[CategoricalAnnotationValue],
        AfterValidator(_categorical_values_are_non_empty_list),
        AfterValidator(_categorical_values_have_unique_labels),
    ]
67
+
68
+
69
class ContinuousAnnotationConfig(_BaseAnnotationConfig):
    """Config for numeric annotations with optional lower and upper bounds."""

    type: Literal[AnnotationType.CONTINUOUS.value]  # type: ignore[name-defined]
    optimization_direction: OptimizationDirection
    lower_bound: Optional[float] = None
    upper_bound: Optional[float] = None

    @model_validator(mode="after")
    def check_bounds(self) -> Self:
        """Require ``lower_bound < upper_bound`` whenever both bounds are set."""
        low, high = self.lower_bound, self.upper_bound
        if low is None or high is None:
            return self
        if low >= high:
            raise ValueError("Lower bound must be strictly less than upper bound")
        return self
84
+
85
+
86
class FreeformAnnotationConfig(_BaseAnnotationConfig):
    """Config for free-form annotations; adds only its ``type`` tag to the base fields."""

    type: Literal[AnnotationType.FREEFORM.value]  # type: ignore[name-defined]
88
+
89
+
90
# Discriminated union of every config variant, dispatched on the ``type`` field.
AnnotationConfigType: TypeAlias = Annotated[
    Union[
        CategoricalAnnotationConfig,
        ContinuousAnnotationConfig,
        FreeformAnnotationConfig,
    ],
    Field(..., discriminator="type"),
]
94
+
95
+
96
class AnnotationConfig(RootModel[AnnotationConfigType]):
    """Root model wrapping exactly one annotation config variant."""

    root: AnnotationConfigType
phoenix/db/types/db_models.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import Any
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
class DBBaseModel(BaseModel):
    """
    A base Pydantic model suitable for use with JSON columns in the database.

    Extra attributes are rejected, enum fields hold their ``.value``, and attribute
    assignments are re-validated. Keyword arguments equal to ``UNDEFINED`` are dropped
    before validation so the corresponding fields count as unset.
    """

    model_config = ConfigDict(
        extra="forbid",  # disallow extra attributes
        use_enum_values=True,
        validate_assignment=True,
    )

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Dropping UNDEFINED leaves those fields "unset", so the exclude_unset dump
        # below omits them, unlike an explicit None, which is kept.
        kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
        super().__init__(*args, **kwargs)

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump to a dict, defaulting to ``exclude_unset=True`` and ``by_alias=True``.

        Either default can be overridden by passing it explicitly. Using defaults
        rather than hard-coded keywords avoids ``TypeError: got multiple values for
        keyword argument`` when a caller supplies them.
        """
        kwargs.setdefault("exclude_unset", True)
        kwargs.setdefault("by_alias", True)
        return super().model_dump(*args, **kwargs)
23
+
24
+
25
class Undefined:
    """
    Singleton marker for a value that was never supplied. Pydantic by itself cannot
    tell an omitted field from one explicitly set to None, so this sentinel stands
    in for "not provided". The instance is always falsy.
    """

    def __new__(cls) -> Any:
        # EAFP form of "create once, then reuse": the first call stores the instance
        # on the class, and later calls return that stored object.
        try:
            return cls._instance
        except AttributeError:
            cls._instance = super().__new__(cls)
            return cls._instance

    def __bool__(self) -> bool:
        return False
39
+
40
+
41
# Shared sentinel; DBBaseModel.__init__ drops keyword arguments that are this object.
UNDEFINED: Any = Undefined()