chainlit 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the package registry's advisory for this release for more details.

@@ -0,0 +1,257 @@
1
+ from typing import Optional, Dict
2
+ import uuid
3
+ import json
4
+ import os
5
+
6
+ import asyncio
7
+ import aiofiles
8
+
9
+ from chainlit.client.base import PaginatedResponse, PageInfo
10
+
11
+ from .base import BaseClient
12
+
13
+ from chainlit.logger import logger
14
+ from chainlit.config import config
15
+ from chainlit.element import mime_to_ext
16
+
17
+
18
class LocalClient(BaseClient):
    """Chainlit persistence client backed by a local SQLite database via Prisma.

    All writes made through one instance go to a single conversation, which
    is created lazily on the first write. Element files are written to disk
    under ``config.project.local_fs_path``. Prisma model classes are imported
    inside each method rather than at module level.
    """

    # Id of the conversation this client writes to; set on the first write.
    conversation_id: Optional[str] = None
    # Serializes conversation creation between concurrent writers.
    lock: asyncio.Lock

    def __init__(self):
        self.lock = asyncio.Lock()

    def before_write(self, variables: Dict):
        """Adapt a payload, in place, to the SQLite schema before writing it.

        SQLite has no JSON or list-of-primitives column types, so
        ``llmSettings`` and ``forIds`` are stored as JSON strings.
        ``tempId`` has no database column and is dropped.
        """
        if "llmSettings" in variables:
            variables["llmSettings"] = json.dumps(variables["llmSettings"])

        if "forIds" in variables:
            variables["forIds"] = json.dumps(variables["forIds"])

        if "tempId" in variables:
            del variables["tempId"]

    def after_read(self, variables: Dict):
        """Undo ``before_write``'s JSON encoding, in place, on a row read back.

        Both columns are nullable in the schema. A missing or NULL value is
        left untouched instead of being passed to ``json.loads`` (which raises
        ``TypeError`` on ``None``).
        """
        for key in ("llmSettings", "forIds"):
            raw = variables.get(key)
            if isinstance(raw, str) and raw:
                variables[key] = json.loads(raw)

    async def is_project_member(self):
        # Local mode has no access control: every caller is a member.
        return True

    async def get_member_role(self):
        return "OWNER"

    async def get_project_members(self):
        return []

    async def get_conversation_id(self):
        """Return the id of this client's conversation, creating it if needed."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def create_conversation(self):
        """Create this client's conversation once and return its id.

        Subsequent calls return the cached id.
        """
        from prisma.models import Conversation

        # If several sends run concurrently, only one may create the
        # conversation. The id is cached while the lock is still held, so
        # waiters always see it.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            res = await Conversation.prisma().create(data={})
            self.conversation_id = res.id
            return res.id

    async def delete_conversation(self, conversation_id):
        """Delete a conversation; messages and elements cascade in the schema."""
        from prisma.models import Conversation

        await Conversation.prisma().delete(where={"id": conversation_id})

        return True

    async def get_conversation(self, conversation_id: int):
        """Return a conversation with its messages and elements as plain JSON data.

        Raises if the conversation does not exist (``find_unique_or_raise``).
        """
        from prisma.models import Conversation

        c = await Conversation.prisma().find_unique_or_raise(
            where={"id": conversation_id}, include={"messages": True, "elements": True}
        )

        # Decode the JSON-string columns written by before_write.
        for m in c.messages:
            if m.llmSettings:
                m.llmSettings = json.loads(m.llmSettings)

        for e in c.elements:
            if e.forIds:
                e.forIds = json.loads(e.forIds)

        return json.loads(c.json())

    async def get_conversations(self, pagination, filter):
        """Return one page of conversations, newest first.

        Only conversations with at least one message matching ``filter`` are
        returned: a message with the requested ``humanFeedback`` value and/or
        whose content contains ``filter.search``. Each conversation includes
        its earliest user-authored message.

        Pagination is cursor based. ``pagination.first`` is the page size.
        ``pagination.cursor`` is the id of the last conversation of the
        previous page; the next cursor is returned in ``pageInfo.endCursor``.
        """
        from prisma.models import Conversation

        some_messages = {}

        if filter.feedback is not None:
            some_messages["humanFeedback"] = filter.feedback

        # An empty search string means "no text filter" rather than sending
        # a null ``contains`` condition.
        if filter.search:
            some_messages["content"] = {"contains": filter.search}

        page_size = pagination.first
        # Fetch one extra row to learn whether another page exists. Otherwise
        # a phantom next page is reported when exactly ``page_size`` rows remain.
        take = page_size + 1 if page_size else page_size

        conversations = await Conversation.prisma().find_many(
            take=take,
            # The cursor row was the last row of the previous page; skip it.
            skip=1 if pagination.cursor else None,
            cursor={"id": pagination.cursor} if pagination.cursor else None,
            include={
                "messages": {
                    "take": 1,
                    "where": {
                        "authorIsUser": True,
                    },
                    "orderBy": [
                        {
                            "createdAt": "asc",
                        }
                    ],
                }
            },
            where={"messages": {"some": some_messages}},
            order={
                "createdAt": "desc",
            },
        )

        has_more = bool(page_size) and len(conversations) > page_size

        if has_more:
            conversations = conversations[:page_size]
            end_cursor = conversations[-1].id
        else:
            end_cursor = None

        return PaginatedResponse(
            pageInfo=PageInfo(hasNextPage=has_more, endCursor=end_cursor),
            data=[json.loads(c.json()) for c in conversations],
        )

    async def create_message(self, variables):
        """Persist a message in this client's conversation and return its id.

        Returns None if no conversation id is available. ``variables`` is not
        modified.
        """
        from prisma.models import Message

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        res = await Message.prisma().create(data=variables)
        return res.id

    async def get_message(self, message_id):
        """Return a message as a dict with decoded JSON columns, or None if missing."""
        from prisma.models import Message

        res = await Message.prisma().find_first(where={"id": message_id})
        if res is None:
            return None

        res = res.dict()
        self.after_read(res)
        return res

    async def update_message(self, message_id, variables):
        """Update a message; ``variables`` is not modified."""
        from prisma.models import Message

        variables = variables.copy()

        self.before_write(variables)

        await Message.prisma().update(data=variables, where={"id": message_id})

        return True

    async def delete_message(self, message_id):
        from prisma.models import Message

        await Message.prisma().delete(where={"id": message_id})

        return True

    async def upsert_element(
        self,
        variables,
    ):
        """Create an element, or update it when ``variables`` contains an ``id``.

        Returns the stored element as a dict. Returns None if no conversation
        id is available or the element to update was not found. Like
        ``create_message``, works on a copy so the caller's dict is not
        modified.
        """
        from prisma.models import Element

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        if "id" in variables:
            res = await Element.prisma().update(
                data=variables, where={"id": variables.get("id")}
            )
        else:
            res = await Element.prisma().create(data=variables)

        return res.dict() if res is not None else None

    async def get_element(
        self,
        conversation_id,
        element_id,
    ):
        """Return an element as plain JSON data; raises if it does not exist.

        ``conversation_id`` is accepted for interface compatibility but unused.
        """
        from prisma.models import Element

        res = await Element.prisma().find_unique_or_raise(where={"id": element_id})
        return json.loads(res.json())

    async def upload_element(self, content: bytes, mime: str):
        """Write ``content`` to a new file in the conversation's folder.

        Returns the URL path ``/files/<conversation id>/<file name>`` of the
        stored file, or None if no conversation id is available.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        # Unknown mime types get a generic extension.
        file_ext = mime_to_ext.get(mime, "bin")
        file_name = f"{uuid.uuid4()}.{file_ext}"

        full_path = os.path.join(config.project.local_fs_path, str(c_id), file_name)

        # exist_ok avoids a race when two uploads create the folder concurrently.
        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        async with aiofiles.open(full_path, "wb") as out:
            await out.write(content)
            await out.flush()

        # Build the URL with "/" explicitly; os.path.join would insert "\" on Windows.
        return f"/files/{c_id}/{file_name}"

    async def set_human_feedback(self, message_id, feedback):
        """Store the human feedback score of a message."""
        from prisma.models import Message

        await Message.prisma().update(
            where={"id": message_id},
            data={
                "humanFeedback": feedback,
            },
        )

        return True
