chainlit-0.4.1-py3-none-any.whl → chainlit-0.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the registry's page for this release for more details.

@@ -14,7 +14,7 @@
14
14
  <script>
15
15
  const global = globalThis;
16
16
  </script>
17
- <script type="module" crossorigin src="/assets/index-68c36c96.js"></script>
17
+ <script type="module" crossorigin src="/assets/index-51393291.js"></script>
18
18
  <link rel="stylesheet" href="/assets/index-f93cc942.css">
19
19
  </head>
20
20
  <body>
@@ -0,0 +1,75 @@
1
def _version_tuple(version: str) -> tuple:
    """Convert a dotted version string into a tuple of ints for numeric comparison.

    Parsing stops at the first dot-separated component that does not start with
    a digit, and only the leading digits of each component are kept, so
    ``"0.1.10"`` -> ``(0, 1, 10)`` and ``"0.1.4.post1"`` -> ``(0, 1, 4)``.
    Plain string comparison would wrongly rank ``"0.1.10"`` below ``"0.1.4"``.
    """
    import re

    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


try:
    import langflow

    # Numeric comparison; a too-old version raises ValueError, which is not
    # caught below, so importing this module fails loudly in that case.
    if _version_tuple(langflow.__version__) < (0, 1, 4):
        raise ValueError(
            "Langflow version is too old, expected >= 0.1.4. Run `pip install langflow --upgrade`"
        )

    LANGFLOW_INSTALLED = True
except ImportError:
    LANGFLOW_INSTALLED = False
12
+
13
+ from typing import Callable, Union, Dict, Optional
14
+ import aiohttp
15
+
16
+ from chainlit.telemetry import trace
17
+ from chainlit.config import config
18
+ from chainlit.lc import langchain_factory
19
+
20
+
21
@trace
def langflow_factory(
    use_async: bool, schema: Union[Dict, str], tweaks: Optional[Dict] = None
) -> Callable:
    """
    Plug and play decorator for the Langflow library.
    One instance per user session is created and cached.
    The per user instance is called every time a new message is received.

    Args:
        use_async (bool): Whether to call the agent asynchronously or not.
        schema (Union[Dict, str]): The langflow schema dict, or a URL from which
            the schema JSON is downloaded when the registered wrapper runs.
        tweaks (Optional[Dict]): Optional tweaks forwarded to `load_flow_from_json`.

    Returns:
        Callable[[Callable], Callable]: A decorator. The decorated function is
        called with the loaded flow and is itself returned unchanged.

    Raises:
        ValueError: If `schema` is not a dict/str or `use_async` is not a bool.
    """
    # isinstance (rather than an exact type check) also accepts dict/str
    # subclasses such as collections.OrderedDict.
    if not isinstance(schema, (dict, str)):
        raise ValueError("langflow_factory schema parameter is required")

    if not isinstance(use_async, bool):
        raise ValueError("langflow_factory use_async parameter is required")

    config.code.langflow_schema = schema

    def decorator(func: Callable) -> Callable:
        async def wrapper():
            from langflow import load_flow_from_json

            # Read the schema back from config each time the wrapper runs.
            flow_schema = config.code.langflow_schema

            if isinstance(flow_schema, str):
                # A string schema is treated as a URL to fetch the JSON from.
                async with aiohttp.ClientSession() as session:
                    async with session.get(flow_schema) as r:
                        if not r.ok:
                            reason = await r.text()
                            raise ValueError(f"Error: {reason}")
                        # content_type=None disables aiohttp's Content-Type check so
                        # hosts that serve JSON as text/plain (e.g. raw file URLs) work.
                        flow_schema = await r.json(content_type=None)

            flow = load_flow_from_json(input=flow_schema, tweaks=tweaks)
            return func(flow)

        # Register the wrapper as the LangChain factory for this app.
        langchain_factory(use_async=use_async)(wrapper)

        return func

    return decorator
