chainlit 0.7.604rc2__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +32 -23
- chainlit/auth.py +9 -10
- chainlit/cache.py +3 -3
- chainlit/cli/__init__.py +12 -2
- chainlit/config.py +22 -13
- chainlit/context.py +7 -3
- chainlit/data/__init__.py +375 -9
- chainlit/data/acl.py +6 -5
- chainlit/element.py +86 -123
- chainlit/emitter.py +117 -50
- chainlit/frontend/dist/assets/index-6aee009a.js +697 -0
- chainlit/frontend/dist/assets/{react-plotly-16f7de12.js → react-plotly-2f07c02a.js} +1 -1
- chainlit/frontend/dist/index.html +1 -1
- chainlit/haystack/callbacks.py +45 -43
- chainlit/hello.py +1 -1
- chainlit/langchain/callbacks.py +135 -120
- chainlit/llama_index/callbacks.py +68 -48
- chainlit/message.py +179 -207
- chainlit/oauth_providers.py +39 -34
- chainlit/playground/provider.py +44 -30
- chainlit/playground/providers/anthropic.py +4 -4
- chainlit/playground/providers/huggingface.py +2 -2
- chainlit/playground/providers/langchain.py +8 -10
- chainlit/playground/providers/openai.py +19 -13
- chainlit/server.py +155 -99
- chainlit/session.py +109 -40
- chainlit/socket.py +54 -38
- chainlit/step.py +393 -0
- chainlit/types.py +78 -21
- chainlit/user.py +32 -0
- chainlit/user_session.py +1 -5
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/METADATA +12 -31
- chainlit-1.0.0rc0.dist-info/RECORD +60 -0
- chainlit/client/base.py +0 -169
- chainlit/client/cloud.py +0 -500
- chainlit/frontend/dist/assets/index-c58dbd4b.js +0 -871
- chainlit/prompt.py +0 -40
- chainlit-0.7.604rc2.dist-info/RECORD +0 -61
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/WHEEL +0 -0
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import glob
|
|
2
2
|
import json
|
|
3
3
|
import mimetypes
|
|
4
|
+
import shutil
|
|
4
5
|
import urllib.parse
|
|
5
6
|
from typing import Optional, Union
|
|
6
7
|
|
|
@@ -17,30 +18,31 @@ from contextlib import asynccontextmanager
|
|
|
17
18
|
from pathlib import Path
|
|
18
19
|
|
|
19
20
|
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
20
|
-
from chainlit.client.cloud import AppUser, PersistedAppUser
|
|
21
21
|
from chainlit.config import (
|
|
22
22
|
APP_ROOT,
|
|
23
23
|
BACKEND_ROOT,
|
|
24
24
|
DEFAULT_HOST,
|
|
25
|
+
FILES_DIRECTORY,
|
|
25
26
|
PACKAGE_ROOT,
|
|
26
27
|
config,
|
|
27
28
|
load_module,
|
|
28
29
|
reload_config,
|
|
29
30
|
)
|
|
30
|
-
from chainlit.data import
|
|
31
|
-
from chainlit.data.acl import
|
|
31
|
+
from chainlit.data import get_data_layer
|
|
32
|
+
from chainlit.data.acl import is_thread_author
|
|
32
33
|
from chainlit.logger import logger
|
|
33
34
|
from chainlit.markdown import get_markdown_str
|
|
34
35
|
from chainlit.playground.config import get_llm_providers
|
|
35
36
|
from chainlit.telemetry import trace_event
|
|
36
37
|
from chainlit.types import (
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
38
|
+
DeleteThreadRequest,
|
|
39
|
+
GenerationRequest,
|
|
40
|
+
GetThreadsRequest,
|
|
40
41
|
Theme,
|
|
41
42
|
UpdateFeedbackRequest,
|
|
42
43
|
)
|
|
43
|
-
from
|
|
44
|
+
from chainlit.user import PersistedUser, User
|
|
45
|
+
from fastapi import Depends, FastAPI, HTTPException, Query, Request, UploadFile, status
|
|
44
46
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
45
47
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
46
48
|
from fastapi.staticfiles import StaticFiles
|
|
@@ -117,6 +119,9 @@ async def lifespan(app: FastAPI):
|
|
|
117
119
|
except asyncio.exceptions.CancelledError:
|
|
118
120
|
pass
|
|
119
121
|
|
|
122
|
+
if FILES_DIRECTORY.is_dir():
|
|
123
|
+
shutil.rmtree(FILES_DIRECTORY)
|
|
124
|
+
|
|
120
125
|
# Force exit the process to avoid potential AnyIO threads still running
|
|
121
126
|
os._exit(0)
|
|
122
127
|
|
|
@@ -134,6 +139,7 @@ def get_build_dir():
|
|
|
134
139
|
|
|
135
140
|
build_dir = get_build_dir()
|
|
136
141
|
|
|
142
|
+
|
|
137
143
|
app = FastAPI(lifespan=lifespan)
|
|
138
144
|
|
|
139
145
|
app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public")
|
|
@@ -156,14 +162,10 @@ app.add_middleware(
|
|
|
156
162
|
)
|
|
157
163
|
|
|
158
164
|
|
|
159
|
-
# Define max HTTP data size to 100 MB
|
|
160
|
-
max_message_size = 100 * 1024 * 1024
|
|
161
|
-
|
|
162
165
|
socket = SocketManager(
|
|
163
166
|
app,
|
|
164
167
|
cors_allowed_origins=[],
|
|
165
168
|
async_mode="asgi",
|
|
166
|
-
max_http_buffer_size=max_message_size,
|
|
167
169
|
)
|
|
168
170
|
|
|
169
171
|
|
|
@@ -244,18 +246,18 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
|
244
246
|
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
245
247
|
)
|
|
246
248
|
|
|
247
|
-
|
|
249
|
+
user = await config.code.password_auth_callback(
|
|
248
250
|
form_data.username, form_data.password
|
|
249
251
|
)
|
|
250
252
|
|
|
251
|
-
if not
|
|
253
|
+
if not user:
|
|
252
254
|
raise HTTPException(
|
|
253
255
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
254
256
|
detail="credentialssignin",
|
|
255
257
|
)
|
|
256
|
-
access_token = create_jwt(
|
|
257
|
-
if
|
|
258
|
-
await
|
|
258
|
+
access_token = create_jwt(user)
|
|
259
|
+
if data_layer := get_data_layer():
|
|
260
|
+
await data_layer.create_user(user)
|
|
259
261
|
return {
|
|
260
262
|
"access_token": access_token,
|
|
261
263
|
"token_type": "bearer",
|
|
@@ -270,17 +272,17 @@ async def header_auth(request: Request):
|
|
|
270
272
|
detail="No header_auth_callback defined",
|
|
271
273
|
)
|
|
272
274
|
|
|
273
|
-
|
|
275
|
+
user = await config.code.header_auth_callback(request.headers)
|
|
274
276
|
|
|
275
|
-
if not
|
|
277
|
+
if not user:
|
|
276
278
|
raise HTTPException(
|
|
277
279
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
278
280
|
detail="Unauthorized",
|
|
279
281
|
)
|
|
280
282
|
|
|
281
|
-
access_token = create_jwt(
|
|
282
|
-
if
|
|
283
|
-
await
|
|
283
|
+
access_token = create_jwt(user)
|
|
284
|
+
if data_layer := get_data_layer():
|
|
285
|
+
await data_layer.create_user(user)
|
|
284
286
|
return {
|
|
285
287
|
"access_token": access_token,
|
|
286
288
|
"token_type": "bearer",
|
|
@@ -369,21 +371,22 @@ async def oauth_callback(
|
|
|
369
371
|
url = get_user_facing_url(request.url)
|
|
370
372
|
token = await provider.get_token(code, url)
|
|
371
373
|
|
|
372
|
-
(raw_user_data,
|
|
374
|
+
(raw_user_data, default_user) = await provider.get_user_info(token)
|
|
373
375
|
|
|
374
|
-
|
|
375
|
-
provider_id, token, raw_user_data,
|
|
376
|
+
user = await config.code.oauth_callback(
|
|
377
|
+
provider_id, token, raw_user_data, default_user
|
|
376
378
|
)
|
|
377
379
|
|
|
378
|
-
if not
|
|
380
|
+
if not user:
|
|
379
381
|
raise HTTPException(
|
|
380
382
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
381
383
|
detail="Unauthorized",
|
|
382
384
|
)
|
|
383
385
|
|
|
384
|
-
access_token = create_jwt(
|
|
385
|
-
|
|
386
|
-
|
|
386
|
+
access_token = create_jwt(user)
|
|
387
|
+
|
|
388
|
+
if data_layer := get_data_layer():
|
|
389
|
+
await data_layer.create_user(user)
|
|
387
390
|
|
|
388
391
|
params = urllib.parse.urlencode(
|
|
389
392
|
{
|
|
@@ -399,23 +402,21 @@ async def oauth_callback(
|
|
|
399
402
|
return response
|
|
400
403
|
|
|
401
404
|
|
|
402
|
-
@app.post("/
|
|
403
|
-
async def
|
|
404
|
-
request:
|
|
405
|
-
current_user: Annotated[
|
|
406
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
407
|
-
],
|
|
405
|
+
@app.post("/generation")
|
|
406
|
+
async def generation(
|
|
407
|
+
request: GenerationRequest,
|
|
408
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
408
409
|
):
|
|
409
410
|
"""Handle a completion request from the prompt playground."""
|
|
410
411
|
|
|
411
412
|
providers = get_llm_providers()
|
|
412
413
|
|
|
413
414
|
try:
|
|
414
|
-
provider = [p for p in providers if p.id == request.
|
|
415
|
+
provider = [p for p in providers if p.id == request.generation.provider][0]
|
|
415
416
|
except IndexError:
|
|
416
417
|
raise HTTPException(
|
|
417
418
|
status_code=404,
|
|
418
|
-
detail=f"LLM provider '{request.
|
|
419
|
+
detail=f"LLM provider '{request.generation.provider}' not found",
|
|
419
420
|
)
|
|
420
421
|
|
|
421
422
|
trace_event("pp_create_completion")
|
|
@@ -426,7 +427,7 @@ async def completion(
|
|
|
426
427
|
|
|
427
428
|
@app.get("/project/llm-providers")
|
|
428
429
|
async def get_providers(
|
|
429
|
-
current_user: Annotated[Union[
|
|
430
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)]
|
|
430
431
|
):
|
|
431
432
|
"""List the providers."""
|
|
432
433
|
trace_event("pp_get_llm_providers")
|
|
@@ -437,7 +438,7 @@ async def get_providers(
|
|
|
437
438
|
|
|
438
439
|
@app.get("/project/settings")
|
|
439
440
|
async def project_settings(
|
|
440
|
-
current_user: Annotated[Union[
|
|
441
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)]
|
|
441
442
|
):
|
|
442
443
|
"""Return project settings. This is called by the UI before the establishing the websocket connection."""
|
|
443
444
|
profiles = []
|
|
@@ -450,126 +451,181 @@ async def project_settings(
|
|
|
450
451
|
"ui": config.ui.to_dict(),
|
|
451
452
|
"features": config.features.to_dict(),
|
|
452
453
|
"userEnv": config.project.user_env,
|
|
453
|
-
"dataPersistence":
|
|
454
|
-
"
|
|
454
|
+
"dataPersistence": get_data_layer() is not None,
|
|
455
|
+
"threadResumable": bool(config.code.on_chat_resume),
|
|
455
456
|
"markdown": get_markdown_str(config.root),
|
|
456
457
|
"chatProfiles": profiles,
|
|
457
458
|
}
|
|
458
459
|
)
|
|
459
460
|
|
|
460
461
|
|
|
461
|
-
@app.put("/
|
|
462
|
+
@app.put("/feedback")
|
|
462
463
|
async def update_feedback(
|
|
463
464
|
request: Request,
|
|
464
465
|
update: UpdateFeedbackRequest,
|
|
465
|
-
current_user: Annotated[
|
|
466
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
467
|
-
],
|
|
466
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
468
467
|
):
|
|
469
468
|
"""Update the human feedback for a particular message."""
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
if not chainlit_client:
|
|
469
|
+
data_layer = get_data_layer()
|
|
470
|
+
if not data_layer:
|
|
474
471
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
475
472
|
|
|
476
473
|
try:
|
|
477
|
-
await
|
|
478
|
-
message_id=update.messageId,
|
|
479
|
-
feedback=update.feedback,
|
|
480
|
-
feedbackComment=update.feedbackComment,
|
|
481
|
-
)
|
|
474
|
+
feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
|
|
482
475
|
except Exception as e:
|
|
483
476
|
raise HTTPException(detail=str(e), status_code=401)
|
|
484
477
|
|
|
485
|
-
return JSONResponse(content={"success": True})
|
|
478
|
+
return JSONResponse(content={"success": True, "feedbackId": feedback_id})
|
|
486
479
|
|
|
487
480
|
|
|
488
|
-
@app.post("/project/
|
|
489
|
-
async def
|
|
481
|
+
@app.post("/project/threads")
|
|
482
|
+
async def get_user_threads(
|
|
490
483
|
request: Request,
|
|
491
|
-
payload:
|
|
492
|
-
current_user: Annotated[
|
|
493
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
494
|
-
],
|
|
484
|
+
payload: GetThreadsRequest,
|
|
485
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
495
486
|
):
|
|
496
|
-
"""Get the
|
|
497
|
-
# Only show the current user
|
|
487
|
+
"""Get the threads page by page."""
|
|
488
|
+
# Only show the current user threads
|
|
498
489
|
|
|
499
|
-
|
|
490
|
+
data_layer = get_data_layer()
|
|
491
|
+
|
|
492
|
+
if not data_layer:
|
|
500
493
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
501
494
|
|
|
502
|
-
payload.filter.
|
|
503
|
-
|
|
495
|
+
payload.filter.userIdentifier = current_user.identifier
|
|
496
|
+
|
|
497
|
+
res = await data_layer.list_threads(payload.pagination, payload.filter)
|
|
504
498
|
return JSONResponse(content=res.to_dict())
|
|
505
499
|
|
|
506
500
|
|
|
507
|
-
@app.get("/project/
|
|
508
|
-
async def
|
|
501
|
+
@app.get("/project/thread/{thread_id}")
|
|
502
|
+
async def get_thread(
|
|
509
503
|
request: Request,
|
|
510
|
-
|
|
511
|
-
current_user: Annotated[
|
|
512
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
513
|
-
],
|
|
504
|
+
thread_id: str,
|
|
505
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
514
506
|
):
|
|
515
|
-
"""Get a specific
|
|
507
|
+
"""Get a specific thread."""
|
|
508
|
+
data_layer = get_data_layer()
|
|
516
509
|
|
|
517
|
-
if not
|
|
510
|
+
if not data_layer:
|
|
518
511
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
519
512
|
|
|
520
|
-
await
|
|
513
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
521
514
|
|
|
522
|
-
res = await
|
|
515
|
+
res = await data_layer.get_thread(thread_id)
|
|
523
516
|
return JSONResponse(content=res)
|
|
524
517
|
|
|
525
518
|
|
|
526
|
-
@app.get("/project/
|
|
527
|
-
async def
|
|
519
|
+
@app.get("/project/thread/{thread_id}/element/{element_id}")
|
|
520
|
+
async def get_thread_element(
|
|
528
521
|
request: Request,
|
|
529
|
-
|
|
522
|
+
thread_id: str,
|
|
530
523
|
element_id: str,
|
|
531
|
-
current_user: Annotated[
|
|
532
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
533
|
-
],
|
|
524
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
534
525
|
):
|
|
535
|
-
"""Get a specific
|
|
526
|
+
"""Get a specific thread element."""
|
|
527
|
+
data_layer = get_data_layer()
|
|
536
528
|
|
|
537
|
-
if not
|
|
529
|
+
if not data_layer:
|
|
538
530
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
539
531
|
|
|
540
|
-
await
|
|
532
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
541
533
|
|
|
542
|
-
res = await
|
|
534
|
+
res = await data_layer.get_element(thread_id, element_id)
|
|
543
535
|
return JSONResponse(content=res)
|
|
544
536
|
|
|
545
537
|
|
|
546
|
-
@app.delete("/project/
|
|
547
|
-
async def
|
|
538
|
+
@app.delete("/project/thread")
|
|
539
|
+
async def delete_thread(
|
|
548
540
|
request: Request,
|
|
549
|
-
payload:
|
|
550
|
-
current_user: Annotated[
|
|
551
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
552
|
-
],
|
|
541
|
+
payload: DeleteThreadRequest,
|
|
542
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
553
543
|
):
|
|
554
|
-
"""Delete a
|
|
544
|
+
"""Delete a thread."""
|
|
545
|
+
|
|
546
|
+
data_layer = get_data_layer()
|
|
555
547
|
|
|
556
|
-
if not
|
|
548
|
+
if not data_layer:
|
|
557
549
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
558
550
|
|
|
559
|
-
|
|
551
|
+
thread_id = payload.threadId
|
|
560
552
|
|
|
561
|
-
await
|
|
553
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
562
554
|
|
|
563
|
-
await
|
|
555
|
+
await data_layer.delete_thread(thread_id)
|
|
564
556
|
return JSONResponse(content={"success": True})
|
|
565
557
|
|
|
566
558
|
|
|
559
|
+
@app.post("/project/file")
|
|
560
|
+
async def upload_file(
|
|
561
|
+
session_id: str,
|
|
562
|
+
file: UploadFile,
|
|
563
|
+
current_user: Annotated[
|
|
564
|
+
Union[None, User, PersistedUser], Depends(get_current_user)
|
|
565
|
+
],
|
|
566
|
+
):
|
|
567
|
+
from chainlit.session import WebsocketSession
|
|
568
|
+
|
|
569
|
+
session = WebsocketSession.get_by_id(session_id)
|
|
570
|
+
|
|
571
|
+
if not session:
|
|
572
|
+
raise HTTPException(
|
|
573
|
+
status_code=404,
|
|
574
|
+
detail="Session not found",
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
if current_user:
|
|
578
|
+
if not session.user or session.user.identifier != current_user.identifier:
|
|
579
|
+
raise HTTPException(
|
|
580
|
+
status_code=401,
|
|
581
|
+
detail="You are not authorized to upload files for this session",
|
|
582
|
+
)
|
|
583
|
+
|
|
584
|
+
session.files_dir.mkdir(exist_ok=True)
|
|
585
|
+
|
|
586
|
+
content = await file.read()
|
|
587
|
+
|
|
588
|
+
file_response = await session.persist_file(
|
|
589
|
+
name=file.filename, content=content, mime=file.content_type
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
return JSONResponse(file_response)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
@app.get("/project/file/{file_id}")
|
|
596
|
+
async def get_file(
|
|
597
|
+
file_id: str,
|
|
598
|
+
session_id: Optional[str] = None,
|
|
599
|
+
token: Optional[str] = None,
|
|
600
|
+
):
|
|
601
|
+
from chainlit.session import WebsocketSession
|
|
602
|
+
|
|
603
|
+
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
604
|
+
|
|
605
|
+
if not session:
|
|
606
|
+
raise HTTPException(
|
|
607
|
+
status_code=404,
|
|
608
|
+
detail="Session not found",
|
|
609
|
+
)
|
|
610
|
+
|
|
611
|
+
if current_user := await get_current_user(token or ""):
|
|
612
|
+
if not session.user or session.user.identifier != current_user.identifier:
|
|
613
|
+
raise HTTPException(
|
|
614
|
+
status_code=401,
|
|
615
|
+
detail="You are not authorized to upload files for this session",
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
if file_id in session.files:
|
|
619
|
+
file = session.files[file_id]
|
|
620
|
+
return FileResponse(file["path"], media_type=file["type"])
|
|
621
|
+
else:
|
|
622
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
623
|
+
|
|
624
|
+
|
|
567
625
|
@app.get("/files/{filename:path}")
|
|
568
626
|
async def serve_file(
|
|
569
627
|
filename: str,
|
|
570
|
-
current_user: Annotated[
|
|
571
|
-
Union[AppUser, PersistedAppUser], Depends(get_current_user)
|
|
572
|
-
],
|
|
628
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
573
629
|
):
|
|
574
630
|
base_path = Path(config.project.local_fs_path).resolve()
|
|
575
631
|
file_path = (base_path / filename).resolve()
|
chainlit/session.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import json
|
|
3
|
-
|
|
2
|
+
import mimetypes
|
|
3
|
+
import shutil
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import aiofiles
|
|
4
8
|
|
|
5
9
|
if TYPE_CHECKING:
|
|
6
10
|
from chainlit.message import Message
|
|
7
|
-
from chainlit.
|
|
8
|
-
|
|
9
|
-
from chainlit.
|
|
10
|
-
from chainlit.data import chainlit_client
|
|
11
|
+
from chainlit.step import Step
|
|
12
|
+
from chainlit.types import FileDict, FileReference
|
|
13
|
+
from chainlit.user import PersistedUser, User
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
class JSONEncoderIgnoreNonSerializable(json.JSONEncoder):
|
|
@@ -27,12 +30,17 @@ def clean_metadata(metadata: Dict):
|
|
|
27
30
|
class BaseSession:
|
|
28
31
|
"""Base object."""
|
|
29
32
|
|
|
33
|
+
active_steps: List["Step"]
|
|
34
|
+
thread_id_to_resume: Optional[str] = None
|
|
35
|
+
|
|
30
36
|
def __init__(
|
|
31
37
|
self,
|
|
32
38
|
# Id of the session
|
|
33
39
|
id: str,
|
|
40
|
+
# Thread id
|
|
41
|
+
thread_id: Optional[str],
|
|
34
42
|
# Logged-in user informations
|
|
35
|
-
user: Optional[Union["
|
|
43
|
+
user: Optional[Union["User", "PersistedUser"]],
|
|
36
44
|
# Logged-in user token
|
|
37
45
|
token: Optional[str],
|
|
38
46
|
# User specific environment variables. Empty if no user environment variables are required.
|
|
@@ -41,41 +49,30 @@ class BaseSession:
|
|
|
41
49
|
root_message: Optional["Message"] = None,
|
|
42
50
|
# Chat profile selected before the session was created
|
|
43
51
|
chat_profile: Optional[str] = None,
|
|
44
|
-
# Conversation id to resume
|
|
45
|
-
conversation_id: Optional[str] = None,
|
|
46
52
|
):
|
|
53
|
+
if thread_id:
|
|
54
|
+
self.thread_id_to_resume = thread_id
|
|
55
|
+
self.thread_id = thread_id or str(uuid.uuid4())
|
|
47
56
|
self.user = user
|
|
48
57
|
self.token = token
|
|
49
58
|
self.root_message = root_message
|
|
50
59
|
self.has_user_message = False
|
|
51
60
|
self.user_env = user_env or {}
|
|
52
61
|
self.chat_profile = chat_profile
|
|
62
|
+
self.active_steps = []
|
|
53
63
|
|
|
54
64
|
self.id = id
|
|
55
|
-
self.conversation_id = conversation_id
|
|
56
65
|
|
|
57
66
|
self.chat_settings: Dict[str, Any] = {}
|
|
58
67
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
else:
|
|
68
|
-
tags = ["chat"]
|
|
69
|
-
|
|
70
|
-
async with self.lock:
|
|
71
|
-
if not self.conversation_id:
|
|
72
|
-
app_user_id = (
|
|
73
|
-
self.user.id if isinstance(self.user, PersistedAppUser) else None
|
|
74
|
-
)
|
|
75
|
-
self.conversation_id = await chainlit_client.create_conversation(
|
|
76
|
-
app_user_id=app_user_id, tags=tags
|
|
77
|
-
)
|
|
78
|
-
return self.conversation_id
|
|
68
|
+
async def persist_file(
|
|
69
|
+
self,
|
|
70
|
+
name: str,
|
|
71
|
+
mime: str,
|
|
72
|
+
path: Optional[str] = None,
|
|
73
|
+
content: Optional[Union[bytes, str]] = None,
|
|
74
|
+
):
|
|
75
|
+
return None
|
|
79
76
|
|
|
80
77
|
def to_persistable(self) -> Dict:
|
|
81
78
|
from chainlit.user_session import user_sessions
|
|
@@ -94,17 +91,24 @@ class HTTPSession(BaseSession):
|
|
|
94
91
|
self,
|
|
95
92
|
# Id of the session
|
|
96
93
|
id: str,
|
|
94
|
+
# Thread id
|
|
95
|
+
thread_id: Optional[str] = None,
|
|
97
96
|
# Logged-in user informations
|
|
98
|
-
user: Optional[Union["
|
|
97
|
+
user: Optional[Union["User", "PersistedUser"]] = None,
|
|
99
98
|
# Logged-in user token
|
|
100
|
-
token: Optional[str],
|
|
101
|
-
user_env: Optional[Dict[str, str]],
|
|
99
|
+
token: Optional[str] = None,
|
|
100
|
+
user_env: Optional[Dict[str, str]] = None,
|
|
102
101
|
# Last message at the root of the chat
|
|
103
102
|
root_message: Optional["Message"] = None,
|
|
104
103
|
# User specific environment variables. Empty if no user environment variables are required.
|
|
105
104
|
):
|
|
106
105
|
super().__init__(
|
|
107
|
-
id=id,
|
|
106
|
+
id=id,
|
|
107
|
+
thread_id=thread_id,
|
|
108
|
+
user=user,
|
|
109
|
+
token=token,
|
|
110
|
+
user_env=user_env,
|
|
111
|
+
root_message=root_message,
|
|
108
112
|
)
|
|
109
113
|
|
|
110
114
|
|
|
@@ -129,28 +133,28 @@ class WebsocketSession(BaseSession):
|
|
|
129
133
|
# Function to emit a message to the user
|
|
130
134
|
emit: Callable[[str, Any], None],
|
|
131
135
|
# Function to ask the user a question
|
|
132
|
-
ask_user: Callable[[Any, Optional[int]],
|
|
136
|
+
ask_user: Callable[[Any, Optional[int]], Any],
|
|
133
137
|
# User specific environment variables. Empty if no user environment variables are required.
|
|
134
138
|
user_env: Dict[str, str],
|
|
139
|
+
# Thread id
|
|
140
|
+
thread_id: Optional[str] = None,
|
|
135
141
|
# Logged-in user informations
|
|
136
|
-
user: Optional[Union["
|
|
142
|
+
user: Optional[Union["User", "PersistedUser"]] = None,
|
|
137
143
|
# Logged-in user token
|
|
138
|
-
token: Optional[str],
|
|
144
|
+
token: Optional[str] = None,
|
|
139
145
|
# Last message at the root of the chat
|
|
140
146
|
root_message: Optional["Message"] = None,
|
|
141
147
|
# Chat profile selected before the session was created
|
|
142
148
|
chat_profile: Optional[str] = None,
|
|
143
|
-
# Conversation id to resume
|
|
144
|
-
conversation_id: Optional[str] = None,
|
|
145
149
|
):
|
|
146
150
|
super().__init__(
|
|
147
151
|
id=id,
|
|
152
|
+
thread_id=thread_id,
|
|
148
153
|
user=user,
|
|
149
154
|
token=token,
|
|
150
155
|
user_env=user_env,
|
|
151
156
|
root_message=root_message,
|
|
152
157
|
chat_profile=chat_profile,
|
|
153
|
-
conversation_id=conversation_id,
|
|
154
158
|
)
|
|
155
159
|
|
|
156
160
|
self.socket_id = socket_id
|
|
@@ -160,9 +164,66 @@ class WebsocketSession(BaseSession):
|
|
|
160
164
|
self.should_stop = False
|
|
161
165
|
self.restored = False
|
|
162
166
|
|
|
167
|
+
self.thread_queues = {} # type: Dict[str, Deque[Callable]]
|
|
168
|
+
self.files = {} # type: Dict[str, "FileDict"]
|
|
169
|
+
|
|
163
170
|
ws_sessions_id[self.id] = self
|
|
164
171
|
ws_sessions_sid[socket_id] = self
|
|
165
172
|
|
|
173
|
+
@property
|
|
174
|
+
def files_dir(self):
|
|
175
|
+
from chainlit.config import FILES_DIRECTORY
|
|
176
|
+
|
|
177
|
+
return FILES_DIRECTORY / self.id
|
|
178
|
+
|
|
179
|
+
async def persist_file(
|
|
180
|
+
self,
|
|
181
|
+
name: str,
|
|
182
|
+
mime: str,
|
|
183
|
+
path: Optional[str] = None,
|
|
184
|
+
content: Optional[Union[bytes, str]] = None,
|
|
185
|
+
) -> "FileReference":
|
|
186
|
+
if not path and not content:
|
|
187
|
+
raise ValueError(
|
|
188
|
+
"Either path or content must be provided to persist a file"
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
self.files_dir.mkdir(exist_ok=True)
|
|
192
|
+
|
|
193
|
+
file_id = str(uuid.uuid4())
|
|
194
|
+
|
|
195
|
+
file_path = self.files_dir / file_id
|
|
196
|
+
|
|
197
|
+
file_extension = mimetypes.guess_extension(mime)
|
|
198
|
+
if file_extension:
|
|
199
|
+
file_path = file_path.with_suffix(file_extension)
|
|
200
|
+
|
|
201
|
+
if path:
|
|
202
|
+
# Copy the file from the given path
|
|
203
|
+
async with aiofiles.open(path, "rb") as src, aiofiles.open(
|
|
204
|
+
file_path, "wb"
|
|
205
|
+
) as dst:
|
|
206
|
+
await dst.write(await src.read())
|
|
207
|
+
elif content:
|
|
208
|
+
# Write the provided content to the file
|
|
209
|
+
async with aiofiles.open(file_path, "wb") as buffer:
|
|
210
|
+
if isinstance(content, str):
|
|
211
|
+
content = content.encode("utf-8")
|
|
212
|
+
await buffer.write(content)
|
|
213
|
+
|
|
214
|
+
# Get the file size
|
|
215
|
+
file_size = file_path.stat().st_size
|
|
216
|
+
# Store the file content in memory
|
|
217
|
+
self.files[file_id] = {
|
|
218
|
+
"id": file_id,
|
|
219
|
+
"path": file_path,
|
|
220
|
+
"name": name,
|
|
221
|
+
"type": mime,
|
|
222
|
+
"size": file_size,
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
return {"id": file_id}
|
|
226
|
+
|
|
166
227
|
def restore(self, new_socket_id: str):
|
|
167
228
|
"""Associate a new socket id to the session."""
|
|
168
229
|
ws_sessions_sid.pop(self.socket_id, None)
|
|
@@ -172,9 +233,17 @@ class WebsocketSession(BaseSession):
|
|
|
172
233
|
|
|
173
234
|
def delete(self):
|
|
174
235
|
"""Delete the session."""
|
|
236
|
+
if self.files_dir.is_dir():
|
|
237
|
+
shutil.rmtree(self.files_dir)
|
|
175
238
|
ws_sessions_sid.pop(self.socket_id, None)
|
|
176
239
|
ws_sessions_id.pop(self.id, None)
|
|
177
240
|
|
|
241
|
+
async def flush_method_queue(self):
|
|
242
|
+
for method_name, queue in self.thread_queues.items():
|
|
243
|
+
while queue:
|
|
244
|
+
method, self, args, kwargs = queue.popleft()
|
|
245
|
+
await method(self, *args, **kwargs)
|
|
246
|
+
|
|
178
247
|
@classmethod
|
|
179
248
|
def get(cls, socket_id: str):
|
|
180
249
|
"""Get session by socket id."""
|