chainlit 1.0.0rc3__py3-none-any.whl → 1.0.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +14 -1
- chainlit/auth.py +4 -3
- chainlit/config.py +2 -6
- chainlit/data/__init__.py +97 -24
- chainlit/emitter.py +9 -5
- chainlit/frontend/dist/assets/index-c4f40824.js +723 -0
- chainlit/frontend/dist/assets/{react-plotly-c9578a93.js → react-plotly-259d6961.js} +1 -1
- chainlit/frontend/dist/index.html +1 -1
- chainlit/haystack/callbacks.py +32 -6
- chainlit/langchain/callbacks.py +8 -6
- chainlit/llama_index/callbacks.py +13 -5
- chainlit/message.py +1 -1
- chainlit/oauth_providers.py +67 -0
- chainlit/playground/config.py +2 -0
- chainlit/playground/provider.py +1 -1
- chainlit/playground/providers/__init__.py +1 -0
- chainlit/playground/providers/anthropic.py +1 -1
- chainlit/playground/providers/langchain.py +8 -7
- chainlit/playground/providers/vertexai.py +51 -6
- chainlit/server.py +35 -15
- chainlit/session.py +5 -1
- chainlit/socket.py +56 -18
- chainlit/step.py +4 -4
- chainlit/telemetry.py +2 -6
- chainlit/types.py +1 -1
- {chainlit-1.0.0rc3.dist-info → chainlit-1.0.100.dist-info}/METADATA +3 -3
- {chainlit-1.0.0rc3.dist-info → chainlit-1.0.100.dist-info}/RECORD +29 -29
- chainlit/frontend/dist/assets/index-15bb372a.js +0 -697
- {chainlit-1.0.0rc3.dist-info → chainlit-1.0.100.dist-info}/WHEEL +0 -0
- {chainlit-1.0.0rc3.dist-info → chainlit-1.0.100.dist-info}/entry_points.txt +0 -0
|
@@ -20,7 +20,7 @@
|
|
|
20
20
|
<script>
|
|
21
21
|
const global = globalThis;
|
|
22
22
|
</script>
|
|
23
|
-
<script type="module" crossorigin src="/assets/index-
|
|
23
|
+
<script type="module" crossorigin src="/assets/index-c4f40824.js"></script>
|
|
24
24
|
<link rel="stylesheet" href="/assets/index-d088547c.css">
|
|
25
25
|
</head>
|
|
26
26
|
<body>
|
chainlit/haystack/callbacks.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from typing import Any, Generic, List, Optional, TypeVar
|
|
3
|
+
import re
|
|
3
4
|
|
|
4
5
|
from chainlit.context import context
|
|
5
6
|
from chainlit.step import Step
|
|
6
7
|
from chainlit.sync import run_sync
|
|
8
|
+
from chainlit import Message
|
|
7
9
|
from haystack.agents import Agent, Tool
|
|
8
10
|
from haystack.agents.agent_step import AgentStep
|
|
9
11
|
|
|
@@ -34,7 +36,7 @@ class HaystackAgentCallbackHandler:
|
|
|
34
36
|
stack: Stack[Step]
|
|
35
37
|
last_step: Optional[Step]
|
|
36
38
|
|
|
37
|
-
def __init__(self, agent: Agent):
|
|
39
|
+
def __init__(self, agent: Agent, stream_final_answer: bool = False, stream_final_answer_agent_name: str = 'Agent'):
|
|
38
40
|
agent.callback_manager.on_agent_start += self.on_agent_start
|
|
39
41
|
agent.callback_manager.on_agent_step += self.on_agent_step
|
|
40
42
|
agent.callback_manager.on_agent_finish += self.on_agent_finish
|
|
@@ -44,10 +46,20 @@ class HaystackAgentCallbackHandler:
|
|
|
44
46
|
agent.tm.callback_manager.on_tool_finish += self.on_tool_finish
|
|
45
47
|
agent.tm.callback_manager.on_tool_error += self.on_tool_error
|
|
46
48
|
|
|
49
|
+
self.final_answer_pattern = agent.final_answer_pattern
|
|
50
|
+
self.stream_final_answer = stream_final_answer
|
|
51
|
+
self.stream_final_answer_agent_name = stream_final_answer_agent_name
|
|
52
|
+
|
|
47
53
|
def on_agent_start(self, **kwargs: Any) -> None:
|
|
48
54
|
# Prepare agent step message for streaming
|
|
49
55
|
self.agent_name = kwargs.get("name", "Agent")
|
|
50
56
|
self.stack = Stack[Step]()
|
|
57
|
+
|
|
58
|
+
if self.stream_final_answer:
|
|
59
|
+
self.final_stream = Message(author=self.stream_final_answer_agent_name, content="")
|
|
60
|
+
self.last_tokens: List[str] = []
|
|
61
|
+
self.answer_reached = False
|
|
62
|
+
|
|
51
63
|
root_message = context.session.root_message
|
|
52
64
|
parent_id = root_message.id if root_message else None
|
|
53
65
|
run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
|
|
@@ -59,10 +71,11 @@ class HaystackAgentCallbackHandler:
|
|
|
59
71
|
self.stack.push(run_step)
|
|
60
72
|
|
|
61
73
|
def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
74
|
+
if self.last_step:
|
|
75
|
+
run_step = self.last_step
|
|
76
|
+
run_step.end = datetime.utcnow().isoformat()
|
|
77
|
+
run_step.output = agent_step.prompt_node_response
|
|
78
|
+
run_sync(run_step.update())
|
|
66
79
|
|
|
67
80
|
# This method is called when a step has finished
|
|
68
81
|
def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None:
|
|
@@ -77,11 +90,24 @@ class HaystackAgentCallbackHandler:
|
|
|
77
90
|
|
|
78
91
|
if not agent_step.is_last():
|
|
79
92
|
# Prepare step for next agent step
|
|
80
|
-
step = Step(name=self.agent_name, parent_id=self.
|
|
93
|
+
step = Step(name=self.agent_name, parent_id=self.last_step.id)
|
|
81
94
|
self.stack.push(step)
|
|
82
95
|
|
|
83
96
|
def on_new_token(self, token, **kwargs: Any) -> None:
|
|
84
97
|
# Stream agent step tokens
|
|
98
|
+
if self.stream_final_answer:
|
|
99
|
+
if self.answer_reached:
|
|
100
|
+
run_sync(self.final_stream.stream_token(token))
|
|
101
|
+
else:
|
|
102
|
+
self.last_tokens.append(token)
|
|
103
|
+
|
|
104
|
+
last_tokens_concat = ''.join(self.last_tokens)
|
|
105
|
+
final_answer_match = re.search(self.final_answer_pattern, last_tokens_concat)
|
|
106
|
+
|
|
107
|
+
if final_answer_match:
|
|
108
|
+
self.answer_reached = True
|
|
109
|
+
run_sync(self.final_stream.stream_token(final_answer_match.group(1)))
|
|
110
|
+
|
|
85
111
|
run_sync(self.stack.peek().stream_token(token))
|
|
86
112
|
|
|
87
113
|
def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None:
|
chainlit/langchain/callbacks.py
CHANGED
|
@@ -6,11 +6,11 @@ from chainlit.context import context_var
|
|
|
6
6
|
from chainlit.message import Message
|
|
7
7
|
from chainlit.playground.providers.openai import stringify_function_call
|
|
8
8
|
from chainlit.step import Step, TrueStepType
|
|
9
|
-
from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
|
|
10
9
|
from langchain.callbacks.tracers.base import BaseTracer
|
|
11
10
|
from langchain.callbacks.tracers.schemas import Run
|
|
12
11
|
from langchain.schema import BaseMessage
|
|
13
12
|
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
|
13
|
+
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
|
|
14
14
|
|
|
15
15
|
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
|
16
16
|
|
|
@@ -388,11 +388,13 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
388
388
|
self.steps = {}
|
|
389
389
|
self.parent_id_map = {}
|
|
390
390
|
self.ignored_runs = set()
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
391
|
+
|
|
392
|
+
if self.context.current_step:
|
|
393
|
+
self.root_parent_id = self.context.current_step.id
|
|
394
|
+
elif self.context.session.root_message:
|
|
395
|
+
self.root_parent_id = self.context.session.root_message.id
|
|
396
|
+
else:
|
|
397
|
+
self.root_parent_id = None
|
|
396
398
|
|
|
397
399
|
if to_ignore is None:
|
|
398
400
|
self.to_ignore = DEFAULT_TO_IGNORE
|
|
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
|
|
4
4
|
from chainlit.context import context_var
|
|
5
5
|
from chainlit.element import Text
|
|
6
6
|
from chainlit.step import Step, StepType
|
|
7
|
-
from
|
|
7
|
+
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
|
|
8
8
|
from llama_index.callbacks import TokenCountingHandler
|
|
9
9
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
10
10
|
from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
|
|
@@ -36,14 +36,19 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
36
36
|
)
|
|
37
37
|
self.context = context_var.get()
|
|
38
38
|
|
|
39
|
+
if self.context.current_step:
|
|
40
|
+
self.root_parent_id = self.context.current_step.id
|
|
41
|
+
elif self.context.session.root_message:
|
|
42
|
+
self.root_parent_id = self.context.session.root_message.id
|
|
43
|
+
else:
|
|
44
|
+
self.root_parent_id = None
|
|
45
|
+
|
|
39
46
|
self.steps = {}
|
|
40
47
|
|
|
41
48
|
def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
|
|
42
49
|
if event_parent_id and event_parent_id in self.steps:
|
|
43
50
|
return event_parent_id
|
|
44
|
-
|
|
45
|
-
return root_message.id
|
|
46
|
-
return None
|
|
51
|
+
return self.root_parent_id
|
|
47
52
|
|
|
48
53
|
def _restore_context(self) -> None:
|
|
49
54
|
"""Restore Chainlit context in the current thread
|
|
@@ -113,7 +118,10 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
113
118
|
[f"Source {idx}" for idx, _ in enumerate(sources)]
|
|
114
119
|
)
|
|
115
120
|
step.elements = [
|
|
116
|
-
Text(
|
|
121
|
+
Text(
|
|
122
|
+
name=f"Source {idx}",
|
|
123
|
+
content=source.node.get_text() or "Empty node",
|
|
124
|
+
)
|
|
117
125
|
for idx, source in enumerate(sources)
|
|
118
126
|
]
|
|
119
127
|
step.output = f"Retrieved the following sources: {source_refs}"
|
chainlit/message.py
CHANGED
chainlit/oauth_providers.py
CHANGED
|
@@ -409,6 +409,72 @@ class DescopeOAuthProvider(OAuthProvider):
|
|
|
409
409
|
return (descope_user, user)
|
|
410
410
|
|
|
411
411
|
|
|
412
|
+
class AWSCognitoOAuthProvider(OAuthProvider):
|
|
413
|
+
id = "aws-cognito"
|
|
414
|
+
env = [
|
|
415
|
+
"OAUTH_COGNITO_CLIENT_ID",
|
|
416
|
+
"OAUTH_COGNITO_CLIENT_SECRET",
|
|
417
|
+
"OAUTH_COGNITO_DOMAIN",
|
|
418
|
+
]
|
|
419
|
+
authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login"
|
|
420
|
+
token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token"
|
|
421
|
+
|
|
422
|
+
def __init__(self):
|
|
423
|
+
self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID")
|
|
424
|
+
self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET")
|
|
425
|
+
self.authorize_params = {
|
|
426
|
+
"response_type": "code",
|
|
427
|
+
"client_id": self.client_id,
|
|
428
|
+
"scope": "openid profile email",
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
async def get_token(self, code: str, url: str):
|
|
432
|
+
payload = {
|
|
433
|
+
"client_id": self.client_id,
|
|
434
|
+
"client_secret": self.client_secret,
|
|
435
|
+
"code": code,
|
|
436
|
+
"grant_type": "authorization_code",
|
|
437
|
+
"redirect_uri": url,
|
|
438
|
+
}
|
|
439
|
+
async with httpx.AsyncClient() as client:
|
|
440
|
+
response = await client.post(
|
|
441
|
+
self.token_url,
|
|
442
|
+
data=payload,
|
|
443
|
+
)
|
|
444
|
+
response.raise_for_status()
|
|
445
|
+
json = response.json()
|
|
446
|
+
|
|
447
|
+
token = json.get("access_token")
|
|
448
|
+
if not token:
|
|
449
|
+
raise HTTPException(
|
|
450
|
+
status_code=400, detail="Failed to get the access token"
|
|
451
|
+
)
|
|
452
|
+
return token
|
|
453
|
+
|
|
454
|
+
async def get_user_info(self, token: str):
|
|
455
|
+
user_info_url = (
|
|
456
|
+
f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo"
|
|
457
|
+
)
|
|
458
|
+
async with httpx.AsyncClient() as client:
|
|
459
|
+
response = await client.get(
|
|
460
|
+
user_info_url,
|
|
461
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
462
|
+
)
|
|
463
|
+
response.raise_for_status()
|
|
464
|
+
|
|
465
|
+
cognito_user = response.json()
|
|
466
|
+
|
|
467
|
+
# Customize user metadata as needed
|
|
468
|
+
user = User(
|
|
469
|
+
identifier=cognito_user["email"],
|
|
470
|
+
metadata={
|
|
471
|
+
"image": cognito_user.get("picture", ""),
|
|
472
|
+
"provider": "aws-cognito",
|
|
473
|
+
},
|
|
474
|
+
)
|
|
475
|
+
return (cognito_user, user)
|
|
476
|
+
|
|
477
|
+
|
|
412
478
|
providers = [
|
|
413
479
|
GithubOAuthProvider(),
|
|
414
480
|
GoogleOAuthProvider(),
|
|
@@ -416,6 +482,7 @@ providers = [
|
|
|
416
482
|
OktaOAuthProvider(),
|
|
417
483
|
Auth0OAuthProvider(),
|
|
418
484
|
DescopeOAuthProvider(),
|
|
485
|
+
AWSCognitoOAuthProvider(),
|
|
419
486
|
]
|
|
420
487
|
|
|
421
488
|
|
chainlit/playground/config.py
CHANGED
|
@@ -9,6 +9,7 @@ from chainlit.playground.providers import (
|
|
|
9
9
|
OpenAI,
|
|
10
10
|
ChatVertexAI,
|
|
11
11
|
GenerationVertexAI,
|
|
12
|
+
Gemini,
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
providers = {
|
|
@@ -19,6 +20,7 @@ providers = {
|
|
|
19
20
|
Anthropic.id: Anthropic,
|
|
20
21
|
ChatVertexAI.id: ChatVertexAI,
|
|
21
22
|
GenerationVertexAI.id: GenerationVertexAI,
|
|
23
|
+
Gemini.id: Gemini,
|
|
22
24
|
} # type: Dict[str, BaseProvider]
|
|
23
25
|
|
|
24
26
|
|
chainlit/playground/provider.py
CHANGED
|
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
|
|
|
4
4
|
from chainlit.config import config
|
|
5
5
|
from chainlit.telemetry import trace_event
|
|
6
6
|
from chainlit.types import GenerationRequest
|
|
7
|
-
from chainlit_client import BaseGeneration, ChatGeneration, GenerationMessage
|
|
8
7
|
from fastapi import HTTPException
|
|
8
|
+
from literalai import BaseGeneration, ChatGeneration, GenerationMessage
|
|
9
9
|
from pydantic.dataclasses import dataclass
|
|
10
10
|
|
|
11
11
|
from chainlit import input_widget
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from chainlit.input_widget import Select, Slider, Tags
|
|
2
2
|
from chainlit.playground.provider import BaseProvider
|
|
3
|
-
from chainlit_client import GenerationMessage
|
|
4
3
|
from fastapi import HTTPException
|
|
5
4
|
from fastapi.responses import StreamingResponse
|
|
5
|
+
from literalai import GenerationMessage
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class AnthropicProvider(BaseProvider):
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from typing import Union
|
|
1
|
+
from typing import List, Union
|
|
2
2
|
|
|
3
|
+
from chainlit.input_widget import InputWidget
|
|
3
4
|
from chainlit.playground.provider import BaseProvider
|
|
4
5
|
from chainlit.sync import make_async
|
|
5
|
-
from chainlit_client import GenerationMessage
|
|
6
6
|
from fastapi.responses import StreamingResponse
|
|
7
|
+
from literalai import GenerationMessage
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class LangchainGenericProvider(BaseProvider):
|
|
@@ -18,13 +19,14 @@ class LangchainGenericProvider(BaseProvider):
|
|
|
18
19
|
id: str,
|
|
19
20
|
name: str,
|
|
20
21
|
llm: Union[LLM, BaseChatModel],
|
|
22
|
+
inputs: List[InputWidget] = [],
|
|
21
23
|
is_chat: bool = False,
|
|
22
24
|
):
|
|
23
25
|
super().__init__(
|
|
24
26
|
id=id,
|
|
25
27
|
name=name,
|
|
26
28
|
env_vars={},
|
|
27
|
-
inputs=
|
|
29
|
+
inputs=inputs,
|
|
28
30
|
is_chat=is_chat,
|
|
29
31
|
)
|
|
30
32
|
self.llm = llm
|
|
@@ -65,10 +67,9 @@ class LangchainGenericProvider(BaseProvider):
|
|
|
65
67
|
|
|
66
68
|
messages = self.create_generation(request)
|
|
67
69
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
input=messages,
|
|
70
|
+
# https://github.com/langchain-ai/langchain/issues/14980
|
|
71
|
+
result = await make_async(self.llm.stream)(
|
|
72
|
+
input=messages, **request.generation.settings
|
|
72
73
|
)
|
|
73
74
|
|
|
74
75
|
def create_event_stream():
|
|
@@ -33,10 +33,10 @@ class ChatVertexAIProvider(BaseProvider):
|
|
|
33
33
|
|
|
34
34
|
self.validate_env(request=request)
|
|
35
35
|
|
|
36
|
-
llm_settings = request.
|
|
36
|
+
llm_settings = request.generation.settings
|
|
37
37
|
self.require_settings(llm_settings)
|
|
38
38
|
|
|
39
|
-
|
|
39
|
+
messages = self.create_generation(request)
|
|
40
40
|
model_name = llm_settings["model"]
|
|
41
41
|
if model_name.startswith("chat-"):
|
|
42
42
|
model = ChatModel.from_pretrained(model_name)
|
|
@@ -52,7 +52,7 @@ class ChatVertexAIProvider(BaseProvider):
|
|
|
52
52
|
|
|
53
53
|
async def create_event_stream():
|
|
54
54
|
for response in await cl.make_async(chat.send_message_streaming)(
|
|
55
|
-
|
|
55
|
+
messages, **llm_settings
|
|
56
56
|
):
|
|
57
57
|
yield response.text
|
|
58
58
|
|
|
@@ -66,10 +66,10 @@ class GenerationVertexAIProvider(BaseProvider):
|
|
|
66
66
|
|
|
67
67
|
self.validate_env(request=request)
|
|
68
68
|
|
|
69
|
-
llm_settings = request.
|
|
69
|
+
llm_settings = request.generation.settings
|
|
70
70
|
self.require_settings(llm_settings)
|
|
71
71
|
|
|
72
|
-
|
|
72
|
+
messages = self.create_generation(request)
|
|
73
73
|
model_name = llm_settings["model"]
|
|
74
74
|
if model_name.startswith("text-"):
|
|
75
75
|
model = TextGenerationModel.from_pretrained(model_name)
|
|
@@ -84,13 +84,42 @@ class GenerationVertexAIProvider(BaseProvider):
|
|
|
84
84
|
|
|
85
85
|
async def create_event_stream():
|
|
86
86
|
for response in await cl.make_async(model.predict_streaming)(
|
|
87
|
-
|
|
87
|
+
messages, **llm_settings
|
|
88
88
|
):
|
|
89
89
|
yield response.text
|
|
90
90
|
|
|
91
91
|
return StreamingResponse(create_event_stream())
|
|
92
92
|
|
|
93
93
|
|
|
94
|
+
class GeminiProvider(BaseProvider):
|
|
95
|
+
async def create_completion(self, request):
|
|
96
|
+
await super().create_completion(request)
|
|
97
|
+
from vertexai.preview.generative_models import GenerativeModel
|
|
98
|
+
from google.cloud import aiplatform
|
|
99
|
+
import os
|
|
100
|
+
|
|
101
|
+
self.validate_env(request=request)
|
|
102
|
+
|
|
103
|
+
llm_settings = request.generation.settings
|
|
104
|
+
self.require_settings(llm_settings)
|
|
105
|
+
|
|
106
|
+
messages = self.create_generation(request)
|
|
107
|
+
aiplatform.init( # TODO: remove this when Gemini is released in all the regions
|
|
108
|
+
project=os.environ["GCP_PROJECT_ID"],
|
|
109
|
+
location='us-central1',
|
|
110
|
+
)
|
|
111
|
+
model = GenerativeModel(llm_settings["model"])
|
|
112
|
+
del llm_settings["model"]
|
|
113
|
+
|
|
114
|
+
async def create_event_stream():
|
|
115
|
+
for response in await cl.make_async(model.generate_content)(
|
|
116
|
+
messages, stream=True, generation_config=llm_settings
|
|
117
|
+
):
|
|
118
|
+
yield response.candidates[0].content.parts[0].text
|
|
119
|
+
|
|
120
|
+
return StreamingResponse(create_event_stream())
|
|
121
|
+
|
|
122
|
+
|
|
94
123
|
gcp_env_vars = {"google_application_credentials": "GOOGLE_APPLICATION_CREDENTIALS"}
|
|
95
124
|
|
|
96
125
|
ChatVertexAI = ChatVertexAIProvider(
|
|
@@ -124,3 +153,19 @@ GenerationVertexAI = GenerationVertexAIProvider(
|
|
|
124
153
|
],
|
|
125
154
|
is_chat=False,
|
|
126
155
|
)
|
|
156
|
+
|
|
157
|
+
Gemini = GeminiProvider(
|
|
158
|
+
id="gemini",
|
|
159
|
+
env_vars=gcp_env_vars,
|
|
160
|
+
name="Gemini",
|
|
161
|
+
inputs=[
|
|
162
|
+
Select(
|
|
163
|
+
id="model",
|
|
164
|
+
label="Model",
|
|
165
|
+
values=["gemini-pro", "gemini-pro-vision"],
|
|
166
|
+
initial_value="gemini-pro",
|
|
167
|
+
),
|
|
168
|
+
*vertexai_common_inputs,
|
|
169
|
+
],
|
|
170
|
+
is_chat=False,
|
|
171
|
+
)
|
chainlit/server.py
CHANGED
|
@@ -42,7 +42,16 @@ from chainlit.types import (
|
|
|
42
42
|
UpdateFeedbackRequest,
|
|
43
43
|
)
|
|
44
44
|
from chainlit.user import PersistedUser, User
|
|
45
|
-
from fastapi import
|
|
45
|
+
from fastapi import (
|
|
46
|
+
Depends,
|
|
47
|
+
FastAPI,
|
|
48
|
+
HTTPException,
|
|
49
|
+
Query,
|
|
50
|
+
Request,
|
|
51
|
+
Response,
|
|
52
|
+
UploadFile,
|
|
53
|
+
status,
|
|
54
|
+
)
|
|
46
55
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
47
56
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
48
57
|
from fastapi.staticfiles import StaticFiles
|
|
@@ -257,13 +266,24 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
|
257
266
|
)
|
|
258
267
|
access_token = create_jwt(user)
|
|
259
268
|
if data_layer := get_data_layer():
|
|
260
|
-
|
|
269
|
+
try:
|
|
270
|
+
await data_layer.create_user(user)
|
|
271
|
+
except Exception as e:
|
|
272
|
+
logger.error(f"Error creating user: {e}")
|
|
273
|
+
|
|
261
274
|
return {
|
|
262
275
|
"access_token": access_token,
|
|
263
276
|
"token_type": "bearer",
|
|
264
277
|
}
|
|
265
278
|
|
|
266
279
|
|
|
280
|
+
@app.post("/logout")
|
|
281
|
+
async def logout(request: Request, response: Response):
|
|
282
|
+
if config.code.on_logout:
|
|
283
|
+
return await config.code.on_logout(request, response)
|
|
284
|
+
return {"success": True}
|
|
285
|
+
|
|
286
|
+
|
|
267
287
|
@app.post("/auth/header")
|
|
268
288
|
async def header_auth(request: Request):
|
|
269
289
|
if not config.code.header_auth_callback:
|
|
@@ -282,7 +302,11 @@ async def header_auth(request: Request):
|
|
|
282
302
|
|
|
283
303
|
access_token = create_jwt(user)
|
|
284
304
|
if data_layer := get_data_layer():
|
|
285
|
-
|
|
305
|
+
try:
|
|
306
|
+
await data_layer.create_user(user)
|
|
307
|
+
except Exception as e:
|
|
308
|
+
logger.error(f"Error creating user: {e}")
|
|
309
|
+
|
|
286
310
|
return {
|
|
287
311
|
"access_token": access_token,
|
|
288
312
|
"token_type": "bearer",
|
|
@@ -318,8 +342,9 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
318
342
|
url=f"{provider.authorize_url}?{params}",
|
|
319
343
|
)
|
|
320
344
|
samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax") # type: Any
|
|
345
|
+
secure = samesite.lower() == 'none'
|
|
321
346
|
response.set_cookie(
|
|
322
|
-
"oauth_state", random, httponly=True, samesite=samesite, max_age=3 * 60
|
|
347
|
+
"oauth_state", random, httponly=True, samesite=samesite, secure=secure, max_age=3 * 60
|
|
323
348
|
)
|
|
324
349
|
return response
|
|
325
350
|
|
|
@@ -389,7 +414,10 @@ async def oauth_callback(
|
|
|
389
414
|
access_token = create_jwt(user)
|
|
390
415
|
|
|
391
416
|
if data_layer := get_data_layer():
|
|
392
|
-
|
|
417
|
+
try:
|
|
418
|
+
await data_layer.create_user(user)
|
|
419
|
+
except Exception as e:
|
|
420
|
+
logger.error(f"Error creating user: {e}")
|
|
393
421
|
|
|
394
422
|
params = urllib.parse.urlencode(
|
|
395
423
|
{
|
|
@@ -471,12 +499,12 @@ async def update_feedback(
|
|
|
471
499
|
"""Update the human feedback for a particular message."""
|
|
472
500
|
data_layer = get_data_layer()
|
|
473
501
|
if not data_layer:
|
|
474
|
-
raise HTTPException(status_code=
|
|
502
|
+
raise HTTPException(status_code=500, detail="Data persistence is not enabled")
|
|
475
503
|
|
|
476
504
|
try:
|
|
477
505
|
feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
|
|
478
506
|
except Exception as e:
|
|
479
|
-
raise HTTPException(detail=str(e), status_code=
|
|
507
|
+
raise HTTPException(detail=str(e), status_code=500)
|
|
480
508
|
|
|
481
509
|
return JSONResponse(content={"success": True, "feedbackId": feedback_id})
|
|
482
510
|
|
|
@@ -599,7 +627,6 @@ async def upload_file(
|
|
|
599
627
|
async def get_file(
|
|
600
628
|
file_id: str,
|
|
601
629
|
session_id: Optional[str] = None,
|
|
602
|
-
token: Optional[str] = None,
|
|
603
630
|
):
|
|
604
631
|
from chainlit.session import WebsocketSession
|
|
605
632
|
|
|
@@ -611,13 +638,6 @@ async def get_file(
|
|
|
611
638
|
detail="Session not found",
|
|
612
639
|
)
|
|
613
640
|
|
|
614
|
-
if current_user := await get_current_user(token or ""):
|
|
615
|
-
if not session.user or session.user.identifier != current_user.identifier:
|
|
616
|
-
raise HTTPException(
|
|
617
|
-
status_code=401,
|
|
618
|
-
detail="You are not authorized to upload files for this session",
|
|
619
|
-
)
|
|
620
|
-
|
|
621
641
|
if file_id in session.files:
|
|
622
642
|
file = session.files[file_id]
|
|
623
643
|
return FileResponse(file["path"], media_type=file["type"])
|
chainlit/session.py
CHANGED
|
@@ -5,6 +5,7 @@ import uuid
|
|
|
5
5
|
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union
|
|
6
6
|
|
|
7
7
|
import aiofiles
|
|
8
|
+
from chainlit.logger import logger
|
|
8
9
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
10
11
|
from chainlit.message import Message
|
|
@@ -242,7 +243,10 @@ class WebsocketSession(BaseSession):
|
|
|
242
243
|
for method_name, queue in self.thread_queues.items():
|
|
243
244
|
while queue:
|
|
244
245
|
method, self, args, kwargs = queue.popleft()
|
|
245
|
-
|
|
246
|
+
try:
|
|
247
|
+
await method(self, *args, **kwargs)
|
|
248
|
+
except Exception as e:
|
|
249
|
+
logger.error(f"Error while flushing {method_name}: {e}")
|
|
246
250
|
|
|
247
251
|
@classmethod
|
|
248
252
|
def get(cls, socket_id: str):
|