@@ -0,0 +1,23 @@
1
+ from fastapi import HTTPException, Request
2
+
3
+ from chainlit.config import config
4
+ from chainlit.client.base import BaseClient
5
+ from chainlit.client.local import LocalClient
6
+ from chainlit.client.cloud import CloudClient
7
+
8
+
9
async def get_client(request: Request) -> BaseClient:
    """Build the persistence client selected by ``config.project.database``.

    Args:
        request: The incoming request. Its ``Authorization`` header is
            forwarded to the cloud client.

    Returns:
        A ``LocalClient``, a ``CloudClient``, or the client produced by the
        user's ``client_factory``, depending on the configured database.

    Raises:
        HTTPException: 500 if the database type is unset or unknown, or if
            ``database = "custom"`` but no client factory was registered.
    """
    auth_header = request.headers.get("Authorization")

    db = config.project.database

    if db == "local":
        return LocalClient()

    if db == "cloud":
        return CloudClient(config.project.id, auth_header)

    if db == "custom":
        factory = config.code.client_factory
        # Fail with a clear error instead of "'NoneType' object is not callable".
        if not factory:
            raise HTTPException(
                status_code=500, detail="Missing custom client factory"
            )
        return await factory()

    raise HTTPException(status_code=500, detail="Invalid database type")
