chainlit 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +30 -7
- chainlit/action.py +2 -4
- chainlit/cache.py +24 -1
- chainlit/cli/__init__.py +64 -21
- chainlit/client/base.py +152 -0
- chainlit/client/cloud.py +440 -0
- chainlit/client/local.py +257 -0
- chainlit/client/utils.py +23 -0
- chainlit/config.py +92 -29
- chainlit/context.py +29 -0
- chainlit/db/__init__.py +35 -0
- chainlit/db/prisma/schema.prisma +48 -0
- chainlit/element.py +54 -41
- chainlit/emitter.py +1 -30
- chainlit/frontend/dist/assets/index-995e21ad.js +11 -0
- chainlit/frontend/dist/assets/index-f93cc942.css +1 -0
- chainlit/frontend/dist/assets/index-fb1e167a.js +523 -0
- chainlit/frontend/dist/index.html +2 -2
- chainlit/lc/agent.py +1 -0
- chainlit/lc/callbacks.py +6 -21
- chainlit/logger.py +7 -2
- chainlit/message.py +22 -16
- chainlit/server.py +169 -59
- chainlit/session.py +1 -3
- chainlit/sync.py +16 -28
- chainlit/types.py +26 -1
- chainlit/user_session.py +1 -1
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/METADATA +8 -3
- chainlit-0.4.2.dist-info/RECORD +44 -0
- chainlit/client.py +0 -287
- chainlit/frontend/dist/assets/index-0cc9e355.css +0 -1
- chainlit/frontend/dist/assets/index-9e4bccd1.js +0 -717
- chainlit-0.4.0.dist-info/RECORD +0 -37
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/WHEEL +0 -0
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/entry_points.txt +0 -0
chainlit/frontend/dist/index.html
CHANGED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
<script>
|
|
15
15
|
const global = globalThis;
|
|
16
16
|
</script>
|
|
17
|
-
<script type="module" crossorigin src="/assets/index-9e4bccd1.js"></script>
|
|
18
|
-
<link rel="stylesheet" href="/assets/index-0cc9e355.css">
|
|
17
|
+
<script type="module" crossorigin src="/assets/index-fb1e167a.js"></script>
|
|
18
|
+
<link rel="stylesheet" href="/assets/index-f93cc942.css">
|
|
19
19
|
</head>
|
|
20
20
|
<body>
|
|
21
21
|
<div id="root"></div>
|
chainlit/lc/agent.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
from chainlit.lc.callbacks import ChainlitCallbackHandler, AsyncChainlitCallbackHandler
|
|
3
3
|
from chainlit.sync import make_async
|
|
4
|
+
from chainlit.context import emitter_var
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
|
chainlit/lc/callbacks.py
CHANGED
|
@@ -6,7 +6,8 @@ from langchain.schema import (
|
|
|
6
6
|
BaseMessage,
|
|
7
7
|
LLMResult,
|
|
8
8
|
)
|
|
9
|
-
from chainlit.emitter import
|
|
9
|
+
from chainlit.emitter import ChainlitEmitter
|
|
10
|
+
from chainlit.context import get_emitter
|
|
10
11
|
from chainlit.message import Message, ErrorMessage
|
|
11
12
|
from chainlit.config import config
|
|
12
13
|
from chainlit.types import LLMSettings
|
|
@@ -107,14 +108,10 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
|
|
|
107
108
|
return
|
|
108
109
|
|
|
109
110
|
if config.code.lc_rename:
|
|
110
|
-
author = run_sync(
|
|
111
|
-
config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
|
|
112
|
-
)
|
|
111
|
+
author = run_sync(config.code.lc_rename(author))
|
|
113
112
|
|
|
114
113
|
self.pop_prompt()
|
|
115
114
|
|
|
116
|
-
__chainlit_emitter__ = self.emitter
|
|
117
|
-
|
|
118
115
|
streamed_message = Message(
|
|
119
116
|
author=author,
|
|
120
117
|
indent=indent,
|
|
@@ -135,11 +132,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
|
|
|
135
132
|
return
|
|
136
133
|
|
|
137
134
|
if config.code.lc_rename:
|
|
138
|
-
author = run_sync(
|
|
139
|
-
config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
|
|
140
|
-
)
|
|
141
|
-
|
|
142
|
-
__chainlit_emitter__ = self.emitter
|
|
135
|
+
author = run_sync(config.code.lc_rename(author))
|
|
143
136
|
|
|
144
137
|
if error:
|
|
145
138
|
run_sync(ErrorMessage(author=author, content=message).send())
|
|
@@ -267,14 +260,10 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
|
|
|
267
260
|
return
|
|
268
261
|
|
|
269
262
|
if config.code.lc_rename:
|
|
270
|
-
author = await config.code.lc_rename(
|
|
271
|
-
author, __chainlit_emitter__=self.emitter
|
|
272
|
-
)
|
|
263
|
+
author = await config.code.lc_rename(author)
|
|
273
264
|
|
|
274
265
|
self.pop_prompt()
|
|
275
266
|
|
|
276
|
-
__chainlit_emitter__ = self.emitter
|
|
277
|
-
|
|
278
267
|
streamed_message = Message(
|
|
279
268
|
author=author,
|
|
280
269
|
indent=indent,
|
|
@@ -295,11 +284,7 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
|
|
|
295
284
|
return
|
|
296
285
|
|
|
297
286
|
if config.code.lc_rename:
|
|
298
|
-
author = await config.code.lc_rename(
|
|
299
|
-
author, __chainlit_emitter__=self.emitter
|
|
300
|
-
)
|
|
301
|
-
|
|
302
|
-
__chainlit_emitter__ = self.emitter
|
|
287
|
+
author = await config.code.lc_rename(author)
|
|
303
288
|
|
|
304
289
|
if error:
|
|
305
290
|
await ErrorMessage(author=author, content=message).send()
|
chainlit/logger.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import sys
|
|
3
|
+
|
|
2
4
|
|
|
3
5
|
logging.basicConfig(
|
|
4
|
-
level=logging.INFO,
|
|
6
|
+
level=logging.INFO,
|
|
7
|
+
stream=sys.stdout,
|
|
8
|
+
format="%(asctime)s - %(message)s",
|
|
9
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
5
10
|
)
|
|
6
11
|
|
|
7
12
|
logging.getLogger("socketio").setLevel(logging.ERROR)
|
|
8
13
|
logging.getLogger("engineio").setLevel(logging.ERROR)
|
|
9
|
-
logging.getLogger("geventwebsocket.handler").setLevel(logging.ERROR)
|
|
10
14
|
logging.getLogger("numexpr").setLevel(logging.ERROR)
|
|
11
15
|
|
|
16
|
+
|
|
12
17
|
logger = logging.getLogger("chainlit")
|
chainlit/message.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from typing import List, Dict, Union
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
3
|
import uuid
|
|
4
|
-
import time
|
|
5
4
|
import asyncio
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
6
|
|
|
7
7
|
from chainlit.telemetry import trace_event
|
|
8
|
-
from chainlit.
|
|
8
|
+
from chainlit.context import get_emitter
|
|
9
9
|
from chainlit.config import config
|
|
10
10
|
from chainlit.types import (
|
|
11
11
|
LLMSettings,
|
|
@@ -16,11 +16,7 @@ from chainlit.types import (
|
|
|
16
16
|
)
|
|
17
17
|
from chainlit.element import Element
|
|
18
18
|
from chainlit.action import Action
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def current_milli_time():
|
|
22
|
-
"""Get the current time in milliseconds."""
|
|
23
|
-
return round(time.time() * 1000)
|
|
19
|
+
from chainlit.logger import logger
|
|
24
20
|
|
|
25
21
|
|
|
26
22
|
class MessageBase(ABC):
|
|
@@ -28,14 +24,13 @@ class MessageBase(ABC):
|
|
|
28
24
|
temp_id: str = None
|
|
29
25
|
streaming = False
|
|
30
26
|
created_at: int = None
|
|
27
|
+
fail_on_persist_error: bool = True
|
|
31
28
|
|
|
32
29
|
def __post_init__(self) -> None:
|
|
33
30
|
trace_event(f"init {self.__class__.__name__}")
|
|
34
31
|
self.temp_id = uuid.uuid4().hex
|
|
35
|
-
self.created_at = current_milli_time()
|
|
32
|
+
self.created_at = datetime.now(timezone.utc).isoformat()
|
|
36
33
|
self.emitter = get_emitter()
|
|
37
|
-
if not self.emitter:
|
|
38
|
-
raise RuntimeError("Message should be instantiated in a Chainlit context")
|
|
39
34
|
|
|
40
35
|
@abstractmethod
|
|
41
36
|
def to_dict(self):
|
|
@@ -44,9 +39,14 @@ class MessageBase(ABC):
|
|
|
44
39
|
async def _create(self):
|
|
45
40
|
msg_dict = self.to_dict()
|
|
46
41
|
if self.emitter.client and not self.id:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
42
|
+
try:
|
|
43
|
+
self.id = await self.emitter.client.create_message(msg_dict)
|
|
44
|
+
if self.id:
|
|
45
|
+
msg_dict["id"] = self.id
|
|
46
|
+
except Exception as e:
|
|
47
|
+
if self.fail_on_persist_error:
|
|
48
|
+
raise e
|
|
49
|
+
logger.error(f"Failed to persist message: {str(e)}")
|
|
50
50
|
|
|
51
51
|
return msg_dict
|
|
52
52
|
|
|
@@ -77,8 +77,7 @@ class MessageBase(ABC):
|
|
|
77
77
|
msg_dict = self.to_dict()
|
|
78
78
|
|
|
79
79
|
if self.emitter.client and self.id:
|
|
80
|
-
self.emitter.client.update_message(self.id, msg_dict)
|
|
81
|
-
msg_dict["id"] = self.id
|
|
80
|
+
await self.emitter.client.update_message(self.id, msg_dict)
|
|
82
81
|
|
|
83
82
|
await self.emitter.update_message(msg_dict)
|
|
84
83
|
|
|
@@ -171,7 +170,7 @@ class Message(MessageBase):
|
|
|
171
170
|
super().__post_init__()
|
|
172
171
|
|
|
173
172
|
def to_dict(self):
|
|
174
|
-
|
|
173
|
+
_dict = {
|
|
175
174
|
"tempId": self.temp_id,
|
|
176
175
|
"createdAt": self.created_at,
|
|
177
176
|
"content": self.content,
|
|
@@ -182,6 +181,11 @@ class Message(MessageBase):
|
|
|
182
181
|
"indent": self.indent,
|
|
183
182
|
}
|
|
184
183
|
|
|
184
|
+
if self.id:
|
|
185
|
+
_dict["id"] = self.id
|
|
186
|
+
|
|
187
|
+
return _dict
|
|
188
|
+
|
|
185
189
|
async def send(self):
|
|
186
190
|
"""
|
|
187
191
|
Send the message to the UI and persist it in the cloud if a project ID is configured.
|
|
@@ -214,10 +218,12 @@ class ErrorMessage(MessageBase):
|
|
|
214
218
|
content: str,
|
|
215
219
|
author: str = config.ui.name,
|
|
216
220
|
indent: int = 0,
|
|
221
|
+
fail_on_persist_error: bool = False,
|
|
217
222
|
):
|
|
218
223
|
self.content = content
|
|
219
224
|
self.author = author
|
|
220
225
|
self.indent = indent
|
|
226
|
+
self.fail_on_persist_error = fail_on_persist_error
|
|
221
227
|
|
|
222
228
|
super().__post_init__()
|
|
223
229
|
|
chainlit/server.py
CHANGED
|
@@ -6,11 +6,13 @@ mimetypes.add_type("text/css", ".css")
|
|
|
6
6
|
import os
|
|
7
7
|
import json
|
|
8
8
|
import webbrowser
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
9
11
|
|
|
10
12
|
from contextlib import asynccontextmanager
|
|
11
13
|
from watchfiles import awatch
|
|
12
14
|
|
|
13
|
-
from fastapi import FastAPI
|
|
15
|
+
from fastapi import FastAPI, Request
|
|
14
16
|
from fastapi.responses import (
|
|
15
17
|
HTMLResponse,
|
|
16
18
|
JSONResponse,
|
|
@@ -21,17 +23,25 @@ from fastapi_socketio import SocketManager
|
|
|
21
23
|
from starlette.middleware.cors import CORSMiddleware
|
|
22
24
|
import asyncio
|
|
23
25
|
|
|
24
|
-
from chainlit.
|
|
26
|
+
from chainlit.context import emitter_var, loop_var
|
|
27
|
+
from chainlit.config import config, load_module, reload_config, DEFAULT_HOST
|
|
25
28
|
from chainlit.session import Session, sessions
|
|
26
29
|
from chainlit.user_session import user_sessions
|
|
27
|
-
from chainlit.client import CloudClient
|
|
30
|
+
from chainlit.client.cloud import CloudClient
|
|
31
|
+
from chainlit.client.local import LocalClient
|
|
32
|
+
from chainlit.client.utils import get_client
|
|
28
33
|
from chainlit.emitter import ChainlitEmitter
|
|
29
34
|
from chainlit.markdown import get_markdown_str
|
|
30
35
|
from chainlit.action import Action
|
|
31
36
|
from chainlit.message import Message, ErrorMessage
|
|
32
37
|
from chainlit.telemetry import trace_event
|
|
33
38
|
from chainlit.logger import logger
|
|
34
|
-
from chainlit.types import
|
|
39
|
+
from chainlit.types import (
|
|
40
|
+
CompletionRequest,
|
|
41
|
+
UpdateFeedbackRequest,
|
|
42
|
+
GetConversationsRequest,
|
|
43
|
+
DeleteConversationRequest,
|
|
44
|
+
)
|
|
35
45
|
|
|
36
46
|
|
|
37
47
|
@asynccontextmanager
|
|
@@ -39,32 +49,56 @@ async def lifespan(app: FastAPI):
|
|
|
39
49
|
host = config.run.host
|
|
40
50
|
port = config.run.port
|
|
41
51
|
|
|
42
|
-
if
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
52
|
+
if host == DEFAULT_HOST:
|
|
53
|
+
url = f"http://localhost:{port}"
|
|
54
|
+
else:
|
|
55
|
+
url = f"http://{host}:{port}"
|
|
56
|
+
|
|
57
|
+
logger.info(f"Your app is available at {url}")
|
|
47
58
|
|
|
48
|
-
|
|
59
|
+
if not config.run.headless:
|
|
60
|
+
# Add a delay before opening the browser
|
|
61
|
+
await asyncio.sleep(1)
|
|
49
62
|
webbrowser.open(url)
|
|
50
63
|
|
|
64
|
+
if config.project.database == "local":
|
|
65
|
+
from prisma import Client, register
|
|
66
|
+
|
|
67
|
+
client = Client()
|
|
68
|
+
register(client)
|
|
69
|
+
await client.connect()
|
|
70
|
+
|
|
51
71
|
watch_task = None
|
|
52
72
|
stop_event = asyncio.Event()
|
|
53
73
|
|
|
54
74
|
if config.run.watch:
|
|
55
75
|
|
|
56
76
|
async def watch_files_for_changes():
|
|
77
|
+
extensions = [".py"]
|
|
78
|
+
files = ["chainlit.md", "config.toml"]
|
|
57
79
|
async for changes in awatch(config.root, stop_event=stop_event):
|
|
58
80
|
for change_type, file_path in changes:
|
|
59
81
|
file_name = os.path.basename(file_path)
|
|
60
82
|
file_ext = os.path.splitext(file_name)[1]
|
|
61
83
|
|
|
62
|
-
if file_ext.lower()
|
|
63
|
-
logger.info(
|
|
84
|
+
if file_ext.lower() in extensions or file_name.lower() in files:
|
|
85
|
+
logger.info(
|
|
86
|
+
f"File {change_type.name}: {file_name}. Reloading app..."
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
try:
|
|
90
|
+
reload_config()
|
|
91
|
+
except Exception as e:
|
|
92
|
+
logger.error(f"Error reloading config: {e}")
|
|
93
|
+
break
|
|
64
94
|
|
|
65
95
|
# Reload the module if the module name is specified in the config
|
|
66
96
|
if config.run.module_name:
|
|
67
|
-
|
|
97
|
+
try:
|
|
98
|
+
load_module(config.run.module_name)
|
|
99
|
+
except Exception as e:
|
|
100
|
+
logger.error(f"Error reloading module: {e}")
|
|
101
|
+
break
|
|
68
102
|
|
|
69
103
|
await socket.emit("reload", {})
|
|
70
104
|
|
|
@@ -74,12 +108,16 @@ async def lifespan(app: FastAPI):
|
|
|
74
108
|
|
|
75
109
|
try:
|
|
76
110
|
yield
|
|
77
|
-
except KeyboardInterrupt:
|
|
78
|
-
logger.error("KeyboardInterrupt received, stopping the watch task...")
|
|
79
111
|
finally:
|
|
112
|
+
if config.project.database == "local":
|
|
113
|
+
await client.disconnect()
|
|
80
114
|
if watch_task:
|
|
81
|
-
|
|
82
|
-
|
|
115
|
+
try:
|
|
116
|
+
stop_event.set()
|
|
117
|
+
watch_task.cancel()
|
|
118
|
+
await watch_task
|
|
119
|
+
except asyncio.exceptions.CancelledError:
|
|
120
|
+
pass
|
|
83
121
|
|
|
84
122
|
|
|
85
123
|
root_dir = os.path.dirname(os.path.abspath(__file__))
|
|
@@ -187,6 +225,80 @@ async def project_settings():
|
|
|
187
225
|
)
|
|
188
226
|
|
|
189
227
|
|
|
228
|
+
@app.put("/message/feedback")
|
|
229
|
+
async def update_feedback(request: Request, update: UpdateFeedbackRequest):
|
|
230
|
+
"""Update the human feedback for a particular message."""
|
|
231
|
+
|
|
232
|
+
client = await get_client(request)
|
|
233
|
+
await client.set_human_feedback(
|
|
234
|
+
message_id=update.messageId, feedback=update.feedback
|
|
235
|
+
)
|
|
236
|
+
return JSONResponse(content={"success": True})
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@app.get("/project/members")
|
|
240
|
+
async def get_project_members(request: Request):
|
|
241
|
+
"""Get all the members of a project."""
|
|
242
|
+
|
|
243
|
+
client = await get_client(request)
|
|
244
|
+
res = await client.get_project_members()
|
|
245
|
+
return JSONResponse(content=res)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@app.get("/project/role")
|
|
249
|
+
async def get_member_role(request: Request):
|
|
250
|
+
"""Get the role of a member."""
|
|
251
|
+
|
|
252
|
+
client = await get_client(request)
|
|
253
|
+
res = await client.get_member_role()
|
|
254
|
+
return PlainTextResponse(content=res)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@app.post("/project/conversations")
|
|
258
|
+
async def get_project_conversations(request: Request, payload: GetConversationsRequest):
|
|
259
|
+
"""Get the conversations page by page."""
|
|
260
|
+
|
|
261
|
+
client = await get_client(request)
|
|
262
|
+
res = await client.get_conversations(payload.pagination, payload.filter)
|
|
263
|
+
return JSONResponse(content=res.to_dict())
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@app.get("/project/conversation/{conversation_id}")
|
|
267
|
+
async def get_conversation(request: Request, conversation_id: str):
|
|
268
|
+
"""Get a specific conversation."""
|
|
269
|
+
|
|
270
|
+
client = await get_client(request)
|
|
271
|
+
res = await client.get_conversation(int(conversation_id))
|
|
272
|
+
return JSONResponse(content=res)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
@app.get("/project/conversation/{conversation_id}/element/{element_id}")
|
|
276
|
+
async def get_conversation(request: Request, conversation_id: str, element_id: str):
|
|
277
|
+
"""Get a specific conversation."""
|
|
278
|
+
|
|
279
|
+
client = await get_client(request)
|
|
280
|
+
res = await client.get_element(int(conversation_id), int(element_id))
|
|
281
|
+
return JSONResponse(content=res)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@app.delete("/project/conversation")
|
|
285
|
+
async def delete_conversation(request: Request, payload: DeleteConversationRequest):
|
|
286
|
+
"""Delete a conversation."""
|
|
287
|
+
|
|
288
|
+
client = await get_client(request)
|
|
289
|
+
await client.delete_conversation(conversation_id=payload.conversationId)
|
|
290
|
+
return JSONResponse(content={"success": True})
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@app.get("/files/{filename:path}")
|
|
294
|
+
async def serve_file(filename: str):
|
|
295
|
+
file_path = Path(config.project.local_fs_path) / filename
|
|
296
|
+
if file_path.is_file():
|
|
297
|
+
return FileResponse(file_path)
|
|
298
|
+
else:
|
|
299
|
+
return {"error": "File not found"}
|
|
300
|
+
|
|
301
|
+
|
|
190
302
|
@app.get("/{path:path}")
|
|
191
303
|
async def serve(path: str):
|
|
192
304
|
"""Serve the UI."""
|
|
@@ -217,36 +329,30 @@ def need_session(id: str):
|
|
|
217
329
|
async def connect(sid, environ):
|
|
218
330
|
user_env = environ.get("HTTP_USER_ENV")
|
|
219
331
|
authorization = environ.get("HTTP_AUTHORIZATION")
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
# Check decorated functions
|
|
223
|
-
if (
|
|
224
|
-
not config.code.lc_factory
|
|
225
|
-
and not config.code.on_message
|
|
226
|
-
and not config.code.on_chat_start
|
|
227
|
-
):
|
|
228
|
-
logger.error(
|
|
229
|
-
"Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
|
|
230
|
-
)
|
|
231
|
-
return False
|
|
332
|
+
client = None
|
|
232
333
|
|
|
233
334
|
# Check authorization
|
|
234
335
|
if not config.project.public and not authorization:
|
|
235
336
|
# Refuse connection if the app is private and no access token is provided
|
|
236
337
|
trace_event("no_access_token")
|
|
237
|
-
logger.error("No access token provided")
|
|
338
|
+
logger.error("Connection refused: No access token provided")
|
|
238
339
|
return False
|
|
239
|
-
elif authorization and config.project.id:
|
|
340
|
+
elif authorization and config.project.id and config.project.database == "cloud":
|
|
240
341
|
# Create the cloud client
|
|
241
|
-
|
|
342
|
+
client = CloudClient(
|
|
242
343
|
project_id=config.project.id,
|
|
243
|
-
session_id=sid,
|
|
244
344
|
access_token=authorization,
|
|
245
345
|
)
|
|
246
|
-
is_project_member = await
|
|
346
|
+
is_project_member = await client.is_project_member()
|
|
247
347
|
if not is_project_member:
|
|
248
|
-
logger.error("You are not a member of this project")
|
|
348
|
+
logger.error("Connection refused: You are not a member of this project")
|
|
249
349
|
return False
|
|
350
|
+
elif config.project.database == "local":
|
|
351
|
+
client = LocalClient()
|
|
352
|
+
elif config.project.database == "custom":
|
|
353
|
+
if not config.code.client_factory:
|
|
354
|
+
raise ValueError("Client factory not provided")
|
|
355
|
+
client = await config.code.client_factory()
|
|
250
356
|
|
|
251
357
|
# Check user env
|
|
252
358
|
if config.project.user_env:
|
|
@@ -256,10 +362,12 @@ async def connect(sid, environ):
|
|
|
256
362
|
for key in config.project.user_env:
|
|
257
363
|
if key not in user_env:
|
|
258
364
|
trace_event("missing_user_env")
|
|
259
|
-
logger.error(
|
|
365
|
+
logger.error(
|
|
366
|
+
"Connection refused: Missing user environment variable: " + key
|
|
367
|
+
)
|
|
260
368
|
return False
|
|
261
369
|
else:
|
|
262
|
-
logger.error("Missing user environment variables")
|
|
370
|
+
logger.error("Connection refused: Missing user environment variables")
|
|
263
371
|
return False
|
|
264
372
|
|
|
265
373
|
# Create the session
|
|
@@ -283,9 +391,8 @@ async def connect(sid, environ):
|
|
|
283
391
|
"id": sid,
|
|
284
392
|
"emit": emit_fn,
|
|
285
393
|
"ask_user": ask_user_fn,
|
|
286
|
-
"client":
|
|
394
|
+
"client": client,
|
|
287
395
|
"user_env": user_env,
|
|
288
|
-
"running_sync": False,
|
|
289
396
|
"should_stop": False,
|
|
290
397
|
} # type: Session
|
|
291
398
|
|
|
@@ -298,15 +405,17 @@ async def connect(sid, environ):
|
|
|
298
405
|
@socket.on("connection_successful")
|
|
299
406
|
async def connection_successful(sid):
|
|
300
407
|
session = need_session(sid)
|
|
301
|
-
|
|
408
|
+
emitter_var.set(ChainlitEmitter(session))
|
|
409
|
+
loop_var.set(asyncio.get_event_loop())
|
|
410
|
+
|
|
302
411
|
if config.code.lc_factory:
|
|
303
412
|
"""Instantiate the langchain agent and store it in the session."""
|
|
304
|
-
agent = await config.code.lc_factory(
|
|
413
|
+
agent = await config.code.lc_factory()
|
|
305
414
|
session["agent"] = agent
|
|
306
415
|
|
|
307
416
|
if config.code.on_chat_start:
|
|
308
417
|
"""Call the on_chat_start function provided by the developer."""
|
|
309
|
-
await config.code.on_chat_start(
|
|
418
|
+
await config.code.on_chat_start()
|
|
310
419
|
|
|
311
420
|
|
|
312
421
|
@socket.on("disconnect")
|
|
@@ -326,7 +435,8 @@ async def stop(sid):
|
|
|
326
435
|
trace_event("stop_task")
|
|
327
436
|
session = sessions[sid]
|
|
328
437
|
|
|
329
|
-
|
|
438
|
+
emitter_var.set(ChainlitEmitter(session))
|
|
439
|
+
loop_var.set(asyncio.get_event_loop())
|
|
330
440
|
|
|
331
441
|
await Message(author="System", content="Task stopped by the user.").send()
|
|
332
442
|
|
|
@@ -340,8 +450,11 @@ async def process_message(session: Session, author: str, input_str: str):
|
|
|
340
450
|
"""Process a message from the user."""
|
|
341
451
|
|
|
342
452
|
try:
|
|
343
|
-
|
|
344
|
-
|
|
453
|
+
emitter = ChainlitEmitter(session)
|
|
454
|
+
emitter_var.set(emitter)
|
|
455
|
+
loop_var.set(asyncio.get_event_loop())
|
|
456
|
+
|
|
457
|
+
await emitter.task_start()
|
|
345
458
|
|
|
346
459
|
if session["client"]:
|
|
347
460
|
# If cloud is enabled, persist the message
|
|
@@ -363,7 +476,6 @@ async def process_message(session: Session, author: str, input_str: str):
|
|
|
363
476
|
await config.code.lc_run(
|
|
364
477
|
langchain_agent,
|
|
365
478
|
input_str,
|
|
366
|
-
__chainlit_emitter__=__chainlit_emitter__,
|
|
367
479
|
)
|
|
368
480
|
return
|
|
369
481
|
else:
|
|
@@ -374,9 +486,7 @@ async def process_message(session: Session, author: str, input_str: str):
|
|
|
374
486
|
|
|
375
487
|
if config.code.lc_postprocess:
|
|
376
488
|
# If the developer provided a custom postprocess function, use it
|
|
377
|
-
await config.code.lc_postprocess(
|
|
378
|
-
raw_res, __chainlit_emitter__=__chainlit_emitter__
|
|
379
|
-
)
|
|
489
|
+
await config.code.lc_postprocess(raw_res)
|
|
380
490
|
return
|
|
381
491
|
elif output_key is not None:
|
|
382
492
|
# Use the output key if provided
|
|
@@ -389,16 +499,16 @@ async def process_message(session: Session, author: str, input_str: str):
|
|
|
389
499
|
|
|
390
500
|
elif config.code.on_message:
|
|
391
501
|
# If no langchain agent is available, call the on_message function provided by the developer
|
|
392
|
-
await config.code.on_message(
|
|
393
|
-
input_str, __chainlit_emitter__=__chainlit_emitter__
|
|
394
|
-
)
|
|
502
|
+
await config.code.on_message(input_str)
|
|
395
503
|
except InterruptedError:
|
|
396
504
|
pass
|
|
397
505
|
except Exception as e:
|
|
398
506
|
logger.exception(e)
|
|
399
|
-
await ErrorMessage(
|
|
507
|
+
await ErrorMessage(
|
|
508
|
+
author="Error", content=str(e) or e.__class__.__name__
|
|
509
|
+
).send()
|
|
400
510
|
finally:
|
|
401
|
-
await
|
|
511
|
+
await emitter.task_end()
|
|
402
512
|
|
|
403
513
|
|
|
404
514
|
@socket.on("ui_message")
|
|
@@ -413,11 +523,10 @@ async def message(sid, data):
|
|
|
413
523
|
await process_message(session, author, input_str)
|
|
414
524
|
|
|
415
525
|
|
|
416
|
-
async def process_action(session: Session, action: Action):
|
|
417
|
-
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
526
|
+
async def process_action(action: Action):
|
|
418
527
|
callback = config.code.action_callbacks.get(action.name)
|
|
419
528
|
if callback:
|
|
420
|
-
await callback(action
|
|
529
|
+
await callback(action)
|
|
421
530
|
else:
|
|
422
531
|
logger.warning("No callback found for action %s", action.name)
|
|
423
532
|
|
|
@@ -426,8 +535,9 @@ async def process_action(session: Session, action: Action):
|
|
|
426
535
|
async def call_action(sid, action):
|
|
427
536
|
"""Handle an action call from the UI."""
|
|
428
537
|
session = need_session(sid)
|
|
538
|
+
emitter_var.set(ChainlitEmitter(session))
|
|
539
|
+
loop_var.set(asyncio.get_event_loop())
|
|
429
540
|
|
|
430
|
-
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
431
541
|
action = Action(**action)
|
|
432
542
|
|
|
433
|
-
await process_action(session, action)
|
|
543
|
+
await process_action(action)
|
chainlit/session.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from typing import Dict, TypedDict, Optional, Callable, Any, Union
|
|
2
|
-
from chainlit.client import BaseClient
|
|
2
|
+
from chainlit.client.base import BaseClient
|
|
3
3
|
from chainlit.types import AskResponse
|
|
4
4
|
|
|
5
5
|
|
|
@@ -16,8 +16,6 @@ class Session(TypedDict):
|
|
|
16
16
|
user_env: Dict[str, str]
|
|
17
17
|
# Optional langchain agent
|
|
18
18
|
agent: Any
|
|
19
|
-
# If the session is currently running a sync task
|
|
20
|
-
running_sync: bool
|
|
21
19
|
# Whether the current task should be stopped
|
|
22
20
|
should_stop: bool
|
|
23
21
|
# Optional client to persist messages and files
|
chainlit/sync.py
CHANGED
|
@@ -1,37 +1,25 @@
|
|
|
1
|
-
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any, TypeVar, Coroutine
|
|
3
|
+
|
|
4
|
+
if sys.version_info >= (3, 10):
|
|
5
|
+
from typing import ParamSpec
|
|
6
|
+
else:
|
|
7
|
+
from typing_extensions import ParamSpec
|
|
2
8
|
|
|
3
9
|
import asyncio
|
|
4
|
-
from syncer import sync
|
|
5
10
|
from asyncer import asyncify
|
|
6
11
|
|
|
7
|
-
from chainlit.
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def make_async(function: Callable):
|
|
11
|
-
emitter = get_emitter()
|
|
12
|
-
if not emitter:
|
|
13
|
-
raise RuntimeError(
|
|
14
|
-
"Emitter not found, please call make_async in a Chainlit context."
|
|
15
|
-
)
|
|
12
|
+
from chainlit.context import get_loop
|
|
16
13
|
|
|
17
|
-
def wrapper(*args, **kwargs):
|
|
18
|
-
emitter.session["running_sync"] = True
|
|
19
|
-
__chainlit_emitter__ = emitter
|
|
20
|
-
res = function(*args, **kwargs)
|
|
21
|
-
emitter.session["running_sync"] = False
|
|
22
|
-
return res
|
|
23
14
|
|
|
24
|
-
|
|
15
|
+
make_async = asyncify
|
|
25
16
|
|
|
17
|
+
T_Retval = TypeVar("T_Retval")
|
|
18
|
+
T_ParamSpec = ParamSpec("T_ParamSpec")
|
|
19
|
+
T = TypeVar("T")
|
|
26
20
|
|
|
27
|
-
def run_sync(co: Any):
|
|
28
|
-
try:
|
|
29
|
-
loop = asyncio.get_event_loop()
|
|
30
|
-
except RuntimeError as e:
|
|
31
|
-
if "There is no current event loop" in str(e):
|
|
32
|
-
loop = None
|
|
33
21
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
return
|
|
22
|
+
def run_sync(co: Coroutine[Any, Any, T_Retval]) -> T_Retval:
|
|
23
|
+
loop = get_loop()
|
|
24
|
+
result = asyncio.run_coroutine_threadsafe(co, loop=loop)
|
|
25
|
+
return result.result()
|