arize-phoenix 3.25.0__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (113)
  1. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +54 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +33 -199
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
@@ -0,0 +1,209 @@
1
+ from typing import NamedTuple, Optional
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from typing_extensions import assert_never
6
+
7
+ from phoenix.db import models
8
+ from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
9
+ from phoenix.db.insertion.helpers import OnConflict, insert_stmt
10
+ from phoenix.exceptions import PhoenixException
11
+ from phoenix.trace import v1 as pb
12
+
13
+
14
class InsertEvaluationError(PhoenixException):
    # Raised when an evaluation cannot be inserted, e.g. because its subject
    # (trace, span, or document position) does not exist in the database.
    pass


class EvaluationInsertionResult(NamedTuple):
    # Identifies which project and evaluation name an insertion affected.
    project_rowid: int
    evaluation_name: str


# Marker subclasses distinguishing which kind of evaluation was inserted;
# they carry no extra data beyond EvaluationInsertionResult.
class SpanEvaluationInsertionEvent(EvaluationInsertionResult): ...


class TraceEvaluationInsertionEvent(EvaluationInsertionResult): ...


class DocumentEvaluationInsertionEvent(EvaluationInsertionResult): ...
30
+
31
+
32
async def insert_evaluation(
    session: AsyncSession,
    evaluation: pb.Evaluation,
) -> Optional[EvaluationInsertionResult]:
    """Route *evaluation* to the appropriate annotation table.

    Dispatches on the protobuf ``subject_id`` oneof: trace, span, and
    document-retrieval evaluations each go to their own insertion helper.

    Raises:
        InsertEvaluationError: if the subject kind is unset.
    """
    subject = evaluation.subject_id
    kind = subject.WhichOneof("kind")
    if kind is None:
        raise InsertEvaluationError("Cannot insert an evaluation that has no evaluation kind")
    name = evaluation.name
    res = evaluation.result
    # Optional protobuf wrapper fields: unwrap only when present.
    label = res.label.value if res.HasField("label") else None
    score = res.score.value if res.HasField("score") else None
    explanation = res.explanation.value if res.HasField("explanation") else None
    if kind == "trace_id":
        return await _insert_trace_evaluation(
            session, subject.trace_id, name, label, score, explanation
        )
    if kind == "span_id":
        return await _insert_span_evaluation(
            session, subject.span_id, name, label, score, explanation
        )
    if kind == "document_retrieval_id":
        retrieval = subject.document_retrieval_id
        return await _insert_document_evaluation(
            session,
            retrieval.span_id,
            retrieval.document_position,
            name,
            label,
            score,
            explanation,
        )
    assert_never(kind)
61
+
62
+
63
async def _insert_trace_evaluation(
    session: AsyncSession,
    trace_id: str,
    evaluation_name: str,
    label: Optional[str],
    score: Optional[float],
    explanation: Optional[str],
) -> TraceEvaluationInsertionEvent:
    """Upsert a trace-level evaluation as a row in ``trace_annotations``.

    Raises:
        InsertEvaluationError: if no trace with *trace_id* exists.
    """
    lookup = select(
        models.Trace.project_rowid,
        models.Trace.id,
    ).where(models.Trace.trace_id == trace_id)
    row = (await session.execute(lookup)).first()
    if row is None:
        raise InsertEvaluationError(
            f"Cannot insert a trace evaluation for a missing trace: {evaluation_name=}, {trace_id=}"
        )
    project_rowid, trace_rowid = row
    dialect = SupportedSQLDialect(session.bind.dialect.name)
    values = {
        "trace_rowid": trace_rowid,
        "name": evaluation_name,
        "label": label,
        "score": score,
        "explanation": explanation,
        "metadata_": {},  # `metadata_` must match ORM
        "annotator_kind": "LLM",
    }
    # The ON CONFLICT UPDATE clause addresses the database column, which is
    # named `metadata` rather than the ORM attribute name `metadata_`.
    set_ = {key: value for key, value in values.items() if key != "metadata_"}
    set_["metadata"] = values["metadata_"]
    await session.execute(
        insert_stmt(
            dialect=dialect,
            table=models.TraceAnnotation,
            values=values,
            constraint="uq_trace_annotations_name_trace_rowid",
            column_names=("name", "trace_rowid"),
            on_conflict=OnConflict.DO_UPDATE,
            set_=set_,
        )
    )
    return TraceEvaluationInsertionEvent(project_rowid, evaluation_name)
