chainlit 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

@@ -14,8 +14,8 @@
14
14
  <script>
15
15
  const global = globalThis;
16
16
  </script>
17
- <script type="module" crossorigin src="/assets/index-9e4bccd1.js"></script>
18
- <link rel="stylesheet" href="/assets/index-0cc9e355.css">
17
+ <script type="module" crossorigin src="/assets/index-fb1e167a.js"></script>
18
+ <link rel="stylesheet" href="/assets/index-f93cc942.css">
19
19
  </head>
20
20
  <body>
21
21
  <div id="root"></div>
chainlit/lc/agent.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from typing import Any
2
2
  from chainlit.lc.callbacks import ChainlitCallbackHandler, AsyncChainlitCallbackHandler
3
3
  from chainlit.sync import make_async
4
+ from chainlit.context import emitter_var
4
5
 
5
6
 
6
7
  async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
chainlit/lc/callbacks.py CHANGED
@@ -6,7 +6,8 @@ from langchain.schema import (
6
6
  BaseMessage,
7
7
  LLMResult,
8
8
  )
9
- from chainlit.emitter import get_emitter, ChainlitEmitter
9
+ from chainlit.emitter import ChainlitEmitter
10
+ from chainlit.context import get_emitter
10
11
  from chainlit.message import Message, ErrorMessage
11
12
  from chainlit.config import config
12
13
  from chainlit.types import LLMSettings
@@ -107,14 +108,10 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
107
108
  return
108
109
 
109
110
  if config.code.lc_rename:
110
- author = run_sync(
111
- config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
112
- )
111
+ author = run_sync(config.code.lc_rename(author))
113
112
 
114
113
  self.pop_prompt()
115
114
 
