chainlit 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

@@ -0,0 +1,23 @@
1
+ from fastapi import HTTPException, Request
2
+
3
+ from chainlit.config import config
4
+ from chainlit.client.base import BaseClient
5
+ from chainlit.client.local import LocalClient
6
+ from chainlit.client.cloud import CloudClient
7
+
8
+
9
+ async def get_client(request: Request) -> BaseClient:
10
+ auth_header = request.headers.get("Authorization")
11
+
12
+ db = config.project.database
13
+
14
+ if db == "local":
15
+ client = LocalClient()
16
+ elif db == "cloud":
17
+ client = CloudClient(config.project.id, auth_header)
18
+ elif db == "custom":
19
+ client = await config.code.client_factory()
20
+ else:
21
+ raise HTTPException(status_code=500, detail="Invalid database type")
22
+
23
+ return client
chainlit/config.py CHANGED
@@ -1,15 +1,16 @@
1
- from typing import Optional, Any, Callable, List, Dict, TYPE_CHECKING
1
+ from typing import Optional, Any, Callable, Literal, List, Dict, TYPE_CHECKING
2
2
  import os
3
3
  import sys
4
4
  import tomli
5
5
  from pydantic.dataclasses import dataclass
6
6
  from dataclasses_json import dataclass_json
7
- from importlib import machinery
7
+ from importlib import util
8
8
  from chainlit.logger import logger
9
9
  from chainlit.version import __version__
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from chainlit.action import Action
13
+ from chainlit.client.base import BaseClient
13
14
 
14
15
  PACKAGE_ROOT = os.path.dirname(__file__)
15
16
 