chainlit/lc/__init__.py CHANGED
@@ -9,3 +9,88 @@ try:
9
9
  LANGCHAIN_INSTALLED = True
10
10
  except ImportError:
11
11
  LANGCHAIN_INSTALLED = False
12
+
13
+ from chainlit.telemetry import trace
14
+ from typing import Callable, Any
15
+
16
+ from chainlit.config import config
17
+ from chainlit.utils import wrap_user_function
18
+
19
+
20
@trace
def langchain_factory(use_async: bool) -> Callable:
    """
    Plug and play decorator for the LangChain library.
    The decorated function should instantiate a new LangChain instance (Chain, Agent...).
    One instance per user session is created and cached.
    The per user instance is called every time a new message is received.

    Args:
        use_async (bool): Whether to call the agent asynchronously or not.

    Returns:
        Callable[[Callable], Callable]: A decorator that registers the factory
        function and returns it unchanged.

    Raises:
        ValueError: If `use_async` is not a bool.
    """
    if not isinstance(use_async, bool):
        raise ValueError("langchain_factory use_async parameter is required")

    # Recorded as soon as the decorator is built, before any function is registered.
    config.code.lc_agent_is_async = use_async

    def register(func: Callable) -> Callable:
        config.code.lc_factory = wrap_user_function(func, with_task=True)
        return func

    return register
47
+
48
+
49
@trace
def langchain_postprocess(func: Callable[[Any], str]) -> Callable:
    """
    Register a hook that post-processes the response of the LangChain object
    instantiated with @langchain_factory.
    The hook receives the raw output of the LangChain object.
    The response will NOT be sent to the UI automatically; the hook must send a Message.

    Args:
        func (Callable[[Any], str]): The post-processing function applied after a
            response is generated. Takes the response as its parameter.

    Returns:
        Callable[[Any], str]: `func`, unchanged.
    """
    wrapped_hook = wrap_user_function(func)
    config.code.lc_postprocess = wrapped_hook
    return func
65
+
66
+
67
@trace
def langchain_run(func: Callable[[Any, str], str]) -> Callable:
    """
    Override how the LangChain object instantiated with @langchain_factory is run.
    Use this when your agent's run method takes custom parameters.
    The hook receives the LangChain agent and the user input.
    The response will NOT be sent to the UI automatically; the hook must send a Message.

    Args:
        func (Callable[[Any, str], str]): Called when a new message is received,
            with the agent and the user input; returns the output string.

    Returns:
        Callable[[Any, str], str]: `func`, unchanged.
    """
    wrapped_runner = wrap_user_function(func)
    config.code.lc_run = wrapped_runner
    return func
82
+
83
+
84
@trace
def langchain_rename(func: Callable[[str], str]) -> Callable[[str], str]:
    """
    Register a hook that renames the LangChain tools/chains used by the agent,
    so friendlier author names are displayed in the UI.

    Args:
        func (Callable[[str], str]): Called with the original tool/chain name;
            returns the name to display.

    Returns:
        Callable[[str], str]: `func`, unchanged.
    """
    wrapped_renamer = wrap_user_function(func)
    config.code.lc_rename = wrapped_renamer
    return func
chainlit/lc/agent.py CHANGED
@@ -1,6 +1,10 @@
1
1
  from typing import Any
2
- from chainlit.lc.callbacks import ChainlitCallbackHandler, AsyncChainlitCallbackHandler
2
+ from chainlit.lc.callbacks import (
3
+ LangchainCallbackHandler,
4
+ AsyncLangchainCallbackHandler,
5
+ )
3
6
  from chainlit.sync import make_async
7
+ from chainlit.context import emitter_var
4
8
 
5
9
 
6
10
  async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
@@ -8,20 +12,20 @@ async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
8
12
  input_key = agent.input_keys[0]
9
13
  if use_async:
10
14
  raw_res = await agent.acall(
11
- {input_key: input_str}, callbacks=[AsyncChainlitCallbackHandler()]
15
+ {input_key: input_str}, callbacks=[AsyncLangchainCallbackHandler()]
12
16
  )