chainlit/config.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional, Any, Callable, List, Dict, TYPE_CHECKING
1
+ from typing import Optional, Any, Callable, Union, Literal, List, Dict, TYPE_CHECKING
2
2
  import os
3
3
  import sys
4
4
  import tomli
@@ -10,6 +10,7 @@ from chainlit.version import __version__
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from chainlit.action import Action
13
+ from chainlit.client.base import BaseClient
13
14
 
14
15
  PACKAGE_ROOT = os.path.dirname(__file__)
15
16
 
@@ -26,10 +27,15 @@ DEFAULT_CONFIG_STR = f"""[project]
26
27
  public = true
27
28
 
28
29
  # The project ID (found on https://cloud.chainlit.io).
29
- # If provided, all the message data will be stored in the cloud.
30
- # The project ID is required when public is set to false.
30
+ # The project ID is required when public is set to false or when using the cloud database.
31
31
  #id = ""
32
32
 
33
+ # Uncomment if you want to persist the chats.
34
+ # local will create a database in your .chainlit directory (requires node.js installed).
35
+ # cloud will use the Chainlit cloud database.
36
+ # custom will load use your custom client.
37
+ # database = "local"
38
+
33
39
  # Whether to enable telemetry (default: true). No personal data is collected.
34
40
  enable_telemetry = true
35
41
 
@@ -102,15 +108,19 @@ class CodeSettings:
102
108
  lc_postprocess: Optional[Callable[[Any], str]] = None
103
109
  lc_factory: Optional[Callable[[], Any]] = None
104
110
  lc_rename: Optional[Callable[[str], str]] = None
111
+ llama_index_factory: Optional[Callable[[], Any]] = None
112
+ langflow_schema: Union[Dict, str] = None
113
+ client_factory: Optional[Callable[[str], "BaseClient"]] = None
105
114
 
106
115
  def validate(self):
107
116
  requires_one_of = [
108
117
  "lc_factory",
118
+ "llama_index_factory",
109
119
  "on_message",
110
120
  "on_chat_start",
111
121
  ]
112
122
 
113
- mutually_exclusive = ["lc_factory"]
123
+ mutually_exclusive = ["lc_factory", "llama_index_factory"]
114
124
 
115
125
  # Check if at least one of the required attributes is set
116
126
  if not any(getattr(self, attr) for attr in requires_one_of):
@@ -136,12 +146,18 @@ class ProjectSettings:
136
146
  id: Optional[str] = None
137
147
  # Whether the app is available to anonymous users or only to team members.
138
148
  public: bool = True
149
+ # Storage type
150
+ database: Optional[Literal["local", "cloud", "custom"]] = None
139
151
  # Whether to enable telemetry. No personal data is collected.
140
152
  enable_telemetry: bool = True
141
153
  # List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
142
154
  user_env: List[str] = None
143
155
  # Path to the local langchain cache database
144
156
  lc_cache_path: str = None
157
+ # Path to the local chat db
158
+ local_db_path: str = None
159
+ # Path to the local file system
160
+ local_fs_path: str = None
145
161
 
146
162
 
147
163
  @dataclass()
@@ -203,10 +219,20 @@ def load_settings():
203
219
  )
204
220
 
205
221
  lc_cache_path = os.path.join(config_dir, ".langchain.db")
222
+ local_db_path = os.path.join(config_dir, "chat.db")
223
+ local_fs_path = os.path.join(config_dir, "chat_files")
224
+
225
+ os.environ[
226
+ "LOCAL_DB_PATH"
227
+ ] = f"file:{local_db_path}?socket_timeout=10&connection_limit=1"
206
228
 
207
229
  project_settings = ProjectSettings(
208
- lc_cache_path=lc_cache_path, **project_config
230
+ lc_cache_path=lc_cache_path,
231
+ local_db_path=local_db_path,
232
+ local_fs_path=local_fs_path,
233
+ **project_config,
209
234
  )
235
+
210
236
  ui_settings = UISettings(**ui_settings)
211
237
 
212
238
  if not project_settings.public and not project_settings.id:
chainlit/context.py ADDED
@@ -0,0 +1,29 @@
1
+ import contextvars
2
+ from typing import TYPE_CHECKING
3
+ from asyncio import AbstractEventLoop
4
+
5
+ if TYPE_CHECKING:
6
+ from chainlit.emitter import ChainlitEmitter
7
+
8
+
9
class ChainlitContextException(Exception):
    """Raised when a Chainlit context variable has not been set in the current context."""

    def __init__(self, msg="Chainlit context not found", *args, **kwargs):
        super().__init__(msg, *args, **kwargs)


# Per-context values, read through the getters below. They are presumably
# set by the code handling each websocket event, so user code can reach them
# without passing them around. TODO(review): confirm against the setters.
emitter_var: "contextvars.ContextVar[ChainlitEmitter]" = contextvars.ContextVar("emitter")
loop_var: "contextvars.ContextVar[AbstractEventLoop]" = contextvars.ContextVar("loop")


def _get_or_raise(var: contextvars.ContextVar):
    """Return the value of *var*, raising ChainlitContextException if it is unset."""
    try:
        return var.get()
    except LookupError:
        # The LookupError carries no extra information for callers; drop it
        # from the traceback instead of chaining it.
        raise ChainlitContextException() from None


def get_emitter() -> "ChainlitEmitter":
    """Return the emitter bound to the current context.

    Raises:
        ChainlitContextException: if no emitter has been set.
    """
    return _get_or_raise(emitter_var)


def get_loop() -> AbstractEventLoop:
    """Return the event loop bound to the current context.

    Raises:
        ChainlitContextException: if no loop has been set.
    """
    return _get_or_raise(loop_var)
@@ -0,0 +1,35 @@
1
+ import os
2
+ from chainlit.logger import logger
3
+ from chainlit.config import config, PACKAGE_ROOT
4
+
5
# Prisma schema shipped with the package; it describes the local chat database.
SCHEMA_PATH = os.path.join(PACKAGE_ROOT, "db", "prisma", "schema.prisma")


def db_push():
    """Create or update the local database so it matches ``SCHEMA_PATH``.

    Runs ``prisma db push`` with the bundled schema and logs an error if the
    CLI reports a non-zero exit status. Afterwards the ``prisma`` package is
    reloaded; without this the client fails to initialize the first time.
    """
    from prisma.cli.prisma import run
    import prisma
    from importlib import reload

    args = ["db", "push", f"--schema={SCHEMA_PATH}"]
    # The schema's datasource reads its URL from the LOCAL_DB_PATH variable.
    env = {"LOCAL_DB_PATH": os.environ.get("LOCAL_DB_PATH")}
    exit_code = run(args, env=env)

    # Previously ignored: a failed push left the database missing or stale
    # with no indication why.
    if exit_code:
        logger.error(f"Prisma db push failed with exit code {exit_code}")

    # Without this the client will fail to initialize the first time.
    reload(prisma)
19
+
20
+
21
def init_local_db():
    """Create the local database on first use when ``database = "local"``.

    Does nothing for other database types or if the database file is
    already present.
    """
    if config.project.database != "local":
        return

    if os.path.exists(config.project.local_db_path):
        return

    db_push()
26
+
27
+
28
def migrate_local_db():
    """Bring an existing local database up to date with the bundled schema.

    Only applies when ``database = "local"``. If the database file does not
    exist yet, nothing is migrated; ``init_local_db`` is what creates it.
    """
    if config.project.database != "local":
        return

    if not os.path.exists(config.project.local_db_path):
        logger.info("Local db does not exist, skipping migration")
        return

    db_push()
    logger.info("Local db migrated")
@@ -0,0 +1,48 @@
1
// Local chat storage is a SQLite database. Its connection URL comes from the
// LOCAL_DB_PATH environment variable, which chainlit's config loader sets to
// a file inside the project's config directory.
datasource db {
  provider = "sqlite"
  url = env("LOCAL_DB_PATH")
}

// Generates the Python client used by chainlit's local persistence client.
generator client {
  provider = "prisma-client-py"
  // prisma-client-py option bounding the depth of generated nested types.
  recursive_type_depth = 5
}
10
+
11
// A chat session. It owns its messages and elements, which are deleted along
// with it (onDelete: Cascade on their relations).
model Conversation {
  id Int @id @default(autoincrement())
  createdAt DateTime @default(now())
  messages Message[]
  elements Element[]
}
17
+
18
// A file or inline element (image, text, pdf, ...) attached to a conversation.
model Element {
  id Int @id @default(autoincrement())
  createdAt DateTime @default(now())
  conversationId Int
  conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
  type String
  // URL of the stored content, e.g. "/files/<conversationId>/<file>" for local uploads.
  url String
  name String
  display String
  size String?
  language String?
  // SQLite has no list-of-primitives type: the ids of the messages this
  // element belongs to are stored as a JSON-encoded string.
  forIds String?
}
31
+
32
// A single chat message, authored either by the user or by the app.
model Message {
  id Int @id @default(autoincrement())
  createdAt DateTime @default(now())
  conversationId Int
  conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
  authorIsUser Boolean @default(false)
  isError Boolean @default(false)
  waitForAnswer Boolean @default(false)
  indent Int @default(0)
  author String
  content String
  // Feedback score set by the user; 0 when no feedback was given.
  humanFeedback Int @default(0)
  language String?
  prompt String?
  // Sqlite does not support JSON
  // (so LLM settings are stored as a JSON-encoded string)
  llmSettings String?
}
chainlit/element.py CHANGED
@@ -1,22 +1,27 @@
1
1
  from pydantic.dataclasses import dataclass
2
- from dataclasses_json import dataclass_json
3
- from typing import Dict, Union, Any
2
+ from typing import Dict, List, Union, Any
4
3
  import uuid
5
4
  import aiofiles
6
5
  from io import BytesIO
7
6
 
8
- from chainlit.emitter import get_emitter, BaseClient
7
+ from chainlit.context import get_emitter
8
+ from chainlit.client.base import BaseClient
9
9
  from chainlit.telemetry import trace_event
10
10
  from chainlit.types import ElementType, ElementDisplay, ElementSize
11
11
 
12
12
  type_to_mime = {
13
- "image": "binary/octet-stream",
13
+ "image": "image/png",
14
14
  "text": "text/plain",
15
15
  "pdf": "application/pdf",
16
16
  }
17
17
 
18
+ mime_to_ext = {
19
+ "image/png": "png",
20
+ "text/plain": "txt",
21
+ "application/pdf": "pdf",
22
+ }
23
+
18
24
 
19
- @dataclass_json
20
25
  @dataclass
21
26
  class Element:
22
27
  # Name of the element, this will be used to reference the element in the UI.
@@ -34,20 +39,35 @@ class Element:
34
39
  # The ID of the element. This is set automatically when the element is sent to the UI if cloud is enabled.
35
40
  id: int = None
36
41
  # The ID of the element if cloud is disabled.
37
- tempId: str = None
42
+ temp_id: str = None
38
43
  # The ID of the message this element is associated with.
39
- forId: str = None
44
+ for_ids: List[str] = None
40
45
 
41
46
  def __post_init__(self) -> None:
42
47
  trace_event(f"init {self.__class__.__name__}")
43
48
  self.emitter = get_emitter()
44
- if not self.emitter:
45
- raise RuntimeError("Element should be instantiated in a Chainlit context")
49
+ self.for_ids = []
50
+ self.temp_id = str(uuid.uuid4())
46
51
 
47
52
  if not self.url and not self.path and not self.content:
48
53
  raise ValueError("Must provide url, path or content to instantiate element")
49
54
 
50
- self.tempId = uuid.uuid4().hex
55
+ def to_dict(self) -> Dict:
56
+ _dict = {
57
+ "tempId": self.temp_id,
58
+ "type": self.type,
59
+ "url": self.url,
60
+ "name": self.name,
61
+ "display": self.display,
62
+ "size": getattr(self, "size", None),
63
+ "language": getattr(self, "language", None),
64
+ "forIds": getattr(self, "for_ids", None),
65
+ }
66
+
67
+ if self.id:
68
+ _dict["id"] = self.id
69
+
70
+ return _dict
51
71
 
52
72
  async def preprocess_content(self):
53
73
  pass
@@ -56,29 +76,15 @@ class Element:
56
76
  if self.path:
57
77
  async with aiofiles.open(self.path, "rb") as f:
58
78
  self.content = await f.read()
59
- await self.preprocess_content()
60
- elif self.content:
61
- await self.preprocess_content()
62
79
  else:
63
80
  raise ValueError("Must provide path or content to load element")
64
81
 
65
- async def persist(self, client: BaseClient, for_id: str = None):
66
- if not self.url and self.content:
82
+ async def persist(self, client: BaseClient):
83
+ if not self.url and self.content and not self.id:
67
84
  self.url = await client.upload_element(
68
85
  content=self.content, mime=type_to_mime[self.type]
69
86
  )
70
-
71
- size = getattr(self, "size", None)
72
- language = getattr(self, "language", None)
73
- element = await client.create_element(
74
- name=self.name,
75
- url=self.url,
76
- type=self.type,
77
- display=self.display,
78
- size=size,
79
- language=language,
80
- for_id=for_id,
81
- )
87
+ element = await client.upsert_element(self.to_dict())
82
88
  return element
83
89
 
84
90
  async def before_emit(self, element: Dict) -> Dict:
@@ -90,22 +96,34 @@ class Element:
90
96
  if not self.content and not self.url and self.path:
91
97
  await self.load()
92
98
 
93
- # Cloud is enabled, upload the element to S3
94
- if self.emitter.client and not self.id:
95
- element = await self.persist(self.emitter.client, for_id)
96
- self.id = element["id"]
99
+ await self.preprocess_content()
100
+
101
+ if for_id:
102
+ self.for_ids.append(for_id)
103
+
104
+ # We have a client, persist the element
105
+ if self.emitter.client:
106
+ element = await self.persist(self.emitter.client)
107
+ self.id = element and element.get("id")
97
108
 
98
109
  elif not self.url and not self.content:
99
110
  raise ValueError("Must provide url or content to send element")
100
111
 
101
112
  element = self.to_dict()
102
- if for_id:
103
- element["forId"] = for_id
113
+
114
+ element["content"] = self.content
104
115
 
105
116
  if self.emitter.emit and element:
106
- trace_event(f"send {self.__class__.__name__}")
107
- element = await self.before_emit(element)
108
- await self.emitter.emit("element", element)
117
+ if len(self.for_ids) > 1:
118
+ trace_event(f"update {self.__class__.__name__}")
119
+ await self.emitter.emit(
120
+ "update_element",
121
+ {"id": self.id or self.temp_id, "forIds": self.for_ids},
122
+ )
123
+ else:
124
+ trace_event(f"send {self.__class__.__name__}")
125
+ element = await self.before_emit(element)
126
+ await self.emitter.emit("element", element)
109
127
 
110
128
 
111
129
  @dataclass
@@ -187,8 +205,3 @@ class Pyplot(Element):
187
205
  self.content = image.getvalue()
188
206
 
189
207
  super().__post_init__()
190
-
191
- async def before_emit(self, element: Dict) -> Dict:
192
- # Prevent the figure from being serialized
193
- del element["figure"]
194
- return element
chainlit/emitter.py CHANGED
@@ -1,9 +1,8 @@
1
1
  from typing import Union, Dict
2
2
  from chainlit.session import Session
3
3
  from chainlit.types import AskSpec
4
- from chainlit.client import BaseClient
4
+ from chainlit.client.base import BaseClient
5
5
  from socketio.exceptions import TimeoutError
6
- import inspect
7
6
 
8
7
 
9
8
  class ChainlitEmitter:
@@ -129,31 +128,3 @@ class ChainlitEmitter:
129
128
  def send_token(self, id: Union[str, int], token: str):
130
129
  """Send a message token to the UI."""
131
130
  return self.emit("stream_token", {"id": id, "token": token})
132
-
133
-
134
- def get_emitter() -> Union[ChainlitEmitter, None]:
135
- """
136
- Get the Chainlit Emitter instance from the current call stack.
137
- This unusual approach is necessary because:
138
- - we need to get the right Emitter instance with the right websocket connection
139
- - to preserve a lean developer experience, we do not pass the Emitter instance to every function call
140
-
141
- What happens is that we set __chainlit_emitter__ in the local variables when we receive a websocket message.
142
- Then we can retrieve it from the call stack when we need it, even if the developer's code has no idea about it.
143
- """
144
- attr = "__chainlit_emitter__"
145
- candidates = [i[0].f_locals.get(attr) for i in inspect.stack()]
146
- emitter = None
147
- for candidate in candidates:
148
- if candidate:
149
- emitter = candidate
150
- break
151
-
152
- return emitter
153
-
154
-
155
- def get_emit_fn():
156
- emitter = get_emitter()
157
- if emitter:
158
- return emitter.emit
159
- return None