105
+
106
+
107
async def _insert_span_evaluation(
    session: AsyncSession,
    span_id: str,
    evaluation_name: str,
    label: Optional[str],
    score: Optional[float],
    explanation: Optional[str],
) -> SpanEvaluationInsertionEvent:
    """Upsert a span-level evaluation as a row in ``span_annotations``.

    Raises:
        InsertEvaluationError: if no span with *span_id* exists.
    """
    lookup = (
        select(
            models.Trace.project_rowid,
            models.Span.id,
        )
        .join_from(models.Span, models.Trace)
        .where(models.Span.span_id == span_id)
    )
    row = (await session.execute(lookup)).first()
    if row is None:
        raise InsertEvaluationError(
            f"Cannot insert a span evaluation for a missing span: {evaluation_name=}, {span_id=}"
        )
    project_rowid, span_rowid = row
    dialect = SupportedSQLDialect(session.bind.dialect.name)
    values = {
        "span_rowid": span_rowid,
        "name": evaluation_name,
        "label": label,
        "score": score,
        "explanation": explanation,
        "metadata_": {},  # `metadata_` must match ORM
        "annotator_kind": "LLM",
    }
    # The ON CONFLICT UPDATE clause addresses the database column, which is
    # named `metadata` rather than the ORM attribute name `metadata_`.
    set_ = {key: value for key, value in values.items() if key != "metadata_"}
    set_["metadata"] = values["metadata_"]
    await session.execute(
        insert_stmt(
            dialect=dialect,
            table=models.SpanAnnotation,
            values=values,
            constraint="uq_span_annotations_name_span_rowid",
            column_names=("name", "span_rowid"),
            on_conflict=OnConflict.DO_UPDATE,
            set_=set_,
        )
    )
    return SpanEvaluationInsertionEvent(project_rowid, evaluation_name)
153
+
154
+
155
async def _insert_document_evaluation(
    session: AsyncSession,
    span_id: str,
    document_position: int,
    evaluation_name: str,
    label: Optional[str],
    score: Optional[float],
    explanation: Optional[str],
) -> EvaluationInsertionResult:
    """Upsert a document-retrieval evaluation into ``document_annotations``.

    Looks up the span by *span_id*, verifies that *document_position* refers
    to an existing retrieval document on that span, then inserts (or updates
    on conflict) the annotation row.

    Raises:
        InsertEvaluationError: if the span is missing, or if the document
            position is out of range for the span's retrieval documents.
    """
    dialect = SupportedSQLDialect(session.bind.dialect.name)
    stmt = (
        select(
            models.Trace.project_rowid,
            models.Span.id,
            num_docs_col(dialect),  # dialect-specific count of retrieval documents
        )
        .join_from(models.Span, models.Trace)
        .where(models.Span.span_id == span_id)
    )
    if not (row := (await session.execute(stmt)).first()):
        # Include `evaluation_name` for consistency with the trace/span
        # insertion helpers' error messages.
        raise InsertEvaluationError(
            f"Cannot insert a document evaluation for a missing span: "
            f"{evaluation_name=}, {span_id=}"
        )
    project_rowid, span_rowid, num_docs = row
    if num_docs is None or num_docs <= document_position:
        raise InsertEvaluationError(
            f"Cannot insert a document evaluation for a non-existent "
            f"document position: {evaluation_name=}, {span_id=}, {document_position=}"
        )
    # NOTE: `dialect` was previously recomputed here; the value from above is
    # reused since the session's bind cannot change mid-function.
    values = dict(
        span_rowid=span_rowid,
        document_position=document_position,
        name=evaluation_name,
        label=label,
        score=score,
        explanation=explanation,
        metadata_={},  # `metadata_` must match ORM
        annotator_kind="LLM",
    )
    # The ON CONFLICT UPDATE clause addresses the database column, which is
    # named `metadata` rather than the ORM attribute name `metadata_`.
    set_ = dict(values)
    set_.pop("metadata_")
    set_["metadata"] = values["metadata_"]
    await session.execute(
        insert_stmt(
            dialect=dialect,
            table=models.DocumentAnnotation,
            values=values,
            constraint="uq_document_annotations_name_span_rowid_document_position",
            column_names=("name", "span_rowid", "document_position"),
            on_conflict=OnConflict.DO_UPDATE,
            set_=set_,
        )
    )
    return DocumentEvaluationInsertionEvent(project_rowid, evaluation_name)