116
- __chainlit_emitter__ = self.emitter
117
-
118
115
  streamed_message = Message(
119
116
  author=author,
120
117
  indent=indent,
@@ -135,11 +132,7 @@ class ChainlitCallbackHandler(BaseChainlitCallbackHandler, BaseCallbackHandler):
135
132
  return
136
133
 
137
134
  if config.code.lc_rename:
138
- author = run_sync(
139
- config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
140
- )
141
-
142
- __chainlit_emitter__ = self.emitter
135
+ author = run_sync(config.code.lc_rename(author))
143
136
 
144
137
  if error:
145
138
  run_sync(ErrorMessage(author=author, content=message).send())
@@ -267,14 +260,10 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
267
260
  return
268
261
 
269
262
  if config.code.lc_rename:
270
- author = await config.code.lc_rename(
271
- author, __chainlit_emitter__=self.emitter
272
- )
263
+ author = await config.code.lc_rename(author)
273
264
 
274
265
  self.pop_prompt()
275
266
 
276
- __chainlit_emitter__ = self.emitter
277
-
278
267
  streamed_message = Message(
279
268
  author=author,
280
269
  indent=indent,
@@ -295,11 +284,7 @@ class AsyncChainlitCallbackHandler(BaseChainlitCallbackHandler, AsyncCallbackHan
295
284
  return
296
285
 
297
286
  if config.code.lc_rename:
298
- author = await config.code.lc_rename(
299
- author, __chainlit_emitter__=self.emitter
300
- )
301
-
302
- __chainlit_emitter__ = self.emitter
287
+ author = await config.code.lc_rename(author)
303
288
 
304
289
  if error:
305
290
  await ErrorMessage(author=author, content=message).send()
chainlit/logger.py CHANGED
@@ -1,12 +1,17 @@
1
1
  import logging
2
+ import sys
3
+
2
4
 
3
5
  logging.basicConfig(
4
- level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
6
+ level=logging.INFO,
7
+ stream=sys.stdout,
8
+ format="%(asctime)s - %(message)s",
9
+ datefmt="%Y-%m-%d %H:%M:%S",
5
10
  )
6
11
 
7
12
  logging.getLogger("socketio").setLevel(logging.ERROR)
8
13
  logging.getLogger("engineio").setLevel(logging.ERROR)
9
- logging.getLogger("geventwebsocket.handler").setLevel(logging.ERROR)
10
14
  logging.getLogger("numexpr").setLevel(logging.ERROR)
11
15
 
16
+
12
17
  logger = logging.getLogger("chainlit")
chainlit/message.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from typing import List, Dict, Union
2
2
  from abc import ABC, abstractmethod
3
3
  import uuid
4
- import time
5
4
  import asyncio
5
+ from datetime import datetime, timezone
6
6
 
7
7
  from chainlit.telemetry import trace_event
8
- from chainlit.emitter import get_emitter
8
+ from chainlit.context import get_emitter
9
9
  from chainlit.config import config
10
10
  from chainlit.types import (
11
11
  LLMSettings,
@@ -16,11 +16,7 @@ from chainlit.types import (
16
16
  )
17
17
  from chainlit.element import Element
18
18
  from chainlit.action import Action
19
-
20
-
21
- def current_milli_time():
22
- """Get the current time in milliseconds."""
23
- return round(time.time() * 1000)
19
+ from chainlit.logger import logger
24
20
 
25
21
 
26
22
  class MessageBase(ABC):
@@ -28,14 +24,13 @@ class MessageBase(ABC):
28
24
  temp_id: str = None
29
25
  streaming = False
30
26
  created_at: int = None
27
+ fail_on_persist_error: bool = True
31
28
 
32
29
  def __post_init__(self) -> None:
33
30
  trace_event(f"init {self.__class__.__name__}")
34
31
  self.temp_id = uuid.uuid4().hex
35
- self.created_at = current_milli_time()
32
+ self.created_at = datetime.now(timezone.utc).isoformat()
36
33
  self.emitter = get_emitter()
37
- if not self.emitter:
38
- raise RuntimeError("Message should be instantiated in a Chainlit context")
39
34
 
40
35
  @abstractmethod
41
36
  def to_dict(self):
@@ -44,9 +39,14 @@ class MessageBase(ABC):
44
39
  async def _create(self):
45
40
  msg_dict = self.to_dict()
46
41
  if self.emitter.client and not self.id:
47
- self.id = await self.emitter.client.create_message(msg_dict)
48
- if self.id:
49
- msg_dict["id"] = self.id
42
+ try:
43
+ self.id = await self.emitter.client.create_message(msg_dict)
44
+ if self.id:
45
+ msg_dict["id"] = self.id
46
+ except Exception as e:
47
+ if self.fail_on_persist_error:
48
+ raise e
49
+ logger.error(f"Failed to persist message: {str(e)}")
50
50
 
51
51
  return msg_dict
52
52
 
@@ -77,8 +77,7 @@ class MessageBase(ABC):
77
77
  msg_dict = self.to_dict()
78
78
 
79
79
  if self.emitter.client and self.id:
80
- self.emitter.client.update_message(self.id, msg_dict)
81
- msg_dict["id"] = self.id
80
+ await self.emitter.client.update_message(self.id, msg_dict)
82
81
 
83
82
  await self.emitter.update_message(msg_dict)
84
83
 
@@ -171,7 +170,7 @@ class Message(MessageBase):
171
170
  super().__post_init__()
172
171
 
173
172
  def to_dict(self):
174
- return {
173
+ _dict = {
175
174
  "tempId": self.temp_id,
176
175
  "createdAt": self.created_at,
177
176
  "content": self.content,
@@ -182,6 +181,11 @@ class Message(MessageBase):
182
181
  "indent": self.indent,
183
182
  }
184
183
 
184
+ if self.id:
185
+ _dict["id"] = self.id
186
+
187
+ return _dict
188
+
185
189
  async def send(self):
186
190
  """
187
191
  Send the message to the UI and persist it in the cloud if a project ID is configured.
@@ -214,10 +218,12 @@ class ErrorMessage(MessageBase):
214
218
  content: str,
215
219
  author: str = config.ui.name,
216
220
  indent: int = 0,
221
+ fail_on_persist_error: bool = False,
217
222
  ):
218
223
  self.content = content
219
224
  self.author = author
220
225
  self.indent = indent
226
+ self.fail_on_persist_error = fail_on_persist_error
221
227
 
222
228
  super().__post_init__()
223
229
 
chainlit/server.py CHANGED
@@ -6,11 +6,13 @@ mimetypes.add_type("text/css", ".css")
6
6
  import os
7
7
  import json
8
8
  import webbrowser
9
+ from pathlib import Path
10
+
9
11
 
10
12
  from contextlib import asynccontextmanager
11
13
  from watchfiles import awatch
12
14
 
13
- from fastapi import FastAPI
15
+ from fastapi import FastAPI, Request
14
16
  from fastapi.responses import (
15
17
  HTMLResponse,
16
18
  JSONResponse,
@@ -21,17 +23,25 @@ from fastapi_socketio import SocketManager
21
23
  from starlette.middleware.cors import CORSMiddleware
22
24
  import asyncio
23
25
 
24
- from chainlit.config import config, load_module, DEFAULT_HOST
26
+ from chainlit.context import emitter_var, loop_var
27
+ from chainlit.config import config, load_module, reload_config, DEFAULT_HOST
25
28
  from chainlit.session import Session, sessions
26
29
  from chainlit.user_session import user_sessions
27
- from chainlit.client import CloudClient
30
+ from chainlit.client.cloud import CloudClient
31
+ from chainlit.client.local import LocalClient
32
+ from chainlit.client.utils import get_client
28
33
  from chainlit.emitter import ChainlitEmitter
29
34
  from chainlit.markdown import get_markdown_str
30
35
  from chainlit.action import Action
31
36
  from chainlit.message import Message, ErrorMessage
32
37
  from chainlit.telemetry import trace_event
33
38
  from chainlit.logger import logger
34
- from chainlit.types import CompletionRequest
39
+ from chainlit.types import (
40
+ CompletionRequest,
41
+ UpdateFeedbackRequest,
42
+ GetConversationsRequest,
43
+ DeleteConversationRequest,
44
+ )
35
45
 
36
46
 
37
47
  @asynccontextmanager
@@ -39,32 +49,56 @@ async def lifespan(app: FastAPI):
39
49
  host = config.run.host
40
50
  port = config.run.port
41
51
 
42
- if not config.run.headless:
43
- if host == DEFAULT_HOST:
44
- url = f"http://localhost:{port}"
45
- else:
46
- url = f"http://{host}:{port}"
52
+ if host == DEFAULT_HOST:
53
+ url = f"http://localhost:{port}"
54
+ else:
55
+ url = f"http://{host}:{port}"
56
+
57
+ logger.info(f"Your app is available at {url}")
47
58
 
48
- logger.info(f"Your app is available at {url}")
59
+ if not config.run.headless:
60
+ # Add a delay before opening the browser
61
+ await asyncio.sleep(1)
49
62
  webbrowser.open(url)
50
63
 
64
+ if config.project.database == "local":
65
+ from prisma import Client, register
66
+
67
+ client = Client()
68
+ register(client)
69
+ await client.connect()
70
+
51
71
  watch_task = None
52
72
  stop_event = asyncio.Event()
53
73
 
54
74
  if config.run.watch:
55
75
 
56
76
  async def watch_files_for_changes():
77
+ extensions = [".py"]
78
+ files = ["chainlit.md", "config.toml"]
57
79
  async for changes in awatch(config.root, stop_event=stop_event):
58
80
  for change_type, file_path in changes:
59
81
  file_name = os.path.basename(file_path)
60
82
  file_ext = os.path.splitext(file_name)[1]
61
83
 
62
- if file_ext.lower() == ".py" or file_name.lower() == "chainlit.md":
63
- logger.info(f"File {change_type.name}: {file_name}")
84
+ if file_ext.lower() in extensions or file_name.lower() in files:
85
+ logger.info(
86
+ f"File {change_type.name}: {file_name}. Reloading app..."
87
+ )
88
+
89
+ try:
90
+ reload_config()
91
+ except Exception as e:
92
+ logger.error(f"Error reloading config: {e}")
93
+ break
64
94
 
65
95
  # Reload the module if the module name is specified in the config
66
96
  if config.run.module_name:
67
- load_module(config.run.module_name)
97
+ try:
98
+ load_module(config.run.module_name)
99
+ except Exception as e:
100
+ logger.error(f"Error reloading module: {e}")
101
+ break
68
102
 
69
103
  await socket.emit("reload", {})
70
104
 
@@ -74,12 +108,16 @@ async def lifespan(app: FastAPI):
74
108
 
75
109
  try:
76
110
  yield
77
- except KeyboardInterrupt:
78
- logger.error("KeyboardInterrupt received, stopping the watch task...")
79
111
  finally:
112
+ if config.project.database == "local":
113
+ await client.disconnect()
80
114
  if watch_task:
81
- stop_event.set()
82
- await watch_task
115
+ try:
116
+ stop_event.set()
117
+ watch_task.cancel()
118
+ await watch_task
119
+ except asyncio.exceptions.CancelledError:
120
+ pass
83
121
 
84
122
 
85
123
  root_dir = os.path.dirname(os.path.abspath(__file__))
@@ -187,6 +225,80 @@ async def project_settings():
187
225
  )
188
226
 
189
227
 
228
+ @app.put("/message/feedback")
229
+ async def update_feedback(request: Request, update: UpdateFeedbackRequest):
230
+ """Update the human feedback for a particular message."""
231
+
232
+ client = await get_client(request)
233
+ await client.set_human_feedback(
234
+ message_id=update.messageId, feedback=update.feedback
235
+ )
236
+ return JSONResponse(content={"success": True})
237
+
238
+
239
+ @app.get("/project/members")
240
+ async def get_project_members(request: Request):
241
+ """Get all the members of a project."""
242
+
243
+ client = await get_client(request)
244
+ res = await client.get_project_members()
245
+ return JSONResponse(content=res)
246
+
247
+
248
+ @app.get("/project/role")
249
+ async def get_member_role(request: Request):
250
+ """Get the role of a member."""
251
+
252
+ client = await get_client(request)
253
+ res = await client.get_member_role()
254
+ return PlainTextResponse(content=res)
255
+
256
+
257
+ @app.post("/project/conversations")
258
+ async def get_project_conversations(request: Request, payload: GetConversationsRequest):
259
+ """Get the conversations page by page."""
260
+
261
+ client = await get_client(request)
262
+ res = await client.get_conversations(payload.pagination, payload.filter)
263
+ return JSONResponse(content=res.to_dict())
264
+
265
+
266
+ @app.get("/project/conversation/{conversation_id}")
267
+ async def get_conversation(request: Request, conversation_id: str):
268
+ """Get a specific conversation."""
269
+
270
+ client = await get_client(request)
271
+ res = await client.get_conversation(int(conversation_id))
272
+ return JSONResponse(content=res)
273
+
274
+
275
+ @app.get("/project/conversation/{conversation_id}/element/{element_id}")
276
+ async def get_conversation(request: Request, conversation_id: str, element_id: str):
277
+ """Get a specific conversation."""
278
+
279
+ client = await get_client(request)
280
+ res = await client.get_element(int(conversation_id), int(element_id))
281
+ return JSONResponse(content=res)
282
+
283
+
284
+ @app.delete("/project/conversation")
285
+ async def delete_conversation(request: Request, payload: DeleteConversationRequest):
286
+ """Delete a conversation."""
287
+
288
+ client = await get_client(request)
289
+ await client.delete_conversation(conversation_id=payload.conversationId)
290
+ return JSONResponse(content={"success": True})
291
+
292
+
293
+ @app.get("/files/{filename:path}")
294
+ async def serve_file(filename: str):
295
+ file_path = Path(config.project.local_fs_path) / filename
296
+ if file_path.is_file():
297
+ return FileResponse(file_path)
298
+ else:
299
+ return {"error": "File not found"}
300
+
301
+
190
302
  @app.get("/{path:path}")
191
303
  async def serve(path: str):
192
304
  """Serve the UI."""
@@ -217,36 +329,30 @@ def need_session(id: str):
217
329
  async def connect(sid, environ):
218
330
  user_env = environ.get("HTTP_USER_ENV")
219
331
  authorization = environ.get("HTTP_AUTHORIZATION")
220
- cloud_client = None
221
-
222
- # Check decorated functions
223
- if (
224
- not config.code.lc_factory
225
- and not config.code.on_message
226
- and not config.code.on_chat_start
227
- ):
228
- logger.error(
229
- "Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
230
- )
231
- return False
332
+ client = None
232
333
 
233
334
  # Check authorization
234
335
  if not config.project.public and not authorization:
235
336
  # Refuse connection if the app is private and no access token is provided
236
337
  trace_event("no_access_token")
237
- logger.error("No access token provided")
338
+ logger.error("Connection refused: No access token provided")
238
339
  return False
239
- elif authorization and config.project.id:
340
+ elif authorization and config.project.id and config.project.database == "cloud":
240
341
  # Create the cloud client
241
- cloud_client = CloudClient(
342
+ client = CloudClient(
242
343
  project_id=config.project.id,
243
- session_id=sid,
244
344
  access_token=authorization,
245
345
  )
246
- is_project_member = await cloud_client.is_project_member()
346
+ is_project_member = await client.is_project_member()
247
347
  if not is_project_member:
248
- logger.error("You are not a member of this project")
348
+ logger.error("Connection refused: You are not a member of this project")
249
349
  return False
350
+ elif config.project.database == "local":
351
+ client = LocalClient()
352
+ elif config.project.database == "custom":
353
+ if not config.code.client_factory:
354
+ raise ValueError("Client factory not provided")
355
+ client = await config.code.client_factory()
250
356
 
251
357
  # Check user env
252
358
  if config.project.user_env:
@@ -256,10 +362,12 @@ async def connect(sid, environ):
256
362
  for key in config.project.user_env:
257
363
  if key not in user_env:
258
364
  trace_event("missing_user_env")
259
- logger.error("Missing user environment variable: " + key)
365
+ logger.error(
366
+ "Connection refused: Missing user environment variable: " + key
367
+ )
260
368
  return False
261
369
  else:
262
- logger.error("Missing user environment variables")
370
+ logger.error("Connection refused: Missing user environment variables")
263
371
  return False
264
372
 
265
373
  # Create the session
@@ -283,9 +391,8 @@ async def connect(sid, environ):
283
391
  "id": sid,
284
392
  "emit": emit_fn,
285
393
  "ask_user": ask_user_fn,
286
- "client": cloud_client,
394
+ "client": client,
287
395
  "user_env": user_env,
288
- "running_sync": False,
289
396
  "should_stop": False,
290
397
  } # type: Session
291
398
 
@@ -298,15 +405,17 @@ async def connect(sid, environ):
298
405
  @socket.on("connection_successful")
299
406
  async def connection_successful(sid):
300
407
  session = need_session(sid)
301
- __chainlit_emitter__ = ChainlitEmitter(session)
408
+ emitter_var.set(ChainlitEmitter(session))
409
+ loop_var.set(asyncio.get_event_loop())
410
+
302
411
  if config.code.lc_factory:
303
412
  """Instantiate the langchain agent and store it in the session."""
304
- agent = await config.code.lc_factory(__chainlit_emitter__=__chainlit_emitter__)
413
+ agent = await config.code.lc_factory()
305
414
  session["agent"] = agent
306
415
 
307
416
  if config.code.on_chat_start:
308
417
  """Call the on_chat_start function provided by the developer."""
309
- await config.code.on_chat_start(__chainlit_emitter__=__chainlit_emitter__)
418
+ await config.code.on_chat_start()
310
419
 
311
420
 
312
421
  @socket.on("disconnect")
@@ -326,7 +435,8 @@ async def stop(sid):
326
435
  trace_event("stop_task")
327
436
  session = sessions[sid]
328
437
 
329
- __chainlit_emitter__ = ChainlitEmitter(session)
438
+ emitter_var.set(ChainlitEmitter(session))
439
+ loop_var.set(asyncio.get_event_loop())
330
440
 
331
441
  await Message(author="System", content="Task stopped by the user.").send()
332
442
 
@@ -340,8 +450,11 @@ async def process_message(session: Session, author: str, input_str: str):
340
450
  """Process a message from the user."""
341
451
 
342
452
  try:
343
- __chainlit_emitter__ = ChainlitEmitter(session)
344
- await __chainlit_emitter__.task_start()
453
+ emitter = ChainlitEmitter(session)
454
+ emitter_var.set(emitter)
455
+ loop_var.set(asyncio.get_event_loop())
456
+
457
+ await emitter.task_start()
345
458
 
346
459
  if session["client"]:
347
460
  # If cloud is enabled, persist the message
@@ -363,7 +476,6 @@ async def process_message(session: Session, author: str, input_str: str):
363
476
  await config.code.lc_run(
364
477
  langchain_agent,
365
478
  input_str,
366
- __chainlit_emitter__=__chainlit_emitter__,
367
479
  )
368
480
  return
369
481
  else:
@@ -374,9 +486,7 @@ async def process_message(session: Session, author: str, input_str: str):
374
486
 
375
487
  if config.code.lc_postprocess:
376
488
  # If the developer provided a custom postprocess function, use it
377
- await config.code.lc_postprocess(
378
- raw_res, __chainlit_emitter__=__chainlit_emitter__
379
- )
489
+ await config.code.lc_postprocess(raw_res)
380
490
  return
381
491
  elif output_key is not None:
382
492
  # Use the output key if provided
@@ -389,16 +499,16 @@ async def process_message(session: Session, author: str, input_str: str):
389
499
 
390
500
  elif config.code.on_message:
391
501
  # If no langchain agent is available, call the on_message function provided by the developer
392
- await config.code.on_message(
393
- input_str, __chainlit_emitter__=__chainlit_emitter__
394
- )
502
+ await config.code.on_message(input_str)
395
503
  except InterruptedError:
396
504
  pass
397
505
  except Exception as e:
398
506
  logger.exception(e)
399
- await ErrorMessage(author="Error", content=str(e)).send()
507
+ await ErrorMessage(
508
+ author="Error", content=str(e) or e.__class__.__name__
509
+ ).send()
400
510
  finally:
401
- await __chainlit_emitter__.task_end()
511
+ await emitter.task_end()
402
512
 
403
513
 
404
514
  @socket.on("ui_message")
@@ -413,11 +523,10 @@ async def message(sid, data):
413
523
  await process_message(session, author, input_str)
414
524
 
415
525
 
416
- async def process_action(session: Session, action: Action):
417
- __chainlit_emitter__ = ChainlitEmitter(session)
526
+ async def process_action(action: Action):
418
527
  callback = config.code.action_callbacks.get(action.name)
419
528
  if callback:
420
- await callback(action, __chainlit_emitter__=__chainlit_emitter__)
529
+ await callback(action)
421
530
  else:
422
531
  logger.warning("No callback found for action %s", action.name)
423
532
 
@@ -426,8 +535,9 @@ async def process_action(session: Session, action: Action):
426
535
  async def call_action(sid, action):
427
536
  """Handle an action call from the UI."""
428
537
  session = need_session(sid)
538
+ emitter_var.set(ChainlitEmitter(session))
539
+ loop_var.set(asyncio.get_event_loop())
429
540
 
430
- __chainlit_emitter__ = ChainlitEmitter(session)
431
541
  action = Action(**action)
432
542
 
433
- await process_action(session, action)
543
+ await process_action(action)
chainlit/session.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from typing import Dict, TypedDict, Optional, Callable, Any, Union
2
- from chainlit.client import BaseClient
2
+ from chainlit.client.base import BaseClient
3
3
  from chainlit.types import AskResponse
4
4
 
5
5
 
@@ -16,8 +16,6 @@ class Session(TypedDict):
16
16
  user_env: Dict[str, str]
17
17
  # Optional langchain agent
18
18
  agent: Any
19
- # If the session is currently running a sync task
20
- running_sync: bool
21
19
  # Whether the current task should be stopped
22
20
  should_stop: bool
23
21
  # Optional client to persist messages and files
chainlit/sync.py CHANGED
@@ -1,37 +1,25 @@
1
- from typing import Any, Callable
1
+ import sys
2
+ from typing import Any, TypeVar, Coroutine
3
+
4
+ if sys.version_info >= (3, 10):
5
+ from typing import ParamSpec
6
+ else:
7
+ from typing_extensions import ParamSpec
2
8
 
3
9
  import asyncio
4
- from syncer import sync
5
10
  from asyncer import asyncify
6
11
 
7
- from chainlit.emitter import get_emitter
8
-
9
-
10
- def make_async(function: Callable):
11
- emitter = get_emitter()
12
- if not emitter:
13
- raise RuntimeError(
14
- "Emitter not found, please call make_async in a Chainlit context."
15
- )
12
+ from chainlit.context import get_loop
16
13
 
17
- def wrapper(*args, **kwargs):
18
- emitter.session["running_sync"] = True
19
- __chainlit_emitter__ = emitter
20
- res = function(*args, **kwargs)
21
- emitter.session["running_sync"] = False
22
- return res
23
14
 
24
- return asyncify(wrapper, cancellable=True)
15
+ make_async = asyncify
25
16
 
17
+ T_Retval = TypeVar("T_Retval")
18
+ T_ParamSpec = ParamSpec("T_ParamSpec")
19
+ T = TypeVar("T")
26
20
 
27
- def run_sync(co: Any):
28
- try:
29
- loop = asyncio.get_event_loop()
30
- except RuntimeError as e:
31
- if "There is no current event loop" in str(e):
32
- loop = None
33
21
 
34
- if loop is None or not loop.is_running():
35
- loop = asyncio.new_event_loop()
36
- asyncio.set_event_loop(loop)
37
- return sync(co)
22
+ def run_sync(co: Coroutine[Any, Any, T_Retval]) -> T_Retval:
23
+ loop = get_loop()
24
+ result = asyncio.run_coroutine_threadsafe(co, loop=loop)
25
+ return result.result()