13
17
  else:
14
18
  raw_res = await make_async(agent.__call__)(
15
- {input_key: input_str}, callbacks=[ChainlitCallbackHandler()]
19
+ {input_key: input_str}, callbacks=[LangchainCallbackHandler()]
16
20
  )
17
21
  else:
18
22
  if use_async:
19
23
  raw_res = await agent.acall(
20
- input_str, callbacks=[AsyncChainlitCallbackHandler()]
24
+ input_str, callbacks=[AsyncLangchainCallbackHandler()]
21
25
  )
22
26
  else:
23
27
  raw_res = await make_async(agent.__call__)(
24
- input_str, callbacks=[ChainlitCallbackHandler()]
28
+ input_str, callbacks=[LangchainCallbackHandler()]
25
29
  )
26
30
 
27
31
  if hasattr(agent, "output_keys"):
chainlit/lc/callbacks.py CHANGED
@@ -6,7 +6,8 @@ from langchain.schema import (
6
6
  BaseMessage,
7
7
  LLMResult,
8
8
  )
9
- from chainlit.emitter import get_emitter, ChainlitEmitter
9
+ from chainlit.emitter import ChainlitEmitter
10
+ from chainlit.context import get_emitter
10
11
  from chainlit.message import Message, ErrorMessage
11
12
  from chainlit.config import config
12
13
  from chainlit.types import LLMSettings
@@ -37,7 +38,7 @@ def get_llm_settings(invocation_params: Union[Dict, None]):
37
38
  return None
38
39
 
39
40
 
40
- class BaseChainlitCallbackHandler(BaseCallbackHandler):
41
+ class BaseLangchainCallbackHandler(BaseCallbackHandler):
41
42
  emitter: ChainlitEmitter
42
43
  # Keep track of the formatted prompts to display them in the prompt playground.
43
44
  prompts: List[str]
@@ -99,7 +100,7 @@ class BaseChainlitCallbackHandler(BaseCallbackHandler):
99
100
  return author, indent, llm_settings
100
101
 
101
102
 
102
- class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
103
+ class LangchainCallbackHandler(BaseLangchainCallbackHandler, BaseCallbackHandler):
103
104
  def start_stream(self):
104
105
  author, indent, llm_settings = self.get_message_params()
105
106
 
@@ -107,14 +108,10 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
107
108
  return
108
109
 
109
110
  if config.code.lc_rename:
110
- author = run_sync(
111
- config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
112
- )
111
+ author = run_sync(config.code.lc_rename(author))
113
112
 
114
113
  self.pop_prompt()
115
114
 
116
- __chainlit_emitter__ = self.emitter
117
-
118
115
  streamed_message = Message(
119
116
  author=author,
120
117
  indent=indent,
@@ -135,11 +132,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
135
132
  return
136
133
 
137
134
  if config.code.lc_rename:
138
- author = run_sync(
139
- config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
140
- )
141
-
142
- __chainlit_emitter__ = self.emitter
135
+ author = run_sync(config.code.lc_rename(author))
143
136
 
144
137
  if error:
145
138
  run_sync(ErrorMessage(author=author, content=message).send())
@@ -259,7 +252,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
259
252
  pass
260
253
 
261
254
 
262
- class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHandler):
255
+ class AsyncLangchainCallbackHandler(BaseLangchainCallbackHandler, AsyncCallbackHandler):
263
256
  async def start_stream(self):
264
257
  author, indent, llm_settings = self.get_message_params()
265
258
 
@@ -267,14 +260,10 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
267
260
  return
268
261
 
269
262
  if config.code.lc_rename:
270
- author = await config.code.lc_rename(
271
- author, __chainlit_emitter__=self.emitter
272
- )
263
+ author = await config.code.lc_rename(author)
273
264
 
274
265
  self.pop_prompt()
275
266
 