@@ -0,0 +1,54 @@
1
+ from enum import Enum, auto
2
+ from typing import Any, Mapping, Optional, Sequence
3
+
4
+ from sqlalchemy import Insert, insert
5
+ from sqlalchemy.dialects.postgresql import insert as insert_postgresql
6
+ from sqlalchemy.dialects.sqlite import insert as insert_sqlite
7
+ from typing_extensions import assert_never
8
+
9
+ from phoenix.db.helpers import SupportedSQLDialect
10
+
11
+
12
class OnConflict(Enum):
    """Conflict-resolution strategy for ``insert_stmt``."""

    DO_NOTHING = auto()  # skip the row on uniqueness violation
    DO_UPDATE = auto()  # update the existing row in place
15
+
16
+
17
def insert_stmt(
    dialect: SupportedSQLDialect,
    table: Any,
    values: Mapping[str, Any],
    constraint: Optional[str] = None,
    column_names: Sequence[str] = (),
    on_conflict: OnConflict = OnConflict.DO_NOTHING,
    set_: Optional[Mapping[str, Any]] = None,
) -> Insert:
    """
    Dialect specific insertion statement using ON CONFLICT DO syntax.

    `constraint` (PostgreSQL) and `column_names` (SQLite) name the same
    uniqueness target in each dialect's vocabulary; both must be supplied
    together or not at all.  Without them, a plain INSERT is returned.
    A DO_UPDATE request with an empty `set_` degrades to DO NOTHING.
    """
    if bool(constraint) != bool(column_names):
        raise ValueError(
            "Both `constraint` and `column_names` must be provided or omitted at the same time."
        )
    if dialect is SupportedSQLDialect.POSTGRESQL:
        if constraint is None:
            return insert(table).values(values)
        stmt_pg = insert_postgresql(table).values(values)
        if on_conflict is OnConflict.DO_UPDATE and set_:
            return stmt_pg.on_conflict_do_update(constraint=constraint, set_=set_)
        return stmt_pg.on_conflict_do_nothing(constraint=constraint)
    if dialect is SupportedSQLDialect.SQLITE:
        if not column_names:
            return insert(table).values(values)
        stmt_lite = insert_sqlite(table).values(values)
        if on_conflict is OnConflict.DO_UPDATE and set_:
            return stmt_lite.on_conflict_do_update(column_names, set_=set_)
        return stmt_lite.on_conflict_do_nothing(column_names)
    assert_never(dialect)
@@ -0,0 +1,142 @@
1
+ from dataclasses import asdict
2
+ from typing import NamedTuple, Optional, cast
3
+
4
+ from openinference.semconv.trace import SpanAttributes
5
+ from sqlalchemy import func, insert, select, update
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from phoenix.db import models
9
+ from phoenix.db.helpers import SupportedSQLDialect
10
+ from phoenix.db.insertion.helpers import OnConflict, insert_stmt
11
+ from phoenix.trace.attributes import get_attribute_value
12
+ from phoenix.trace.schemas import Span, SpanStatusCode
13
+
14
+
15
class SpanInsertionEvent(NamedTuple):
    # Returned by `insert_span` on success; identifies the project that
    # received the span.
    project_rowid: int


class ClearProjectSpansEvent(NamedTuple):
    # NOTE(review): not constructed in this module — presumably emitted
    # elsewhere when a project's spans are cleared; confirm against callers.
    project_rowid: int
