chainlit 0.7.700__py3-none-any.whl → 1.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (38)
  1. chainlit/__init__.py +32 -23
  2. chainlit/auth.py +9 -10
  3. chainlit/cli/__init__.py +1 -2
  4. chainlit/config.py +13 -12
  5. chainlit/context.py +7 -3
  6. chainlit/data/__init__.py +375 -9
  7. chainlit/data/acl.py +6 -5
  8. chainlit/element.py +86 -123
  9. chainlit/emitter.py +117 -50
  10. chainlit/frontend/dist/assets/{index-71698725.js → index-6aee009a.js} +118 -292
  11. chainlit/frontend/dist/assets/{react-plotly-2c0acdf0.js → react-plotly-2f07c02a.js} +1 -1
  12. chainlit/frontend/dist/index.html +1 -1
  13. chainlit/haystack/callbacks.py +45 -43
  14. chainlit/hello.py +1 -1
  15. chainlit/langchain/callbacks.py +132 -120
  16. chainlit/llama_index/callbacks.py +68 -48
  17. chainlit/message.py +179 -207
  18. chainlit/oauth_providers.py +39 -34
  19. chainlit/playground/provider.py +44 -30
  20. chainlit/playground/providers/anthropic.py +4 -4
  21. chainlit/playground/providers/huggingface.py +2 -2
  22. chainlit/playground/providers/langchain.py +8 -10
  23. chainlit/playground/providers/openai.py +19 -13
  24. chainlit/server.py +155 -99
  25. chainlit/session.py +109 -40
  26. chainlit/socket.py +47 -36
  27. chainlit/step.py +393 -0
  28. chainlit/types.py +78 -21
  29. chainlit/user.py +32 -0
  30. chainlit/user_session.py +1 -5
  31. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc1.dist-info}/METADATA +12 -31
  32. chainlit-1.0.0rc1.dist-info/RECORD +60 -0
  33. chainlit/client/base.py +0 -169
  34. chainlit/client/cloud.py +0 -502
  35. chainlit/prompt.py +0 -40
  36. chainlit-0.7.700.dist-info/RECORD +0 -61
  37. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc1.dist-info}/WHEEL +0 -0
  38. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc1.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import glob
2
2
  import json
3
3
  import mimetypes
4
+ import shutil
4
5
  import urllib.parse
5
6
  from typing import Optional, Union
6
7
 
@@ -17,30 +18,31 @@ from contextlib import asynccontextmanager
17
18
  from pathlib import Path
18
19
 
19
20
  from chainlit.auth import create_jwt, get_configuration, get_current_user
20
- from chainlit.client.cloud import AppUser, PersistedAppUser
21
21
  from chainlit.config import (
22
22
  APP_ROOT,
23
23
  BACKEND_ROOT,
24
24
  DEFAULT_HOST,
25
+ FILES_DIRECTORY,
25
26
  PACKAGE_ROOT,
26
27
  config,
27
28
  load_module,
28
29
  reload_config,
29
30
  )
30
- from chainlit.data import chainlit_client
31
- from chainlit.data.acl import is_conversation_author
31
+ from chainlit.data import get_data_layer
32
+ from chainlit.data.acl import is_thread_author
32
33
  from chainlit.logger import logger
33
34
  from chainlit.markdown import get_markdown_str
34
35
  from chainlit.playground.config import get_llm_providers
35
36
  from chainlit.telemetry import trace_event
36
37
  from chainlit.types import (
37
- CompletionRequest,
38
- DeleteConversationRequest,
39
- GetConversationsRequest,
38
+ DeleteThreadRequest,
39
+ GenerationRequest,
40
+ GetThreadsRequest,
40
41
  Theme,
41
42
  UpdateFeedbackRequest,
42
43
  )
43
- from fastapi import Depends, FastAPI, HTTPException, Query, Request, status
44
+ from chainlit.user import PersistedUser, User
45
+ from fastapi import Depends, FastAPI, HTTPException, Query, Request, UploadFile, status
44
46
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
45
47
  from fastapi.security import OAuth2PasswordRequestForm
46
48
  from fastapi.staticfiles import StaticFiles
