chainlit 1.3.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (82)
  1. chainlit/__init__.py +58 -56
  2. chainlit/action.py +12 -10
  3. chainlit/{auth.py → auth/__init__.py} +24 -34
  4. chainlit/auth/cookie.py +123 -0
  5. chainlit/auth/jwt.py +37 -0
  6. chainlit/cache.py +4 -6
  7. chainlit/callbacks.py +65 -11
  8. chainlit/chat_context.py +2 -2
  9. chainlit/chat_settings.py +3 -1
  10. chainlit/cli/__init__.py +15 -2
  11. chainlit/config.py +46 -90
  12. chainlit/context.py +4 -3
  13. chainlit/copilot/dist/index.js +8608 -642
  14. chainlit/data/__init__.py +96 -8
  15. chainlit/data/acl.py +3 -2
  16. chainlit/data/base.py +1 -15
  17. chainlit/data/chainlit_data_layer.py +584 -0
  18. chainlit/data/dynamodb.py +7 -4
  19. chainlit/data/literalai.py +4 -6
  20. chainlit/data/sql_alchemy.py +9 -8
  21. chainlit/data/storage_clients/__init__.py +0 -0
  22. chainlit/data/{storage_clients.py → storage_clients/azure.py} +2 -33
  23. chainlit/data/storage_clients/azure_blob.py +80 -0
  24. chainlit/data/storage_clients/base.py +22 -0
  25. chainlit/data/storage_clients/gcs.py +78 -0
  26. chainlit/data/storage_clients/s3.py +49 -0
  27. chainlit/discord/__init__.py +4 -4
  28. chainlit/discord/app.py +2 -1
  29. chainlit/element.py +41 -9
  30. chainlit/emitter.py +37 -16
  31. chainlit/frontend/dist/assets/{DailyMotion-CwoOhIL8.js → DailyMotion-DgRzV5GZ.js} +1 -1
  32. chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +22 -0
  33. chainlit/frontend/dist/assets/{Facebook-BhnGXlzq.js → Facebook-C0vx6HWv.js} +1 -1
  34. chainlit/frontend/dist/assets/{FilePlayer-CPSVT6fz.js → FilePlayer-CdhzeHPP.js} +1 -1
  35. chainlit/frontend/dist/assets/{Kaltura-COYaLzsL.js → Kaltura-5iVmeUct.js} +1 -1
  36. chainlit/frontend/dist/assets/{Mixcloud-JdadNiQ5.js → Mixcloud-C2zi77Ex.js} +1 -1
  37. chainlit/frontend/dist/assets/{Mux-CBN7RO2u.js → Mux-Vkebogdf.js} +1 -1
  38. chainlit/frontend/dist/assets/{Preview-CxAFvvjV.js → Preview-DwY_sEIl.js} +1 -1
  39. chainlit/frontend/dist/assets/{SoundCloud-JlgmASWm.js → SoundCloud-CREBXAWo.js} +1 -1
  40. chainlit/frontend/dist/assets/{Streamable-CUWgr6Zw.js → Streamable-B5Lu25uy.js} +1 -1
  41. chainlit/frontend/dist/assets/{Twitch-BiN1HEDM.js → Twitch-y9iKCcM1.js} +1 -1
  42. chainlit/frontend/dist/assets/{Vidyard-qhPmrhDm.js → Vidyard-ClYvcuEu.js} +1 -1
  43. chainlit/frontend/dist/assets/{Vimeo-CrZVSCaT.js → Vimeo-D6HvM2jt.js} +1 -1
  44. chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +1 -0
  45. chainlit/frontend/dist/assets/{YouTube-DKjw5Hbn.js → YouTube-D10tR6CJ.js} +1 -1
  46. chainlit/frontend/dist/assets/index-CI4qFOt5.js +8665 -0
  47. chainlit/frontend/dist/assets/index-CrrqM0nZ.css +1 -0
  48. chainlit/frontend/dist/assets/{react-plotly-Dpmqg5Sy.js → react-plotly-BpxUS-ab.js} +1 -1
  49. chainlit/frontend/dist/index.html +2 -2
  50. chainlit/haystack/callbacks.py +5 -4
  51. chainlit/input_widget.py +6 -4
  52. chainlit/langchain/callbacks.py +56 -47
  53. chainlit/langflow/__init__.py +1 -0
  54. chainlit/llama_index/callbacks.py +7 -7
  55. chainlit/message.py +8 -10
  56. chainlit/mistralai/__init__.py +3 -2
  57. chainlit/oauth_providers.py +70 -3
  58. chainlit/openai/__init__.py +3 -2
  59. chainlit/secret.py +1 -1
  60. chainlit/server.py +481 -182
  61. chainlit/session.py +7 -5
  62. chainlit/slack/__init__.py +3 -3
  63. chainlit/slack/app.py +3 -2
  64. chainlit/socket.py +89 -112
  65. chainlit/step.py +12 -12
  66. chainlit/sync.py +2 -1
  67. chainlit/teams/__init__.py +3 -3
  68. chainlit/teams/app.py +1 -0
  69. chainlit/translations/en-US.json +2 -1
  70. chainlit/translations/nl-NL.json +229 -0
  71. chainlit/types.py +24 -8
  72. chainlit/user.py +2 -1
  73. chainlit/utils.py +3 -2
  74. chainlit/version.py +3 -2
  75. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/METADATA +17 -37
  76. chainlit-2.0.0.dist-info/RECORD +106 -0
  77. chainlit/frontend/dist/assets/Wistia-C891KrBP.js +0 -1
  78. chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
  79. chainlit/frontend/dist/assets/index-DLRdQOIx.js +0 -723
  80. chainlit-1.3.1.dist-info/RECORD +0 -96
  81. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/WHEEL +0 -0
  82. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/entry_points.txt +0 -0