@@ -26,10 +27,15 @@ DEFAULT_CONFIG_STR = f"""[project]
26
27
  public = true
27
28
 
28
29
  # The project ID (found on https://cloud.chainlit.io).
29
- # If provided, all the message data will be stored in the cloud.
30
- # The project ID is required when public is set to false.
30
+ # The project ID is required when public is set to false or when using the cloud database.
31
31
  #id = ""
32
32
 
33
+ # Uncomment if you want to persist the chats.
34
+ # local will create a database in your .chainlit directory (requires node.js installed).
35
+ # cloud will use the Chainlit cloud database.
36
+ # custom will load use your custom client.
37
+ # database = "local"
38
+
33
39
  # Whether to enable telemetry (default: true). No personal data is collected.
34
40
  enable_telemetry = true
35
41
 
@@ -102,6 +108,32 @@ class CodeSettings:
102
108
  lc_postprocess: Optional[Callable[[Any], str]] = None
103
109
  lc_factory: Optional[Callable[[], Any]] = None
104
110
  lc_rename: Optional[Callable[[str], str]] = None
111
+ client_factory: Optional[Callable[[str], "BaseClient"]] = None
112
+
113
+ def validate(self):
114
+ requires_one_of = [
115
+ "lc_factory",
116
+ "on_message",
117
+ "on_chat_start",
118
+ ]
119
+
120
+ mutually_exclusive = ["lc_factory"]
121
+
122
+ # Check if at least one of the required attributes is set
123
+ if not any(getattr(self, attr) for attr in requires_one_of):
124
+ raise ValueError(
125
+ f"Module should at least expose one of {', '.join(requires_one_of)} function"
126
+ )
127
+
128
+ # Check if any mutually exclusive attributes are set together
129
+ for i, attr1 in enumerate(mutually_exclusive):
130
+ for attr2 in mutually_exclusive[i + 1 :]:
131
+ if getattr(self, attr1) and getattr(self, attr2):
132
+ raise ValueError(
133
+ f"Module should not expose both {attr1} and {attr2} functions"
134
+ )
135
+
136
+ return True
105
137
 
106
138
 
107
139
  @dataclass_json
@@ -111,12 +143,18 @@ class ProjectSettings:
111
143
  id: Optional[str] = None
112
144
  # Whether the app is available to anonymous users or only to team members.
113
145
  public: bool = True
146
+ # Storage type
147
+ database: Optional[Literal["local", "cloud", "custom"]] = None
114
148
  # Whether to enable telemetry. No personal data is collected.
115
149
  enable_telemetry: bool = True
116
150
  # List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
117
151
  user_env: List[str] = None
118
152
  # Path to the local langchain cache database
119
153
  lc_cache_path: str = None
154
+ # Path to the local chat db
155
+ local_db_path: str = None
156
+ # Path to the local file system
157
+ local_fs_path: str = None
120
158
 
121
159
 
122
160
  @dataclass()
@@ -145,35 +183,26 @@ def init_config(log=False):
145
183
  logger.info(f"Config file already exists at {config_file}")
146
184
 
147
185
 
148
- def reset_module_config():
149
- if not config:
150
- return
151
-
152
- config.code = CodeSettings(action_callbacks={})
153
-
154
-
155
186
  def load_module(target: str):
156
187
  """Load the specified module."""
157
188
 
158
- # Reset the config fields that belonged to the previous module
159
- reset_module_config()
160
-
161
189
  # Get the target's directory
162
190
  target_dir = os.path.dirname(os.path.abspath(target))
163
191
 
164
192
  # Add the target's directory to the Python path
165
193
  sys.path.insert(0, target_dir)
166
194
 
167
- loader = machinery.SourceFileLoader(target, target)
168
- config.code.module = loader.load_module()
195
+ spec = util.spec_from_file_location(target, target)
196
+ module = util.module_from_spec(spec)
197
+ spec.loader.exec_module(module)
169
198
 
170
199
  # Remove the target's directory from the Python path
171
200
  sys.path.pop(0)
172
201
 
202
+ config.code.validate()
173
203
 
174
- def load_config():
175
- """Load the configuration from the config file."""
176
- init_config()
204
+
205
+ def load_settings():
177
206
  with open(config_file, "rb") as f:
178
207
  toml_dict = tomli.load(f)
179
208
  # Load project settings
@@ -187,23 +216,57 @@ def load_config():
187
216
  )
188
217
 
189
218
  lc_cache_path = os.path.join(config_dir, ".langchain.db")
219
+ local_db_path = os.path.join(config_dir, "chat.db")
220
+ local_fs_path = os.path.join(config_dir, "chat_files")
221
+
222
+ os.environ[
223
+ "LOCAL_DB_PATH"
224
+ ] = f"file:{local_db_path}?socket_timeout=10&connection_limit=1"
190
225
 
191
226
  project_settings = ProjectSettings(
192
- lc_cache_path=lc_cache_path, **project_config
227
+ lc_cache_path=lc_cache_path,
228
+ local_db_path=local_db_path,
229
+ local_fs_path=local_fs_path,
230
+ **project_config,
193
231
  )
232
+
194
233
  ui_settings = UISettings(**ui_settings)
195
234
 
196
- if not project_settings.public and not project_settings.project_id:
235
+ if not project_settings.public and not project_settings.id:
197
236
  raise ValueError("Project ID is required when public is set to false.")
198
237
 
199
- config = ChainlitConfig(
200
- chainlit_server=chainlit_server,
201
- chainlit_prod_url=chainlit_prod_url,
202
- ui=ui_settings,
203
- run=RunSettings(),
204
- project=project_settings,
205
- code=CodeSettings(action_callbacks={}),
206
- )
238
+ return {
239
+ "ui": ui_settings,
240
+ "project": project_settings,
241
+ "code": CodeSettings(action_callbacks={}),
242
+ }
243
+
244
+
245
+ def reload_config():
246
+ """Reload the configuration from the config file."""
247
+ global config
248
+ if config is None:
249
+ return
250
+
251
+ settings = load_settings()
252
+
253
+ config.code = settings["code"]
254
+ config.ui = settings["ui"]
255
+ config.project = settings["project"]
256
+
257
+
258
+ def load_config():
259
+ """Load the configuration from the config file."""
260
+ init_config()
261
+
262
+ settings = load_settings()
263
+
264
+ config = ChainlitConfig(
265
+ chainlit_server=chainlit_server,
266
+ chainlit_prod_url=chainlit_prod_url,
267
+ run=RunSettings(),
268
+ **settings,
269
+ )
207
270
 
208
271
  return config
209
272
 
chainlit/context.py ADDED
@@ -0,0 +1,29 @@
1
+ import contextvars
2
+ from typing import TYPE_CHECKING
3
+ from asyncio import AbstractEventLoop
4
+
5
+ if TYPE_CHECKING:
6
+ from chainlit.emitter import ChainlitEmitter
7
+
8
+
9
+ class ChainlitContextException(Exception):
10
+ def __init__(self, msg="Chainlit context not found", *args, **kwargs):
11
+ super().__init__(msg, *args, **kwargs)
12
+
13
+
14
+ emitter_var = contextvars.ContextVar("emitter")
15
+ loop_var = contextvars.ContextVar("loop")
16
+
17
+
18
+ def get_emitter() -> "ChainlitEmitter":
19
+ try:
20
+ return emitter_var.get()
21
+ except LookupError:
22
+ raise ChainlitContextException()
23
+
24
+
25
+ def get_loop() -> AbstractEventLoop:
26
+ try:
27
+ return loop_var.get()
28
+ except LookupError:
29
+ raise ChainlitContextException()
@@ -0,0 +1,35 @@
1
+ import os
2
+ from chainlit.logger import logger
3
+ from chainlit.config import config, PACKAGE_ROOT
4
+
5
+ SCHEMA_PATH = os.path.join(PACKAGE_ROOT, "db/prisma/schema.prisma")
6
+
7
+
8
+ def db_push():
9
+ from prisma.cli.prisma import run
10
+ import prisma
11
+ from importlib import reload
12
+
13
+ args = ["db", "push", f"--schema={SCHEMA_PATH}"]
14
+ env = {"LOCAL_DB_PATH": os.environ.get("LOCAL_DB_PATH")}
15
+ run(args, env=env)
16
+
17
+ # Without this the client will fail to initialize the first time.
18
+ reload(prisma)
19
+
20
+
21
+ def init_local_db():
22
+ use_local_db = config.project.database == "local"
23
+ if use_local_db:
24
+ if not os.path.exists(config.project.local_db_path):
25
+ db_push()
26
+
27
+
28
+ def migrate_local_db():
29
+ use_local_db = config.project.database == "local"
30
+ if use_local_db:
31
+ if os.path.exists(config.project.local_db_path):
32
+ db_push()
33
+ logger.info(f"Local db migrated")
34
+ else:
35
+ logger.info(f"Local db does not exist, skipping migration")
@@ -0,0 +1,48 @@
1
+ datasource db {
2
+ provider = "sqlite"
3
+ url = env("LOCAL_DB_PATH")
4
+ }
5
+
6
+ generator client {
7
+ provider = "prisma-client-py"
8
+ recursive_type_depth = 5
9
+ }
10
+
11
+ model Conversation {
12
+ id Int @id @default(autoincrement())
13
+ createdAt DateTime @default(now())
14
+ messages Message[]
15
+ elements Element[]
16
+ }
17
+
18
+ model Element {
19
+ id Int @id @default(autoincrement())
20
+ createdAt DateTime @default(now())
21
+ conversationId Int
22
+ conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
23
+ type String
24
+ url String
25
+ name String
26
+ display String
27
+ size String?
28
+ language String?
29
+ forIds String?
30
+ }
31
+
32
+ model Message {
33
+ id Int @id @default(autoincrement())
34
+ createdAt DateTime @default(now())
35
+ conversationId Int
36
+ conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
37
+ authorIsUser Boolean @default(false)
38
+ isError Boolean @default(false)
39
+ waitForAnswer Boolean @default(false)
40
+ indent Int @default(0)
41
+ author String
42
+ content String
43
+ humanFeedback Int @default(0)
44
+ language String?
45
+ prompt String?
46
+ // Sqlite does not support JSON
47
+ llmSettings String?
48
+ }
chainlit/element.py CHANGED
@@ -1,22 +1,27 @@
1
1
  from pydantic.dataclasses import dataclass
2
- from dataclasses_json import dataclass_json
3
- from typing import Dict, Union, Any
2
+ from typing import Dict, List, Union, Any
4
3
  import uuid
5
4
  import aiofiles
6
5
  from io import BytesIO
7
6
 
8
- from chainlit.emitter import get_emitter, BaseClient
7
+ from chainlit.context import get_emitter
8
+ from chainlit.client.base import BaseClient
9
9
  from chainlit.telemetry import trace_event
10
10
  from chainlit.types import ElementType, ElementDisplay, ElementSize
11
11
 
12
12
  type_to_mime = {
13
- "image": "binary/octet-stream",
13
+ "image": "image/png",
14
14
  "text": "text/plain",
15
15
  "pdf": "application/pdf",
16
16
  }
17
17
 
18
+ mime_to_ext = {
19
+ "image/png": "png",
20
+ "text/plain": "txt",
21
+ "application/pdf": "pdf",
22
+ }
23
+
18
24
 
19
- @dataclass_json
20
25
  @dataclass
21
26
  class Element:
22
27
  # Name of the element, this will be used to reference the element in the UI.
@@ -34,20 +39,35 @@ class Element:
34
39
  # The ID of the element. This is set automatically when the element is sent to the UI if cloud is enabled.
35
40
  id: int = None
36
41
  # The ID of the element if cloud is disabled.
37
- tempId: str = None
42
+ temp_id: str = None
38
43
  # The ID of the message this element is associated with.
39
- forId: str = None
44
+ for_ids: List[str] = None
40
45
 
41
46
  def __post_init__(self) -> None:
42
47
  trace_event(f"init {self.__class__.__name__}")
43
48
  self.emitter = get_emitter()
44
- if not self.emitter:
45
- raise RuntimeError("Element should be instantiated in a Chainlit context")
49
+ self.for_ids = []
50
+ self.temp_id = str(uuid.uuid4())
46
51
 
47
52
  if not self.url and not self.path and not self.content:
48
53
  raise ValueError("Must provide url, path or content to instantiate element")
49
54
 
50
- self.tempId = uuid.uuid4().hex
55
+ def to_dict(self) -> Dict:
56
+ _dict = {
57
+ "tempId": self.temp_id,
58
+ "type": self.type,
59
+ "url": self.url,
60
+ "name": self.name,
61
+ "display": self.display,
62
+ "size": getattr(self, "size", None),
63
+ "language": getattr(self, "language", None),
64
+ "forIds": getattr(self, "for_ids", None),
65
+ }
66
+
67
+ if self.id:
68
+ _dict["id"] = self.id
69
+
70
+ return _dict
51
71
 
52
72
  async def preprocess_content(self):
53
73
  pass
@@ -56,29 +76,15 @@ class Element:
56
76
  if self.path:
57
77
  async with aiofiles.open(self.path, "rb") as f:
58
78
  self.content = await f.read()
59
- await self.preprocess_content()
60
- elif self.content:
61
- await self.preprocess_content()
62
79
  else:
63
80
  raise ValueError("Must provide path or content to load element")
64
81
 
65
- async def persist(self, client: BaseClient, for_id: str = None):
66
- if not self.url and self.content:
82
+ async def persist(self, client: BaseClient):
83
+ if not self.url and self.content and not self.id:
67
84
  self.url = await client.upload_element(
68
85
  content=self.content, mime=type_to_mime[self.type]
69
86
  )
70
-
71
- size = getattr(self, "size", None)
72
- language = getattr(self, "language", None)
73
- element = await client.create_element(
74
- name=self.name,
75
- url=self.url,
76
- type=self.type,
77
- display=self.display,
78
- size=size,
79
- language=language,
80
- for_id=for_id,
81
- )
87
+ element = await client.upsert_element(self.to_dict())
82
88
  return element
83
89
 
84
90
  async def before_emit(self, element: Dict) -> Dict:
@@ -90,22 +96,34 @@ class Element:
90
96
  if not self.content and not self.url and self.path:
91
97
  await self.load()
92
98
 
93
- # Cloud is enabled, upload the element to S3
94
- if self.emitter.client and not self.id:
95
- element = await self.persist(self.emitter.client, for_id)
96
- self.id = element["id"]
99
+ await self.preprocess_content()
100
+
101
+ if for_id:
102
+ self.for_ids.append(for_id)
103
+
104
+ # We have a client, persist the element
105
+ if self.emitter.client:
106
+ element = await self.persist(self.emitter.client)
107
+ self.id = element and element.get("id")
97
108
 
98
109
  elif not self.url and not self.content:
99
110
  raise ValueError("Must provide url or content to send element")
100
111
 
101
112
  element = self.to_dict()
102
- if for_id:
103
- element["forId"] = for_id
113
+
114
+ element["content"] = self.content
104
115
 
105
116
  if self.emitter.emit and element:
106
- trace_event(f"send {self.__class__.__name__}")
107
- element = await self.before_emit(element)
108
- await self.emitter.emit("element", element)
117
+ if len(self.for_ids) > 1:
118
+ trace_event(f"update {self.__class__.__name__}")
119
+ await self.emitter.emit(
120
+ "update_element",
121
+ {"id": self.id or self.temp_id, "forIds": self.for_ids},
122
+ )
123
+ else:
124
+ trace_event(f"send {self.__class__.__name__}")
125
+ element = await self.before_emit(element)
126
+ await self.emitter.emit("element", element)
109
127
 
110
128
 
111
129
  @dataclass
@@ -187,8 +205,3 @@ class Pyplot(Element):
187
205
  self.content = image.getvalue()
188
206
 
189
207
  super().__post_init__()
190
-
191
- async def before_emit(self, element: Dict) -> Dict:
192
- # Prevent the figure from being serialized
193
- del element["figure"]
194
- return element
chainlit/emitter.py CHANGED
@@ -1,9 +1,8 @@
1
1
  from typing import Union, Dict
2
2
  from chainlit.session import Session
3
3
  from chainlit.types import AskSpec
4
- from chainlit.client import BaseClient
4
+ from chainlit.client.base import BaseClient
5
5
  from socketio.exceptions import TimeoutError
6
- import inspect
7
6
 
8
7
 
9
8
  class ChainlitEmitter:
@@ -129,31 +128,3 @@ class ChainlitEmitter:
129
128
  def send_token(self, id: Union[str, int], token: str):
130
129
  """Send a message token to the UI."""
131
130
  return self.emit("stream_token", {"id": id, "token": token})
132
-
133
-
134
- def get_emitter() -> Union[ChainlitEmitter, None]:
135
- """
136
- Get the Chainlit Emitter instance from the current call stack.
137
- This unusual approach is necessary because:
138
- - we need to get the right Emitter instance with the right websocket connection
139
- - to preserve a lean developer experience, we do not pass the Emitter instance to every function call
140
-
141
- What happens is that we set __chainlit_emitter__ in the local variables when we receive a websocket message.
142
- Then we can retrieve it from the call stack when we need it, even if the developer's code has no idea about it.
143
- """
144
- attr = "__chainlit_emitter__"
145
- candidates = [i[0].f_locals.get(attr) for i in inspect.stack()]
146
- emitter = None
147
- for candidate in candidates:
148
- if candidate:
149
- emitter = candidate
150
- break
151
-
152
- return emitter
153
-
154
-
155
- def get_emit_fn():
156
- emitter = get_emitter()
157
- if emitter:
158
- return emitter.emit
159
- return None