arize-phoenix 3.24.0__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (113)
  1. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +54 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +33 -199
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
phoenix/db/README.md ADDED
@@ -0,0 +1,25 @@
+ # Database
+
+ This module is responsible for the database connection and the migrations.
+
+ ## Migrations
+
+ All migrations are managed by Alembic. Migrations are applied to the database automatically when the application starts.
+
+ ### Applying migrations
+
+ To manually apply the migrations, run the following command:
+
+ ```bash
+ alembic upgrade head
+ ```
+
+ ### Creating a migration
+
+ All migrations are stored in the `migrations` folder. To create a new migration, run the following command:
+
+ ```bash
+ alembic revision -m "your_revision_name"
+ ```
+
+ Then fill the migration file with the necessary changes.
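
The README ends at "fill the migration file"; as a point of reference, here is a minimal sketch of what a filled-in revision looks like. The revision id, table, and column are illustrative; only `cf03bd6bae1d` (the init migration listed in the files above) is real.

```python
"""add description to projects (illustrative example, not a real revision)

Revision ID: 0123abcd4567
Revises: cf03bd6bae1d
"""
import sqlalchemy as sa
from alembic import op

revision = "0123abcd4567"  # hypothetical id; alembic generates this
down_revision = "cf03bd6bae1d"  # the init revision shipped in this release
branch_labels = None
depends_on = None


def upgrade() -> None:
    # hypothetical schema change for illustration only
    op.add_column("projects", sa.Column("description", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("projects", "description")
```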
phoenix/db/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .engines import get_printable_db_url
+ from .migrate import migrate
+
+ __all__ = ["migrate", "get_printable_db_url"]
phoenix/db/alembic.ini ADDED
@@ -0,0 +1,119 @@
+ # A generic, single database configuration.
+
+ [alembic]
+ # path to migration scripts
+ # Note this is overridden in .migrate during programmatic migrations
+ script_location = migrations
+
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+ # Uncomment the line below if you want the files to be prepended with date and time
+ # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+ # for all available tokens
+ # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+ # sys.path path, will be prepended to sys.path if present.
+ # defaults to the current working directory.
+ prepend_sys_path = .
+
+ # timezone to use when rendering the date within the migration file
+ # as well as the filename.
+ # If specified, requires the python>=3.9 or backports.zoneinfo library.
+ # Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+ # string value is passed to ZoneInfo()
+ # leave blank for localtime
+ # timezone =
+
+ # max length of characters to apply to the
+ # "slug" field
+ # truncate_slug_length = 40
+
+ # set to 'true' to run the environment during
+ # the 'revision' command, regardless of autogenerate
+ # revision_environment = false
+
+ # set to 'true' to allow .pyc and .pyo files without
+ # a source .py file to be detected as revisions in the
+ # versions/ directory
+ # sourceless = false
+
+ # version location specification; This defaults
+ # to migrations/versions. When using multiple version
+ # directories, initial revisions must be specified with --version-path.
+ # The path separator used here should be the separator specified by "version_path_separator" below.
+ # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
+
+ # version path separator; As mentioned above, this is the character used to split
+ # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+ # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+ # Valid values for version_path_separator are:
+ #
+ # version_path_separator = :
+ # version_path_separator = ;
+ # version_path_separator = space
+ version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
+
+ # set to 'true' to search source files recursively
+ # in each "version_locations" directory
+ # new in Alembic version 1.10
+ # recursive_version_locations = false
+
+ # the output encoding used when revision files
+ # are written from script.py.mako
+ # output_encoding = utf-8
+
+ # NB: This is commented out intentionally as it is dynamic
+ # See migrations/env.py
+ # sqlalchemy.url = driver://user:pass@localhost/dbname
+
+
+ [post_write_hooks]
+ # post_write_hooks defines scripts or Python functions that are run
+ # on newly generated revision scripts. See the documentation for further
+ # detail and examples
+
+ # format using "black" - use the console_scripts runner, against the "black" entrypoint
+ # hooks = black
+ # black.type = console_scripts
+ # black.entrypoint = black
+ # black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+ # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+ # hooks = ruff
+ # ruff.type = exec
+ # ruff.executable = %(here)s/.venv/bin/ruff
+ # ruff.options = --fix REVISION_SCRIPT_FILENAME
+
+ # Logging configuration
+ [loggers]
+ keys = root,sqlalchemy,alembic
+
+ [handlers]
+ keys = console
+
+ [formatters]
+ keys = generic
+
+ [logger_root]
+ level = WARN
+ handlers = console
+ qualname =
+
+ [logger_sqlalchemy]
+ level = WARN
+ handlers =
+ qualname = sqlalchemy.engine
+
+ [logger_alembic]
+ level = WARN
+ handlers =
+ qualname = alembic
+
+ [handler_console]
+ class = StreamHandler
+ args = (sys.stderr,)
+ level = NOTSET
+ formatter = generic
+
+ [formatter_generic]
+ format = %(levelname)-5.5s [%(name)s] %(message)s
+ datefmt = %H:%M:%S
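
The `script_location` comment above notes that the value is overridden during programmatic migrations; the module that does so (`phoenix/db/migrate.py`, +71 lines) is not shown in this diff. As a hedged sketch only, this is what such an override typically looks like with Alembic's API; the paths and URL are illustrative, not the actual contents of `migrate.py`:

```python
# Illustrative only: how script_location and sqlalchemy.url are commonly
# overridden when driving Alembic programmatically. Not the actual migrate.py.
from pathlib import Path

from alembic import command
from alembic.config import Config

db_dir = Path("phoenix/db")  # hypothetical location of alembic.ini
config = Config(str(db_dir / "alembic.ini"))
config.set_main_option("script_location", str(db_dir / "migrations"))
config.set_main_option("sqlalchemy.url", "sqlite:///phoenix.db")  # dynamic; see env.py
command.upgrade(config, "head")  # equivalent to `alembic upgrade head`
```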
phoenix/db/bulk_inserter.py ADDED
@@ -0,0 +1,206 @@
+ import asyncio
+ import logging
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from itertools import islice
+ from time import perf_counter, time
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Awaitable,
+     Callable,
+     Iterable,
+     List,
+     Optional,
+     Set,
+     Tuple,
+ )
+
+ from cachetools import LRUCache
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from typing_extensions import TypeAlias
+
+ import phoenix.trace.v1 as pb
+ from phoenix.db.insertion.evaluation import (
+     EvaluationInsertionResult,
+     InsertEvaluationError,
+     insert_evaluation,
+ )
+ from phoenix.db.insertion.span import SpanInsertionEvent, insert_span
+ from phoenix.server.api.dataloaders import CacheForDataLoaders
+ from phoenix.trace.schemas import Span
+
+ logger = logging.getLogger(__name__)
+
+ ProjectRowId: TypeAlias = int
+
+
+ @dataclass(frozen=True)
+ class TransactionResult:
+     updated_project_rowids: Set[ProjectRowId] = field(default_factory=set)
+
+
+ class BulkInserter:
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         *,
+         cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
+         initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None,
+         initial_batch_of_evaluations: Optional[Iterable[pb.Evaluation]] = None,
+         run_interval_in_seconds: float = 2,
+         max_num_per_transaction: int = 1000,
+         enable_prometheus: bool = False,
+     ) -> None:
+         """
+         :param db: A function to initiate a new database session.
+         :param initial_batch_of_spans: Initial batch of spans to insert.
+         :param run_interval_in_seconds: The time interval between the starts of each
+             bulk insert. If there's nothing to insert, the inserter goes back to sleep.
+         :param max_num_per_transaction: The maximum number of items to insert in a single
+             transaction. Multiple transactions will be used if there are more items in the batch.
+         """
+         self._db = db
+         self._running = False
+         self._run_interval_seconds = run_interval_in_seconds
+         self._max_num_per_transaction = max_num_per_transaction
+         self._spans: List[Tuple[Span, str]] = (
+             [] if initial_batch_of_spans is None else list(initial_batch_of_spans)
+         )
+         self._evaluations: List[pb.Evaluation] = (
+             [] if initial_batch_of_evaluations is None else list(initial_batch_of_evaluations)
+         )
+         self._task: Optional[asyncio.Task[None]] = None
+         self._last_updated_at_by_project: LRUCache[ProjectRowId, datetime] = LRUCache(maxsize=100)
+         self._cache_for_dataloaders = cache_for_dataloaders
+         self._enable_prometheus = enable_prometheus
+
+     def last_updated_at(self, project_rowid: Optional[ProjectRowId] = None) -> Optional[datetime]:
+         if isinstance(project_rowid, ProjectRowId):
+             return self._last_updated_at_by_project.get(project_rowid)
+         return max(self._last_updated_at_by_project.values(), default=None)
+
+     async def __aenter__(
+         self,
+     ) -> Tuple[Callable[[Span, str], Awaitable[None]], Callable[[pb.Evaluation], Awaitable[None]]]:
+         self._running = True
+         self._task = asyncio.create_task(self._bulk_insert())
+         return self._queue_span, self._queue_evaluation
+
+     async def __aexit__(self, *args: Any) -> None:
+         self._running = False
+
+     async def _queue_span(self, span: Span, project_name: str) -> None:
+         self._spans.append((span, project_name))
+
+     async def _queue_evaluation(self, evaluation: pb.Evaluation) -> None:
+         self._evaluations.append(evaluation)
+
+     async def _bulk_insert(self) -> None:
+         spans_buffer, evaluations_buffer = None, None
+         next_run_at = time() + self._run_interval_seconds
+         while self._spans or self._evaluations or self._running:
+             await asyncio.sleep(next_run_at - time())
+             next_run_at = time() + self._run_interval_seconds
+             if not (self._spans or self._evaluations):
+                 continue
+             # It's important to grab the buffers at the same time so there's
+             # no race condition, since an eval insertion will fail if the span
+             # it references doesn't exist. Grabbing the eval buffer later may
+             # include an eval whose span is in the queue but missed being
+             # included in the span buffer that was grabbed previously.
+             if self._spans:
+                 spans_buffer = self._spans
+                 self._spans = []
+             if self._evaluations:
+                 evaluations_buffer = self._evaluations
+                 self._evaluations = []
+             # Spans should be inserted before the evaluations, since an evaluation
+             # insertion will fail if the span it references doesn't exist.
+             transaction_result = TransactionResult()
+             if spans_buffer:
+                 result = await self._insert_spans(spans_buffer)
+                 transaction_result.updated_project_rowids.update(result.updated_project_rowids)
+                 spans_buffer = None
+             if evaluations_buffer:
+                 result = await self._insert_evaluations(evaluations_buffer)
+                 transaction_result.updated_project_rowids.update(result.updated_project_rowids)
+                 evaluations_buffer = None
+             for project_rowid in transaction_result.updated_project_rowids:
+                 self._last_updated_at_by_project[project_rowid] = datetime.now(timezone.utc)
+
+     async def _insert_spans(self, spans: List[Tuple[Span, str]]) -> TransactionResult:
+         transaction_result = TransactionResult()
+         for i in range(0, len(spans), self._max_num_per_transaction):
+             try:
+                 start = perf_counter()
+                 async with self._db() as session:
+                     for span, project_name in islice(spans, i, i + self._max_num_per_transaction):
+                         if self._enable_prometheus:
+                             from phoenix.server.prometheus import BULK_LOADER_SPAN_INSERTIONS
+
+                             BULK_LOADER_SPAN_INSERTIONS.inc()
+                         result: Optional[SpanInsertionEvent] = None
+                         try:
+                             async with session.begin_nested():
+                                 result = await insert_span(session, span, project_name)
+                         except Exception:
+                             if self._enable_prometheus:
+                                 from phoenix.server.prometheus import BULK_LOADER_EXCEPTIONS
+
+                                 BULK_LOADER_EXCEPTIONS.inc()
+                             logger.exception(
+                                 f"Failed to insert span with span_id={span.context.span_id}"
+                             )
+                         if result is not None:
+                             transaction_result.updated_project_rowids.add(result.project_rowid)
+                             if (cache := self._cache_for_dataloaders) is not None:
+                                 cache.invalidate(result)
+                 if self._enable_prometheus:
+                     from phoenix.server.prometheus import BULK_LOADER_INSERTION_TIME
+
+                     BULK_LOADER_INSERTION_TIME.observe(perf_counter() - start)
+             except Exception:
+                 if self._enable_prometheus:
+                     from phoenix.server.prometheus import BULK_LOADER_EXCEPTIONS
+
+                     BULK_LOADER_EXCEPTIONS.inc()
+                 logger.exception("Failed to insert spans")
+         return transaction_result
+
+     async def _insert_evaluations(self, evaluations: List[pb.Evaluation]) -> TransactionResult:
+         transaction_result = TransactionResult()
+         for i in range(0, len(evaluations), self._max_num_per_transaction):
+             try:
+                 start = perf_counter()
+                 async with self._db() as session:
+                     for evaluation in islice(evaluations, i, i + self._max_num_per_transaction):
+                         if self._enable_prometheus:
+                             from phoenix.server.prometheus import BULK_LOADER_EVALUATION_INSERTIONS
+
+                             BULK_LOADER_EVALUATION_INSERTIONS.inc()
+                         result: Optional[EvaluationInsertionResult] = None
+                         try:
+                             async with session.begin_nested():
+                                 result = await insert_evaluation(session, evaluation)
+                         except InsertEvaluationError as error:
+                             if self._enable_prometheus:
+                                 from phoenix.server.prometheus import BULK_LOADER_EXCEPTIONS
+
+                                 BULK_LOADER_EXCEPTIONS.inc()
+                             logger.exception(f"Failed to insert evaluation: {str(error)}")
+                         if result is not None:
+                             transaction_result.updated_project_rowids.add(result.project_rowid)
+                             if (cache := self._cache_for_dataloaders) is not None:
+                                 cache.invalidate(result)
+                 if self._enable_prometheus:
+                     from phoenix.server.prometheus import BULK_LOADER_INSERTION_TIME
+
+                     BULK_LOADER_INSERTION_TIME.observe(perf_counter() - start)
+             except Exception:
+                 if self._enable_prometheus:
+                     from phoenix.server.prometheus import BULK_LOADER_EXCEPTIONS
+
+                     BULK_LOADER_EXCEPTIONS.inc()
+                 logger.exception("Failed to insert evaluations")
+         return transaction_result
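
For readers skimming the diff: the inserter is driven as an async context manager, and `__aenter__` hands back the two queueing coroutines shown above. A minimal usage sketch, assuming a `db` factory that yields `AsyncSession` contexts; the factory below is illustrative, not Phoenix's own wiring:

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from phoenix.db.bulk_inserter import BulkInserter


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    @asynccontextmanager
    async def db() -> AsyncIterator[AsyncSession]:
        # BulkInserter expects a zero-arg factory returning a session context
        async with session_factory.begin() as session:
            yield session

    async with BulkInserter(db) as (queue_span, queue_evaluation):
        # Callers enqueue items; the background task flushes roughly every
        # run_interval_in_seconds (2s by default), in transactions capped at
        # max_num_per_transaction items.
        # await queue_span(span, "default")  # span: phoenix.trace.schemas.Span
        await asyncio.sleep(3)  # allow at least one flush cycle in this sketch


asyncio.run(main())
```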
phoenix/db/engines.py ADDED
@@ -0,0 +1,152 @@
+ import asyncio
+ import json
+ from datetime import datetime
+ from enum import Enum
+ from sqlite3 import Connection
+ from typing import Any
+
+ import aiosqlite
+ import numpy as np
+ import sqlean
+ from sqlalchemy import URL, event, make_url
+ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+ from typing_extensions import assert_never
+
+ from phoenix.db.helpers import SupportedSQLDialect
+ from phoenix.db.migrate import migrate_in_thread
+ from phoenix.db.models import init_models
+
+ sqlean.extensions.enable("text", "stats")
+
+
+ def set_sqlite_pragma(connection: Connection, _: Any) -> None:
+     cursor = connection.cursor()
+     cursor.execute("PRAGMA foreign_keys = ON;")
+     cursor.execute("PRAGMA journal_mode = WAL;")
+     cursor.execute("PRAGMA synchronous = OFF;")
+     cursor.execute("PRAGMA cache_size = -32000;")
+     cursor.execute("PRAGMA busy_timeout = 10000;")
+     cursor.close()
+
+
+ def get_printable_db_url(connection_str: str) -> str:
+     return make_url(connection_str).render_as_string(hide_password=True)
+
+
+ def get_async_db_url(connection_str: str) -> URL:
+     """
+     Parses the database URL string and returns a URL object that is async
+     """
+     url = make_url(connection_str)
+     if not url.database:
+         raise ValueError("Failed to parse database from connection string")
+     backend = SupportedSQLDialect(url.get_backend_name())
+     if backend is SupportedSQLDialect.SQLITE:
+         return url.set(drivername="sqlite+aiosqlite")
+     elif backend is SupportedSQLDialect.POSTGRESQL:
+         url = url.set(drivername="postgresql+asyncpg")
+         # For some reason username and password cannot be parsed from the typical slot
+         # So we need to parse them out manually
+         if url.username and url.password:
+             url = url.set(
+                 query={"user": url.username, "password": url.password},
+                 password=None,
+                 username=None,
+             )
+         return url
+     else:
+         assert_never(backend)
+
+
+ def create_engine(
+     connection_str: str,
+     migrate: bool = True,
+     echo: bool = False,
+ ) -> AsyncEngine:
+     """
+     Factory to create a SQLAlchemy engine from a URL string.
+     """
+     url = make_url(connection_str)
+     if not url.database:
+         raise ValueError("Failed to parse database from connection string")
+     backend = SupportedSQLDialect(url.get_backend_name())
+     url = get_async_db_url(url.render_as_string(hide_password=False))
+     if backend is SupportedSQLDialect.SQLITE:
+         return aio_sqlite_engine(url=url, migrate=migrate, echo=echo)
+     elif backend is SupportedSQLDialect.POSTGRESQL:
+         return aio_postgresql_engine(url=url, migrate=migrate, echo=echo)
+     else:
+         assert_never(backend)
+
+
+ def aio_sqlite_engine(
+     url: URL,
+     migrate: bool = True,
+     echo: bool = False,
+     shared_cache: bool = True,
+ ) -> AsyncEngine:
+     database = url.database or ":memory:"
+     if database.startswith("file:"):
+         database = database[5:]
+     if database.startswith(":memory:") and shared_cache:
+         url = url.set(query={**url.query, "cache": "shared"}, database=":memory:")
+         database = url.render_as_string().partition("///")[-1]
+
+     def async_creator() -> aiosqlite.Connection:
+         conn = aiosqlite.Connection(
+             lambda: sqlean.connect(f"file:{database}", uri=True),
+             iter_chunk_size=64,
+         )
+         conn.daemon = True
+         return conn
+
+     engine = create_async_engine(
+         url=url,
+         echo=echo,
+         json_serializer=_dumps,
+         async_creator=async_creator,
+     )
+     event.listen(engine.sync_engine, "connect", set_sqlite_pragma)
+     if not migrate:
+         return engine
+     if database.startswith(":memory:"):
+         try:
+             asyncio.get_running_loop()
+         except RuntimeError:
+             asyncio.run(init_models(engine))
+         else:
+             asyncio.create_task(init_models(engine))
+     else:
+         migrate_in_thread(engine.url)
+     return engine
+
+
+ def aio_postgresql_engine(
+     url: URL,
+     migrate: bool = True,
+     echo: bool = False,
+ ) -> AsyncEngine:
+     engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps)
+     if not migrate:
+         return engine
+     migrate_in_thread(engine.url)
+     return engine
+
+
+ def _dumps(obj: Any) -> str:
+     return json.dumps(obj, cls=_Encoder)
+
+
+ class _Encoder(json.JSONEncoder):
+     def default(self, obj: Any) -> Any:
+         if isinstance(obj, datetime):
+             return obj.isoformat()
+         elif isinstance(obj, Enum):
+             return obj.value
+         elif isinstance(obj, np.ndarray):
+             return list(obj)
+         elif isinstance(obj, np.integer):
+             return int(obj)
+         elif isinstance(obj, np.floating):
+             return float(obj)
+         return super().default(obj)
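
In short, `create_engine` inspects the URL's backend and rewrites it to an async driver (`sqlite+aiosqlite` or `postgresql+asyncpg`) before building the engine, running migrations unless told otherwise. A brief sketch using only the functions shown above; the output comments are approximate:

```python
from phoenix.db.engines import create_engine, get_async_db_url, get_printable_db_url

# SQLite file database; migrate=False here just skips the migration
# thread for the purposes of this sketch.
engine = create_engine("sqlite:///phoenix.db", migrate=False)

# The URL rewriting can be inspected on its own:
print(get_async_db_url("sqlite:///phoenix.db"))
# ~> sqlite+aiosqlite:///phoenix.db
print(get_async_db_url("postgresql://user:pass@localhost/phoenix"))
# ~> postgresql+asyncpg://localhost/phoenix?user=user&password=pass
#    (credentials moved into the query string for asyncpg)
print(get_printable_db_url("postgresql://user:pass@localhost/phoenix"))
# ~> postgresql://user:***@localhost/phoenix (password masked)
```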
phoenix/db/helpers.py ADDED
@@ -0,0 +1,47 @@
+ from enum import Enum
+ from typing import Any
+
+ from openinference.semconv.trace import (
+     OpenInferenceSpanKindValues,
+     RerankerAttributes,
+     SpanAttributes,
+ )
+ from sqlalchemy import Integer, SQLColumnExpression, case, func
+ from typing_extensions import assert_never
+
+ from phoenix.db import models
+
+
+ class SupportedSQLDialect(Enum):
+     SQLITE = "sqlite"
+     POSTGRESQL = "postgresql"
+
+     @classmethod
+     def _missing_(cls, v: Any) -> "SupportedSQLDialect":
+         if isinstance(v, str) and v and v.isascii() and not v.islower():
+             return cls(v.lower())
+         raise ValueError(f"`{v}` is not a supported SQL backend/dialect.")
+
+
+ def num_docs_col(dialect: SupportedSQLDialect) -> SQLColumnExpression[Integer]:
+     if dialect is SupportedSQLDialect.POSTGRESQL:
+         array_length = func.jsonb_array_length
+     elif dialect is SupportedSQLDialect.SQLITE:
+         array_length = func.json_array_length
+     else:
+         assert_never(dialect)
+     retrieval_docs = models.Span.attributes[_RETRIEVAL_DOCUMENTS]
+     num_retrieval_docs = array_length(retrieval_docs)
+     reranker_docs = models.Span.attributes[_RERANKER_OUTPUT_DOCUMENTS]
+     num_reranker_docs = array_length(reranker_docs)
+     return case(
+         (
+             func.upper(models.Span.span_kind) == OpenInferenceSpanKindValues.RERANKER.value.upper(),
+             num_reranker_docs,
+         ),
+         else_=num_retrieval_docs,
+     ).label("num_docs")
+
+
+ _RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS.split(".")
+ _RERANKER_OUTPUT_DOCUMENTS = RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS.split(".")
+ _RERANKER_OUTPUT_DOCUMENTS = RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS.split(".")