chainlit 1.0.0rc2__py3-none-any.whl → 1.0.100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chainlit might be problematic.

@@ -20,7 +20,7 @@
     <script>
       const global = globalThis;
     </script>
-    <script type="module" crossorigin src="/assets/index-aac0232b.js"></script>
+    <script type="module" crossorigin src="/assets/index-c4f40824.js"></script>
     <link rel="stylesheet" href="/assets/index-d088547c.css">
   </head>
   <body>
@@ -1,9 +1,11 @@
 from datetime import datetime
 from typing import Any, Generic, List, Optional, TypeVar
+import re
 
 from chainlit.context import context
 from chainlit.step import Step
 from chainlit.sync import run_sync
+from chainlit import Message
 from haystack.agents import Agent, Tool
 from haystack.agents.agent_step import AgentStep
 
@@ -34,7 +36,7 @@ class HaystackAgentCallbackHandler:
     stack: Stack[Step]
     last_step: Optional[Step]
 
-    def __init__(self, agent: Agent):
+    def __init__(self, agent: Agent, stream_final_answer: bool = False, stream_final_answer_agent_name: str = 'Agent'):
         agent.callback_manager.on_agent_start += self.on_agent_start
         agent.callback_manager.on_agent_step += self.on_agent_step
         agent.callback_manager.on_agent_finish += self.on_agent_finish
@@ -44,10 +46,20 @@ class HaystackAgentCallbackHandler:
         agent.tm.callback_manager.on_tool_finish += self.on_tool_finish
         agent.tm.callback_manager.on_tool_error += self.on_tool_error
 
+        self.final_answer_pattern = agent.final_answer_pattern
+        self.stream_final_answer = stream_final_answer
+        self.stream_final_answer_agent_name = stream_final_answer_agent_name
+
     def on_agent_start(self, **kwargs: Any) -> None:
         # Prepare agent step message for streaming
         self.agent_name = kwargs.get("name", "Agent")
         self.stack = Stack[Step]()
+
+        if self.stream_final_answer:
+            self.final_stream = Message(author=self.stream_final_answer_agent_name, content="")
+            self.last_tokens: List[str] = []
+            self.answer_reached = False
+
         root_message = context.session.root_message
         parent_id = root_message.id if root_message else None
         run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
@@ -59,10 +71,11 @@ class HaystackAgentCallbackHandler:
         self.stack.push(run_step)
 
     def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
-        run_step = self.stack.pop()
-        run_step.end = datetime.utcnow().isoformat()
-        run_step.output = agent_step.prompt_node_response
-        run_sync(run_step.update())
+        if self.last_step:
+            run_step = self.last_step
+            run_step.end = datetime.utcnow().isoformat()
+            run_step.output = agent_step.prompt_node_response
+            run_sync(run_step.update())
 
     # This method is called when a step has finished
     def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None:
@@ -77,11 +90,24 @@ class HaystackAgentCallbackHandler:
 
         if not agent_step.is_last():
             # Prepare step for next agent step
-            step = Step(name=self.agent_name, parent_id=self.stack.peek().id)
+            step = Step(name=self.agent_name, parent_id=self.last_step.id)
             self.stack.push(step)
 
     def on_new_token(self, token, **kwargs: Any) -> None:
         # Stream agent step tokens
+        if self.stream_final_answer:
+            if self.answer_reached:
+                run_sync(self.final_stream.stream_token(token))
+            else:
+                self.last_tokens.append(token)
+
+                last_tokens_concat = ''.join(self.last_tokens)
+                final_answer_match = re.search(self.final_answer_pattern, last_tokens_concat)
+
+                if final_answer_match:
+                    self.answer_reached = True
+                    run_sync(self.final_stream.stream_token(final_answer_match.group(1)))
+
        run_sync(self.stack.peek().stream_token(token))
 
     def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None:
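The stream_final_answer option above makes the Haystack handler buffer tokens until the agent's final_answer_pattern matches, then stream only the matched answer into a dedicated Message. A minimal sketch of opting in from an app, assuming the handler is importable from chainlit.haystack.callbacks and that build_agent() is a hypothetical helper returning a configured Haystack Agent:

import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler  # assumed import path

@cl.on_chat_start
async def start():
    agent = build_agent()  # hypothetical helper returning a haystack.agents.Agent
    # Constructing the handler hooks it into the agent's callback managers;
    # with stream_final_answer=True only the matched final answer is streamed,
    # authored by stream_final_answer_agent_name.
    HaystackAgentCallbackHandler(
        agent,
        stream_final_answer=True,
        stream_final_answer_agent_name="Agent",
    )
    cl.user_session.set("agent", agent)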
@@ -6,11 +6,11 @@ from chainlit.context import context_var
 from chainlit.message import Message
 from chainlit.playground.providers.openai import stringify_function_call
 from chainlit.step import Step, TrueStepType
-from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain.schema.output import ChatGenerationChunk, GenerationChunk
+from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
 
@@ -388,11 +388,13 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         self.steps = {}
         self.parent_id_map = {}
         self.ignored_runs = set()
-        self.root_parent_id = (
-            self.context.session.root_message.id
-            if self.context.session.root_message
-            else None
-        )
+
+        if self.context.current_step:
+            self.root_parent_id = self.context.current_step.id
+        elif self.context.session.root_message:
+            self.root_parent_id = self.context.session.root_message.id
+        else:
+            self.root_parent_id = None
 
         if to_ignore is None:
             self.to_ignore = DEFAULT_TO_IGNORE
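This change makes the LangChain tracer parent its runs under the active step when one exists, falling back to the session's root message otherwise. A rough sketch of the pattern that benefits, assuming chain is a pre-built LangChain runnable defined elsewhere; cl.step and cl.AsyncLangchainCallbackHandler are the standard Chainlit helpers:

import chainlit as cl

@cl.step(name="Research")
async def research(question: str):
    # Runs traced by the handler inside this function now nest under the
    # "Research" step instead of attaching to the root message.
    return await chain.ainvoke(  # chain is a hypothetical, pre-built runnable
        {"question": question},
        config={"callbacks": [cl.AsyncLangchainCallbackHandler()]},
    )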
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
 from chainlit.context import context_var
 from chainlit.element import Text
 from chainlit.step import Step, StepType
-from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
+from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from llama_index.callbacks import TokenCountingHandler
 from llama_index.callbacks.schema import CBEventType, EventPayload
 from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
@@ -36,14 +36,19 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
         )
         self.context = context_var.get()
 
+        if self.context.current_step:
+            self.root_parent_id = self.context.current_step.id
+        elif self.context.session.root_message:
+            self.root_parent_id = self.context.session.root_message.id
+        else:
+            self.root_parent_id = None
+
         self.steps = {}
 
     def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
         if event_parent_id and event_parent_id in self.steps:
             return event_parent_id
-        if root_message := self.context.session.root_message:
-            return root_message.id
-        return None
+        return self.root_parent_id
 
     def _restore_context(self) -> None:
         """Restore Chainlit context in the current thread
@@ -113,7 +118,10 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
             [f"Source {idx}" for idx, _ in enumerate(sources)]
         )
         step.elements = [
-            Text(name=f"Source {idx}", content=source.node.get_text())
+            Text(
+                name=f"Source {idx}",
+                content=source.node.get_text() or "Empty node",
+            )
             for idx, source in enumerate(sources)
         ]
         step.output = f"Retrieved the following sources: {source_refs}"
chainlit/message.py CHANGED
@@ -22,7 +22,7 @@ from chainlit.types import (
     AskSpec,
     FileDict,
 )
-from chainlit_client.step import MessageStepType
+from literalai.step import MessageStepType
 
 
 class MessageBase(ABC):
@@ -409,6 +409,72 @@ class DescopeOAuthProvider(OAuthProvider):
         return (descope_user, user)
 
 
+class AWSCognitoOAuthProvider(OAuthProvider):
+    id = "aws-cognito"
+    env = [
+        "OAUTH_COGNITO_CLIENT_ID",
+        "OAUTH_COGNITO_CLIENT_SECRET",
+        "OAUTH_COGNITO_DOMAIN",
+    ]
+    authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login"
+    token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token"
+
+    def __init__(self):
+        self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID")
+        self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET")
+        self.authorize_params = {
+            "response_type": "code",
+            "client_id": self.client_id,
+            "scope": "openid profile email",
+        }
+
+    async def get_token(self, code: str, url: str):
+        payload = {
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "code": code,
+            "grant_type": "authorization_code",
+            "redirect_uri": url,
+        }
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                self.token_url,
+                data=payload,
+            )
+            response.raise_for_status()
+            json = response.json()
+
+            token = json.get("access_token")
+            if not token:
+                raise HTTPException(
+                    status_code=400, detail="Failed to get the access token"
+                )
+            return token
+
+    async def get_user_info(self, token: str):
+        user_info_url = (
+            f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo"
+        )
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                user_info_url,
+                headers={"Authorization": f"Bearer {token}"},
+            )
+            response.raise_for_status()
+
+            cognito_user = response.json()
+
+            # Customize user metadata as needed
+            user = User(
+                identifier=cognito_user["email"],
+                metadata={
+                    "image": cognito_user.get("picture", ""),
+                    "provider": "aws-cognito",
+                },
+            )
+            return (cognito_user, user)
+
+
 providers = [
     GithubOAuthProvider(),
     GoogleOAuthProvider(),
@@ -416,6 +482,7 @@ providers = [
     OktaOAuthProvider(),
     Auth0OAuthProvider(),
     DescopeOAuthProvider(),
+    AWSCognitoOAuthProvider(),
 ]
 
 
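Like the existing providers, the new aws-cognito provider activates when its environment variables are present and the app defines an OAuth callback. A rough sketch, assuming the usual cl.oauth_callback hook; the Cognito domain, client id and secret must be exported before the server starts, since the provider builds its authorize and token URLs from OAUTH_COGNITO_DOMAIN at import time:

import chainlit as cl

# Assumed environment (placeholders): OAUTH_COGNITO_DOMAIN,
# OAUTH_COGNITO_CLIENT_ID and OAUTH_COGNITO_CLIENT_SECRET.

@cl.oauth_callback
def oauth_callback(provider_id: str, token: str, raw_user_data: dict, default_user: cl.User):
    # provider_id is "aws-cognito" for logins that went through the new provider.
    return default_user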
@@ -9,6 +9,7 @@ from chainlit.playground.providers import (
     OpenAI,
     ChatVertexAI,
     GenerationVertexAI,
+    Gemini,
 )
 
 providers = {
@@ -19,6 +20,7 @@ providers = {
     Anthropic.id: Anthropic,
     ChatVertexAI.id: ChatVertexAI,
     GenerationVertexAI.id: GenerationVertexAI,
+    Gemini.id: Gemini,
 }  # type: Dict[str, BaseProvider]
 
 
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
 from chainlit.config import config
 from chainlit.telemetry import trace_event
 from chainlit.types import GenerationRequest
-from chainlit_client import BaseGeneration, ChatGeneration, GenerationMessage
 from fastapi import HTTPException
+from literalai import BaseGeneration, ChatGeneration, GenerationMessage
 from pydantic.dataclasses import dataclass
 
 from chainlit import input_widget
@@ -9,4 +9,5 @@ from .openai import (
 from .vertexai import (
     ChatVertexAI,
     GenerationVertexAI,
+    Gemini
 )
@@ -1,8 +1,8 @@
 from chainlit.input_widget import Select, Slider, Tags
 from chainlit.playground.provider import BaseProvider
-from chainlit_client import GenerationMessage
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
+from literalai import GenerationMessage
 
 
 class AnthropicProvider(BaseProvider):
@@ -1,9 +1,10 @@
-from typing import Union
+from typing import List, Union
 
+from chainlit.input_widget import InputWidget
 from chainlit.playground.provider import BaseProvider
 from chainlit.sync import make_async
-from chainlit_client import GenerationMessage
 from fastapi.responses import StreamingResponse
+from literalai import GenerationMessage
 
 
 class LangchainGenericProvider(BaseProvider):
@@ -18,13 +19,14 @@ class LangchainGenericProvider(BaseProvider):
         id: str,
         name: str,
         llm: Union[LLM, BaseChatModel],
+        inputs: List[InputWidget] = [],
         is_chat: bool = False,
     ):
         super().__init__(
             id=id,
             name=name,
             env_vars={},
-            inputs=[],
+            inputs=inputs,
             is_chat=is_chat,
         )
         self.llm = llm
@@ -65,10 +67,9 @@ class LangchainGenericProvider(BaseProvider):
 
         messages = self.create_generation(request)
 
-        stream = make_async(self.llm.stream)
-
-        result = await stream(
-            input=messages,
+        # https://github.com/langchain-ai/langchain/issues/14980
+        result = await make_async(self.llm.stream)(
+            input=messages, **request.generation.settings
         )
 
         def create_event_stream():
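Because LangchainGenericProvider now takes an inputs list, a custom playground provider can expose its own settings widgets. A sketch of registering one, assuming the add_llm_provider helper and module paths below, and using langchain's FakeListLLM purely as a stand-in model:

from chainlit.input_widget import Slider
from chainlit.playground.config import add_llm_provider  # assumed registration helper
from chainlit.playground.providers.langchain import LangchainGenericProvider  # assumed module path
from langchain.llms.fake import FakeListLLM  # stand-in LLM for the sketch

add_llm_provider(
    LangchainGenericProvider(
        id="my-llm",
        name="My LLM",
        llm=FakeListLLM(responses=["Hello from the playground"]),
        inputs=[
            Slider(id="temperature", label="Temperature", min=0, max=1, step=0.1, initial=0.7),
        ],
        is_chat=False,
    )
)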
@@ -237,7 +237,6 @@ class AzureOpenAIProvider(BaseProvider):
             api_version=env_settings["api_version"],
             azure_endpoint=env_settings["azure_endpoint"],
             azure_ad_token=self.get_var(request, "AZURE_AD_TOKEN"),
-            azure_ad_token_provider=self.get_var(request, "AZURE_AD_TOKEN_PROVIDER"),
             azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"),
         )
         llm_settings = request.generation.settings
@@ -290,7 +289,6 @@ class AzureChatOpenAIProvider(BaseProvider):
             api_version=env_settings["api_version"],
             azure_endpoint=env_settings["azure_endpoint"],
             azure_ad_token=self.get_var(request, "AZURE_AD_TOKEN"),
-            azure_ad_token_provider=self.get_var(request, "AZURE_AD_TOKEN_PROVIDER"),
             azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"),
         )
 
@@ -33,10 +33,10 @@ class ChatVertexAIProvider(BaseProvider):
 
         self.validate_env(request=request)
 
-        llm_settings = request.prompt.settings
+        llm_settings = request.generation.settings
         self.require_settings(llm_settings)
 
-        prompt = self.create_prompt(request)
+        messages = self.create_generation(request)
         model_name = llm_settings["model"]
         if model_name.startswith("chat-"):
             model = ChatModel.from_pretrained(model_name)
@@ -52,7 +52,7 @@
 
         async def create_event_stream():
             for response in await cl.make_async(chat.send_message_streaming)(
-                prompt[0].formatted, **llm_settings
+                messages, **llm_settings
             ):
                 yield response.text
 
@@ -66,10 +66,10 @@ class GenerationVertexAIProvider(BaseProvider):
 
         self.validate_env(request=request)
 
-        llm_settings = request.prompt.settings
+        llm_settings = request.generation.settings
         self.require_settings(llm_settings)
 
-        prompt = self.create_prompt(request)
+        messages = self.create_generation(request)
         model_name = llm_settings["model"]
         if model_name.startswith("text-"):
             model = TextGenerationModel.from_pretrained(model_name)
@@ -84,13 +84,42 @@
 
         async def create_event_stream():
             for response in await cl.make_async(model.predict_streaming)(
-                prompt, **llm_settings
+                messages, **llm_settings
             ):
                 yield response.text
 
         return StreamingResponse(create_event_stream())
 
 
+class GeminiProvider(BaseProvider):
+    async def create_completion(self, request):
+        await super().create_completion(request)
+        from vertexai.preview.generative_models import GenerativeModel
+        from google.cloud import aiplatform
+        import os
+
+        self.validate_env(request=request)
+
+        llm_settings = request.generation.settings
+        self.require_settings(llm_settings)
+
+        messages = self.create_generation(request)
+        aiplatform.init(  # TODO: remove this when Gemini is released in all the regions
+            project=os.environ["GCP_PROJECT_ID"],
+            location='us-central1',
+        )
+        model = GenerativeModel(llm_settings["model"])
+        del llm_settings["model"]
+
+        async def create_event_stream():
+            for response in await cl.make_async(model.generate_content)(
+                messages, stream=True, generation_config=llm_settings
+            ):
+                yield response.candidates[0].content.parts[0].text
+
+        return StreamingResponse(create_event_stream())
+
+
 gcp_env_vars = {"google_application_credentials": "GOOGLE_APPLICATION_CREDENTIALS"}
 
 ChatVertexAI = ChatVertexAIProvider(
@@ -124,3 +153,19 @@ GenerationVertexAI = GenerationVertexAIProvider(
     ],
     is_chat=False,
 )
+
+Gemini = GeminiProvider(
+    id="gemini",
+    env_vars=gcp_env_vars,
+    name="Gemini",
+    inputs=[
+        Select(
+            id="model",
+            label="Model",
+            values=["gemini-pro", "gemini-pro-vision"],
+            initial_value="gemini-pro",
+        ),
+        *vertexai_common_inputs,
+    ],
+    is_chat=False,
+)
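The Gemini provider pulls the model name from the playground settings and, as the TODO notes, pins aiplatform to us-central1. A minimal sketch of the environment it assumes, with placeholder values:

import os

# Assumptions taken from the diff: authentication goes through the shared
# GOOGLE_APPLICATION_CREDENTIALS variable (gcp_env_vars), and aiplatform.init()
# reads the project from GCP_PROJECT_ID.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"  # placeholder
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"  # placeholder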
chainlit/server.py CHANGED
@@ -42,7 +42,16 @@ from chainlit.types import (
     UpdateFeedbackRequest,
 )
 from chainlit.user import PersistedUser, User
-from fastapi import Depends, FastAPI, HTTPException, Query, Request, UploadFile, status
+from fastapi import (
+    Depends,
+    FastAPI,
+    HTTPException,
+    Query,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
 from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
 from fastapi.security import OAuth2PasswordRequestForm
 from fastapi.staticfiles import StaticFiles
@@ -257,13 +266,24 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
     )
     access_token = create_jwt(user)
     if data_layer := get_data_layer():
-        await data_layer.create_user(user)
+        try:
+            await data_layer.create_user(user)
+        except Exception as e:
+            logger.error(f"Error creating user: {e}")
+
     return {
         "access_token": access_token,
         "token_type": "bearer",
     }
 
 
+@app.post("/logout")
+async def logout(request: Request, response: Response):
+    if config.code.on_logout:
+        return await config.code.on_logout(request, response)
+    return {"success": True}
+
+
 @app.post("/auth/header")
 async def header_auth(request: Request):
     if not config.code.header_auth_callback:
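The new /logout route delegates to an optional on_logout hook before returning. A rough sketch of such a hook, assuming it is registered with a cl.on_logout decorator that maps to the config.code.on_logout callable awaited above:

import chainlit as cl
from fastapi import Request, Response

@cl.on_logout  # assumed decorator name
async def on_logout(request: Request, response: Response):
    # Clean up any custom cookies the app set alongside the auth cookie.
    response.delete_cookie("my_app_cookie")  # hypothetical cookie name
    return {"success": True}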
@@ -282,7 +302,11 @@ async def header_auth(request: Request):
 
     access_token = create_jwt(user)
     if data_layer := get_data_layer():
-        await data_layer.create_user(user)
+        try:
+            await data_layer.create_user(user)
+        except Exception as e:
+            logger.error(f"Error creating user: {e}")
+
     return {
         "access_token": access_token,
         "token_type": "bearer",
@@ -318,8 +342,9 @@ async def oauth_login(provider_id: str, request: Request):
         url=f"{provider.authorize_url}?{params}",
     )
     samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")  # type: Any
+    secure = samesite.lower() == 'none'
     response.set_cookie(
-        "oauth_state", random, httponly=True, samesite=samesite, max_age=3 * 60
+        "oauth_state", random, httponly=True, samesite=samesite, secure=secure, max_age=3 * 60
     )
     return response
 
@@ -389,7 +414,10 @@ async def oauth_callback(
     access_token = create_jwt(user)
 
     if data_layer := get_data_layer():
-        await data_layer.create_user(user)
+        try:
+            await data_layer.create_user(user)
+        except Exception as e:
+            logger.error(f"Error creating user: {e}")
 
     params = urllib.parse.urlencode(
         {
@@ -471,12 +499,12 @@ async def update_feedback(
     """Update the human feedback for a particular message."""
     data_layer = get_data_layer()
     if not data_layer:
-        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
+        raise HTTPException(status_code=500, detail="Data persistence is not enabled")
 
     try:
         feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
     except Exception as e:
-        raise HTTPException(detail=str(e), status_code=401)
+        raise HTTPException(detail=str(e), status_code=500)
 
     return JSONResponse(content={"success": True, "feedbackId": feedback_id})
 
@@ -599,7 +627,6 @@ async def upload_file(
 async def get_file(
     file_id: str,
     session_id: Optional[str] = None,
-    token: Optional[str] = None,
 ):
     from chainlit.session import WebsocketSession
 
@@ -611,13 +638,6 @@ async def get_file(
             detail="Session not found",
         )
 
-    if current_user := await get_current_user(token or ""):
-        if not session.user or session.user.identifier != current_user.identifier:
-            raise HTTPException(
-                status_code=401,
-                detail="You are not authorized to upload files for this session",
-            )
-
     if file_id in session.files:
         file = session.files[file_id]
         return FileResponse(file["path"], media_type=file["type"])
chainlit/session.py CHANGED
@@ -5,6 +5,7 @@ import uuid
 from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union
 
 import aiofiles
+from chainlit.logger import logger
 
 if TYPE_CHECKING:
     from chainlit.message import Message
@@ -56,7 +57,7 @@ class BaseSession:
         self.user = user
         self.token = token
         self.root_message = root_message
-        self.has_user_message = False
+        self.has_first_interaction = False
         self.user_env = user_env or {}
         self.chat_profile = chat_profile
         self.active_steps = []
@@ -242,7 +243,10 @@ class WebsocketSession(BaseSession):
         for method_name, queue in self.thread_queues.items():
             while queue:
                 method, self, args, kwargs = queue.popleft()
-                await method(self, *args, **kwargs)
+                try:
+                    await method(self, *args, **kwargs)
+                except Exception as e:
+                    logger.error(f"Error while flushing {method_name}: {e}")
 
     @classmethod
     def get(cls, socket_id: str):