chainlit 1.3.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +58 -56
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +24 -34
- chainlit/auth/cookie.py +123 -0
- chainlit/auth/jwt.py +37 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +65 -11
- chainlit/chat_context.py +2 -2
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +15 -2
- chainlit/config.py +46 -90
- chainlit/context.py +4 -3
- chainlit/copilot/dist/index.js +8608 -642
- chainlit/data/__init__.py +96 -8
- chainlit/data/acl.py +3 -2
- chainlit/data/base.py +1 -15
- chainlit/data/chainlit_data_layer.py +584 -0
- chainlit/data/dynamodb.py +7 -4
- chainlit/data/literalai.py +4 -6
- chainlit/data/sql_alchemy.py +9 -8
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/{storage_clients.py → storage_clients/azure.py} +2 -33
- chainlit/data/storage_clients/azure_blob.py +80 -0
- chainlit/data/storage_clients/base.py +22 -0
- chainlit/data/storage_clients/gcs.py +78 -0
- chainlit/data/storage_clients/s3.py +49 -0
- chainlit/discord/__init__.py +4 -4
- chainlit/discord/app.py +2 -1
- chainlit/element.py +41 -9
- chainlit/emitter.py +37 -16
- chainlit/frontend/dist/assets/{DailyMotion-Bq4wFES6.js → DailyMotion-DgRzV5GZ.js} +1 -1
- chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +22 -0
- chainlit/frontend/dist/assets/{Facebook-CHEgeJDe.js → Facebook-C0vx6HWv.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-BMFA6He5.js → FilePlayer-CdhzeHPP.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-BS4Q0SKd.js → Kaltura-5iVmeUct.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-tLlgZy_i.js → Mixcloud-C2zi77Ex.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-Bcz0qNhS.js → Mux-Vkebogdf.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-RsJjlwJx.js → Preview-DwY_sEIl.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-B9UgR7Bk.js → SoundCloud-CREBXAWo.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-BOgIqbui.js → Streamable-B5Lu25uy.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-CBX_d6nV.js → Twitch-y9iKCcM1.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-C5HPuozf.js → Vidyard-ClYvcuEu.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-CHBmywi9.js → Vimeo-D6HvM2jt.js} +1 -1
- chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +1 -0
- chainlit/frontend/dist/assets/{YouTube-CA7t0q0j.js → YouTube-D10tR6CJ.js} +1 -1
- chainlit/frontend/dist/assets/index-CI4qFOt5.js +8665 -0
- chainlit/frontend/dist/assets/index-CrrqM0nZ.css +1 -0
- chainlit/frontend/dist/assets/{react-plotly-Ba2Cl614.js → react-plotly-BpxUS-ab.js} +1 -1
- chainlit/frontend/dist/index.html +2 -2
- chainlit/haystack/callbacks.py +5 -4
- chainlit/input_widget.py +6 -4
- chainlit/langchain/callbacks.py +56 -47
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +7 -7
- chainlit/message.py +8 -10
- chainlit/mistralai/__init__.py +3 -2
- chainlit/oauth_providers.py +70 -3
- chainlit/openai/__init__.py +3 -2
- chainlit/secret.py +1 -1
- chainlit/server.py +481 -182
- chainlit/session.py +7 -5
- chainlit/slack/__init__.py +3 -3
- chainlit/slack/app.py +3 -2
- chainlit/socket.py +89 -112
- chainlit/step.py +12 -12
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +3 -3
- chainlit/teams/app.py +1 -0
- chainlit/translations/en-US.json +2 -1
- chainlit/translations/nl-NL.json +229 -0
- chainlit/types.py +24 -8
- chainlit/user.py +2 -1
- chainlit/utils.py +3 -2
- chainlit/version.py +3 -2
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/METADATA +15 -35
- chainlit-2.0.0.dist-info/RECORD +106 -0
- chainlit/frontend/dist/assets/Wistia-1Gb23ljh.js +0 -1
- chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
- chainlit/frontend/dist/assets/index-DnjoDoLU.js +0 -723
- chainlit-1.3.2.dist-info/RECORD +0 -96
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/WHEEL +0 -0
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/entry_points.txt +0 -0
chainlit/frontend/dist/index.html
CHANGED
|
@@ -21,8 +21,8 @@
|
|
|
21
21
|
<script>
|
|
22
22
|
const global = globalThis;
|
|
23
23
|
</script>
|
|
24
|
-
<script type="module" crossorigin src="/assets/index-DnjoDoLU.js"></script>
|
|
25
|
-
<link rel="stylesheet" crossorigin href="/assets/index-CwmincdQ.css">
|
|
24
|
+
<script type="module" crossorigin src="/assets/index-CI4qFOt5.js"></script>
|
|
25
|
+
<link rel="stylesheet" crossorigin href="/assets/index-CrrqM0nZ.css">
|
|
26
26
|
</head>
|
|
27
27
|
<body>
|
|
28
28
|
<div id="root"></div>
|
chainlit/haystack/callbacks.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from typing import Any, Generic, List, Optional, TypeVar
|
|
3
3
|
|
|
4
|
-
from chainlit import Message
|
|
5
|
-
from chainlit.step import Step
|
|
6
|
-
from chainlit.sync import run_sync
|
|
7
4
|
from haystack.agents import Agent, Tool
|
|
8
5
|
from haystack.agents.agent_step import AgentStep
|
|
9
6
|
from literalai.helper import utc_now
|
|
10
7
|
|
|
8
|
+
from chainlit import Message
|
|
9
|
+
from chainlit.step import Step
|
|
10
|
+
from chainlit.sync import run_sync
|
|
11
|
+
|
|
11
12
|
T = TypeVar("T")
|
|
12
13
|
|
|
13
14
|
|
|
@@ -131,7 +132,7 @@ class HaystackAgentCallbackHandler:
|
|
|
131
132
|
tool_result: str,
|
|
132
133
|
tool_name: Optional[str] = None,
|
|
133
134
|
tool_input: Optional[str] = None,
|
|
134
|
-
**kwargs: Any
|
|
135
|
+
**kwargs: Any,
|
|
135
136
|
) -> None:
|
|
136
137
|
# Tool finished, send step with tool_result
|
|
137
138
|
tool_step = self.stack.pop()
|
chainlit/input_widget.py
CHANGED
|
@@ -2,8 +2,10 @@ from abc import abstractmethod
|
|
|
2
2
|
from collections import defaultdict
|
|
3
3
|
from typing import Any, Dict, List, Optional
|
|
4
4
|
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
from pydantic.dataclasses import dataclass
|
|
7
|
+
|
|
5
8
|
from chainlit.types import InputWidgetType
|
|
6
|
-
from pydantic.dataclasses import Field, dataclass
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
@dataclass
|
|
@@ -75,7 +77,7 @@ class Select(InputWidget):
|
|
|
75
77
|
initial: Optional[str] = None
|
|
76
78
|
initial_index: Optional[int] = None
|
|
77
79
|
initial_value: Optional[str] = None
|
|
78
|
-
values: List[str] = Field(default_factory=lambda: [])
|
|
80
|
+
values: List[str] = Field(default_factory=list)
|
|
79
81
|
items: Dict[str, str] = Field(default_factory=lambda: defaultdict(dict))
|
|
80
82
|
|
|
81
83
|
def __post_init__(
|
|
@@ -167,8 +169,8 @@ class Tags(InputWidget):
|
|
|
167
169
|
"""Useful to create an input for an array of strings."""
|
|
168
170
|
|
|
169
171
|
type: InputWidgetType = "tags"
|
|
170
|
-
initial: List[str] = Field(default_factory=lambda: [])
|
|
171
|
-
values: List[str] = Field(default_factory=lambda: [])
|
|
172
|
+
initial: List[str] = Field(default_factory=list)
|
|
173
|
+
values: List[str] = Field(default_factory=list)
|
|
172
174
|
|
|
173
175
|
def to_dict(self) -> Dict[str, Any]:
|
|
174
176
|
return {
|
chainlit/langchain/callbacks.py
CHANGED
|
@@ -3,17 +3,19 @@ import time
|
|
|
3
3
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
|
4
4
|
from uuid import UUID
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
from chainlit.message import Message
|
|
8
|
-
from chainlit.step import Step
|
|
9
|
-
from langchain.callbacks.tracers.base import BaseTracer
|
|
6
|
+
import pydantic
|
|
10
7
|
from langchain.callbacks.tracers.schemas import Run
|
|
11
8
|
from langchain.schema import BaseMessage
|
|
12
9
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
|
10
|
+
from langchain_core.tracers.base import AsyncBaseTracer
|
|
13
11
|
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
|
|
14
12
|
from literalai.helper import utc_now
|
|
15
13
|
from literalai.observability.step import TrueStepType
|
|
16
14
|
|
|
15
|
+
from chainlit.context import context_var
|
|
16
|
+
from chainlit.message import Message
|
|
17
|
+
from chainlit.step import Step
|
|
18
|
+
|
|
17
19
|
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
|
18
20
|
|
|
19
21
|
|
|
@@ -122,6 +124,14 @@ class GenerationHelper:
|
|
|
122
124
|
key: self.ensure_values_serializable(value)
|
|
123
125
|
for key, value in data.items()
|
|
124
126
|
}
|
|
127
|
+
elif isinstance(data, pydantic.BaseModel):
|
|
128
|
+
# Fallback to support pydantic v1
|
|
129
|
+
# https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
|
|
130
|
+
if pydantic.VERSION.startswith("1"):
|
|
131
|
+
return data.dict()
|
|
132
|
+
|
|
133
|
+
# pydantic v2
|
|
134
|
+
return data.model_dump() # pyright: ignore reportAttributeAccessIssue
|
|
125
135
|
elif isinstance(data, list):
|
|
126
136
|
return [self.ensure_values_serializable(item) for item in data]
|
|
127
137
|
elif isinstance(data, (str, int, float, bool, type(None))):
|
|
@@ -249,7 +259,7 @@ DEFAULT_TO_IGNORE = [
|
|
|
249
259
|
DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
|
|
250
260
|
|
|
251
261
|
|
|
252
|
-
class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
262
|
+
class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
|
|
253
263
|
steps: Dict[str, Step]
|
|
254
264
|
parent_id_map: Dict[str, str]
|
|
255
265
|
ignored_runs: set
|
|
@@ -268,7 +278,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
268
278
|
to_keep: Optional[List[str]] = None,
|
|
269
279
|
**kwargs: Any,
|
|
270
280
|
) -> None:
|
|
271
|
-
|
|
281
|
+
AsyncBaseTracer.__init__(self, **kwargs)
|
|
272
282
|
GenerationHelper.__init__(self)
|
|
273
283
|
FinalStreamHelper.__init__(
|
|
274
284
|
self,
|
|
@@ -296,7 +306,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
296
306
|
else:
|
|
297
307
|
self.to_keep = to_keep
|
|
298
308
|
|
|
299
|
-
def on_chat_model_start(
|
|
309
|
+
async def on_chat_model_start(
|
|
300
310
|
self,
|
|
301
311
|
serialized: Dict[str, Any],
|
|
302
312
|
messages: List[List[BaseMessage]],
|
|
@@ -305,8 +315,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
305
315
|
parent_run_id: Optional["UUID"] = None,
|
|
306
316
|
tags: Optional[List[str]] = None,
|
|
307
317
|
metadata: Optional[Dict[str, Any]] = None,
|
|
318
|
+
name: Optional[str] = None,
|
|
308
319
|
**kwargs: Any,
|
|
309
|
-
) ->
|
|
320
|
+
) -> Run:
|
|
310
321
|
lc_messages = messages[0]
|
|
311
322
|
self.chat_generations[str(run_id)] = {
|
|
312
323
|
"input_messages": lc_messages,
|
|
@@ -315,46 +326,48 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
315
326
|
"tt_first_token": None,
|
|
316
327
|
}
|
|
317
328
|
|
|
318
|
-
return super().on_chat_model_start(
|
|
329
|
+
return await super().on_chat_model_start(
|
|
319
330
|
serialized,
|
|
320
331
|
messages,
|
|
321
332
|
run_id=run_id,
|
|
322
333
|
parent_run_id=parent_run_id,
|
|
323
334
|
tags=tags,
|
|
324
335
|
metadata=metadata,
|
|
336
|
+
name=name,
|
|
325
337
|
**kwargs,
|
|
326
338
|
)
|
|
327
339
|
|
|
328
|
-
def on_llm_start(
|
|
340
|
+
async def on_llm_start(
|
|
329
341
|
self,
|
|
330
342
|
serialized: Dict[str, Any],
|
|
331
343
|
prompts: List[str],
|
|
332
344
|
*,
|
|
333
345
|
run_id: "UUID",
|
|
346
|
+
parent_run_id: Optional[UUID] = None,
|
|
334
347
|
tags: Optional[List[str]] = None,
|
|
335
|
-
parent_run_id: Optional["UUID"] = None,
|
|
336
348
|
metadata: Optional[Dict[str, Any]] = None,
|
|
337
|
-
name: Optional[str] = None,
|
|
338
349
|
**kwargs: Any,
|
|
339
|
-
) ->
|
|
340
|
-
|
|
341
|
-
"prompt": prompts[0],
|
|
342
|
-
"start": time.time(),
|
|
343
|
-
"token_count": 0,
|
|
344
|
-
"tt_first_token": None,
|
|
345
|
-
}
|
|
346
|
-
return super().on_llm_start(
|
|
350
|
+
) -> None:
|
|
351
|
+
await super().on_llm_start(
|
|
347
352
|
serialized,
|
|
348
353
|
prompts,
|
|
349
354
|
run_id=run_id,
|
|
350
355
|
parent_run_id=parent_run_id,
|
|
351
356
|
tags=tags,
|
|
352
357
|
metadata=metadata,
|
|
353
|
-
name=name,
|
|
354
358
|
**kwargs,
|
|
355
359
|
)
|
|
356
360
|
|
|
357
|
-
|
|
361
|
+
self.completion_generations[str(run_id)] = {
|
|
362
|
+
"prompt": prompts[0],
|
|
363
|
+
"start": time.time(),
|
|
364
|
+
"token_count": 0,
|
|
365
|
+
"tt_first_token": None,
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
return None
|
|
369
|
+
|
|
370
|
+
async def on_llm_new_token(
|
|
358
371
|
self,
|
|
359
372
|
token: str,
|
|
360
373
|
*,
|
|
@@ -362,7 +375,14 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
362
375
|
run_id: "UUID",
|
|
363
376
|
parent_run_id: Optional["UUID"] = None,
|
|
364
377
|
**kwargs: Any,
|
|
365
|
-
) ->
|
|
378
|
+
) -> None:
|
|
379
|
+
await super().on_llm_new_token(
|
|
380
|
+
token=token,
|
|
381
|
+
chunk=chunk,
|
|
382
|
+
run_id=run_id,
|
|
383
|
+
parent_run_id=parent_run_id,
|
|
384
|
+
**kwargs,
|
|
385
|
+
)
|
|
366
386
|
if isinstance(chunk, ChatGenerationChunk):
|
|
367
387
|
start = self.chat_generations[str(run_id)]
|
|
368
388
|
else:
|
|
@@ -377,24 +397,13 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
377
397
|
if self.answer_reached:
|
|
378
398
|
if not self.final_stream:
|
|
379
399
|
self.final_stream = Message(content="")
|
|
380
|
-
self._run_sync(self.final_stream.send())
|
|
381
|
-
self._run_sync(self.final_stream.stream_token(token))
|
|
400
|
+
await self.final_stream.send()
|
|
401
|
+
await self.final_stream.stream_token(token)
|
|
382
402
|
self.has_streamed_final_answer = True
|
|
383
403
|
else:
|
|
384
404
|
self.answer_reached = self._check_if_answer_reached()
|
|
385
405
|
|
|
386
|
-
|
|
387
|
-
token,
|
|
388
|
-
chunk=chunk,
|
|
389
|
-
run_id=run_id,
|
|
390
|
-
parent_run_id=parent_run_id,
|
|
391
|
-
)
|
|
392
|
-
|
|
393
|
-
def _run_sync(self, co): # TODO: WHAT TO DO WITH THIS?
|
|
394
|
-
context_var.set(self.context)
|
|
395
|
-
self.context.loop.create_task(co)
|
|
396
|
-
|
|
397
|
-
def _persist_run(self, run: Run) -> None:
|
|
406
|
+
async def _persist_run(self, run: Run) -> None:
|
|
398
407
|
pass
|
|
399
408
|
|
|
400
409
|
def _get_run_parent_id(self, run: Run):
|
|
@@ -445,8 +454,8 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
445
454
|
self.ignored_runs.add(str(run.id))
|
|
446
455
|
return ignore, parent_id
|
|
447
456
|
|
|
448
|
-
def _start_trace(self, run: Run) -> None:
|
|
449
|
-
super()._start_trace(run)
|
|
457
|
+
async def _start_trace(self, run: Run) -> None:
|
|
458
|
+
await super()._start_trace(run)
|
|
450
459
|
context_var.set(self.context)
|
|
451
460
|
|
|
452
461
|
ignore, parent_id = self._should_ignore_run(run)
|
|
@@ -459,7 +468,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
459
468
|
if ignore:
|
|
460
469
|
return
|
|
461
470
|
|
|
462
|
-
step_type:
|
|
471
|
+
step_type: TrueStepType = "undefined"
|
|
463
472
|
if run.run_type == "agent":
|
|
464
473
|
step_type = "run"
|
|
465
474
|
elif run.run_type == "chain":
|
|
@@ -489,9 +498,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
489
498
|
|
|
490
499
|
self.steps[str(run.id)] = step
|
|
491
500
|
|
|
492
|
-
|
|
501
|
+
await step.send()
|
|
493
502
|
|
|
494
|
-
def _on_run_update(self, run: Run) -> None:
|
|
503
|
+
async def _on_run_update(self, run: Run) -> None:
|
|
495
504
|
"""Process a run upon update."""
|
|
496
505
|
context_var.set(self.context)
|
|
497
506
|
|
|
@@ -576,10 +585,10 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
576
585
|
|
|
577
586
|
if current_step:
|
|
578
587
|
current_step.end = utc_now()
|
|
579
|
-
|
|
588
|
+
await current_step.update()
|
|
580
589
|
|
|
581
590
|
if self.final_stream and self.has_streamed_final_answer:
|
|
582
|
-
self._run_sync(self.final_stream.update())
|
|
591
|
+
await self.final_stream.update()
|
|
583
592
|
|
|
584
593
|
return
|
|
585
594
|
|
|
@@ -599,16 +608,16 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
599
608
|
else output
|
|
600
609
|
)
|
|
601
610
|
current_step.end = utc_now()
|
|
602
|
-
|
|
611
|
+
await current_step.update()
|
|
603
612
|
|
|
604
|
-
def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
|
|
613
|
+
async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
|
|
605
614
|
context_var.set(self.context)
|
|
606
615
|
|
|
607
616
|
if current_step := self.steps.get(str(run_id), None):
|
|
608
617
|
current_step.is_error = True
|
|
609
618
|
current_step.output = str(error)
|
|
610
619
|
current_step.end = utc_now()
|
|
611
|
-
|
|
620
|
+
await current_step.update()
|
|
612
621
|
|
|
613
622
|
on_llm_error = _on_error
|
|
614
623
|
on_chain_error = _on_error
|
chainlit/langflow/__init__.py
CHANGED
chainlit/llama_index/callbacks.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
|
2
2
|
|
|
3
|
-
from chainlit.context import context_var
|
|
4
|
-
from chainlit.element import Text
|
|
5
|
-
from chainlit.step import Step, StepType
|
|
6
3
|
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
|
|
7
4
|
from literalai.helper import utc_now
|
|
8
5
|
from llama_index.core.callbacks import TokenCountingHandler
|
|
@@ -10,6 +7,10 @@ from llama_index.core.callbacks.schema import CBEventType, EventPayload
|
|
|
10
7
|
from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
|
|
11
8
|
from llama_index.core.tools.types import ToolMetadata
|
|
12
9
|
|
|
10
|
+
from chainlit.context import context_var
|
|
11
|
+
from chainlit.element import Text
|
|
12
|
+
from chainlit.step import Step, StepType
|
|
13
|
+
|
|
13
14
|
DEFAULT_IGNORE = [
|
|
14
15
|
CBEventType.CHUNKING,
|
|
15
16
|
CBEventType.SYNTHESIZE,
|
|
@@ -143,16 +144,15 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
143
144
|
context_var.get().loop.create_task(step.update())
|
|
144
145
|
|
|
145
146
|
elif event_type == CBEventType.LLM:
|
|
146
|
-
formatted_messages = payload.get(
|
|
147
|
-
EventPayload.MESSAGES
|
|
148
|
-
) # type: Optional[List[ChatMessage]]
|
|
147
|
+
formatted_messages = payload.get(EventPayload.MESSAGES) # type: Optional[List[ChatMessage]]
|
|
149
148
|
formatted_prompt = payload.get(EventPayload.PROMPT)
|
|
150
149
|
response = payload.get(EventPayload.RESPONSE)
|
|
151
150
|
|
|
152
151
|
if formatted_messages:
|
|
153
152
|
messages = [
|
|
154
153
|
GenerationMessage(
|
|
155
|
-
role=m.role.value,
|
|
154
|
+
role=m.role.value, # type: ignore
|
|
155
|
+
content=m.content or "",
|
|
156
156
|
)
|
|
157
157
|
for m in formatted_messages
|
|
158
158
|
]
|
chainlit/message.py
CHANGED
|
@@ -5,6 +5,9 @@ import uuid
|
|
|
5
5
|
from abc import ABC
|
|
6
6
|
from typing import Dict, List, Optional, Union, cast
|
|
7
7
|
|
|
8
|
+
from literalai.helper import utc_now
|
|
9
|
+
from literalai.observability.step import MessageStepType
|
|
10
|
+
|
|
8
11
|
from chainlit.action import Action
|
|
9
12
|
from chainlit.chat_context import chat_context
|
|
10
13
|
from chainlit.config import config
|
|
@@ -22,8 +25,6 @@ from chainlit.types import (
|
|
|
22
25
|
AskSpec,
|
|
23
26
|
FileDict,
|
|
24
27
|
)
|
|
25
|
-
from literalai.helper import utc_now
|
|
26
|
-
from literalai.observability.step import MessageStepType
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class MessageBase(ABC):
|
|
@@ -42,7 +43,6 @@ class MessageBase(ABC):
|
|
|
42
43
|
metadata: Optional[Dict] = None
|
|
43
44
|
tags: Optional[List[str]] = None
|
|
44
45
|
wait_for_answer = False
|
|
45
|
-
indent: Optional[int] = None
|
|
46
46
|
|
|
47
47
|
def __post_init__(self) -> None:
|
|
48
48
|
trace_event(f"init {self.__class__.__name__}")
|
|
@@ -59,7 +59,7 @@ class MessageBase(ABC):
|
|
|
59
59
|
@classmethod
|
|
60
60
|
def from_dict(self, _dict: StepDict):
|
|
61
61
|
type = _dict.get("type", "assistant_message")
|
|
62
|
-
|
|
62
|
+
return Message(
|
|
63
63
|
id=_dict["id"],
|
|
64
64
|
parent_id=_dict.get("parentId"),
|
|
65
65
|
created_at=_dict["createdAt"],
|
|
@@ -67,10 +67,9 @@ class MessageBase(ABC):
|
|
|
67
67
|
author=_dict.get("name", config.ui.name),
|
|
68
68
|
type=type, # type: ignore
|
|
69
69
|
language=_dict.get("language"),
|
|
70
|
+
metadata=_dict.get("metadata", {}),
|
|
70
71
|
)
|
|
71
72
|
|
|
72
|
-
return message
|
|
73
|
-
|
|
74
73
|
def to_dict(self) -> StepDict:
|
|
75
74
|
_dict: StepDict = {
|
|
76
75
|
"id": self.id,
|
|
@@ -86,7 +85,6 @@ class MessageBase(ABC):
|
|
|
86
85
|
"streaming": self.streaming,
|
|
87
86
|
"isError": self.is_error,
|
|
88
87
|
"waitForAnswer": self.wait_for_answer,
|
|
89
|
-
"indent": self.indent,
|
|
90
88
|
"metadata": self.metadata or {},
|
|
91
89
|
"tags": self.tags,
|
|
92
90
|
}
|
|
@@ -114,7 +112,7 @@ class MessageBase(ABC):
|
|
|
114
112
|
except Exception as e:
|
|
115
113
|
if self.fail_on_persist_error:
|
|
116
114
|
raise e
|
|
117
|
-
logger.error(f"Failed to persist message update: {str(e)}")
|
|
115
|
+
logger.error(f"Failed to persist message update: {e!s}")
|
|
118
116
|
|
|
119
117
|
await context.emitter.update_step(step_dict)
|
|
120
118
|
|
|
@@ -134,7 +132,7 @@ class MessageBase(ABC):
|
|
|
134
132
|
except Exception as e:
|
|
135
133
|
if self.fail_on_persist_error:
|
|
136
134
|
raise e
|
|
137
|
-
logger.error(f"Failed to persist message deletion: {str(e)}")
|
|
135
|
+
logger.error(f"Failed to persist message deletion: {e!s}")
|
|
138
136
|
|
|
139
137
|
await context.emitter.delete_step(step_dict)
|
|
140
138
|
|
|
@@ -150,7 +148,7 @@ class MessageBase(ABC):
|
|
|
150
148
|
except Exception as e:
|
|
151
149
|
if self.fail_on_persist_error:
|
|
152
150
|
raise e
|
|
153
|
-
logger.error(f"Failed to persist message creation: {str(e)}")
|
|
151
|
+
logger.error(f"Failed to persist message creation: {e!s}")
|
|
154
152
|
|
|
155
153
|
return step_dict
|
|
156
154
|
|
chainlit/mistralai/__init__.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from typing import Union
|
|
3
3
|
|
|
4
|
-
from chainlit.context import get_context
|
|
5
|
-
from chainlit.step import Step
|
|
6
4
|
from literalai import ChatGeneration, CompletionGeneration
|
|
7
5
|
from literalai.helper import timestamp_utc
|
|
8
6
|
|
|
7
|
+
from chainlit.context import get_context
|
|
8
|
+
from chainlit.step import Step
|
|
9
|
+
|
|
9
10
|
|
|
10
11
|
def instrument_mistralai():
|
|
11
12
|
from literalai.instrumentation.mistralai import instrument_mistralai
|
chainlit/oauth_providers.py
CHANGED
|
@@ -4,9 +4,10 @@ import urllib.parse
|
|
|
4
4
|
from typing import Dict, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import httpx
|
|
7
|
+
from fastapi import HTTPException
|
|
8
|
+
|
|
7
9
|
from chainlit.secret import random_secret
|
|
8
10
|
from chainlit.user import User
|
|
9
|
-
from fastapi import HTTPException
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class OAuthProvider:
|
|
@@ -22,10 +23,10 @@ class OAuthProvider:
|
|
|
22
23
|
return all([os.environ.get(env) for env in self.env])
|
|
23
24
|
|
|
24
25
|
async def get_token(self, code: str, url: str) -> str:
|
|
25
|
-
raise NotImplementedError
|
|
26
|
+
raise NotImplementedError
|
|
26
27
|
|
|
27
28
|
async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
|
|
28
|
-
raise NotImplementedError
|
|
29
|
+
raise NotImplementedError
|
|
29
30
|
|
|
30
31
|
def get_env_prefix(self) -> str:
|
|
31
32
|
"""Return environment prefix, like AZURE_AD."""
|
|
@@ -664,6 +665,71 @@ class GitlabOAuthProvider(OAuthProvider):
|
|
|
664
665
|
return (gitlab_user, user)
|
|
665
666
|
|
|
666
667
|
|
|
668
|
+
class KeycloakOAuthProvider(OAuthProvider):
|
|
669
|
+
env = [
|
|
670
|
+
"OAUTH_KEYCLOAK_CLIENT_ID",
|
|
671
|
+
"OAUTH_KEYCLOAK_CLIENT_SECRET",
|
|
672
|
+
"OAUTH_KEYCLOAK_REALM",
|
|
673
|
+
"OAUTH_KEYCLOAK_BASE_URL",
|
|
674
|
+
]
|
|
675
|
+
id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")
|
|
676
|
+
|
|
677
|
+
def __init__(self):
|
|
678
|
+
self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
|
|
679
|
+
self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
|
|
680
|
+
self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
|
|
681
|
+
self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
|
|
682
|
+
self.authorize_url = (
|
|
683
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/auth"
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
self.authorize_params = {
|
|
687
|
+
"scope": "profile email openid",
|
|
688
|
+
"response_type": "code",
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
if prompt := self.get_prompt():
|
|
692
|
+
self.authorize_params["prompt"] = prompt
|
|
693
|
+
|
|
694
|
+
async def get_token(self, code: str, url: str):
|
|
695
|
+
payload = {
|
|
696
|
+
"client_id": self.client_id,
|
|
697
|
+
"client_secret": self.client_secret,
|
|
698
|
+
"code": code,
|
|
699
|
+
"grant_type": "authorization_code",
|
|
700
|
+
"redirect_uri": url,
|
|
701
|
+
}
|
|
702
|
+
async with httpx.AsyncClient() as client:
|
|
703
|
+
response = await client.post(
|
|
704
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token",
|
|
705
|
+
data=payload,
|
|
706
|
+
)
|
|
707
|
+
response.raise_for_status()
|
|
708
|
+
json = response.json()
|
|
709
|
+
token = json.get("access_token")
|
|
710
|
+
if not token:
|
|
711
|
+
raise httpx.HTTPStatusError(
|
|
712
|
+
"Failed to get the access token",
|
|
713
|
+
request=response.request,
|
|
714
|
+
response=response,
|
|
715
|
+
)
|
|
716
|
+
return token
|
|
717
|
+
|
|
718
|
+
async def get_user_info(self, token: str):
|
|
719
|
+
async with httpx.AsyncClient() as client:
|
|
720
|
+
response = await client.get(
|
|
721
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/userinfo",
|
|
722
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
723
|
+
)
|
|
724
|
+
response.raise_for_status()
|
|
725
|
+
kc_user = response.json()
|
|
726
|
+
user = User(
|
|
727
|
+
identifier=kc_user["email"],
|
|
728
|
+
metadata={"provider": "keycloak"},
|
|
729
|
+
)
|
|
730
|
+
return (kc_user, user)
|
|
731
|
+
|
|
732
|
+
|
|
667
733
|
providers = [
|
|
668
734
|
GithubOAuthProvider(),
|
|
669
735
|
GoogleOAuthProvider(),
|
|
@@ -674,6 +740,7 @@ providers = [
|
|
|
674
740
|
DescopeOAuthProvider(),
|
|
675
741
|
AWSCognitoOAuthProvider(),
|
|
676
742
|
GitlabOAuthProvider(),
|
|
743
|
+
KeycloakOAuthProvider(),
|
|
677
744
|
]
|
|
678
745
|
|
|
679
746
|
|
chainlit/openai/__init__.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from typing import Union
|
|
3
3
|
|
|
4
|
+
from literalai import ChatGeneration, CompletionGeneration
|
|
5
|
+
from literalai.helper import timestamp_utc
|
|
6
|
+
|
|
4
7
|
from chainlit.context import local_steps
|
|
5
8
|
from chainlit.step import Step
|
|
6
9
|
from chainlit.utils import check_module_version
|
|
7
|
-
from literalai import ChatGeneration, CompletionGeneration
|
|
8
|
-
from literalai.helper import timestamp_utc
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def instrument_openai():
|
chainlit/secret.py
CHANGED