chainlit/frontend/dist/index.html CHANGED
@@ -21,8 +21,8 @@
21
21
  <script>
22
22
  const global = globalThis;
23
23
  </script>
24
- <script type="module" crossorigin src="/assets/index-DLRdQOIx.js"></script>
25
- <link rel="stylesheet" crossorigin href="/assets/index-CwmincdQ.css">
24
+ <script type="module" crossorigin src="/assets/index-CI4qFOt5.js"></script>
25
+ <link rel="stylesheet" crossorigin href="/assets/index-CrrqM0nZ.css">
26
26
  </head>
27
27
  <body>
28
28
  <div id="root"></div>
chainlit/haystack/callbacks.py CHANGED
@@ -1,13 +1,14 @@
1
1
  import re
2
2
  from typing import Any, Generic, List, Optional, TypeVar
3
3
 
4
- from chainlit import Message
5
- from chainlit.step import Step
6
- from chainlit.sync import run_sync
7
4
  from haystack.agents import Agent, Tool
8
5
  from haystack.agents.agent_step import AgentStep
9
6
  from literalai.helper import utc_now
10
7
 
8
+ from chainlit import Message
9
+ from chainlit.step import Step
10
+ from chainlit.sync import run_sync
11
+
11
12
  T = TypeVar("T")
12
13
 
13
14
 
@@ -131,7 +132,7 @@ class HaystackAgentCallbackHandler:
131
132
  tool_result: str,
132
133
  tool_name: Optional[str] = None,
133
134
  tool_input: Optional[str] = None,
134
- **kwargs: Any
135
+ **kwargs: Any,
135
136
  ) -> None:
136
137
  # Tool finished, send step with tool_result
137
138
  tool_step = self.stack.pop()
chainlit/input_widget.py CHANGED
@@ -2,8 +2,10 @@ from abc import abstractmethod
2
2
  from collections import defaultdict
3
3
  from typing import Any, Dict, List, Optional
4
4
 
5
+ from pydantic import Field
6
+ from pydantic.dataclasses import dataclass
7
+
5
8
  from chainlit.types import InputWidgetType
6
- from pydantic.dataclasses import Field, dataclass
7
9
 
8
10
 
9
11
  @dataclass
@@ -75,7 +77,7 @@ class Select(InputWidget):
75
77
  initial: Optional[str] = None
76
78
  initial_index: Optional[int] = None
77
79
  initial_value: Optional[str] = None