276
- __chainlit_emitter__ = self.emitter
277
-
278
267
  streamed_message = Message(
279
268
  author=author,
280
269
  indent=indent,
@@ -295,11 +284,7 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
295
284
  return
296
285
 
297
286
  if config.code.lc_rename:
298
- author = await config.code.lc_rename(
299
- author, __chainlit_emitter__=self.emitter
300
- )
301
-
302
- __chainlit_emitter__ = self.emitter
287
+ author = await config.code.lc_rename(author)
303
288
 
304
289
  if error:
305
290
  await ErrorMessage(author=author, content=message).send()
@@ -0,0 +1,34 @@
1
def _version_tuple(version: str) -> tuple:
    """Convert a dotted version string into a tuple of ints for numeric comparison.

    Parsing stops at the first dot-separated component that does not start with
    a digit, and only the leading digits of each component are kept, so
    ``"0.10.0"`` -> ``(0, 10, 0)`` and ``"0.6.27.post1"`` -> ``(0, 6, 27)``.
    Plain string comparison would rank ``"0.10.0"`` below ``"0.6.27"`` and
    ``"0.6.3"`` above it.
    """
    import re

    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


try:
    import llama_index

    # Numeric comparison; a too-old version raises ValueError, which is not
    # caught below, so importing this module fails loudly in that case.
    if _version_tuple(llama_index.__version__) < (0, 6, 27):
        raise ValueError(
            "LlamaIndex version is too old, expected >= 0.6.27. Run `pip install llama_index --upgrade`"
        )

    LLAMA_INDEX_INSTALLED = True
except ImportError:
    LLAMA_INDEX_INSTALLED = False
12
+
13
+
14
+ from chainlit.telemetry import trace
15
+ from typing import Callable
16
+
17
+ from chainlit.config import config
18
+ from chainlit.utils import wrap_user_function
19
+
20
+
21
@trace
def llama_index_factory(func: Callable) -> Callable:
    """
    Plug and play decorator for the Llama Index library.
    The decorated function should instantiate a new Llama Index instance.
    One instance per user session is created and cached.
    The per user instance is called every time a new message is received.

    Args:
        func (Callable): The factory function to register.

    Returns:
        Callable[[], Any]: `func` itself, unchanged.
    """
    registered_factory = wrap_user_function(func, with_task=True)
    config.code.llama_index_factory = registered_factory
    return func
@@ -0,0 +1,99 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+
4
+ from llama_index.callbacks.base import BaseCallbackHandler
5
+ from llama_index.callbacks.schema import CBEventType, EventPayload
6
+
7
+
8
+ from chainlit.message import Message
9
+ from chainlit.element import Text
10
+ from chainlit.sync import run_sync
11
+
12
+
13
DEFAULT_IGNORE = [
    CBEventType.CHUNKING,
    CBEventType.SYNTHESIZE,
    CBEventType.EMBEDDING,
    CBEventType.NODE_PARSING,
    CBEventType.QUERY,
    CBEventType.TREE,
]


class LlamaIndexCallbackHandler(BaseCallbackHandler):
    """Callback handler that displays LlamaIndex events as Chainlit messages.

    Every event start posts an empty, indented message authored by the event
    type. RETRIEVE ends post the retrieved sources as Text elements, and LLM
    ends post the LLM response together with its prompt.
    """

    def __init__(
        self,
        event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
        event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
    ) -> None:
        """Initialize the handler.

        Args:
            event_starts_to_ignore: Event types whose start should be skipped.
            event_ends_to_ignore: Event types whose end should be skipped.

        NOTE(review): this class never reads these attributes itself; presumably
        LlamaIndex's callback manager filters on them — confirm.
        """
        # Copied into tuples so the shared DEFAULT_IGNORE list is never aliased.
        self.event_starts_to_ignore = tuple(event_starts_to_ignore)
        self.event_ends_to_ignore = tuple(event_ends_to_ignore)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Post an empty, indented message authored by the event type.

        Returns:
            The id of the event, as given in ``event_id``.
        """
        run_sync(
            Message(
                author=event_type,
                indent=1,
                content="",
            ).send()
        )
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Post a message when a RETRIEVE or LLM event ends; other events are ignored."""
        # payload is Optional: treat a missing payload as empty instead of
        # failing on .get().
        payload = payload or {}

        if event_type == CBEventType.RETRIEVE:
            sources = payload.get(EventPayload.NODES)
            if sources:
                # Enumerate once so names and elements stay aligned even if
                # `sources` is a one-shot iterable.
                named_sources = [
                    (f"Source {idx}", source) for idx, source in enumerate(sources)
                ]
                elements = [
                    Text(name=name, content=source.node.get_text())
                    for name, source in named_sources
                ]
                # Plain ", " separator: "\, " is an invalid escape that emits a
                # literal backslash before each comma.
                source_refs = ", ".join(name for name, _ in named_sources)
                content = f"Retrieved the following sources: {source_refs}"

                run_sync(
                    Message(
                        content=content, author=event_type, elements=elements, indent=1
                    ).send()
                )

        elif event_type == CBEventType.LLM:
            run_sync(
                Message(
                    content=payload.get(EventPayload.RESPONSE, ""),
                    author=event_type,
                    indent=1,
                    prompt=payload.get(EventPayload.PROMPT),
                ).send()
            )

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched (no-op)."""
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited (no-op)."""
        pass
@@ -0,0 +1,34 @@
1
+ from typing import Union
2
+ from llama_index.response.schema import Response, StreamingResponse
3
+ from llama_index.chat_engine.types import BaseChatEngine
4
+ from llama_index.indices.query.base import BaseQueryEngine
5
+
6
+ from chainlit.message import Message
7
+ from chainlit.sync import make_async
8
+
9
+
10
async def run_llama(instance: Union[BaseChatEngine, BaseQueryEngine], input_str: str):
    """Run a LlamaIndex query or chat engine on ``input_str`` and send the answer.

    Streaming responses are forwarded token by token. Any other response
    (``Response`` or an unrecognized type) is sent as ``str(response)``.

    Raises:
        NotImplementedError: If ``instance`` is neither a query nor a chat engine.
    """
    # Trick to display the loader in the UI until the first token is streamed
    await Message(content="").send()

    response_message = Message(content="")

    if isinstance(instance, BaseQueryEngine):
        response = await make_async(instance.query)(input_str)
    elif isinstance(instance, BaseChatEngine):
        response = await make_async(instance.chat)(input_str)
    else:
        raise NotImplementedError

    if isinstance(response, StreamingResponse):
        gen = response.response_gen
        exhausted = object()
        # The token generator is synchronous; advance it through make_async
        # (like the query/chat call above) so waiting for each token does not
        # block the event loop.
        while True:
            token = await make_async(next)(gen, exhausted)
            if token is exhausted:
                break
            await response_message.stream_token(token=token)

        if response.response_txt:
            response_message.content = response.response_txt
    else:
        # Response, or any other response object: previously unknown types
        # were silently never sent.
        response_message.content = str(response)

    await response_message.send()
chainlit/logger.py CHANGED
@@ -1,12 +1,17 @@
import logging
import sys

# Root logging configuration: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    stream=sys.stdout,
)

