arize-phoenix 8.23.0__py3-none-any.whl → 8.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

@@ -1,7 +1,9 @@
1
1
  from typing import Optional
2
2
 
3
3
  from fastapi import APIRouter, HTTPException, Path, Query
4
+ from pydantic import Field
4
5
  from sqlalchemy import select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
5
7
  from starlette.requests import Request
6
8
  from starlette.status import (
7
9
  HTTP_204_NO_CONTENT,
@@ -14,6 +16,7 @@ from strawberry.relay import GlobalID
14
16
  from phoenix.config import DEFAULT_PROJECT_NAME
15
17
  from phoenix.db import models
16
18
  from phoenix.db.enums import UserRole
19
+ from phoenix.db.helpers import exclude_experiment_projects
17
20
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
18
21
  from phoenix.server.api.routers.v1.utils import (
19
22
  PaginatedResponseBody,
@@ -27,7 +30,7 @@ router = APIRouter(tags=["projects"])
27
30
 
28
31
 
class ProjectData(V1RoutesBaseModel):
    # Common project payload fields: a required, non-empty name plus an optional
    # free-text description.
    # NOTE: documented with comments instead of a docstring because pydantic
    # copies a model's docstring into its JSON/OpenAPI schema description.
    name: str = Field(min_length=1)
    description: Optional[str] = Field(default=None)
32
35
 
33
36
 
@@ -43,8 +46,8 @@ class GetProjectResponseBody(ResponseBody[Project]):
43
46
  pass
44
47
 
45
48
 
class CreateProjectRequestBody(ProjectData):
    # Flat request body for project creation: exactly the ProjectData fields
    # (`name`, `description`) at the top level of the JSON payload.
    # (No docstring on purpose: pydantic would publish it in the OpenAPI schema.)
    ...
48
51
 
49
52
 
50
53
  class CreateProjectResponseBody(ResponseBody[Project]):
@@ -52,7 +55,7 @@ class CreateProjectResponseBody(ResponseBody[Project]):
52
55
 
53
56
 
class UpdateProjectRequestBody(V1RoutesBaseModel):
    # Only the description is updatable; project names cannot be changed.
    # An omitted or null description is treated by update_project as
    # "leave the current description as is".
    # (No docstring on purpose: pydantic would publish it in the OpenAPI schema.)
    description: Optional[str] = Field(default=None)
56
59
 
57
60
 
58
61
  class UpdateProjectResponseBody(ResponseBody[Project]):
@@ -75,19 +78,25 @@ async def get_projects(
75
78
  request: Request,
76
79
  cursor: Optional[str] = Query(
77
80
  default=None,
78
- description="Cursor for pagination (base64-encoded project ID)",
81
+ description="Cursor for pagination (project ID)",
79
82
  ),
80
83
  limit: int = Query(
81
84
  default=100, description="The max number of projects to return at a time.", gt=0
82
85
  ),
86
+ include_experiment_projects: bool = Query(
87
+ default=False,
88
+ description="Include experiment projects in the response. Experiment projects are created from running experiments.", # noqa: E501
89
+ ),
83
90
  ) -> GetProjectsResponseBody:
84
91
  """
85
92
  Retrieve a paginated list of all projects in the system.
86
93
 
87
94
  Args:
88
95
  request (Request): The FastAPI request object.
89
- cursor (Optional[str]): Pagination cursor (base64-encoded project ID).
96
+ cursor (Optional[str]): Pagination cursor (project ID).
90
97
  limit (int): Maximum number of projects to return per request.
98
+ include_experiment_projects (bool): Flag to include experiment projects in the response.
99
+ Experiment projects are created from running experiments.
91
100
 
92
101
  Returns:
93
102
  GetProjectsResponseBody: Response containing a list of projects and pagination information.
@@ -95,9 +104,10 @@ async def get_projects(
95
104
  Raises:
96
105
  HTTPException: If the cursor format is invalid.
97
106
  """ # noqa: E501
107
+ stmt = select(models.Project).order_by(models.Project.id.desc())
108
+ if not include_experiment_projects:
109
+ stmt = exclude_experiment_projects(stmt)
98
110
  async with request.app.state.db() as session:
99
- stmt = select(models.Project).order_by(models.Project.id.desc())
100
-
101
111
  if cursor:
102
112
  try:
103
113
  cursor_id = GlobalID.from_id(cursor).node_id
@@ -109,27 +119,26 @@ async def get_projects(
109
119
  )
110
120
 
111
121
  stmt = stmt.limit(limit + 1)
112
- result = await session.execute(stmt)
113
- orm_projects = result.scalars().all()
122
+ projects = (await session.scalars(stmt)).all()
114
123
 
115
- if not orm_projects:
124
+ if not projects:
116
125
  return GetProjectsResponseBody(next_cursor=None, data=[])
117
126
 
118
127
  next_cursor = None
119
- if len(orm_projects) == limit + 1:
120
- last_project = orm_projects[-1]
128
+ if len(projects) == limit + 1:
129
+ last_project = projects[-1]
121
130
  next_cursor = str(GlobalID(ProjectNodeType.__name__, str(last_project.id)))
122
- orm_projects = orm_projects[:-1]
131
+ projects = projects[:-1]
123
132
 
124
- projects = [_project_from_orm_project(orm_project) for orm_project in orm_projects]
125
- return GetProjectsResponseBody(next_cursor=next_cursor, data=projects)
133
+ project_responses = [_to_project_response(project) for project in projects]
134
+ return GetProjectsResponseBody(next_cursor=next_cursor, data=project_responses)
126
135
 
127
136
 
128
137
  @router.get(
129
- "/projects/{project_id}",
138
+ "/projects/{project_identifier}",
130
139
  operation_id="getProject",
131
- summary="Get project by ID", # noqa: E501
132
- description="Retrieve a specific project using its unique identifier.", # noqa: E501
140
+ summary="Get project by ID or name", # noqa: E501
141
+ description="Retrieve a specific project using its unique identifier: either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
133
142
  response_description="The requested project", # noqa: E501
134
143
  responses=add_errors_to_responses(
135
144
  [
@@ -140,41 +149,27 @@ async def get_projects(
140
149
  )
async def get_project(
    request: Request,
    project_identifier: str = Path(
        description="The project identifier: either project ID or project name. If using a project name, it cannot contain slash (/), question mark (?), or pound sign (#) characters.",  # noqa: E501
    ),
) -> GetProjectResponseBody:
    """
    Look up a single project, addressed either by its GlobalID or by its name.

    Args:
        request (Request): The incoming FastAPI request; its app state provides the DB session factory.
        project_identifier (str): A project GlobalID or a project name. Names containing
            slash (/), question mark (?), or pound sign (#) cannot be used in the URL path.

    Returns:
        GetProjectResponseBody: Envelope whose `data` is the matching project.

    Raises:
        HTTPException: Propagated from the identifier lookup when no project matches.
    """  # noqa: E501
    async with request.app.state.db() as session:
        # Serialize while the session is still open so ORM attributes are loaded.
        found = await _get_project_by_identifier(session, project_identifier)
        payload = _to_project_response(found)
    return GetProjectResponseBody(data=payload)
179
174
 
180
175
 
@@ -207,23 +202,22 @@ async def create_project(
207
202
  Raises:
208
203
  HTTPException: If any validation error occurs.
209
204
  """
210
- project = request_body.project
211
205
  async with request.app.state.db() as session:
212
- project_orm = models.Project(
213
- name=project.name,
214
- description=project.description,
206
+ project = models.Project(
207
+ name=request_body.name,
208
+ description=request_body.description,
215
209
  )
216
- session.add(project_orm)
210
+ session.add(project)
217
211
  await session.flush()
218
- data = _project_from_orm_project(project_orm)
212
+ data = _to_project_response(project)
219
213
  return CreateProjectResponseBody(data=data)
220
214
 
221
215
 
222
216
  @router.put(
223
- "/projects/{project_id}",
217
+ "/projects/{project_identifier}",
224
218
  operation_id="updateProject",
225
- summary="Update a project", # noqa: E501
226
- description="Update an existing project with new configuration. Project names cannot be changed.", # noqa: E501
219
+ summary="Update a project by ID or name", # noqa: E501
220
+ description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
227
221
  response_description="The updated project", # noqa: E501
228
222
  responses=add_errors_to_responses(
229
223
  [
@@ -236,21 +230,24 @@ async def create_project(
236
230
  async def update_project(
237
231
  request: Request,
238
232
  request_body: UpdateProjectRequestBody,
239
- project_id: str = Path(description="The ID of the project to update."),
233
+ project_identifier: str = Path(
234
+ description="The project identifier: either project ID or project name. If using a project name, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
235
+ ),
240
236
  ) -> UpdateProjectResponseBody:
241
237
  """
242
238
  Update an existing project.
243
239
 
244
240
  Args:
245
241
  request (Request): The FastAPI request object.
246
- request_body (UpdateProjectRequestBody): The request body containing updated project data.
247
- project_id (str): The ID of the project to update.
242
+ request_body (UpdateProjectRequestBody): The request body containing the new description.
243
+ project_identifier (str): The project identifier: either project ID or project name.
244
+ If using a project name, it cannot contain slash (/), question mark (?), or pound sign (#) characters.
248
245
 
249
246
  Returns:
250
247
  UpdateProjectResponseBody: Response containing the updated project.
251
248
 
252
249
  Raises:
253
- HTTPException: If the project ID is invalid, the project is not found, or the name is changed.
250
+ HTTPException: If the project identifier format is invalid or the project is not found.
254
251
  """ # noqa: E501
255
252
  if request.app.state.authentication_enabled:
256
253
  async with request.app.state.db() as session:
@@ -267,43 +264,21 @@ async def update_project(
267
264
  detail="Only admins can update projects",
268
265
  )
269
266
  async with request.app.state.db() as session:
270
- try:
271
- id_ = from_global_id_with_expected_type(
272
- GlobalID.from_id(project_id),
273
- ProjectNodeType.__name__,
274
- )
275
- project_orm = await session.get(models.Project, id_)
276
- except ValueError:
277
- raise HTTPException(
278
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
279
- detail=f"Invalid project ID format: {project_id}",
280
- )
281
-
282
- if project_orm is None:
283
- raise HTTPException(
284
- status_code=HTTP_404_NOT_FOUND,
285
- detail=f"Project with ID {project_id} not found",
286
- )
287
-
288
- # Prevent changing the project name
289
- if project_orm.name != request_body.project.name:
290
- raise HTTPException(
291
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
292
- detail="Project names cannot be changed",
293
- )
267
+ project = await _get_project_by_identifier(session, project_identifier)
294
268
 
295
- # Only update the description
296
- project_orm.description = request_body.project.description
269
+ # Update the description if provided
270
+ if request_body.description is not None:
271
+ project.description = request_body.description
297
272
 
298
- data = _project_from_orm_project(project_orm)
273
+ data = _to_project_response(project)
299
274
  return UpdateProjectResponseBody(data=data)
300
275
 
301
276
 
302
277
  @router.delete(
303
- "/projects/{project_id}",
278
+ "/projects/{project_identifier}",
304
279
  operation_id="deleteProject",
305
- summary="Delete a project", # noqa: E501
306
- description="Delete an existing project and all its associated data.", # noqa: E501
280
+ summary="Delete a project by ID or name", # noqa: E501
281
+ description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
307
282
  response_description="No content returned on successful deletion", # noqa: E501
308
283
  status_code=HTTP_204_NO_CONTENT,
309
284
  responses=add_errors_to_responses(
@@ -316,20 +291,23 @@ async def update_project(
316
291
  )
317
292
  async def delete_project(
318
293
  request: Request,
319
- project_id: str = Path(description="The ID of the project to delete."),
294
+ project_identifier: str = Path(
295
+ description="The project identifier: either project ID or project name. If using a project name, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
296
+ ),
320
297
  ) -> None:
321
298
  """
322
299
  Delete an existing project.
323
300
 
324
301
  Args:
325
302
  request (Request): The FastAPI request object.
326
- project_id (str): The ID of the project to delete.
303
+ project_identifier (str): The project identifier: either project ID or project name.
304
+ If using a project name, it cannot contain slash (/), question mark (?), or pound sign (#) characters.
327
305
 
328
306
  Returns:
329
307
  None: Returns a 204 No Content response on success.
330
308
 
331
309
  Raises:
332
- HTTPException: If the project ID is invalid, the project is not found, or it's the default project.
310
+ HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
333
311
  """ # noqa: E501
334
312
  if request.app.state.authentication_enabled:
335
313
  async with request.app.state.db() as session:
@@ -346,23 +324,7 @@ async def delete_project(
346
324
  detail="Only admins can delete projects",
347
325
  )
348
326
  async with request.app.state.db() as session:
349
- try:
350
- id_ = from_global_id_with_expected_type(
351
- GlobalID.from_id(project_id),
352
- ProjectNodeType.__name__,
353
- )
354
- project = await session.get(models.Project, id_)
355
- except ValueError:
356
- raise HTTPException(
357
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
358
- detail=f"Invalid project ID format: {project_id}",
359
- )
360
-
361
- if project is None:
362
- raise HTTPException(
363
- status_code=HTTP_404_NOT_FOUND,
364
- detail=f"Project with ID {project_id} not found",
365
- )
327
+ project = await _get_project_by_identifier(session, project_identifier)
366
328
 
367
329
  # The default project must not be deleted - it's forbidden
368
330
  if project.name == DEFAULT_PROJECT_NAME:
@@ -375,9 +337,57 @@ async def delete_project(
375
337
  return None
376
338
 
377
339
 
def _to_project_response(project: models.Project) -> Project:
    """Convert a `models.Project` row into the API `Project` schema.

    The numeric primary key is exposed as a relay GlobalID string tagged with
    the `ProjectNodeType` type name.
    """
    global_id = GlobalID(ProjectNodeType.__name__, str(project.id))
    return Project(
        id=str(global_id),
        name=project.name,
        description=project.description,
    )
346
+
347
+
async def _get_project_by_identifier(
    session: AsyncSession,
    project_identifier: str,
) -> models.Project:
    """
    Get a project by its ID or name.

    The identifier is first decoded as a relay GlobalID of type `ProjectNodeType`.
    If that fails, it is treated as a project name instead. Consequently, a project
    whose *name* happens to be a valid Project GlobalID can only be reached by ID.

    Args:
        session: The database session.
        project_identifier: The project ID (GlobalID string) or the project name.

    Returns:
        The project object.

    Raises:
        HTTPException: 404 if no project matches the identifier.
    """
    project_id: Optional[int]
    try:
        project_id = from_global_id_with_expected_type(
            GlobalID.from_id(project_identifier),
            ProjectNodeType.__name__,
        )
    except ValueError:
        # Decode / type-mismatch failures surface as ValueError (the same check the
        # per-endpoint lookups used before). Anything else is unexpected and must
        # propagate instead of being silently reinterpreted as a project name.
        project_id = None

    if project_id is None:
        project = await session.scalar(
            select(models.Project).filter_by(name=project_identifier)
        )
        if project is None:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with name {project_identifier} not found",
            )
        return project

    project = await session.get(models.Project, project_id)
    if project is None:
        raise HTTPException(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"Project with ID {project_identifier} not found",
        )
    return project
@@ -23,6 +23,7 @@ from strawberry.relay.types import GlobalID
23
23
  from strawberry.types import Info
24
24
  from typing_extensions import TypeAlias, assert_never
25
25
 
26
+ from phoenix.config import PLAYGROUND_PROJECT_NAME
26
27
  from phoenix.datetime_utils import local_now, normalize_datetime
27
28
  from phoenix.db import models
28
29
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
@@ -84,7 +85,6 @@ ChatCompletionResult: TypeAlias = tuple[
84
85
  DatasetExampleID, Optional[models.Span], models.ExperimentRun
85
86
  ]
86
87
  ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
87
- PLAYGROUND_PROJECT_NAME = "playground"
88
88
 
89
89
 
90
90
  @strawberry.type
phoenix/server/app.py CHANGED
@@ -60,6 +60,7 @@ from phoenix.config import (
60
60
  get_env_host,
61
61
  get_env_port,
62
62
  server_instrumentation_is_enabled,
63
+ verify_server_environment_variables,
63
64
  )
64
65
  from phoenix.core.model_schema import Model
65
66
  from phoenix.db import models
@@ -551,6 +552,7 @@ def create_graphql_router(
551
552
  read_only: bool = False,
552
553
  secret: Optional[str] = None,
553
554
  token_store: Optional[TokenStore] = None,
555
+ email_sender: Optional[EmailSender] = None,
554
556
  ) -> GraphQLRouter[Context, None]:
555
557
  """Creates the GraphQL router.
556
558
 
@@ -566,6 +568,8 @@ def create_graphql_router(
566
568
  cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
567
569
  read_only (bool, optional): Marks the app as read-only. Defaults to False.
568
570
  secret (Optional[str], optional): The application secret for auth. Defaults to None.
571
+ token_store (Optional[TokenStore], optional): The token store for auth. Defaults to None.
572
+ email_sender (Optional[EmailSender], optional): The email sender. Defaults to None.
569
573
 
570
574
  Returns:
571
575
  GraphQLRouter: The router mounted at /graphql
@@ -654,6 +658,7 @@ def create_graphql_router(
654
658
  auth_enabled=authentication_enabled,
655
659
  secret=secret,
656
660
  token_store=token_store,
661
+ email_sender=email_sender,
657
662
  )
658
663
 
659
664
  return GraphQLRouter(
@@ -768,6 +773,7 @@ def create_app(
768
773
  bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
769
774
  allowed_origins: Optional[list[str]] = None,
770
775
  ) -> FastAPI:
776
+ verify_server_environment_variables()
771
777
  if model.embedding_dimensions:
772
778
  try:
773
779
  import fast_hdbscan # noqa: F401
@@ -870,6 +876,7 @@ def create_app(
870
876
  read_only=read_only,
871
877
  secret=secret,
872
878
  token_store=token_store,
879
+ email_sender=email_sender,
873
880
  )
874
881
  if enable_prometheus:
875
882
  from phoenix.server.prometheus import PrometheusMiddleware
@@ -1,14 +1,19 @@
1
- import asyncio
2
1
  import smtplib
3
2
  import ssl
4
3
  from email.message import EmailMessage
5
4
  from pathlib import Path
6
5
  from typing import Literal
7
6
 
7
+ from anyio import to_thread
8
8
  from jinja2 import Environment, FileSystemLoader, select_autoescape
9
+ from typing_extensions import TypeAlias
10
+
11
+ from phoenix.config import get_env_root_url
9
12
 
10
13
  EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
11
14
 
15
+ ConnectionMethod: TypeAlias = Literal["STARTTLS", "SSL", "PLAIN"]
16
+
12
17
 
13
18
  class SimpleEmailSender:
14
19
  def __init__(
@@ -18,7 +23,7 @@ class SimpleEmailSender:
18
23
  username: str,
19
24
  password: str,
20
25
  sender_email: str,
21
- connection_method: Literal["STARTTLS", "SSL", "PLAIN"] = "STARTTLS",
26
+ connection_method: ConnectionMethod = "STARTTLS",
22
27
  validate_certs: bool = True,
23
28
  ) -> None:
24
29
  self.smtp_server = smtp_server
@@ -26,7 +31,7 @@ class SimpleEmailSender:
26
31
  self.username = username
27
32
  self.password = password
28
33
  self.sender_email = sender_email
29
- self.connection_method = connection_method.upper()
34
+ self.connection_method: ConnectionMethod = connection_method
30
35
  self.validate_certs = validate_certs
31
36
 
32
37
  self.env = Environment(
@@ -34,6 +39,28 @@ class SimpleEmailSender:
34
39
  autoescape=select_autoescape(["html", "xml"]),
35
40
  )
36
41
 
async def send_welcome_email(
    self,
    email: str,
    name: str,
) -> None:
    """Render the `welcome.html` template for *name* and mail it to *email*.

    The rendered HTML links to the configured Phoenix root URL. Delivery goes
    through `self._send_email`, which performs blocking smtplib I/O, so it is
    dispatched to a worker thread with anyio.
    """
    html_content = self.env.get_template("welcome.html").render(
        name=name,
        welcome_url=str(get_env_root_url()),
    )

    message = EmailMessage()
    for header, value in (
        ("Subject", "[Phoenix] Welcome to Arize Phoenix"),
        ("From", self.sender_email),
        ("To", email),
    ):
        message[header] = value
    message.set_content(html_content, subtype="html")

    await to_thread.run_sync(self._send_email, message)
63
+
37
64
  async def send_password_reset_email(
38
65
  self,
39
66
  email: str,
@@ -51,47 +78,47 @@ class SimpleEmailSender:
51
78
  msg["To"] = email
52
79
  msg.set_content(html_content, subtype="html")
53
80
 
54
- def send_email() -> None:
55
- context: ssl.SSLContext
56
- if self.validate_certs:
57
- context = ssl.create_default_context()
58
- else:
59
- context = ssl._create_unverified_context()
60
-
61
- methods_to_try = [self.connection_method]
62
- # add secure method fallbacks
63
- if self.connection_method != "PLAIN":
64
- if self.connection_method != "STARTTLS":
65
- methods_to_try.append("STARTTLS")
66
- elif self.connection_method != "SSL":
67
- methods_to_try.append("SSL")
68
-
69
- for method in methods_to_try:
70
- try:
71
- if method == "STARTTLS":
72
- server = smtplib.SMTP(self.smtp_server, self.smtp_port)
73
- server.ehlo()
74
- server.starttls(context=context)
75
- server.ehlo()
76
- elif method == "SSL":
77
- server = smtplib.SMTP_SSL(self.smtp_server, self.smtp_port, context=context)
78
- server.ehlo()
79
- elif method == "PLAIN":
80
- server = smtplib.SMTP(self.smtp_server, self.smtp_port)
81
- server.ehlo()
82
- else:
83
- continue # Unsupported method
84
-
85
- if self.username and self.password:
86
- server.login(self.username, self.password)
87
-
88
- server.send_message(msg)
89
- server.quit()
90
- break # Success
91
- except Exception as e:
92
- print(f"Failed to send email using {method}: {e}")
93
- continue
94
- else:
95
- raise Exception("All connection methods failed")
96
-
97
- await asyncio.to_thread(send_email)
81
+ await to_thread.run_sync(self._send_email, msg)
82
+
def _send_email(self, msg: EmailMessage) -> None:
    """
    Deliver *msg* over SMTP (blocking; the async senders run this in a worker thread).

    The configured `connection_method` is tried first. Unless it is "PLAIN", the
    other secure methods ("STARTTLS" / "SSL") are then tried as fallbacks. Login
    happens only when both `username` and `password` are set.

    Every opened connection is closed, whether the attempt succeeded or failed.
    Once the server has accepted the message, no further method is tried, so a
    failing QUIT can no longer make the same email go out twice.

    Raises:
        Exception: If every connection method fails. The last failure is chained
            as the exception's `__cause__`.
    """
    context: ssl.SSLContext
    if self.validate_certs:
        context = ssl.create_default_context()
    else:
        # Certificate verification explicitly disabled via validate_certs=False.
        context = ssl._create_unverified_context()

    methods_to_try: list[ConnectionMethod] = [self.connection_method]
    # add secure method fallbacks
    if self.connection_method != "PLAIN":
        if self.connection_method != "STARTTLS":
            methods_to_try.append("STARTTLS")
        if self.connection_method != "SSL":
            methods_to_try.append("SSL")

    last_error = None
    for method in methods_to_try:
        server = None
        try:
            if method == "STARTTLS":
                server = smtplib.SMTP(self.smtp_server, self.smtp_port)
                server.ehlo()
                server.starttls(context=context)
                server.ehlo()
            elif method == "SSL":
                server = smtplib.SMTP_SSL(self.smtp_server, self.smtp_port, context=context)
                server.ehlo()
            elif method == "PLAIN":
                server = smtplib.SMTP(self.smtp_server, self.smtp_port)
                server.ehlo()
            else:
                continue  # Unsupported method

            if self.username and self.password:
                server.login(self.username, self.password)

            server.send_message(msg)
        except Exception as e:
            print(f"Failed to send email using {method}: {e}")
            last_error = e
            continue
        finally:
            # Always release the connection. A QUIT failure is deliberately kept
            # out of the try above so it can never trigger a resend.
            if server is not None:
                try:
                    server.quit()
                except Exception:
                    server.close()
        return  # Success
    raise Exception("All connection methods failed") from last_error
@@ -0,0 +1,12 @@
{# Welcome email body. Variables: name (the recipient name passed to send_welcome_email) and welcome_url (link target for "Get Started"). -#}
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <title>Welcome to Arize Phoenix</title>
  </head>
  <body>
    <h1>Welcome to Arize Phoenix!</h1>
    <p>Hi {{ name }}, please click the link below to get started:</p>
    <a href="{{ welcome_url }}">Get Started</a>
  </body>
</html>
@@ -3,9 +3,24 @@ from __future__ import annotations
3
3
  from typing import Protocol
4
4
 
5
5
 
class WelcomeEmailSender(Protocol):
    """Structural type for objects able to send the new-user welcome email."""

    async def send_welcome_email(
        self,
        email: str,
        name: str,
    ) -> None:
        """Send a welcome email greeting *name* to the address *email*."""


class PasswordResetEmailSender(Protocol):
    """Structural type for objects able to send a password-reset email."""

    async def send_password_reset_email(
        self,
        email: str,
        reset_url: str,
    ) -> None:
        """Send a password-reset email containing *reset_url* to *email*."""


class EmailSender(
    WelcomeEmailSender,
    PasswordResetEmailSender,
    Protocol,
):
    """Combined protocol: implementations must send both welcome and password-reset emails."""