78
- values: List[str] = Field(default_factory=lambda: [])
80
+ values: List[str] = Field(default_factory=list)
79
81
  items: Dict[str, str] = Field(default_factory=lambda: defaultdict(dict))
80
82
 
81
83
  def __post_init__(
@@ -167,8 +169,8 @@ class Tags(InputWidget):
167
169
  """Useful to create an input for an array of strings."""
168
170
 
169
171
  type: InputWidgetType = "tags"
170
- initial: List[str] = Field(default_factory=lambda: [])
171
- values: List[str] = Field(default_factory=lambda: [])
172
+ initial: List[str] = Field(default_factory=list)
173
+ values: List[str] = Field(default_factory=list)
172
174
 
173
175
  def to_dict(self) -> Dict[str, Any]:
174
176
  return {
chainlit/langchain/callbacks.py CHANGED
@@ -3,17 +3,19 @@ import time
3
3
  from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
4
4
  from uuid import UUID
5
5
 
6
- from chainlit.context import context_var
7
- from chainlit.message import Message
8
- from chainlit.step import Step
9
- from langchain.callbacks.tracers.base import BaseTracer
6
+ import pydantic
10
7
  from langchain.callbacks.tracers.schemas import Run
11
8
  from langchain.schema import BaseMessage
12
9
  from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
10
+ from langchain_core.tracers.base import AsyncBaseTracer
13
11
  from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
14
12
  from literalai.helper import utc_now
15
13
  from literalai.observability.step import TrueStepType
16
14
 
15
+ from chainlit.context import context_var
16
+ from chainlit.message import Message
17
+ from chainlit.step import Step
18
+
17
19
  DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
18
20
 
19
21
 
@@ -122,6 +124,14 @@ class GenerationHelper:
122
124
  key: self.ensure_values_serializable(value)
123
125
  for key, value in data.items()
124
126
  }
127
+ elif isinstance(data, pydantic.BaseModel):
128
+ # Fallback to support pydantic v1
129
+ # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
130
+ if pydantic.VERSION.startswith("1"):
131
+ return data.dict()
132
+
133
+ # pydantic v2
134
+ return data.model_dump() # pyright: ignore reportAttributeAccessIssue
125
135
  elif isinstance(data, list):
126
136
  return [self.ensure_values_serializable(item) for item in data]
127
137
  elif isinstance(data, (str, int, float, bool, type(None))):
@@ -249,7 +259,7 @@ DEFAULT_TO_IGNORE = [
249
259
  DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
250
260
 
251
261
 
252
- class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
262
+ class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
253
263
  steps: Dict[str, Step]
254
264
  parent_id_map: Dict[str, str]
255
265
  ignored_runs: set
@@ -268,7 +278,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
268
278
  to_keep: Optional[List[str]] = None,
269
279
  **kwargs: Any,
270
280
  ) -> None:
271
- BaseTracer.__init__(self, **kwargs)
281
+ AsyncBaseTracer.__init__(self, **kwargs)
272
282
  GenerationHelper.__init__(self)
273
283
  FinalStreamHelper.__init__(
274
284
  self,
@@ -296,7 +306,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
296
306
  else:
297
307
  self.to_keep = to_keep
298
308
 
299
- def on_chat_model_start(
309
+ async def on_chat_model_start(
300
310
  self,
301
311
  serialized: Dict[str, Any],
302
312
  messages: List[List[BaseMessage]],
@@ -305,8 +315,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
305
315
  parent_run_id: Optional["UUID"] = None,
306
316
  tags: Optional[List[str]] = None,
307
317
  metadata: Optional[Dict[str, Any]] = None,
318
+ name: Optional[str] = None,
308
319
  **kwargs: Any,
309
- ) -> Any:
320
+ ) -> Run:
310
321
  lc_messages = messages[0]
311
322
  self.chat_generations[str(run_id)] = {
312
323
  "input_messages": lc_messages,
@@ -315,46 +326,48 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
315
326
  "tt_first_token": None,
316
327
  }
317
328
 
318
- return super().on_chat_model_start(
329
+ return await super().on_chat_model_start(
319
330
  serialized,
320
331
  messages,
321
332
  run_id=run_id,
322
333
  parent_run_id=parent_run_id,
323
334
  tags=tags,
324
335
  metadata=metadata,
336
+ name=name,
325
337
  **kwargs,
326
338
  )
327
339
 
328
- def on_llm_start(
340
+ async def on_llm_start(
329
341
  self,
330
342
  serialized: Dict[str, Any],
331
343
  prompts: List[str],
332
344
  *,
333
345
  run_id: "UUID",
346
+ parent_run_id: Optional[UUID] = None,
334
347
  tags: Optional[List[str]] = None,
335
- parent_run_id: Optional["UUID"] = None,
336
348
  metadata: Optional[Dict[str, Any]] = None,
337
- name: Optional[str] = None,
338
349
  **kwargs: Any,
339
- ) -> Run:
340
- self.completion_generations[str(run_id)] = {
341
- "prompt": prompts[0],
342
- "start": time.time(),
343
- "token_count": 0,
344
- "tt_first_token": None,
345
- }
346
- return super().on_llm_start(
350
+ ) -> None:
351
+ await super().on_llm_start(
347
352
  serialized,
348
353
  prompts,
349
354
  run_id=run_id,
350
355
  parent_run_id=parent_run_id,
351
356
  tags=tags,
352
357
  metadata=metadata,
353
- name=name,
354
358
  **kwargs,
355
359
  )
356
360
 
357
- def on_llm_new_token(
361
+ self.completion_generations[str(run_id)] = {
362
+ "prompt": prompts[0],
363
+ "start": time.time(),
364
+ "token_count": 0,
365
+ "tt_first_token": None,
366
+ }
367
+
368
+ return None
369
+
370
+ async def on_llm_new_token(
358
371
  self,
359
372
  token: str,
360
373
  *,
@@ -362,7 +375,14 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
362
375
  run_id: "UUID",
363
376
  parent_run_id: Optional["UUID"] = None,
364
377
  **kwargs: Any,
365
- ) -> Run:
378
+ ) -> None:
379
+ await super().on_llm_new_token(
380
+ token=token,
381
+ chunk=chunk,
382
+ run_id=run_id,
383
+ parent_run_id=parent_run_id,
384
+ **kwargs,
385
+ )
366
386
  if isinstance(chunk, ChatGenerationChunk):
367
387
  start = self.chat_generations[str(run_id)]
368
388
  else:
@@ -377,24 +397,13 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
377
397
  if self.answer_reached:
378
398
  if not self.final_stream:
379
399
  self.final_stream = Message(content="")
380
- self._run_sync(self.final_stream.send())
381
- self._run_sync(self.final_stream.stream_token(token))
400
+ await self.final_stream.send()
401
+ await self.final_stream.stream_token(token)
382
402
  self.has_streamed_final_answer = True
383
403
  else:
384
404
  self.answer_reached = self._check_if_answer_reached()
385
405
 
386
- return super().on_llm_new_token(
387
- token,
388
- chunk=chunk,
389
- run_id=run_id,
390
- parent_run_id=parent_run_id,
391
- )
392
-
393
- def _run_sync(self, co): # TODO: WHAT TO DO WITH THIS?
394
- context_var.set(self.context)
395
- self.context.loop.create_task(co)
396
-
397
- def _persist_run(self, run: Run) -> None:
406
+ async def _persist_run(self, run: Run) -> None:
398
407
  pass
399
408
 
400
409
  def _get_run_parent_id(self, run: Run):
@@ -445,8 +454,8 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
445
454
  self.ignored_runs.add(str(run.id))
446
455
  return ignore, parent_id
447
456
 
448
- def _start_trace(self, run: Run) -> None:
449
- super()._start_trace(run)
457
+ async def _start_trace(self, run: Run) -> None:
458
+ await super()._start_trace(run)
450
459
  context_var.set(self.context)
451
460
 
452
461
  ignore, parent_id = self._should_ignore_run(run)
@@ -459,7 +468,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
459
468
  if ignore:
460
469
  return
461
470
 
462
- step_type: "TrueStepType" = "undefined"
471
+ step_type: TrueStepType = "undefined"
463
472
  if run.run_type == "agent":
464
473
  step_type = "run"
465
474
  elif run.run_type == "chain":
@@ -489,9 +498,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
489
498
 
490
499
  self.steps[str(run.id)] = step
491
500
 
492
- self._run_sync(step.send())
501
+ await step.send()
493
502
 
494
- def _on_run_update(self, run: Run) -> None:
503
+ async def _on_run_update(self, run: Run) -> None:
495
504
  """Process a run upon update."""
496
505
  context_var.set(self.context)
497
506
 
@@ -576,10 +585,10 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
576
585
 
577
586
  if current_step:
578
587
  current_step.end = utc_now()
579
- self._run_sync(current_step.update())
588
+ await current_step.update()
580
589
 
581
590
  if self.final_stream and self.has_streamed_final_answer:
582
- self._run_sync(self.final_stream.update())
591
+ await self.final_stream.update()
583
592
 
584
593
  return
585
594
 
@@ -599,16 +608,16 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
599
608
  else output
600
609
  )
601
610
  current_step.end = utc_now()
602
- self._run_sync(current_step.update())
611
+ await current_step.update()
603
612
 
604
- def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
613
+ async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
605
614
  context_var.set(self.context)
606
615
 
607
616
  if current_step := self.steps.get(str(run_id), None):
608
617
  current_step.is_error = True
609
618
  current_step.output = str(error)
610
619
  current_step.end = utc_now()
611
- self._run_sync(current_step.update())
620
+ await current_step.update()
612
621
 
613
622
  on_llm_error = _on_error
614
623
  on_chain_error = _on_error
chainlit/langflow/__init__.py CHANGED
@@ -8,6 +8,7 @@ if not check_module_version("langflow", "0.1.4"):
8
8
  from typing import Dict, Optional, Union
9
9
 
10
10
  import httpx
11
+
11
12
  from chainlit.telemetry import trace_event
12
13
 
13
14
 
chainlit/llama_index/callbacks.py CHANGED
@@ -1,8 +1,5 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from chainlit.context import context_var
4
- from chainlit.element import Text
5
- from chainlit.step import Step, StepType
6
3
  from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
7
4
  from literalai.helper import utc_now
8
5
  from llama_index.core.callbacks import TokenCountingHandler
@@ -10,6 +7,10 @@ from llama_index.core.callbacks.schema import CBEventType, EventPayload
10
7
  from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
11
8
  from llama_index.core.tools.types import ToolMetadata
12
9
 
10
+ from chainlit.context import context_var
11
+ from chainlit.element import Text
12
+ from chainlit.step import Step, StepType
13
+
13
14
  DEFAULT_IGNORE = [
14
15
  CBEventType.CHUNKING,
15
16
  CBEventType.SYNTHESIZE,
@@ -143,16 +144,15 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
143
144
  context_var.get().loop.create_task(step.update())
144
145
 
145
146
  elif event_type == CBEventType.LLM:
146
- formatted_messages = payload.get(
147
- EventPayload.MESSAGES
148
- ) # type: Optional[List[ChatMessage]]
147
+ formatted_messages = payload.get(EventPayload.MESSAGES) # type: Optional[List[ChatMessage]]
149
148
  formatted_prompt = payload.get(EventPayload.PROMPT)
150
149
  response = payload.get(EventPayload.RESPONSE)
151
150
 
152
151
  if formatted_messages:
153
152
  messages = [
154
153
  GenerationMessage(
155
- role=m.role.value, content=m.content or "" # type: ignore
154
+ role=m.role.value, # type: ignore
155
+ content=m.content or "",
156
156
  )
157
157
  for m in formatted_messages
158
158
  ]
chainlit/message.py CHANGED
@@ -5,6 +5,9 @@ import uuid
5
5
  from abc import ABC
6
6
  from typing import Dict, List, Optional, Union, cast
7
7
 
8
+ from literalai.helper import utc_now
9
+ from literalai.observability.step import MessageStepType
10
+
8
11
  from chainlit.action import Action
9
12
  from chainlit.chat_context import chat_context
10
13
  from chainlit.config import config
@@ -22,8 +25,6 @@ from chainlit.types import (
22
25
  AskSpec,
23
26
  FileDict,
24
27
  )
25
- from literalai.helper import utc_now
26
- from literalai.observability.step import MessageStepType
27
28
 
28
29
 
29
30
  class MessageBase(ABC):
@@ -42,7 +43,6 @@ class MessageBase(ABC):
42
43
  metadata: Optional[Dict] = None
43
44
  tags: Optional[List[str]] = None
44
45
  wait_for_answer = False
45
- indent: Optional[int] = None
46
46
 
47
47
  def __post_init__(self) -> None:
48
48
  trace_event(f"init {self.__class__.__name__}")
@@ -59,7 +59,7 @@ class MessageBase(ABC):
59
59
  @classmethod
60
60
  def from_dict(self, _dict: StepDict):
61
61
  type = _dict.get("type", "assistant_message")
62
- message = Message(
62
+ return Message(
63
63
  id=_dict["id"],
64
64
  parent_id=_dict.get("parentId"),
65
65
  created_at=_dict["createdAt"],
@@ -67,10 +67,9 @@ class MessageBase(ABC):
67
67
  author=_dict.get("name", config.ui.name),
68
68
  type=type, # type: ignore
69
69
  language=_dict.get("language"),
70
+ metadata=_dict.get("metadata", {}),
70
71
  )
71
72
 
72
- return message
73
-
74
73
  def to_dict(self) -> StepDict:
75
74
  _dict: StepDict = {
76
75
  "id": self.id,
@@ -86,7 +85,6 @@ class MessageBase(ABC):
86
85
  "streaming": self.streaming,
87
86
  "isError": self.is_error,
88
87
  "waitForAnswer": self.wait_for_answer,
89
- "indent": self.indent,
90
88
  "metadata": self.metadata or {},
91
89
  "tags": self.tags,
92
90
  }
@@ -114,7 +112,7 @@ class MessageBase(ABC):
114
112
  except Exception as e:
115
113
  if self.fail_on_persist_error:
116
114
  raise e
117
- logger.error(f"Failed to persist message update: {str(e)}")
115
+ logger.error(f"Failed to persist message update: {e!s}")
118
116
 
119
117
  await context.emitter.update_step(step_dict)
120
118
 
@@ -134,7 +132,7 @@ class MessageBase(ABC):
134
132
  except Exception as e:
135
133
  if self.fail_on_persist_error:
136
134
  raise e
137
- logger.error(f"Failed to persist message deletion: {str(e)}")
135
+ logger.error(f"Failed to persist message deletion: {e!s}")
138
136
 
139
137
  await context.emitter.delete_step(step_dict)
140
138
 
@@ -150,7 +148,7 @@ class MessageBase(ABC):
150
148
  except Exception as e:
151
149
  if self.fail_on_persist_error:
152
150
  raise e
153
- logger.error(f"Failed to persist message creation: {str(e)}")
151
+ logger.error(f"Failed to persist message creation: {e!s}")
154
152
 
155
153
  return step_dict
156
154
 
chainlit/mistralai/__init__.py CHANGED
@@ -1,11 +1,12 @@
1
1
  import asyncio
2
2
  from typing import Union
3
3
 
4
- from chainlit.context import get_context
5
- from chainlit.step import Step
6
4
  from literalai import ChatGeneration, CompletionGeneration
7
5
  from literalai.helper import timestamp_utc
8
6
 
7
+ from chainlit.context import get_context
8
+ from chainlit.step import Step
9
+
9
10
 
10
11
  def instrument_mistralai():
11
12
  from literalai.instrumentation.mistralai import instrument_mistralai
chainlit/oauth_providers.py CHANGED
@@ -4,9 +4,10 @@ import urllib.parse
4
4
  from typing import Dict, List, Optional, Tuple
5
5
 
6
6
  import httpx
7
+ from fastapi import HTTPException
8
+
7
9
  from chainlit.secret import random_secret
8
10
  from chainlit.user import User
9
- from fastapi import HTTPException
10
11
 
11
12
 
12
13
  class OAuthProvider:
@@ -22,10 +23,10 @@ class OAuthProvider:
22
23
  return all([os.environ.get(env) for env in self.env])
23
24
 
24
25
  async def get_token(self, code: str, url: str) -> str:
25
- raise NotImplementedError()
26
+ raise NotImplementedError
26
27
 
27
28
  async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
28
- raise NotImplementedError()
29
+ raise NotImplementedError
29
30
 
30
31
  def get_env_prefix(self) -> str:
31
32
  """Return environment prefix, like AZURE_AD."""
@@ -664,6 +665,71 @@ class GitlabOAuthProvider(OAuthProvider):
664
665
  return (gitlab_user, user)
665
666
 
666
667
 
668
+ class KeycloakOAuthProvider(OAuthProvider):
669
+ env = [
670
+ "OAUTH_KEYCLOAK_CLIENT_ID",
671
+ "OAUTH_KEYCLOAK_CLIENT_SECRET",
672
+ "OAUTH_KEYCLOAK_REALM",
673
+ "OAUTH_KEYCLOAK_BASE_URL",
674
+ ]
675
+ id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")
676
+
677
+ def __init__(self):
678
+ self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
679
+ self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
680
+ self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
681
+ self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
682
+ self.authorize_url = (
683
+ f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/auth"
684
+ )
685
+
686
+ self.authorize_params = {
687
+ "scope": "profile email openid",
688
+ "response_type": "code",
689
+ }
690
+
691
+ if prompt := self.get_prompt():
692
+ self.authorize_params["prompt"] = prompt
693
+
694
+ async def get_token(self, code: str, url: str):
695
+ payload = {
696
+ "client_id": self.client_id,
697
+ "client_secret": self.client_secret,
698
+ "code": code,
699
+ "grant_type": "authorization_code",
700
+ "redirect_uri": url,
701
+ }
702
+ async with httpx.AsyncClient() as client:
703
+ response = await client.post(
704
+ f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token",
705
+ data=payload,
706
+ )
707
+ response.raise_for_status()
708
+ json = response.json()
709
+ token = json.get("access_token")
710
+ if not token:
711
+ raise httpx.HTTPStatusError(
712
+ "Failed to get the access token",
713
+ request=response.request,
714
+ response=response,
715
+ )
716
+ return token
717
+
718
+ async def get_user_info(self, token: str):
719
+ async with httpx.AsyncClient() as client:
720
+ response = await client.get(
721
+ f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/userinfo",
722
+ headers={"Authorization": f"Bearer {token}"},
723
+ )
724
+ response.raise_for_status()
725
+ kc_user = response.json()
726
+ user = User(
727
+ identifier=kc_user["email"],
728
+ metadata={"provider": "keycloak"},
729
+ )
730
+ return (kc_user, user)
731
+
732
+
667
733
  providers = [
668
734
  GithubOAuthProvider(),
669
735
  GoogleOAuthProvider(),
@@ -674,6 +740,7 @@ providers = [
674
740
  DescopeOAuthProvider(),
675
741
  AWSCognitoOAuthProvider(),
676
742
  GitlabOAuthProvider(),
743
+ KeycloakOAuthProvider(),
677
744
  ]
678
745
 
679
746
 
chainlit/openai/__init__.py CHANGED
@@ -1,11 +1,12 @@
1
1
  import asyncio
2
2
  from typing import Union
3
3
 
4
+ from literalai import ChatGeneration, CompletionGeneration
5
+ from literalai.helper import timestamp_utc
6
+
4
7
  from chainlit.context import local_steps
5
8
  from chainlit.step import Step
6
9
  from chainlit.utils import check_module_version
7
- from literalai import ChatGeneration, CompletionGeneration
8
- from literalai.helper import timestamp_utc
9
10
 
10
11
 
11
12
  def instrument_openai():
chainlit/secret.py CHANGED
@@ -6,4 +6,4 @@ chars = string.ascii_letters + string.digits + "$%*,-./:=>?@^_~"
6
6
 
7
7
 
8
8
  def random_secret(length: int = 64):
9
- return "".join((secrets.choice(chars) for i in range(length)))
9
+ return "".join(secrets.choice(chars) for i in range(length))