21
+
22
+
23
async def insert_span(
    session: AsyncSession,
    span: Span,
    project_name: str,
) -> Optional[SpanInsertionEvent]:
    """Insert *span* under *project_name*, creating project and trace rows on
    demand and maintaining cumulative (subtree) counters.

    Returns None when the span already exists (the insert is skipped via
    ON CONFLICT DO NOTHING); otherwise returns a SpanInsertionEvent carrying
    the project rowid.
    """
    dialect = SupportedSQLDialect(session.bind.dialect.name)
    # Upsert the project row. DO_UPDATE with the identical name is used so
    # that RETURNING yields the rowid even when the project already exists.
    project_rowid = await session.scalar(
        insert_stmt(
            dialect=dialect,
            table=models.Project,
            constraint="uq_projects_name",
            column_names=("name",),
            values=dict(name=project_name),
            on_conflict=OnConflict.DO_UPDATE,
            set_=dict(name=project_name),
        ).returning(models.Project.id)
    )
    assert project_rowid is not None
    if trace := await session.scalar(
        select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
    ):
        trace_rowid = trace.id
        # Widen the trace's time window if this span falls outside it.
        if span.start_time < trace.start_time or trace.end_time < span.end_time:
            trace_start_time = min(trace.start_time, span.start_time)
            trace_end_time = max(trace.end_time, span.end_time)
            await session.execute(
                update(models.Trace)
                .where(models.Trace.id == trace_rowid)
                .values(
                    start_time=trace_start_time,
                    end_time=trace_end_time,
                )
            )
    else:
        # First span of this trace: create the trace row with this span's window.
        trace_rowid = cast(
            int,
            await session.scalar(
                insert(models.Trace)
                .values(
                    project_rowid=project_rowid,
                    trace_id=span.context.trace_id,
                    start_time=span.start_time,
                    end_time=span.end_time,
                )
                .returning(models.Trace.id)
            ),
        )
    # Seed the cumulative counters with this span's own contribution...
    cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
    cumulative_llm_token_count_prompt = cast(
        int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
    )
    cumulative_llm_token_count_completion = cast(
        int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0
    )
    # ...then add the totals of any children that arrived before this span.
    if accumulation := (
        await session.execute(
            select(
                func.sum(models.Span.cumulative_error_count),
                func.sum(models.Span.cumulative_llm_token_count_prompt),
                func.sum(models.Span.cumulative_llm_token_count_completion),
            ).where(models.Span.parent_id == span.context.span_id)
        )
    ).first():
        cumulative_error_count += cast(int, accumulation[0] or 0)
        cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0)
        cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0)
    span_rowid = await session.scalar(
        insert_stmt(
            dialect=dialect,
            table=models.Span,
            constraint="uq_spans_span_id",
            column_names=("span_id",),
            values=dict(
                span_id=span.context.span_id,
                trace_rowid=trace_rowid,
                parent_id=span.parent_id,
                span_kind=span.span_kind.value,
                name=span.name,
                start_time=span.start_time,
                end_time=span.end_time,
                attributes=span.attributes,
                events=[asdict(event) for event in span.events],
                status_code=span.status_code.value,
                status_message=span.status_message,
                cumulative_error_count=cumulative_error_count,
                cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt,
                cumulative_llm_token_count_completion=cumulative_llm_token_count_completion,
            ),
            on_conflict=OnConflict.DO_NOTHING,
        ).returning(models.Span.id)
    )
    if span_rowid is None:
        # Conflict on span_id: the span was already recorded; nothing to do.
        return None
    # Propagate cumulative values to ancestors. This is usually a no-op, since
    # the parent usually arrives after the child. But in the event that a
    # child arrives after its parent, we need to make sure that all the
    # ancestors' cumulative values are updated.
    ancestors = (
        # Recursive CTE anchored at this span's immediate parent...
        select(models.Span.id, models.Span.parent_id)
        .where(models.Span.span_id == span.parent_id)
        .cte(recursive=True)
    )
    child = ancestors.alias()
    ancestors = ancestors.union_all(
        # ...walking upward: each step joins the parent of the previous row.
        select(models.Span.id, models.Span.parent_id).join(
            child, models.Span.span_id == child.c.parent_id
        )
    )
    await session.execute(
        update(models.Span)
        .where(models.Span.id.in_(select(ancestors.c.id)))
        .values(
            cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count,
            cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt
            + cumulative_llm_token_count_prompt,
            cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion
            + cumulative_llm_token_count_completion,
        )
    )
    return SpanInsertionEvent(project_rowid)
phoenix/db/migrate.py ADDED
@@ -0,0 +1,71 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from queue import Empty, Queue
4
+ from threading import Thread
5
+ from typing import Optional
6
+
7
+ from alembic import command
8
+ from alembic.config import Config
9
+ from sqlalchemy import URL
10
+
11
+ from phoenix.exceptions import PhoenixMigrationError
12
+ from phoenix.settings import Settings
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def printif(condition: bool, text: str) -> None:
    """Write *text* to stdout, but only when *condition* is truthy."""
    if not condition:
        return
    print(text)
