arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's advisory for this release for more details.

Files changed (34)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
  3. phoenix/config.py +42 -0
  4. phoenix/server/api/helpers/playground_clients.py +671 -0
  5. phoenix/server/api/helpers/playground_registry.py +70 -0
  6. phoenix/server/api/helpers/playground_spans.py +325 -0
  7. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  8. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  9. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  10. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  11. phoenix/server/api/mutations/__init__.py +4 -0
  12. phoenix/server/api/mutations/chat_mutations.py +374 -0
  13. phoenix/server/api/queries.py +41 -52
  14. phoenix/server/api/schema.py +42 -10
  15. phoenix/server/api/subscriptions.py +326 -595
  16. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  17. phoenix/server/api/types/GenerativeProvider.py +27 -3
  18. phoenix/server/api/types/Span.py +37 -0
  19. phoenix/server/api/types/TemplateLanguage.py +9 -0
  20. phoenix/server/app.py +61 -13
  21. phoenix/server/main.py +14 -1
  22. phoenix/server/static/.vite/manifest.json +9 -9
  23. phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
  24. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
  25. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
  26. phoenix/server/templates/index.html +1 -0
  27. phoenix/services.py +4 -0
  28. phoenix/session/session.py +15 -1
  29. phoenix/utilities/template_formatters.py +11 -1
  30. phoenix/version.py +1 -1
  31. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  32. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  33. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  34. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,44 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry.relay import GlobalID
5
+
6
+ from .Experiment import Experiment
7
+ from .Span import Span
8
+
9
+
10
@strawberry.interface
class ChatCompletionSubscriptionPayload:
    """GraphQL interface implemented by every chat-completion subscription payload."""

    # Optional GlobalID of a dataset example associated with this payload;
    # None unless explicitly set.
    dataset_example_id: Optional[GlobalID] = None


@strawberry.type
class TextChunk(ChatCompletionSubscriptionPayload):
    """Payload carrying a fragment of generated text."""

    # The text fragment itself.
    content: str


@strawberry.type
class FunctionCallChunk(ChatCompletionSubscriptionPayload):
    """Payload describing a function call: its name and its arguments string."""

    name: str
    arguments: str


@strawberry.type
class ToolCallChunk(ChatCompletionSubscriptionPayload):
    """Payload describing a tool call, identified by `id`, wrapping a function call."""

    id: str
    function: FunctionCallChunk


@strawberry.type
class FinishedChatCompletion(ChatCompletionSubscriptionPayload):
    """Payload signalling a finished chat completion, carrying its Span."""

    span: Span


@strawberry.type
class ChatCompletionSubscriptionError(ChatCompletionSubscriptionPayload):
    """Payload reporting an error as a human-readable message."""

    message: str


@strawberry.type
class ChatCompletionOverDatasetSubscriptionResult(ChatCompletionSubscriptionPayload):
    """Payload for a chat completion run over a dataset, carrying its Experiment."""

    experiment: Experiment
@@ -5,12 +5,36 @@ import strawberry
5
5
 
6
6
@strawberry.enum
class GenerativeProviderKey(Enum):
    """Keys identifying the generative model providers; values are display names."""

    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    AZURE_OPENAI = "Azure OpenAI"


@strawberry.type
class GenerativeProvider:
    """A generative model provider, with the Python packages its default client needs."""

    name: str
    key: GenerativeProviderKey

    @strawberry.field
    async def dependencies(self) -> list[str]:
        """List the dependencies of the provider's default playground client.

        Returns an empty list when the registry has no default client for this key.
        """
        # Imported lazily, inside the resolver, rather than at module import time.
        from phoenix.server.api.helpers.playground_registry import (
            PLAYGROUND_CLIENT_REGISTRY,
            PROVIDER_DEFAULT,
        )

        client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
        return client.dependencies() if client else []

    @strawberry.field
    async def dependencies_installed(self) -> bool:
        """Report whether the default playground client's dependencies are installed.

        Returns False when the registry has no default client for this key.
        """
        from phoenix.server.api.helpers.playground_registry import (
            PLAYGROUND_CLIENT_REGISTRY,
            PROVIDER_DEFAULT,
        )

        client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
        return client.dependencies_are_installed() if client else False
@@ -20,10 +20,12 @@ from phoenix.server.api.helpers.dataset_helpers import (
20
20
  get_dataset_example_input,
21
21
  get_dataset_example_output,
22
22
  )