@@ -117,6 +119,9 @@ async def lifespan(app: FastAPI):
117
119
  except asyncio.exceptions.CancelledError:
118
120
  pass
119
121
 
122
+ if FILES_DIRECTORY.is_dir():
123
+ shutil.rmtree(FILES_DIRECTORY)
124
+
120
125
  # Force exit the process to avoid potential AnyIO threads still running
121
126
  os._exit(0)
122
127
 
@@ -134,6 +139,7 @@ def get_build_dir():
134
139
 
135
140
  build_dir = get_build_dir()
136
141
 
142
+
137
143
  app = FastAPI(lifespan=lifespan)
138
144
 
139
145
  app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public")
@@ -156,14 +162,10 @@ app.add_middleware(
156
162
  )
157
163
 
158
164
 
159
- # Define max HTTP data size to 100 MB
160
- max_message_size = 100 * 1024 * 1024
161
-
162
165
  socket = SocketManager(
163
166
  app,
164
167
  cors_allowed_origins=[],
165
168
  async_mode="asgi",
166
- max_http_buffer_size=max_message_size,
167
169
  )
168
170
 
169
171
 
@@ -244,18 +246,18 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
244
246
  status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
245
247
  )
246
248
 
247
- app_user = await config.code.password_auth_callback(
249
+ user = await config.code.password_auth_callback(
248
250
  form_data.username, form_data.password
249
251
  )
250
252
 
251
- if not app_user:
253
+ if not user:
252
254
  raise HTTPException(
253
255
  status_code=status.HTTP_401_UNAUTHORIZED,
254
256
  detail="credentialssignin",
255
257
  )
256
- access_token = create_jwt(app_user)
257
- if chainlit_client:
258
- await chainlit_client.create_app_user(app_user=app_user)
258
+ access_token = create_jwt(user)
259
+ if data_layer := get_data_layer():
260
+ await data_layer.create_user(user)
259
261
  return {
260
262
  "access_token": access_token,
261
263
  "token_type": "bearer",
@@ -270,17 +272,17 @@ async def header_auth(request: Request):
270
272
  detail="No header_auth_callback defined",
271
273
  )
272
274
 
273
- app_user = await config.code.header_auth_callback(request.headers)
275
+ user = await config.code.header_auth_callback(request.headers)
274
276
 
275
- if not app_user:
277
+ if not user:
276
278
  raise HTTPException(
277
279
  status_code=status.HTTP_401_UNAUTHORIZED,
278
280
  detail="Unauthorized",
279
281
  )
280
282
 
281
- access_token = create_jwt(app_user)
282
- if chainlit_client:
283
- await chainlit_client.create_app_user(app_user=app_user)
283
+ access_token = create_jwt(user)
284
+ if data_layer := get_data_layer():
285
+ await data_layer.create_user(user)
284
286
  return {
285
287
  "access_token": access_token,
286
288
  "token_type": "bearer",
@@ -369,21 +371,22 @@ async def oauth_callback(
369
371
  url = get_user_facing_url(request.url)
370
372
  token = await provider.get_token(code, url)
371
373
 
372
- (raw_user_data, default_app_user) = await provider.get_user_info(token)
374
+ (raw_user_data, default_user) = await provider.get_user_info(token)
373
375
 
374
- app_user = await config.code.oauth_callback(
375
- provider_id, token, raw_user_data, default_app_user
376
+ user = await config.code.oauth_callback(
377
+ provider_id, token, raw_user_data, default_user
376
378
  )
377
379
 
378
- if not app_user:
380
+ if not user:
379
381
  raise HTTPException(
380
382
  status_code=status.HTTP_401_UNAUTHORIZED,
381
383
  detail="Unauthorized",
382
384
  )
383
385
 
384
- access_token = create_jwt(app_user)
385
- if chainlit_client:
386
- await chainlit_client.create_app_user(app_user=app_user)
386
+ access_token = create_jwt(user)
387
+
388
+ if data_layer := get_data_layer():
389
+ await data_layer.create_user(user)
387
390
 
388
391
  params = urllib.parse.urlencode(
389
392
  {
@@ -399,23 +402,21 @@ async def oauth_callback(
399
402
  return response
400
403
 
401
404
 
402
- @app.post("/completion")
403
- async def completion(
404
- request: CompletionRequest,
405
- current_user: Annotated[
406
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
407
- ],
405
+ @app.post("/generation")
406
+ async def generation(
407
+ request: GenerationRequest,
408
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
408
409
  ):
409
410
  """Handle a completion request from the prompt playground."""
410
411
 
411
412
  providers = get_llm_providers()
412
413
 
413
414
  try:
414
- provider = [p for p in providers if p.id == request.prompt.provider][0]
415
+ provider = [p for p in providers if p.id == request.generation.provider][0]
415
416
  except IndexError:
416
417
  raise HTTPException(
417
418
  status_code=404,
418
- detail=f"LLM provider '{request.prompt.provider}' not found",
419
+ detail=f"LLM provider '{request.generation.provider}' not found",
419
420
  )
420
421
 
421
422
  trace_event("pp_create_completion")
@@ -426,7 +427,7 @@ async def completion(
426
427
 
427
428
  @app.get("/project/llm-providers")
428
429
  async def get_providers(
429
- current_user: Annotated[Union[AppUser, PersistedAppUser], Depends(get_current_user)]
430
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)]
430
431
  ):
431
432
  """List the providers."""
432
433
  trace_event("pp_get_llm_providers")
@@ -437,7 +438,7 @@ async def get_providers(
437
438
 
438
439
  @app.get("/project/settings")
439
440
  async def project_settings(
440
- current_user: Annotated[Union[AppUser, PersistedAppUser], Depends(get_current_user)]
441
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)]
441
442
  ):
442
443
  """Return project settings. This is called by the UI before the establishing the websocket connection."""
443
444
  profiles = []
@@ -450,126 +451,181 @@ async def project_settings(
450
451
  "ui": config.ui.to_dict(),
451
452
  "features": config.features.to_dict(),
452
453
  "userEnv": config.project.user_env,
453
- "dataPersistence": config.data_persistence,
454
- "conversationResumable": bool(config.code.on_chat_resume),
454
+ "dataPersistence": get_data_layer() is not None,
455
+ "threadResumable": bool(config.code.on_chat_resume),
455
456
  "markdown": get_markdown_str(config.root),
456
457
  "chatProfiles": profiles,
457
458
  }
458
459
  )
459
460
 
460
461
 
461
- @app.put("/message/feedback")
462
+ @app.put("/feedback")
462
463
  async def update_feedback(
463
464
  request: Request,
464
465
  update: UpdateFeedbackRequest,
465
- current_user: Annotated[
466
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
467
- ],
466
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
468
467
  ):
469
468
  """Update the human feedback for a particular message."""
470
-
471
- # TODO: check that message belong to a user's conversation
472
-
473
- if not chainlit_client:
469
+ data_layer = get_data_layer()
470
+ if not data_layer:
474
471
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
475
472
 
476
473
  try:
477
- await chainlit_client.set_human_feedback(
478
- message_id=update.messageId,
479
- feedback=update.feedback,
480
- feedbackComment=update.feedbackComment,
481
- )
474
+ feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
482
475
  except Exception as e:
483
476
  raise HTTPException(detail=str(e), status_code=401)
484
477
 
485
- return JSONResponse(content={"success": True})
478
+ return JSONResponse(content={"success": True, "feedbackId": feedback_id})
486
479
 
487
480
 
488
- @app.post("/project/conversations")
489
- async def get_user_conversations(
481
+ @app.post("/project/threads")
482
+ async def get_user_threads(
490
483
  request: Request,
491
- payload: GetConversationsRequest,
492
- current_user: Annotated[
493
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
494
- ],
484
+ payload: GetThreadsRequest,
485
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
495
486
  ):
496
- """Get the conversations page by page."""
497
- # Only show the current user conversations
487
+ """Get the threads page by page."""
488
+ # Only show the current user threads
498
489
 
499
- if not chainlit_client:
490
+ data_layer = get_data_layer()
491
+
492
+ if not data_layer:
500
493
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
501
494
 
502
- payload.filter.username = current_user.username
503
- res = await chainlit_client.get_conversations(payload.pagination, payload.filter)
495
+ payload.filter.userIdentifier = current_user.identifier
496
+
497
+ res = await data_layer.list_threads(payload.pagination, payload.filter)
504
498
  return JSONResponse(content=res.to_dict())
505
499
 
506
500
 
507
- @app.get("/project/conversation/{conversation_id}")
508
- async def get_conversation(
501
+ @app.get("/project/thread/{thread_id}")
502
+ async def get_thread(
509
503
  request: Request,
510
- conversation_id: str,
511
- current_user: Annotated[
512
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
513
- ],
504
+ thread_id: str,
505
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
514
506
  ):
515
- """Get a specific conversation."""
507
+ """Get a specific thread."""
508
+ data_layer = get_data_layer()
516
509
 
517
- if not chainlit_client:
510
+ if not data_layer:
518
511
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
519
512
 
520
- await is_conversation_author(current_user.username, conversation_id)
513
+ await is_thread_author(current_user.identifier, thread_id)
521
514
 
522
- res = await chainlit_client.get_conversation(conversation_id)
515
+ res = await data_layer.get_thread(thread_id)
523
516
  return JSONResponse(content=res)
524
517
 
525
518
 
526
- @app.get("/project/conversation/{conversation_id}/element/{element_id}")
527
- async def get_conversation_element(
519
+ @app.get("/project/thread/{thread_id}/element/{element_id}")
520
+ async def get_thread_element(
528
521
  request: Request,
529
- conversation_id: str,
522
+ thread_id: str,
530
523
  element_id: str,
531
- current_user: Annotated[
532
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
533
- ],
524
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
534
525
  ):
535
- """Get a specific conversation element."""
526
+ """Get a specific thread element."""
527
+ data_layer = get_data_layer()
536
528
 
537
- if not chainlit_client:
529
+ if not data_layer:
538
530
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
539
531
 
540
- await is_conversation_author(current_user.username, conversation_id)
532
+ await is_thread_author(current_user.identifier, thread_id)
541
533
 
542
- res = await chainlit_client.get_element(conversation_id, element_id)
534
+ res = await data_layer.get_element(thread_id, element_id)
543
535
  return JSONResponse(content=res)
544
536
 
545
537
 
546
- @app.delete("/project/conversation")
547
- async def delete_conversation(
538
+ @app.delete("/project/thread")
539
+ async def delete_thread(
548
540
  request: Request,
549
- payload: DeleteConversationRequest,
550
- current_user: Annotated[
551
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
552
- ],
541
+ payload: DeleteThreadRequest,
542
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
553
543
  ):
554
- """Delete a conversation."""
544
+ """Delete a thread."""
545
+
546
+ data_layer = get_data_layer()
555
547
 
556
- if not chainlit_client:
548
+ if not data_layer:
557
549
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
558
550
 
559
- conversation_id = payload.conversationId
551
+ thread_id = payload.threadId
560
552
 
561
- await is_conversation_author(current_user.username, conversation_id)
553
+ await is_thread_author(current_user.identifier, thread_id)
562
554
 
563
- await chainlit_client.delete_conversation(conversation_id)
555
+ await data_layer.delete_thread(thread_id)
564
556
  return JSONResponse(content={"success": True})
565
557
 
566
558
 
559
@app.post("/project/file")
async def upload_file(
    session_id: str,
    file: UploadFile,
    current_user: Annotated[
        Union[None, User, PersistedUser], Depends(get_current_user)
    ],
):
    """Persist an uploaded file for a websocket session.

    Args:
        session_id: Id of the websocket session the file is attached to.
        file: The multipart upload to store.
        current_user: The user resolved by ``get_current_user``; may be falsy,
            in which case no ownership check is performed.

    Returns:
        A JSON response wrapping the value returned by ``session.persist_file``.

    Raises:
        HTTPException: 404 when no session matches ``session_id``; 401 when a
            user is resolved but is not the user attached to the session.
    """
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(session_id)

    if not session:
        raise HTTPException(
            status_code=404,
            detail="Session not found",
        )

    # Ownership is only enforced when a user was resolved for the request.
    if current_user:
        if not session.user or session.user.identifier != current_user.identifier:
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to upload files for this session",
            )

    session.files_dir.mkdir(exist_ok=True)

    content = await file.read()

    # A multipart part may omit its Content-Type header, leaving content_type
    # as None; fall back to the generic binary type so persisting cannot fail
    # on a missing mime.
    mime = file.content_type or "application/octet-stream"

    file_response = await session.persist_file(
        name=file.filename, content=content, mime=mime
    )

    return JSONResponse(file_response)
593
+
594
+
595
@app.get("/project/file/{file_id}")
async def get_file(
    file_id: str,
    session_id: Optional[str] = None,
    token: Optional[str] = None,
):
    """Serve a file previously persisted for a websocket session.

    Args:
        file_id: Key of the file in ``session.files``.
        session_id: Id of the websocket session that owns the file.
        token: Token passed to ``get_current_user``; an empty string is used
            when it is not provided.

    Returns:
        A ``FileResponse`` built from the stored path and mime type.

    Raises:
        HTTPException: 404 when the session or the file is not found; 401 when
            a user is resolved but is not the user attached to the session.
    """
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(session_id) if session_id else None

    if not session:
        raise HTTPException(
            status_code=404,
            detail="Session not found",
        )

    # Ownership is only enforced when the token resolves to a user.
    if current_user := await get_current_user(token or ""):
        if not session.user or session.user.identifier != current_user.identifier:
            raise HTTPException(
                status_code=401,
                # This endpoint reads files; the message previously said "upload".
                detail="You are not authorized to access files for this session",
            )

    if file_id not in session.files:
        raise HTTPException(status_code=404, detail="File not found")

    file = session.files[file_id]
    return FileResponse(file["path"], media_type=file["type"])
623
+
624
+
567
625
  @app.get("/files/{filename:path}")
568
626
  async def serve_file(
569
627
  filename: str,
570
- current_user: Annotated[
571
- Union[AppUser, PersistedAppUser], Depends(get_current_user)
572
- ],
628
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
573
629
  ):
574
630
  base_path = Path(config.project.local_fs_path).resolve()
575
631
  file_path = (base_path / filename).resolve()
chainlit/session.py CHANGED
@@ -1,13 +1,16 @@
1
- import asyncio
2
1
  import json
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
2
+ import mimetypes
3
+ import shutil
4
+ import uuid
5
+ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union
6
+
7
+ import aiofiles
4
8
 
5
9
  if TYPE_CHECKING:
6
10
  from chainlit.message import Message
7
- from chainlit.types import AskResponse
8
-
9
- from chainlit.client.cloud import AppUser, PersistedAppUser
10
- from chainlit.data import chainlit_client
11
+ from chainlit.step import Step
12
+ from chainlit.types import FileDict, FileReference
13
+ from chainlit.user import PersistedUser, User
11
14
 
12
15
 
13
16
  class JSONEncoderIgnoreNonSerializable(json.JSONEncoder):
@@ -27,12 +30,17 @@ def clean_metadata(metadata: Dict):
27
30
  class BaseSession:
28
31
  """Base object."""
29
32
 
33
+ active_steps: List["Step"]
34
+ thread_id_to_resume: Optional[str] = None
35
+
30
36
  def __init__(
31
37
  self,
32
38
  # Id of the session
33
39
  id: str,
40
+ # Thread id
41
+ thread_id: Optional[str],
34
42
  # Logged-in user information
35
- user: Optional[Union["AppUser", "PersistedAppUser"]],
43
+ user: Optional[Union["User", "PersistedUser"]],
36
44
  # Logged-in user token
37
45
  token: Optional[str],
38
46
  # User specific environment variables. Empty if no user environment variables are required.
@@ -41,41 +49,30 @@ class BaseSession:
41
49
  root_message: Optional["Message"] = None,
42
50
  # Chat profile selected before the session was created
43
51
  chat_profile: Optional[str] = None,
44
- # Conversation id to resume
45
- conversation_id: Optional[str] = None,
46
52
  ):
53
+ if thread_id:
54
+ self.thread_id_to_resume = thread_id
55
+ self.thread_id = thread_id or str(uuid.uuid4())
47
56
  self.user = user
48
57
  self.token = token
49
58
  self.root_message = root_message
50
59
  self.has_user_message = False
51
60
  self.user_env = user_env or {}
52
61
  self.chat_profile = chat_profile
62
+ self.active_steps = []
53
63
 
54
64
  self.id = id
55
- self.conversation_id = conversation_id
56
65
 
57
66
  self.chat_settings: Dict[str, Any] = {}
58
67
 
59
- self.lock = asyncio.Lock()
60
-
61
- async def get_conversation_id(self) -> Optional[str]:
62
- if not chainlit_client:
63
- return None
64
-
65
- if isinstance(self, HTTPSession):
66
- tags = ["api"]
67
- else:
68
- tags = ["chat"]
69
-
70
- async with self.lock:
71
- if not self.conversation_id:
72
- app_user_id = (
73
- self.user.id if isinstance(self.user, PersistedAppUser) else None
74
- )
75
- self.conversation_id = await chainlit_client.create_conversation(
76
- app_user_id=app_user_id, tags=tags
77
- )
78
- return self.conversation_id
68
async def persist_file(
    self,
    name: str,
    mime: str,
    path: Optional[str] = None,
    content: Optional[Union[bytes, str]] = None,
):
    """Base implementation: accept the file arguments and store nothing.

    All arguments are ignored here and the coroutine always resolves to
    ``None``.
    """
    stored_reference = None
    return stored_reference
79
76
 
80
77
  def to_persistable(self) -> Dict:
81
78
  from chainlit.user_session import user_sessions
@@ -94,17 +91,24 @@ class HTTPSession(BaseSession):
94
91
  self,
95
92
  # Id of the session
96
93
  id: str,
94
+ # Thread id
95
+ thread_id: Optional[str] = None,
97
96
  # Logged-in user information
98
- user: Optional[Union["AppUser", "PersistedAppUser"]],
97
+ user: Optional[Union["User", "PersistedUser"]] = None,
99
98
  # Logged-in user token
100
- token: Optional[str],
101
- user_env: Optional[Dict[str, str]],
99
+ token: Optional[str] = None,
100
+ user_env: Optional[Dict[str, str]] = None,
102
101
  # Last message at the root of the chat
103
102
  root_message: Optional["Message"] = None,
104
103
  # User specific environment variables. Empty if no user environment variables are required.
105
104
  ):
106
105
  super().__init__(
107
- id=id, user=user, token=token, user_env=user_env, root_message=root_message
106
+ id=id,
107
+ thread_id=thread_id,
108
+ user=user,
109
+ token=token,
110
+ user_env=user_env,
111
+ root_message=root_message,
108
112
  )
109
113
 
110
114
 
@@ -129,28 +133,28 @@ class WebsocketSession(BaseSession):
129
133
  # Function to emit a message to the user
130
134
  emit: Callable[[str, Any], None],
131
135
  # Function to ask the user a question
132
- ask_user: Callable[[Any, Optional[int]], Union["AskResponse", None]],
136
+ ask_user: Callable[[Any, Optional[int]], Any],
133
137
  # User specific environment variables. Empty if no user environment variables are required.
134
138
  user_env: Dict[str, str],
139
+ # Thread id
140
+ thread_id: Optional[str] = None,
135
141
  # Logged-in user information
136
- user: Optional[Union["AppUser", "PersistedAppUser"]],
142
+ user: Optional[Union["User", "PersistedUser"]] = None,
137
143
  # Logged-in user token
138
- token: Optional[str],
144
+ token: Optional[str] = None,
139
145
  # Last message at the root of the chat
140
146
  root_message: Optional["Message"] = None,
141
147
  # Chat profile selected before the session was created
142
148
  chat_profile: Optional[str] = None,
143
- # Conversation id to resume
144
- conversation_id: Optional[str] = None,
145
149
  ):
146
150
  super().__init__(
147
151
  id=id,
152
+ thread_id=thread_id,
148
153
  user=user,
149
154
  token=token,
150
155
  user_env=user_env,
151
156
  root_message=root_message,
152
157
  chat_profile=chat_profile,
153
- conversation_id=conversation_id,
154
158
  )
155
159
 
156
160
  self.socket_id = socket_id
@@ -160,9 +164,66 @@ class WebsocketSession(BaseSession):
160
164
  self.should_stop = False
161
165
  self.restored = False
162
166
 
167
+ self.thread_queues = {} # type: Dict[str, Deque[Callable]]
168
+ self.files = {} # type: Dict[str, "FileDict"]
169
+
163
170
  ws_sessions_id[self.id] = self
164
171
  ws_sessions_sid[socket_id] = self
165
172
 
173
@property
def files_dir(self):
    """Directory for this session's files: ``FILES_DIRECTORY`` joined with the session id."""
    # Imported at call time, as in the original.
    from chainlit.config import FILES_DIRECTORY

    session_directory = FILES_DIRECTORY / self.id
    return session_directory
178
+
179
async def persist_file(
    self,
    name: str,
    mime: str,
    path: Optional[str] = None,
    content: Optional[Union[bytes, str]] = None,
) -> "FileReference":
    """Write a file into the session's files directory and register it.

    The file is stored under a fresh UUID, with an extension guessed from
    ``mime`` when one is known. When ``path`` is given the file is copied from
    it; otherwise ``content`` is written (``str`` content is UTF-8 encoded).

    Args:
        name: Display name recorded in the file's metadata.
        mime: Mime type recorded in the metadata and used to guess the suffix.
        path: Optional source path to copy from; takes precedence over
            ``content``.
        content: Optional raw payload; an empty payload is valid and produces
            an empty file.

    Returns:
        A dict holding the generated file id under ``"id"``.

    Raises:
        ValueError: If neither ``path`` nor ``content`` is provided.
    """
    # Only a *missing* payload is an error: b"" / "" are legitimate empty files.
    if not path and content is None:
        raise ValueError(
            "Either path or content must be provided to persist a file"
        )

    self.files_dir.mkdir(exist_ok=True)

    file_id = str(uuid.uuid4())

    file_path = self.files_dir / file_id

    file_extension = mimetypes.guess_extension(mime)
    if file_extension:
        file_path = file_path.with_suffix(file_extension)

    if path:
        # Copy the file from the given path
        async with aiofiles.open(path, "rb") as src, aiofiles.open(
            file_path, "wb"
        ) as dst:
            await dst.write(await src.read())
    else:
        # Write the provided content to the file (possibly empty)
        if isinstance(content, str):
            content = content.encode("utf-8")
        async with aiofiles.open(file_path, "wb") as buffer:
            await buffer.write(content)

    # Get the file size
    file_size = file_path.stat().st_size
    # Register the file's metadata in memory; the bytes themselves stay on disk
    self.files[file_id] = {
        "id": file_id,
        "path": file_path,
        "name": name,
        "type": mime,
        "size": file_size,
    }

    return {"id": file_id}
226
+
166
227
  def restore(self, new_socket_id: str):
167
228
  """Associate a new socket id to the session."""
168
229
  ws_sessions_sid.pop(self.socket_id, None)
@@ -172,9 +233,17 @@ class WebsocketSession(BaseSession):
172
233
 
173
234
  def delete(self):
174
235
  """Delete the session."""
236
+ if self.files_dir.is_dir():
237
+ shutil.rmtree(self.files_dir)
175
238
  ws_sessions_sid.pop(self.socket_id, None)
176
239
  ws_sessions_id.pop(self.id, None)
177
240
 
241
async def flush_method_queue(self):
    """Run every queued call, draining each queue from the left.

    Each queue entry is a ``(method, target, args, kwargs)`` tuple and is
    invoked as ``await method(target, *args, **kwargs)``.
    """
    # Only the queues are needed, not the names they are keyed by.
    for queue in self.thread_queues.values():
        while queue:
            # Unpack into ``target`` rather than rebinding ``self``.
            method, target, args, kwargs = queue.popleft()
            await method(target, *args, **kwargs)
246
+
178
247
  @classmethod
179
248
  def get(cls, socket_id: str):
180
249
  """Get session by socket id."""