chainlit 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic.
- chainlit/__init__.py +41 -130
- chainlit/action.py +2 -4
- chainlit/cli/__init__.py +64 -9
- chainlit/cli/mock.py +1571 -7
- chainlit/client/base.py +152 -0
- chainlit/client/cloud.py +440 -0
- chainlit/client/local.py +257 -0
- chainlit/client/utils.py +23 -0
- chainlit/config.py +31 -5
- chainlit/context.py +29 -0
- chainlit/db/__init__.py +35 -0
- chainlit/db/prisma/schema.prisma +48 -0
- chainlit/element.py +54 -41
- chainlit/emitter.py +1 -30
- chainlit/frontend/dist/assets/{index-51a1a88f.js → index-37b5009c.js} +1 -1
- chainlit/frontend/dist/assets/index-51393291.js +523 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/langflow/__init__.py +75 -0
- chainlit/lc/__init__.py +85 -0
- chainlit/lc/agent.py +9 -5
- chainlit/lc/callbacks.py +9 -24
- chainlit/llama_index/__init__.py +34 -0
- chainlit/llama_index/callbacks.py +99 -0
- chainlit/llama_index/run.py +34 -0
- chainlit/logger.py +7 -2
- chainlit/message.py +25 -19
- chainlit/server.py +149 -38
- chainlit/session.py +3 -3
- chainlit/sync.py +20 -27
- chainlit/types.py +26 -1
- chainlit/user_session.py +1 -1
- chainlit/utils.py +51 -0
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/METADATA +7 -3
- chainlit-0.4.3.dist-info/RECORD +49 -0
- chainlit/client.py +0 -287
- chainlit/frontend/dist/assets/index-68c36c96.js +0 -707
- chainlit-0.4.1.dist-info/RECORD +0 -38
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/WHEEL +0 -0
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/entry_points.txt +0 -0

chainlit/frontend/dist/index.html
CHANGED

@@ -14,7 +14,7 @@
     <script>
       const global = globalThis;
     </script>
-    <script type="module" crossorigin src="/assets/index-68c36c96.js"></script>
+    <script type="module" crossorigin src="/assets/index-51393291.js"></script>
     <link rel="stylesheet" href="/assets/index-f93cc942.css">
   </head>
   <body>
chainlit/langflow/__init__.py
ADDED

@@ -0,0 +1,75 @@
+try:
+    import langflow
+
+    if langflow.__version__ < "0.1.4":
+        raise ValueError(
+            "Langflow version is too old, expected >= 0.1.4. Run `pip install langflow --upgrade`"
+        )
+
+    LANGFLOW_INSTALLED = True
+except ImportError:
+    LANGFLOW_INSTALLED = False
+
+from typing import Callable, Union, Dict, Optional
+import aiohttp
+
+from chainlit.telemetry import trace
+from chainlit.config import config
+from chainlit.lc import langchain_factory
+
+
+@trace
+def langflow_factory(
+    use_async: bool, schema: Union[Dict, str], tweaks: Optional[Dict] = None
+) -> Callable:
+    """
+    Plug and play decorator for the Langflow library.
+    One instance per user session is created and cached.
+    The per user instance is called every time a new message is received.
+
+    Args:
+        use_async (bool): Whether to call the agent asynchronously or not.
+        schema (Union[Dict, str]): The Langflow schema dict or URL.
+        tweaks (Optional[Dict]): Optional tweaks to be processed.
+
+
+    Returns:
+        Callable[[], Any]: The decorated factory function.
+    """
+
+    # Check if the factory is called with the correct parameter
+    if type(schema) not in [dict, str]:
+        error_message = "langflow_factory schema parameter is required"
+        raise ValueError(error_message)
+
+    # Check if the factory is called with the correct parameter
+    if type(use_async) != bool:
+        error_message = "langflow_factory use_async parameter is required"
+        raise ValueError(error_message)
+
+    config.code.langflow_schema = schema
+
+    def decorator(func: Callable) -> Callable:
+        async def wrapper():
+            from langflow import load_flow_from_json
+
+            schema = config.code.langflow_schema
+
+            if type(schema) == str:
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(
+                        schema,
+                    ) as r:
+                        if not r.ok:
+                            reason = await r.text()
+                            raise ValueError(f"Error: {reason}")
+                        schema = await r.json()
+
+            flow = load_flow_from_json(input=schema, tweaks=tweaks)
+            return func(flow)
+
+        langchain_factory(use_async=use_async)(wrapper)
+
+        return func
+
+    return decorator
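
For orientation, a minimal usage sketch of the new decorator; the flow URL and handler below are hypothetical, and the decorated function simply receives the flow loaded by load_flow_from_json and returns the instance to cache for the session:

from chainlit.langflow import langflow_factory

# Hypothetical flow location; a dict from a Langflow JSON export also works.
FLOW_URL = "https://example.com/api/v1/flows/my-flow"

@langflow_factory(use_async=False, schema=FLOW_URL)
def factory(flow):
    # `flow` is the object built by langflow.load_flow_from_json; returning it
    # lets chainlit cache one instance per user session.
    return flow
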
chainlit/lc/__init__.py
CHANGED

@@ -9,3 +9,88 @@ try:
     LANGCHAIN_INSTALLED = True
 except ImportError:
     LANGCHAIN_INSTALLED = False
+
+from chainlit.telemetry import trace
+from typing import Callable, Any
+
+from chainlit.config import config
+from chainlit.utils import wrap_user_function
+
+
+@trace
+def langchain_factory(use_async: bool) -> Callable:
+    """
+    Plug and play decorator for the LangChain library.
+    The decorated function should instantiate a new LangChain instance (Chain, Agent...).
+    One instance per user session is created and cached.
+    The per user instance is called every time a new message is received.
+
+    Args:
+        use_async (bool): Whether to call the agent asynchronously or not.
+
+    Returns:
+        Callable[[], Any]: The decorated factory function.
+    """
+
+    # Check if the factory is called with the correct parameter
+    if type(use_async) != bool:
+        error_message = "langchain_factory use_async parameter is required"
+        raise ValueError(error_message)
+
+    def decorator(func: Callable) -> Callable:
+        config.code.lc_factory = wrap_user_function(func, with_task=True)
+        return func
+
+    config.code.lc_agent_is_async = use_async
+
+    return decorator
+
+
+@trace
+def langchain_postprocess(func: Callable[[Any], str]) -> Callable:
+    """
+    Useful to post process the response of a LangChain object instantiated with @langchain_factory.
+    The decorated function takes the raw output of the LangChain object as input.
+    The response will NOT be automatically sent to the UI, you need to send a Message.
+
+    Args:
+        func (Callable[[Any], str]): The post-processing function to apply after generating a response. Takes the response as parameter.
+
+    Returns:
+        Callable[[Any], str]: The decorated post-processing function.
+    """
+
+    config.code.lc_postprocess = wrap_user_function(func)
+    return func
+
+
+@trace
+def langchain_run(func: Callable[[Any, str], str]) -> Callable:
+    """
+    Useful to override the default behavior of the LangChain object instantiated with @langchain_factory.
+    Use when your agent run method has custom parameters.
+    Takes the LangChain agent and the user input as parameters.
+    The response will NOT be automatically sent to the UI, you need to send a Message.
+    Args:
+        func (Callable[[Any, str], str]): The function to be called when a new message is received. Takes the agent and user input as parameters and returns the output string.
+
+    Returns:
+        Callable[[Any, str], Any]: The decorated function.
+    """
+    config.code.lc_run = wrap_user_function(func)
+    return func
+
+
+@trace
+def langchain_rename(func: Callable[[str], str]) -> Callable[[str], str]:
+    """
+    Useful to rename the LangChain tools/chains used in the agent and display more friendly author names in the UI.
+    Args:
+        func (Callable[[str], str]): The function to be called to rename a tool/chain. Takes the original tool/chain name as parameter.
+
+    Returns:
+        Callable[[Any, str], Any]: The decorated function.
+    """
+
+    config.code.lc_rename = wrap_user_function(func)
+    return func
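
A minimal sketch of how langchain_factory and langchain_rename compose; the chain construction is illustrative, not taken from this release, and it assumes the decorators are re-exported at the package level as in the chainlit docs:

import chainlit as cl
from langchain import LLMChain, OpenAI, PromptTemplate

@cl.langchain_factory(use_async=False)
def factory():
    # Built once per user session and cached; called on every new message.
    prompt = PromptTemplate(template="Q: {question}\nA:", input_variables=["question"])
    return LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

@cl.langchain_rename
def rename(orig_author: str) -> str:
    # Show a friendlier author name for LLMChain steps in the UI.
    return {"LLMChain": "Assistant"}.get(orig_author, orig_author)
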
chainlit/lc/agent.py
CHANGED

@@ -1,6 +1,10 @@
 from typing import Any
-from chainlit.lc.callbacks import ChainlitCallbackHandler, AsyncChainlitCallbackHandler
+from chainlit.lc.callbacks import (
+    LangchainCallbackHandler,
+    AsyncLangchainCallbackHandler,
+)
 from chainlit.sync import make_async
+from chainlit.context import emitter_var
 
 
 async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):

@@ -8,20 +12,20 @@ async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
         input_key = agent.input_keys[0]
         if use_async:
             raw_res = await agent.acall(
-                {input_key: input_str}, callbacks=[AsyncChainlitCallbackHandler()]
+                {input_key: input_str}, callbacks=[AsyncLangchainCallbackHandler()]
             )
         else:
             raw_res = await make_async(agent.__call__)(
-                {input_key: input_str}, callbacks=[ChainlitCallbackHandler()]
+                {input_key: input_str}, callbacks=[LangchainCallbackHandler()]
             )
     else:
         if use_async:
             raw_res = await agent.acall(
-                input_str, callbacks=[AsyncChainlitCallbackHandler()]
+                input_str, callbacks=[AsyncLangchainCallbackHandler()]
             )
         else:
            raw_res = await make_async(agent.__call__)(
-                input_str, callbacks=[ChainlitCallbackHandler()]
+                input_str, callbacks=[LangchainCallbackHandler()]
            )
 
     if hasattr(agent, "output_keys"):
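
Outside of run_langchain_agent, the renamed handlers can be attached to a LangChain call in the same way; a hedged sketch, where `chain` stands in for any Chain or Agent created elsewhere:

from typing import Any
from chainlit.lc.callbacks import (
    LangchainCallbackHandler,
    AsyncLangchainCallbackHandler,
)

async def call_with_handlers(chain: Any, text: str, use_async: bool):
    # Mirrors the dispatch above: the async handler goes with acall,
    # the sync handler with a plain __call__.
    if use_async:
        return await chain.acall(text, callbacks=[AsyncLangchainCallbackHandler()])
    return chain(text, callbacks=[LangchainCallbackHandler()])
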
chainlit/lc/callbacks.py
CHANGED

@@ -6,7 +6,8 @@ from langchain.schema import (
     BaseMessage,
     LLMResult,
 )
-from chainlit.emitter import ChainlitEmitter, get_emitter
+from chainlit.emitter import ChainlitEmitter
+from chainlit.context import get_emitter
 from chainlit.message import Message, ErrorMessage
 from chainlit.config import config
 from chainlit.types import LLMSettings

@@ -37,7 +38,7 @@ def get_llm_settings(invocation_params: Union[Dict, None]):
         return None
 
 
-class BaseChainlitCallbackHandler(BaseCallbackHandler):
+class BaseLangchainCallbackHandler(BaseCallbackHandler):
     emitter: ChainlitEmitter
     # Keep track of the formatted prompts to display them in the prompt playground.
     prompts: List[str]

@@ -99,7 +100,7 @@ class BaseChainlitCallbackHandler(BaseCallbackHandler):
         return author, indent, llm_settings
 
 
-class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
+class LangchainCallbackHandler(BaseLangchainCallbackHandler, BaseCallbackHandler):
     def start_stream(self):
         author, indent, llm_settings = self.get_message_params()
 

@@ -107,14 +108,10 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
             return
 
         if config.code.lc_rename:
-            author = run_sync(
-                config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
-            )
+            author = run_sync(config.code.lc_rename(author))
 
         self.pop_prompt()
 
-        __chainlit_emitter__ = self.emitter
-
         streamed_message = Message(
             author=author,
             indent=indent,

@@ -135,11 +132,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
             return
 
         if config.code.lc_rename:
-            author = run_sync(
-                config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
-            )
-
-        __chainlit_emitter__ = self.emitter
+            author = run_sync(config.code.lc_rename(author))
 
         if error:
             run_sync(ErrorMessage(author=author, content=message).send())

@@ -259,7 +252,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
         pass
 
 
-class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHandler):
+class AsyncLangchainCallbackHandler(BaseLangchainCallbackHandler, AsyncCallbackHandler):
     async def start_stream(self):
         author, indent, llm_settings = self.get_message_params()
 

@@ -267,14 +260,10 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHandler):
             return
 
         if config.code.lc_rename:
-            author = await config.code.lc_rename(
-                author, __chainlit_emitter__=self.emitter
-            )
+            author = await config.code.lc_rename(author)
 
         self.pop_prompt()
 
-        __chainlit_emitter__ = self.emitter
-
         streamed_message = Message(
             author=author,
             indent=indent,

@@ -295,11 +284,7 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHandler):
             return
 
         if config.code.lc_rename:
-            author = await config.code.lc_rename(
-                author, __chainlit_emitter__=self.emitter
-            )
-
-        __chainlit_emitter__ = self.emitter
+            author = await config.code.lc_rename(author)
 
         if error:
             await ErrorMessage(author=author, content=message).send()
chainlit/llama_index/__init__.py
ADDED

@@ -0,0 +1,34 @@
+try:
+    import llama_index
+
+    if llama_index.__version__ < "0.6.27":
+        raise ValueError(
+            "LlamaIndex version is too old, expected >= 0.6.27. Run `pip install llama_index --upgrade`"
+        )
+
+    LLAMA_INDEX_INSTALLED = True
+except ImportError:
+    LLAMA_INDEX_INSTALLED = False
+
+
+from chainlit.telemetry import trace
+from typing import Callable
+
+from chainlit.config import config
+from chainlit.utils import wrap_user_function
+
+
+@trace
+def llama_index_factory(func: Callable) -> Callable:
+    """
+    Plug and play decorator for the Llama Index library.
+    The decorated function should instantiate a new Llama instance.
+    One instance per user session is created and cached.
+    The per user instance is called every time a new message is received.
+
+    Returns:
+        Callable[[], Any]: The decorated factory function.
+    """
+
+    config.code.llama_index_factory = wrap_user_function(func, with_task=True)
+    return func
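
A minimal usage sketch (the data directory and index are illustrative, and it assumes the decorator is re-exported as cl.llama_index_factory):

import chainlit as cl
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader

@cl.llama_index_factory
def factory():
    # Built once per user session and cached; the returned engine is what
    # run_llama later receives for each incoming message.
    documents = SimpleDirectoryReader("./data").load_data()
    index = GPTVectorStoreIndex.from_documents(documents)
    return index.as_query_engine(streaming=True)
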
chainlit/llama_index/callbacks.py
ADDED

@@ -0,0 +1,99 @@
+from typing import Any, Dict, List, Optional
+
+
+from llama_index.callbacks.base import BaseCallbackHandler
+from llama_index.callbacks.schema import CBEventType, EventPayload
+
+
+from chainlit.message import Message
+from chainlit.element import Text
+from chainlit.sync import run_sync
+
+
+DEFAULT_IGNORE = [
+    CBEventType.CHUNKING,
+    CBEventType.SYNTHESIZE,
+    CBEventType.EMBEDDING,
+    CBEventType.NODE_PARSING,
+    CBEventType.QUERY,
+    CBEventType.TREE,
+]
+
+
+class LlamaIndexCallbackHandler(BaseCallbackHandler):
+    """Base callback handler that can be used to track event starts and ends."""
+
+    def __init__(
+        self,
+        event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
+        event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
+    ) -> None:
+        """Initialize the base callback handler."""
+        self.event_starts_to_ignore = tuple(event_starts_to_ignore)
+        self.event_ends_to_ignore = tuple(event_ends_to_ignore)
+
+    def on_event_start(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> str:
+        """Run when an event starts and return id of event."""
+        run_sync(
+            Message(
+                author=event_type,
+                indent=1,
+                content="",
+            ).send()
+        )
+        return ""
+
+    def on_event_end(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> None:
+        """Run when an event ends."""
+
+        if event_type == CBEventType.RETRIEVE:
+            sources = payload.get(EventPayload.NODES)
+            if sources:
+                elements = [
+                    Text(name=f"Source {idx}", content=source.node.get_text())
+                    for idx, source in enumerate(sources)
+                ]
+                source_refs = "\, ".join(
+                    [f"Source {idx}" for idx, _ in enumerate(sources)]
+                )
+                content = f"Retrieved the following sources: {source_refs}"
+
+                run_sync(
+                    Message(
+                        content=content, author=event_type, elements=elements, indent=1
+                    ).send()
+                )
+
+        if event_type == CBEventType.LLM:
+            run_sync(
+                Message(
+                    content=payload.get(EventPayload.RESPONSE, ""),
+                    author=event_type,
+                    indent=1,
+                    prompt=payload.get(EventPayload.PROMPT),
+                ).send()
+            )
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
+        """Run when an overall trace is launched."""
+        pass
+
+    def end_trace(
+        self,
+        trace_id: Optional[str] = None,
+        trace_map: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        """Run when an overall trace is exited."""
+        pass
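
To surface these events, the handler has to be registered with LlamaIndex; a sketch using the 0.6.x ServiceContext API (the wiring is assumed from that API, not shown in this diff):

from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager
from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler

# RETRIEVE and LLM events then show up as indented chainlit messages;
# the event types listed in DEFAULT_IGNORE stay silent.
callback_manager = CallbackManager([LlamaIndexCallbackHandler()])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)
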
chainlit/llama_index/run.py
ADDED

@@ -0,0 +1,34 @@
+from typing import Union
+from llama_index.response.schema import Response, StreamingResponse
+from llama_index.chat_engine.types import BaseChatEngine
+from llama_index.indices.query.base import BaseQueryEngine
+
+from chainlit.message import Message
+from chainlit.sync import make_async
+
+
+async def run_llama(instance: Union[BaseChatEngine, BaseQueryEngine], input_str: str):
+    # Trick to display the loader in the UI until the first token is streamed
+    await Message(content="").send()
+
+    response_message = Message(content="")
+
+    if isinstance(instance, BaseQueryEngine):
+        response = await make_async(instance.query)(input_str)
+    elif isinstance(instance, BaseChatEngine):
+        response = await make_async(instance.chat)(input_str)
+    else:
+        raise NotImplementedError
+
+    if isinstance(response, Response):
+        response_message.content = str(response)
+        await response_message.send()
+    elif isinstance(response, StreamingResponse):
+        gen = response.response_gen
+        for token in gen:
+            await response_message.stream_token(token=token)
+
+        if response.response_txt:
+            response_message.content = response.response_txt
+
+        await response_message.send()
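
Which branch of run_llama fires depends on the engine the factory returned; an illustrative sketch (the index is hypothetical):

from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader

index = GPTVectorStoreIndex.from_documents(SimpleDirectoryReader("./data").load_data())

query_engine = index.as_query_engine()                    # .query() -> Response, sent in one go
streaming_engine = index.as_query_engine(streaming=True)  # -> StreamingResponse, streamed token by token
chat_engine = index.as_chat_engine()                      # BaseChatEngine -> .chat() path
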
chainlit/logger.py
CHANGED

@@ -1,12 +1,17 @@
 import logging
+import sys
+
 
 logging.basicConfig(
-    level=logging.INFO,
+    level=logging.INFO,
+    stream=sys.stdout,
+    format="%(asctime)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
 )
 
 logging.getLogger("socketio").setLevel(logging.ERROR)
 logging.getLogger("engineio").setLevel(logging.ERROR)
-logging.getLogger("geventwebsocket.handler").setLevel(logging.ERROR)
 logging.getLogger("numexpr").setLevel(logging.ERROR)
 
+
 logger = logging.getLogger("chainlit")
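
With this configuration, chainlit log lines go to stdout with a timestamp prefix; a quick sketch (message and timestamp illustrative):

from chainlit.logger import logger

logger.info("Chainlit app started")
# 2023-07-01 12:00:00 - Chainlit app started
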
chainlit/message.py
CHANGED

@@ -1,11 +1,11 @@
 from typing import List, Dict, Union
 from abc import ABC, abstractmethod
 import uuid
-import time
 import asyncio
+from datetime import datetime, timezone
 
 from chainlit.telemetry import trace_event
-from chainlit.emitter import get_emitter
+from chainlit.context import get_emitter
 from chainlit.config import config
 from chainlit.types import (
     LLMSettings,

@@ -16,11 +16,7 @@ from chainlit.types import (
 )
 from chainlit.element import Element
 from chainlit.action import Action
-
-
-def current_milli_time():
-    """Get the current time in milliseconds."""
-    return round(time.time() * 1000)
+from chainlit.logger import logger
 
 
 class MessageBase(ABC):

@@ -28,14 +24,13 @@
     temp_id: str = None
     streaming = False
    created_at: int = None
+    fail_on_persist_error: bool = True
 
     def __post_init__(self) -> None:
         trace_event(f"init {self.__class__.__name__}")
         self.temp_id = uuid.uuid4().hex
-        self.created_at = current_milli_time()
+        self.created_at = datetime.now(timezone.utc).isoformat()
         self.emitter = get_emitter()
-        if not self.emitter:
-            raise RuntimeError("Message should be instantiated in a Chainlit context")
 
     @abstractmethod
     def to_dict(self):
@@ -44,9 +39,14 @@
     async def _create(self):
         msg_dict = self.to_dict()
         if self.emitter.client and not self.id:
-            self.id = await self.emitter.client.create_message(msg_dict)
-            if self.id:
-                msg_dict["id"] = self.id
+            try:
+                self.id = await self.emitter.client.create_message(msg_dict)
+                if self.id:
+                    msg_dict["id"] = self.id
+            except Exception as e:
+                if self.fail_on_persist_error:
+                    raise e
+                logger.error(f"Failed to persist message: {str(e)}")
 
         return msg_dict
 
@@ -77,8 +77,7 @@
         msg_dict = self.to_dict()
 
         if self.emitter.client and self.id:
-            self.emitter.client.update_message(self.id, msg_dict)
-            msg_dict["id"] = self.id
+            await self.emitter.client.update_message(self.id, msg_dict)
 
         await self.emitter.update_message(msg_dict)
 

@@ -171,7 +170,7 @@ class Message(MessageBase):
         super().__post_init__()
 
     def to_dict(self):
-        return {
+        _dict = {
             "tempId": self.temp_id,
             "createdAt": self.created_at,
             "content": self.content,

@@ -182,6 +181,11 @@ class Message(MessageBase):
             "indent": self.indent,
         }
 
+        if self.id:
+            _dict["id"] = self.id
+
+        return _dict
+
     async def send(self):
         """
         Send the message to the UI and persist it in the cloud if a project ID is configured.

@@ -214,10 +218,12 @@ class ErrorMessage(MessageBase):
         content: str,
         author: str = config.ui.name,
         indent: int = 0,
+        fail_on_persist_error: bool = False,
     ):
         self.content = content
         self.author = author
         self.indent = indent
+        self.fail_on_persist_error = fail_on_persist_error
 
         super().__post_init__()
 

@@ -241,10 +247,10 @@
 
 
 class AskMessageBase(MessageBase):
-    def remove(self):
-        removed = super().remove()
+    async def remove(self):
+        removed = await super().remove()
         if removed:
-            self.emitter.clear_ask()
+            await self.emitter.clear_ask()
 
 
 class AskUserMessage(AskMessageBase):
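
One behavioral note worth spelling out: created_at switches from epoch milliseconds to an ISO-8601 UTC string, so the old and new expressions compare as follows (values illustrative):

import time
from datetime import datetime, timezone

old_created_at = round(time.time() * 1000)               # e.g. 1688212800000
new_created_at = datetime.now(timezone.utc).isoformat()  # e.g. "2023-07-01T12:00:00+00:00"
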
|