20
+
21
+
22
+ def migrate(url: URL, error_queue: Optional["Queue[Exception]"] = None) -> None:
23
+ """
24
+ Runs migrations on the database.
25
+ NB: Migrate only works on non-memory databases.
26
+
27
+ Args:
28
+ url: The database URL.
29
+ """
30
+ try:
31
+ log_migrations = Settings.log_migrations
32
+ printif(log_migrations, "🏃‍♀️‍➡️ Running migrations on the database.")
33
+ printif(log_migrations, "---------------------------")
34
+ config_path = str(Path(__file__).parent.resolve() / "alembic.ini")
35
+ alembic_cfg = Config(config_path)
36
+
37
+ # Explicitly set the migration directory
38
+ scripts_location = str(Path(__file__).parent.resolve() / "migrations")
39
+ alembic_cfg.set_main_option("script_location", scripts_location)
40
+ alembic_cfg.set_main_option("sqlalchemy.url", str(url))
41
+ command.upgrade(alembic_cfg, "head")
42
+ printif(log_migrations, "---------------------------")
43
+ printif(log_migrations, "✅ Migrations complete.")
44
+ except Exception as e:
45
+ if error_queue:
46
+ error_queue.put(e)
47
+ raise e
48
+
49
+
50
def migrate_in_thread(url: URL) -> None:
    """
    Runs migrations on the database in a separate thread.
    This is needed because depending on the context (notebook)
    the migration process can fail to execute in the main thread.
    """
    errors: Queue[Exception] = Queue()
    worker = Thread(target=migrate, args=(url, errors))
    worker.start()
    worker.join()

    # An empty queue means the migration thread finished cleanly.
    try:
        cause = errors.get_nowait()
    except Empty:
        return

    if cause is not None:
        error_message = (
            "\n\nUnable to migrate configured Phoenix DB. Original error:\n"
            f"{type(cause).__name__}: {str(cause)}"
        )
        raise PhoenixMigrationError(error_message) from cause
@@ -0,0 +1,121 @@
1
+ import asyncio
2
+ from logging.config import fileConfig
3
+
4
+ from alembic import context
5
+ from sqlalchemy import Connection, engine_from_config, pool
6
+ from sqlalchemy.ext.asyncio import AsyncEngine
7
+
8
+ from phoenix.config import get_env_database_connection_str
9
+ from phoenix.db.engines import get_async_db_url
10
+ from phoenix.db.models import Base
11
+ from phoenix.settings import Settings
12
+
13
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name, disable_existing_loggers=False)

# add your model's MetaData object here
# for 'autogenerate' support

# Phoenix's declarative base supplies the schema Alembic diffs against.
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
31
+
32
+
33
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the Alembic context with just a URL instead of an Engine,
    so no DBAPI needs to be available; calls to context.execute() emit
    the given string to the script output.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        transaction_per_migration=True,
    )

    with context.begin_transaction():
        context.run_migrations()
56
+
57
+
58
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context. An engine may also be
    passed in via ``config.attributes["connection"]``.
    """
    connectable = context.config.attributes.get("connection", None)
    if connectable is None:
        # Use a distinct local name (`section`) so we don't shadow the
        # module-level Alembic `config` object.
        section = context.config.get_section(context.config.config_ini_section) or {}
        if "sqlalchemy.url" not in section:
            # Fall back to the environment-configured Phoenix database.
            connection_str = get_env_database_connection_str()
            section["sqlalchemy.url"] = get_async_db_url(connection_str).render_as_string(
                hide_password=False
            )
        connectable = AsyncEngine(
            engine_from_config(
                section,
                prefix="sqlalchemy.",
                poolclass=pool.NullPool,
                future=True,
                echo=Settings.log_migrations,
            )
        )

    if isinstance(connectable, AsyncEngine):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: drive the coroutine to completion here.
            asyncio.run(run_async_migrations(connectable))
        else:
            # Already inside a loop (e.g. a notebook): schedule as a task.
            asyncio.create_task(run_async_migrations(connectable))
    else:
        run_migrations(connectable)
92
+
93
+
94
async def run_async_migrations(connectable: AsyncEngine) -> None:
    # Open an async connection and run the synchronous migration routine
    # on it via run_sync.
    async with connectable.connect() as connection:
        await connection.run_sync(run_migrations)
97
+
98
+
99
def run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on *connection* and run all migrations
    inside an explicit transaction, rolling back on failure and always
    closing the connection afterwards."""
    transaction = connection.begin()
    try:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
            transactional_ddl=True,
            transaction_per_migration=True,
        )
        context.run_migrations()
        transaction.commit()
    except Exception:
        # Undo any partially applied DDL before propagating the error.
        transaction.rollback()
        raise
    finally:
        connection.close()
116
+
117
+
118
# Entry point: Alembic imports this module and we immediately dispatch on
# whether it was invoked in offline (SQL script) or online (live DB) mode.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
@@ -0,0 +1,26 @@
1
## Mako template for newly generated Alembic migration scripts.
## Lines starting with `##` are Mako comments and are stripped from the
## rendered output.
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}