23
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
23
24
  from phoenix.server.api.input_types.SpanAnnotationSort import (
24
25
  SpanAnnotationColumn,
25
26
  SpanAnnotationSort,
26
27
  )
28
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
27
29
  from phoenix.server.api.types.SortDir import SortDir
28
30
  from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
29
31
  from phoenix.trace.attributes import get_attribute_value
@@ -291,6 +293,41 @@ class Span(Node):
291
293
  examples = await info.context.data_loaders.span_dataset_examples.load(self.id_attr)
292
294
  return bool(examples)
293
295
 
296
+ @strawberry.field(description="Invocation parameters for the span") # type: ignore
297
+ async def invocation_parameters(self, info: Info[Context, None]) -> list[InvocationParameter]:
298
+ from phoenix.server.api.helpers.playground_clients import OpenAIStreamingClient
299
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
300
+
301
+ db_span = self.db_span
302
+ attributes = db_span.attributes
303
+ llm_provider: GenerativeProviderKey = (
304
+ get_attribute_value(attributes, SpanAttributes.LLM_PROVIDER)
305
+ or GenerativeProviderKey.OPENAI
306
+ )
307
+ llm_model = get_attribute_value(attributes, SpanAttributes.LLM_MODEL_NAME)
308
+ invocation_parameters = get_attribute_value(
309
+ attributes, SpanAttributes.LLM_INVOCATION_PARAMETERS
310
+ )
311
+ if invocation_parameters is None:
312
+ return []
313
+ invocation_parameters = json.loads(invocation_parameters)
314
+ # find the client class for the provider, if there is no client class or provider,
315
+ # return openai as default
316
+ client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(llm_provider, llm_model)
317
+ if not client_class:
318
+ client_class = OpenAIStreamingClient
319
+ supported_invocation_parameters = client_class.supported_invocation_parameters()
320
+ # filter supported invocation parameters down to those whose canonical_name is in the
321
+ # invocation_parameters keys
322
+ return [
323
+ ip
324
+ for ip in supported_invocation_parameters
325
+ if (
326
+ ip.canonical_name in invocation_parameters
327
+ or ip.invocation_name in invocation_parameters
328
+ )
329
+ ]
330
+
294
331
 
295
332
  def to_gql_span(span: models.Span) -> Span:
296
333
  events: list[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.enum
class TemplateLanguage(Enum):
    """Template languages that can be selected through the GraphQL API."""

    # Mustache templates.
    MUSTACHE = "MUSTACHE"
    # f-string templates.
    F_STRING = "F_STRING"
phoenix/server/app.py CHANGED
@@ -1,7 +1,9 @@
1
1
  import asyncio
2
2
  import contextlib
3
+ import importlib
3
4
  import json
4
5
  import logging
6
+ import os
5
7
  from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
6
8
  from contextlib import AbstractAsyncContextManager, AsyncExitStack
7
9
  from dataclasses import dataclass, field
@@ -40,7 +42,6 @@ from starlette.types import Scope, StatefulLifespan
40
42
  from starlette.websockets import WebSocket
41
43
  from strawberry.extensions import SchemaExtension
42
44
  from strawberry.fastapi import GraphQLRouter
43
- from strawberry.schema import BaseSchema
44
45
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
45
46
  from typing_extensions import TypeAlias
46
47
 
@@ -51,6 +52,8 @@ from phoenix.config import (
51
52
  SERVER_DIR,
52
53
  OAuth2ClientConfig,
53
54
  get_env_csrf_trusted_origins,
55
+ get_env_fastapi_middleware_paths,
56
+ get_env_gql_extension_paths,
54
57
  get_env_host,
55
58
  get_env_port,
56
59
  server_instrumentation_is_enabled,
@@ -98,7 +101,7 @@ from phoenix.server.api.routers import (
98
101
  oauth2_router,
99
102
  )
100
103
  from phoenix.server.api.routers.v1 import REST_API_VERSION
101
- from phoenix.server.api.schema import schema
104
+ from phoenix.server.api.schema import build_graphql_schema
102
105
  from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
103
106
  from phoenix.server.dml_event import DmlEvent
104
107
  from phoenix.server.dml_event_handler import DmlEventHandler
@@ -150,6 +153,28 @@ ProjectName: TypeAlias = str
150
153
  _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
151
154
 
152
155
 
156
def import_object_from_file(file_path: str, object_name: str) -> Any:
    """Import an object (class or function) from a Python file.

    The file is executed as a new module named ``custom_module_<hash>``, where
    the hash is derived from ``file_path``. The module is registered in
    ``sys.modules`` before it executes, as in the importlib "importing a source
    file directly" recipe. This lets code in the file that looks itself up via
    ``__module__`` work (e.g. dataclasses, pickling).

    Args:
        file_path: Path of the Python source file to load.
        object_name: Name of the module-level object to return.

    Returns:
        The object bound to ``object_name`` in the loaded module.

    Raises:
        ImportError: If any of the following happens:

            - the file does not exist;
            - the file cannot be loaded or executed;
            - the file does not define ``object_name``.

            Any underlying exception is attached as ``__cause__``.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded, so import it explicitly.
    import importlib.util
    import sys

    prefix = f"Could not import '{object_name}' from '{file_path}'"
    if not os.path.isfile(file_path):
        raise ImportError(f"{prefix}: File '{file_path}' does not exist.")
    module_name = f"custom_module_{hash(file_path)}"
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None:
            raise ImportError(f"Could not load spec for '{file_path}'")
        loader = spec.loader
        if loader is None:
            raise ImportError(f"No loader found for '{file_path}'")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            raise
    except Exception as e:
        raise ImportError(f"{prefix}: {e}") from e
    try:
        return getattr(module, object_name)
    except AttributeError as e:
        raise ImportError(
            f"{prefix}: Module '{file_path}' does not have an object '{object_name}'."
        ) from e
176
+
177
+
153
178
  class OAuth2Idp(TypedDict):
154
179
  name: str
155
180
  displayName: str
@@ -166,6 +191,7 @@ class AppConfig(NamedTuple):
166
191
  web_manifest_path: Path
167
192
  authentication_enabled: bool
168
193
  """ Whether authentication is enabled """
194
+ websockets_enabled: bool
169
195
  oauth2_idps: Sequence[OAuth2Idp]
170
196
 
171
197
 
@@ -216,6 +242,7 @@ class Static(StaticFiles):
216
242
  "manifest": self._web_manifest,
217
243
  "authentication_enabled": self._app_config.authentication_enabled,
218
244
  "oauth2_idps": self._app_config.oauth2_idps,
245
+ "websockets_enabled": self._app_config.websockets_enabled,
219
246
  },
220
247
  )
221
248
  except Exception as e:
@@ -256,6 +283,28 @@ class HeadersMiddleware(BaseHTTPMiddleware):
256
283
  return response
257
284
 
258
285
 
286
def user_fastapi_middlewares() -> list[Middleware]:
    """Load user-supplied HTTP middlewares configured through the environment.

    Each ``(file_path, object_name)`` pair from
    ``get_env_fastapi_middleware_paths()`` is loaded with
    ``import_object_from_file``. Each loaded class is wrapped in a
    ``Middleware`` entry, in configuration order.

    Returns:
        The middlewares to add to the application.

    Raises:
        ImportError: If a configured object cannot be imported.
        TypeError: If a configured object is not a ``BaseHTTPMiddleware`` subclass.
    """
    middlewares: list[Middleware] = []
    for file_path, object_name in get_env_fastapi_middleware_paths():
        middleware_class = import_object_from_file(file_path, object_name)
        # ``issubclass`` raises its own, unrelated TypeError when given a
        # non-class (e.g. a function), so check that it is a class first.
        if not (
            isinstance(middleware_class, type)
            and issubclass(middleware_class, BaseHTTPMiddleware)
        ):
            raise TypeError(f"{middleware_class} is not a subclass of BaseHTTPMiddleware")
        middlewares.append(Middleware(middleware_class))
    return middlewares
295
+
296
+
297
def user_gql_extensions() -> list[Union[type[SchemaExtension], SchemaExtension]]:
    """Load user-supplied strawberry schema extensions configured through the environment.

    Each ``(file_path, object_name)`` pair from ``get_env_gql_extension_paths()``
    is loaded with ``import_object_from_file``. The loaded classes are returned
    in configuration order.

    Returns:
        The extension classes to pass to the GraphQL schema.

    Raises:
        ImportError: If a configured object cannot be imported.
        TypeError: If a configured object is not a ``SchemaExtension`` subclass.
    """
    extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
    for file_path, object_name in get_env_gql_extension_paths():
        extension_class = import_object_from_file(file_path, object_name)
        # ``issubclass`` raises its own, unrelated TypeError when given a
        # non-class (e.g. an instance or function), so check that it is a class first.
        if not (
            isinstance(extension_class, type)
            and issubclass(extension_class, SchemaExtension)
        ):
            raise TypeError(f"{extension_class} is not a subclass of SchemaExtension")
        extensions.append(extension_class)
    return extensions
306
+
307
+
259
308
  ProjectRowId: TypeAlias = int
260
309
 
261
310
 
@@ -463,7 +512,7 @@ async def check_healthz(_: Request) -> PlainTextResponse:
463
512
 
464
513
  def create_graphql_router(
465
514
  *,
466
- schema: BaseSchema,
515
+ graphql_schema: strawberry.Schema,
467
516
  db: DbSessionFactory,
468
517
  model: Model,
469
518
  export_path: Path,
@@ -567,7 +616,7 @@ def create_graphql_router(
567
616
  )
568
617
 
569
618
  return GraphQLRouter(
570
- schema,
619
+ graphql_schema,
571
620
  graphql_ide="graphiql",
572
621
  context_getter=get_context,
573
622
  include_in_schema=False,
@@ -653,6 +702,7 @@ def create_app(
653
702
  model: Model,
654
703
  authentication_enabled: bool,
655
704
  umap_params: UMAPParameters,
705
+ enable_websockets: bool,
656
706
  corpus: Optional[Model] = None,
657
707
  debug: bool = False,
658
708
  dev: bool = False,
@@ -700,6 +750,7 @@ def create_app(
700
750
  )
701
751
  last_updated_at = LastUpdatedAt()
702
752
  middlewares: list[Middleware] = [Middleware(HeadersMiddleware)]
753
+ middlewares.extend(user_fastapi_middlewares())
703
754
  if origins := get_env_csrf_trusted_origins():
704
755
  trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
705
756
  middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
@@ -733,8 +784,9 @@ def create_app(
733
784
  initial_batch_of_evaluations=initial_batch_of_evaluations,
734
785
  )
735
786
  tracer_provider = None
736
- strawberry_extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
737
- strawberry_extensions.extend(schema.get_extensions())
787
+ graphql_schema_extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
788
+ graphql_schema_extensions.extend(user_gql_extensions())
789
+
738
790
  if server_instrumentation_is_enabled():
739
791
  tracer_provider = initialize_opentelemetry_tracer_provider()
740
792
  from opentelemetry.trace import TracerProvider
@@ -752,16 +804,11 @@ def create_app(
752
804
  # used by OpenInference.
753
805
  self._tracer = cast(TracerProvider, tracer_provider).get_tracer("strawberry")
754
806
 
755
- strawberry_extensions.append(_OpenTelemetryExtension)
807
+ graphql_schema_extensions.append(_OpenTelemetryExtension)
756
808
 
757
809
  graphql_router = create_graphql_router(
758
810
  db=db,
759
- schema=strawberry.Schema(
760
- query=schema.query,
761
- mutation=schema.mutation,
762
- subscription=schema.subscription,
763
- extensions=strawberry_extensions,
764
- ),
811
+ graphql_schema=build_graphql_schema(graphql_schema_extensions),
765
812
  model=model,
766
813
  corpus=corpus,
767
814
  authentication_enabled=authentication_enabled,
@@ -830,6 +877,7 @@ def create_app(
830
877
  authentication_enabled=authentication_enabled,
831
878
  web_manifest_path=web_manifest_path,
832
879
  oauth2_idps=oauth2_idps,
880
+ websockets_enabled=enable_websockets,
833
881
  ),
834
882
  ),
835
883
  name="static",
phoenix/server/main.py CHANGED
@@ -21,6 +21,7 @@ from phoenix.config import (
21
21
  get_env_database_schema,
22
22
  get_env_db_logging_level,
23
23
  get_env_enable_prometheus,
24
+ get_env_enable_websockets,
24
25
  get_env_grpc_port,
25
26
  get_env_host,
26
27
  get_env_host_root_path,
@@ -95,6 +96,7 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
95
96
  | 🚀 Phoenix Server 🚀
96
97
  | Phoenix UI: {{ ui_path }}
97
98
  | Authentication: {{ auth_enabled }}
99
+ | Websockets: {{ websockets_enabled }}
98
100
  | Log traces:
99
101
  | - gRPC: {{ grpc_path }}
100
102
  | - HTTP: {{ http_path }}
@@ -162,7 +164,7 @@ def main() -> None:
162
164
  parser.add_argument("--debug", action="store_true", help=SUPPRESS)
163
165
  parser.add_argument("--dev", action="store_true", help=SUPPRESS)
164
166
  parser.add_argument("--no-ui", action="store_true", help=SUPPRESS)
165
-
167
+ parser.add_argument("--enable-websockets", type=str, help=SUPPRESS)
166
168
  subparsers = parser.add_subparsers(dest="command", required=True, help=SUPPRESS)
167
169
 
168
170
  serve_parser = subparsers.add_parser("serve")
@@ -348,6 +350,14 @@ def main() -> None:
348
350
  corpus_model = (
349
351
  None if corpus_inferences is None else create_model_from_inferences(corpus_inferences)
350
352
  )
353
+
354
+ # Get enable_websockets from environment variable or command line argument
355
+ enable_websockets = get_env_enable_websockets()
356
+ if args.enable_websockets is not None:
357
+ enable_websockets = args.enable_websockets.lower() == "true"
358
+ if enable_websockets is None:
359
+ enable_websockets = True
360
+
351
361
  # Print information about the server
352
362
  root_path = urljoin(f"http://{host}:{port}", host_root_path)
353
363
  msg = _WELCOME_MESSAGE.render(
@@ -358,6 +368,7 @@ def main() -> None:
358
368
  storage=get_printable_db_url(db_connection_str),
359
369
  schema=get_env_database_schema(),
360
370
  auth_enabled=authentication_enabled,
371
+ websockets_enabled=enable_websockets,
361
372
  )
362
373
  if sys.platform.startswith("win"):
363
374
  msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
@@ -382,10 +393,12 @@ def main() -> None:
382
393
  connection_method="STARTTLS",
383
394
  validate_certs=get_env_smtp_validate_certs(),
384
395
  )
396
+
385
397
  app = create_app(
386
398
  db=factory,
387
399
  export_path=export_path,
388
400
  model=model,
401
+ enable_websockets=enable_websockets,
389
402
  authentication_enabled=authentication_enabled,
390
403
  umap_params=umap_params,
391
404
  corpus=corpus_model,
@@ -1,22 +1,22 @@
1
1
  {
2
- "_components-C70HJiXz.js": {
3
- "file": "assets/components-C70HJiXz.js",
2
+ "_components-Csu8UKOs.js": {
3
+ "file": "assets/components-Csu8UKOs.js",
4
4
  "name": "components",
5
5
  "imports": [
6
6
  "_vendor-CtqfhlbC.js",
7
7
  "_vendor-arizeai-C_3SBz56.js",
8
- "_pages-C8-Sl7JI.js",
8
+ "_pages-UeWaKXNs.js",
9
9
  "_vendor-three-DwGkEfCM.js",
10
10
  "_vendor-codemirror-wfdk9cjp.js"
11
11
  ]
12
12
  },
13
- "_pages-C8-Sl7JI.js": {
14
- "file": "assets/pages-C8-Sl7JI.js",
13
+ "_pages-UeWaKXNs.js": {
14
+ "file": "assets/pages-UeWaKXNs.js",
15
15
  "name": "pages",
16
16
  "imports": [
17
17
  "_vendor-CtqfhlbC.js",
18
18
  "_vendor-arizeai-C_3SBz56.js",
19
- "_components-C70HJiXz.js",
19
+ "_components-Csu8UKOs.js",
20
20
  "_vendor-recharts-BiVnSv90.js",
21
21
  "_vendor-codemirror-wfdk9cjp.js"
22
22
  ]
@@ -61,15 +61,15 @@
61
61
  "name": "vendor-three"
62
62
  },
63
63
  "index.tsx": {
64
- "file": "assets/index-DLe1Oo3l.js",
64
+ "file": "assets/index-Bk5C9EA7.js",
65
65
  "name": "index",
66
66
  "src": "index.tsx",
67
67
  "isEntry": true,
68
68
  "imports": [
69
69
  "_vendor-CtqfhlbC.js",
70
70
  "_vendor-arizeai-C_3SBz56.js",
71
- "_pages-C8-Sl7JI.js",
72
- "_components-C70HJiXz.js",
71
+ "_pages-UeWaKXNs.js",
72
+ "_components-Csu8UKOs.js",
73
73
  "_vendor-three-DwGkEfCM.js",
74
74
  "_vendor-recharts-BiVnSv90.js",
75
75
  "_vendor-codemirror-wfdk9cjp.js"