# Quiet noisy third-party loggers so that only their errors are shown.
for _noisy_logger_name in ("socketio", "engineio", "numexpr"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.ERROR)
del _noisy_logger_name

# Package logger shared across chainlit modules.
logger = logging.getLogger("chainlit")
chainlit/message.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from typing import List, Dict, Union
2
2
  from abc import ABC, abstractmethod
3
3
  import uuid
4
- import time
5
4
  import asyncio
5
+ from datetime import datetime, timezone
6
6
 
7
7
  from chainlit.telemetry import trace_event
8
- from chainlit.emitter import get_emitter
8
+ from chainlit.context import get_emitter
9
9
  from chainlit.config import config
10
10
  from chainlit.types import (
11
11
  LLMSettings,
@@ -16,11 +16,7 @@ from chainlit.types import (
16
16
  )
17
17
  from chainlit.element import Element
18
18
  from chainlit.action import Action
19
-
20
-
21
- def current_milli_time():
22
- """Get the current time in milliseconds."""
23
- return round(time.time() * 1000)
19
+ from chainlit.logger import logger
24
20
 
25
21
 
26
22
  class MessageBase(ABC):
@@ -28,14 +24,13 @@ class MessageBase(ABC):
28
24
  temp_id: str = None
29
25
  streaming = False
30
26
  created_at: int = None
27
+ fail_on_persist_error: bool = True
31
28
 
32
29
  def __post_init__(self) -> None:
33
30
  trace_event(f"init {self.__class__.__name__}")
34
31
  self.temp_id = uuid.uuid4().hex
35
- self.created_at = current_milli_time()
32
+ self.created_at = datetime.now(timezone.utc).isoformat()
36
33
  self.emitter = get_emitter()
37
- if not self.emitter:
38
- raise RuntimeError("Message should be instantiated in a Chainlit context")
39
34
 
40
35
  @abstractmethod
41
36
  def to_dict(self):
@@ -44,9 +39,14 @@ class MessageBase(ABC):
44
39
  async def _create(self):
45
40
  msg_dict = self.to_dict()
46
41
  if self.emitter.client and not self.id:
47
- self.id = await self.emitter.client.create_message(msg_dict)
48
- if self.id:
49
- msg_dict["id"] = self.id
42
+ try:
43
+ self.id = await self.emitter.client.create_message(msg_dict)
44
+ if self.id:
45
+ msg_dict["id"] = self.id
46
+ except Exception as e:
47
+ if self.fail_on_persist_error:
48
+ raise e
49
+ logger.error(f"Failed to persist message: {str(e)}")
50
50
 
51
51
  return msg_dict
52
52
 
@@ -77,8 +77,7 @@ class MessageBase(ABC):
77
77
  msg_dict = self.to_dict()
78
78
 
79
79
  if self.emitter.client and self.id:
80
- self.emitter.client.update_message(self.id, msg_dict)
81
- msg_dict["id"] = self.id
80
+ await self.emitter.client.update_message(self.id, msg_dict)
82
81
 
83
82
  await self.emitter.update_message(msg_dict)
84
83
 
@@ -171,7 +170,7 @@ class Message(MessageBase):
171
170
  super().__post_init__()
172
171
 
173
172
  def to_dict(self):
174
- return {
173
+ _dict = {
175
174
  "tempId": self.temp_id,
176
175
  "createdAt": self.created_at,
177
176
  "content": self.content,
@@ -182,6 +181,11 @@ class Message(MessageBase):
182
181
  "indent": self.indent,
183
182
  }
184
183
 
184
+ if self.id:
185
+ _dict["id"] = self.id
186
+
187
+ return _dict
188
+
185
189
  async def send(self):
186
190
  """
187
191
  Send the message to the UI and persist it in the cloud if a project ID is configured.
@@ -214,10 +218,12 @@ class ErrorMessage(MessageBase):
214
218
  content: str,
215
219
  author: str = config.ui.name,
216
220
  indent: int = 0,
221
+ fail_on_persist_error: bool = False,
217
222
  ):
218
223
  self.content = content
219
224
  self.author = author
220
225
  self.indent = indent
226
+ self.fail_on_persist_error = fail_on_persist_error
221
227
 
222
228
  super().__post_init__()
223
229
 
@@ -241,10 +247,10 @@ class ErrorMessage(MessageBase):
241
247
 
242
248
 
243
249
  class AskMessageBase(MessageBase):
244
- def remove(self):
245
- removed = super().remove()
250
+ async def remove(self):
251
+ removed = await super().remove()
246
252
  if removed:
247
- self.emitter.clear_ask()
253
+ await self.emitter.clear_ask()
248
254
 
249
255
 
250
256
  class AskUserMessage(AskMessageBase):