chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of chainlit might be problematic.
- chainlit/__init__.py +98 -279
- chainlit/_utils.py +8 -0
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +28 -36
- chainlit/auth/cookie.py +123 -0
- chainlit/auth/jwt.py +39 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +362 -0
- chainlit/chat_context.py +64 -0
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +77 -8
- chainlit/config.py +191 -102
- chainlit/context.py +42 -13
- chainlit/copilot/dist/index.js +8750 -903
- chainlit/data/__init__.py +101 -416
- chainlit/data/acl.py +6 -2
- chainlit/data/base.py +107 -0
- chainlit/data/chainlit_data_layer.py +614 -0
- chainlit/data/dynamodb.py +590 -0
- chainlit/data/literalai.py +500 -0
- chainlit/data/sql_alchemy.py +721 -0
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +81 -0
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +26 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +75 -0
- chainlit/data/utils.py +29 -0
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +354 -0
- chainlit/element.py +91 -33
- chainlit/emitter.py +81 -29
- chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
- chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
- chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
- chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
- chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
- chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
- chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
- chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
- chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
- chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
- chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
- chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
- chainlit/frontend/dist/index.html +2 -4
- chainlit/haystack/callbacks.py +4 -7
- chainlit/input_widget.py +8 -4
- chainlit/langchain/callbacks.py +103 -68
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +65 -40
- chainlit/markdown.py +22 -6
- chainlit/message.py +54 -56
- chainlit/mistralai/__init__.py +50 -0
- chainlit/oauth_providers.py +266 -8
- chainlit/openai/__init__.py +10 -18
- chainlit/secret.py +1 -1
- chainlit/server.py +789 -228
- chainlit/session.py +108 -90
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +397 -0
- chainlit/socket.py +199 -116
- chainlit/step.py +141 -89
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +338 -0
- chainlit/translations/bn.json +244 -0
- chainlit/translations/en-US.json +122 -8
- chainlit/translations/gu.json +244 -0
- chainlit/translations/he-IL.json +244 -0
- chainlit/translations/hi.json +244 -0
- chainlit/translations/ja.json +242 -0
- chainlit/translations/kn.json +244 -0
- chainlit/translations/ml.json +244 -0
- chainlit/translations/mr.json +244 -0
- chainlit/translations/nl-NL.json +242 -0
- chainlit/translations/ta.json +244 -0
- chainlit/translations/te.json +244 -0
- chainlit/translations/zh-CN.json +243 -0
- chainlit/translations.py +60 -0
- chainlit/types.py +133 -28
- chainlit/user.py +14 -3
- chainlit/user_session.py +6 -3
- chainlit/utils.py +52 -5
- chainlit/version.py +3 -2
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
- chainlit-2.0.4.dist-info/RECORD +107 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-9711593e.js +0 -723
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -40
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -13
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -408
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit/translations/pt-BR.json +0 -155
- chainlit-1.0.401.dist-info/RECORD +0 -66
- /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/playground/__init__.py
DELETED
chainlit/playground/config.py
DELETED
@@ -1,40 +0,0 @@
-from typing import Dict
-
-from chainlit.playground.provider import BaseProvider
-from chainlit.playground.providers import (
-    Anthropic,
-    AzureChatOpenAI,
-    AzureOpenAI,
-    ChatOpenAI,
-    OpenAI,
-    ChatVertexAI,
-    GenerationVertexAI,
-    Gemini,
-)
-
-providers = {
-    AzureChatOpenAI.id: AzureChatOpenAI,
-    AzureOpenAI.id: AzureOpenAI,
-    ChatOpenAI.id: ChatOpenAI,
-    OpenAI.id: OpenAI,
-    Anthropic.id: Anthropic,
-    ChatVertexAI.id: ChatVertexAI,
-    GenerationVertexAI.id: GenerationVertexAI,
-    Gemini.id: Gemini,
-}  # type: Dict[str, BaseProvider]
-
-
-def has_llm_provider(id: str):
-    return id in providers
-
-
-def add_llm_provider(provider: BaseProvider):
-    if not provider.is_configured():
-        raise ValueError(
-            f"{provider.name} LLM provider requires the following environment variables: {', '.join(provider.env_vars.values())}"
-        )
-    providers[provider.id] = provider
-
-
-def get_llm_providers():
-    return [provider for provider in providers.values() if provider.is_configured()]
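
For context, a minimal sketch of how this removed 1.x registry was driven. The provider id, name, and MY_FAKE_API_KEY variable below are placeholders, and a real integration would subclass BaseProvider and override create_completion, as the provider modules further down did.

import os

from chainlit.playground.config import add_llm_provider, get_llm_providers, has_llm_provider
from chainlit.playground.provider import BaseProvider

# Placeholder credential: add_llm_provider() raises ValueError when any of the
# provider's env_vars is missing from the environment.
os.environ.setdefault("MY_FAKE_API_KEY", "dummy")

# BaseProvider is a pydantic dataclass, so it can be instantiated directly for
# illustration; the id, name, and env var here are hypothetical.
my_provider = BaseProvider(
    id="my-provider",
    name="My Provider",
    env_vars={"api_key": "MY_FAKE_API_KEY"},
    inputs=[],
    is_chat=False,
)

add_llm_provider(my_provider)
assert has_llm_provider("my-provider")
# Only providers whose environment variables resolve are exposed to the playground UI.
print([p.name for p in get_llm_providers()])
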
chainlit/playground/provider.py
DELETED
@@ -1,108 +0,0 @@
-import os
-from typing import Any, Dict, List, Optional, Union
-
-from chainlit.config import config
-from chainlit.telemetry import trace_event
-from chainlit.types import GenerationRequest
-from fastapi import HTTPException
-from literalai import BaseGeneration, ChatGeneration, GenerationMessage
-from pydantic.dataclasses import dataclass
-
-from chainlit import input_widget
-
-
-@dataclass
-class BaseProvider:
-    id: str
-    name: str
-    env_vars: Dict[str, str]
-    inputs: List[input_widget.InputWidget]
-    is_chat: bool
-
-    # Convert the message to string format
-    def message_to_string(self, message: GenerationMessage):
-        return message["content"]
-
-    # Concatenate multiple messages with a joiner
-    def concatenate_messages(self, messages: List[GenerationMessage], joiner="\n\n"):
-        return joiner.join([self.message_to_string(m) for m in messages])
-
-    # Format the template based on the prompt inputs
-    def _format_template(self, template: str, inputs: Optional[Dict]):
-        return template.format(**(inputs or {}))
-
-    # Create a prompt based on the request
-    def create_generation(self, request: GenerationRequest):
-        if request.chatGeneration and request.chatGeneration.messages:
-            messages = request.chatGeneration.messages
-        else:
-            messages = None
-
-        if self.is_chat:
-            if messages:
-                return messages
-            elif request.completionGeneration and request.completionGeneration.prompt:
-                return [
-                    GenerationMessage(
-                        content=request.completionGeneration.prompt,
-                        role="user",
-                    ),
-                ]
-            else:
-                raise HTTPException(
-                    status_code=422, detail="Could not create generation"
-                )
-        else:
-            if request.completionGeneration:
-                return request.completionGeneration.prompt
-            elif messages:
-                return self.concatenate_messages(messages)
-            else:
-                raise HTTPException(status_code=422, detail="Could not create prompt")
-
-    # Create a completion event
-    async def create_completion(self, request: GenerationRequest):
-        trace_event("completion")
-
-    # Get the environment variable based on the request
-    def get_var(self, request: GenerationRequest, var: str) -> Union[str, None]:
-        user_env = config.project.user_env or []
-
-        if var in user_env:
-            return request.userEnv.get(var)
-        else:
-            return os.environ.get(var)
-
-    # Check if the environment variable is available
-    def _is_env_var_available(self, var: str) -> bool:
-        user_env = config.project.user_env or []
-        return var in os.environ or var in user_env
-
-    # Check if the provider is configured
-    def is_configured(self):
-        for var in self.env_vars.values():
-            if not self._is_env_var_available(var):
-                return False
-        return True
-
-    # Validate the environment variables in the request
-    def validate_env(self, request: GenerationRequest):
-        return {k: self.get_var(request, v) for k, v in self.env_vars.items()}
-
-    # Check if the required settings are present
-    def require_settings(self, settings: Dict[str, Any]):
-        for _input in self.inputs:
-            if _input.id not in settings:
-                raise HTTPException(
-                    status_code=422,
-                    detail=f"Field {_input.id} is a required setting but is not found.",
-                )
-
-    # Convert the provider to dictionary format
-    def to_dict(self):
-        return {
-            "id": self.id,
-            "name": self.name,
-            "inputs": [input_widget.to_dict() for input_widget in self.inputs],
-            "is_chat": self.is_chat,
-        }
chainlit/playground/providers/anthropic.py
DELETED
@@ -1,118 +0,0 @@
-from chainlit.input_widget import Select, Slider, Tags
-from chainlit.playground.provider import BaseProvider
-from fastapi import HTTPException
-from fastapi.responses import StreamingResponse
-from literalai import GenerationMessage
-
-
-class AnthropicProvider(BaseProvider):
-    def message_to_string(self, message: GenerationMessage) -> str:
-        import anthropic
-
-        if message["role"] == "user":
-            message_text = f"{anthropic.HUMAN_PROMPT} {message['content']}"
-        elif message["role"] == "assistant":
-            message_text = f"{anthropic.AI_PROMPT} {message['content']}"
-        elif message["role"] == "function":
-            message_text = f"{anthropic.AI_PROMPT} {message['content']}"
-        elif message["role"] == "system":
-            message_text = (
-                f"{anthropic.HUMAN_PROMPT} <admin>{message['content']}</admin>"
-            )
-        else:
-            raise HTTPException(
-                status_code=400, detail=f"Got unknown type {message['role']}"
-            )
-        return message_text
-
-    async def create_completion(self, request):
-        await super().create_completion(request)
-        import anthropic
-
-        env_settings = self.validate_env(request=request)
-
-        llm_settings = request.generation.settings
-        self.require_settings(llm_settings)
-
-        prompt = self.concatenate_messages(self.create_generation(request), joiner="")
-
-        if not prompt.endswith(anthropic.AI_PROMPT):
-            prompt += anthropic.AI_PROMPT
-
-        client = anthropic.AsyncAnthropic(**env_settings)
-
-        llm_settings["stream"] = True
-
-        try:
-            stream = await client.completions.create(prompt=prompt, **llm_settings)
-        except anthropic.APIConnectionError as e:
-            raise HTTPException(
-                status_code=503,
-                detail=e.__cause__,
-            )
-        except anthropic.RateLimitError as e:
-            raise HTTPException(
-                status_code=429,
-            )
-        except anthropic.APIStatusError as e:
-            raise HTTPException(status_code=e.status_code, detail=e.response)
-
-        async def create_event_stream():
-            async for data in stream:
-                token = data.completion
-                yield token
-
-        return StreamingResponse(create_event_stream())
-
-
-Anthropic = AnthropicProvider(
-    id="anthropic-chat",
-    name="Anthropic",
-    env_vars={"api_key": "ANTHROPIC_API_KEY"},
-    inputs=[
-        Select(
-            id="model",
-            label="Model",
-            values=["claude-2", "claude-instant-1"],
-            initial_value="claude-2",
-        ),
-        Slider(
-            id="max_tokens_to_sample",
-            label="Max Tokens To Sample",
-            min=1.0,
-            max=100000,
-            step=1.0,
-            initial=1000,
-        ),
-        Tags(
-            id="stop_sequences",
-            label="Stop Sequences",
-            initial=[],
-        ),
-        Slider(
-            id="temperature",
-            label="Temperature",
-            min=0.0,
-            max=1.0,
-            step=0.01,
-            initial=1,
-        ),
-        Slider(
-            id="top_p",
-            label="Top P",
-            min=0.0,
-            max=1.0,
-            step=0.01,
-            initial=0.7,
-        ),
-        Slider(
-            id="top_k",
-            label="Top K",
-            min=0.0,
-            max=2048.0,
-            step=1.0,
-            initial=0,
-        ),
-    ],
-    is_chat=True,
-)
chainlit/playground/providers/huggingface.py
DELETED
@@ -1,75 +0,0 @@
-from typing import Optional
-
-from chainlit.input_widget import Slider
-from chainlit.playground.provider import BaseProvider
-from chainlit.sync import make_async
-from fastapi import HTTPException
-from fastapi.responses import StreamingResponse
-from pydantic.dataclasses import dataclass
-
-
-@dataclass
-class BaseHuggingFaceProvider(BaseProvider):
-    repo_id: Optional[str] = None
-    task = "text2text-generation"
-
-    async def create_completion(self, request):
-        await super().create_completion(request)
-        from huggingface_hub.inference_api import InferenceApi
-
-        env_settings = self.validate_env(request=request)
-        llm_settings = request.generation.settings
-        self.require_settings(llm_settings)
-
-        client = InferenceApi(
-            repo_id=self.repo_id,
-            token=env_settings["api_token"],
-            task=self.task,
-        )
-
-        prompt = self.create_generation(request)
-
-        response = await make_async(client)(inputs=prompt, params=llm_settings)
-
-        if "error" in response:
-            raise HTTPException(
-                status_code=500,
-                detail=f"Error raised by inference API: {response['error']}",
-            )
-        if client.task == "text2text-generation":
-
-            def create_event_stream():
-                yield response[0]["generated_text"]
-
-            return StreamingResponse(create_event_stream())
-        else:
-            raise HTTPException(status_code=400, detail="Unsupported task")
-
-
-flan_hf_env_vars = {"api_token": "HUGGINGFACE_API_TOKEN"}
-
-HFFlanT5 = BaseHuggingFaceProvider(
-    id="huggingface_hub",
-    repo_id="declare-lab/flan-alpaca-large",
-    name="Flan Alpaca Large",
-    env_vars=flan_hf_env_vars,
-    inputs=[
-        Slider(
-            id="temperature",
-            label="Temperature",
-            min=0.0,
-            max=1.0,
-            step=0.01,
-            initial=0.9,
-        ),
-        Slider(
-            id="max_length",
-            label="Completion max length",
-            min=1.0,
-            max=5000,
-            step=1.0,
-            initial=256,
-        ),
-    ],
-    is_chat=False,
-)
chainlit/playground/providers/langchain.py
DELETED
@@ -1,89 +0,0 @@
-from typing import List, Union
-
-from chainlit.input_widget import InputWidget
-from chainlit.playground.provider import BaseProvider
-from chainlit.sync import make_async
-from fastapi.responses import StreamingResponse
-from literalai import GenerationMessage
-
-
-class LangchainGenericProvider(BaseProvider):
-    from langchain.chat_models.base import BaseChatModel
-    from langchain.llms.base import LLM
-    from langchain.schema import BaseMessage
-
-    llm: Union[LLM, BaseChatModel]
-
-    def __init__(
-        self,
-        id: str,
-        name: str,
-        llm: Union[LLM, BaseChatModel],
-        inputs: List[InputWidget] = [],
-        is_chat: bool = False,
-    ):
-        super().__init__(
-            id=id,
-            name=name,
-            env_vars={},
-            inputs=inputs,
-            is_chat=is_chat,
-        )
-        self.llm = llm
-
-    def prompt_message_to_langchain_message(self, message: GenerationMessage):
-        from langchain.schema.messages import (
-            AIMessage,
-            FunctionMessage,
-            HumanMessage,
-            SystemMessage,
-        )
-
-        content = "" if message["content"] is None else message["content"]
-        if message["role"] == "user":
-            return HumanMessage(content=content)  # type: ignore
-        elif message["role"] == "assistant":
-            return AIMessage(content=content)  # type: ignore
-        elif message["role"] == "system":
-            return SystemMessage(content=content)  # type: ignore
-        elif message["role"] == "tool":
-            return FunctionMessage(
-                content=content,  # type: ignore
-                name=message["name"] if message["name"] else "function",
-            )
-        else:
-            raise ValueError(f"Got unknown type {message['role']}")
-
-    def format_message(self, message, prompt):
-        message = super().format_message(message, prompt)
-        return self.prompt_message_to_langchain_message(message)
-
-    def message_to_string(self, message: BaseMessage) -> str:  # type: ignore[override]
-        return str(getattr(message, "content", ""))
-
-    async def create_completion(self, request):
-        from langchain.schema.messages import BaseMessageChunk
-
-        await super().create_completion(request)
-
-        messages = self.create_generation(request)
-
-        # https://github.com/langchain-ai/langchain/issues/14980
-        result = await make_async(self.llm.stream)(
-            input=messages, **request.generation.settings
-        )
-
-        def create_event_stream():
-            try:
-                for chunk in result:
-                    if isinstance(chunk, BaseMessageChunk):
-                        yield chunk.content
-                    else:
-                        yield chunk
-            except Exception as e:
-                # The better solution would be to return a 500 error, but
-                # langchain raises the error in the stream, and the http
-                # headers have already been sent.
-                yield f"Failed to create completion: {str(e)}"
-
-        return StreamingResponse(create_event_stream())
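
As a closing reference, a hedged sketch of the custom-provider pattern these deleted modules enabled in 1.x: wrap a LangChain model in LangchainGenericProvider and register it with add_llm_provider. FakeListLLM is only a stand-in so the sketch is self-contained; any LangChain LLM or BaseChatModel instance would be passed in practice, and none of this API exists in 2.0.4.

from langchain.llms.fake import FakeListLLM

from chainlit.playground.config import add_llm_provider
from chainlit.playground.providers.langchain import LangchainGenericProvider

# Stand-in model; a real app would pass its own LLM or chat model instance.
llm = FakeListLLM(responses=["Hello from the playground"])

add_llm_provider(
    LangchainGenericProvider(
        id="fake-list-llm",    # placeholder id shown in the playground UI
        name="Fake List LLM",
        llm=llm,
        is_chat=False,
    )
)
# LangchainGenericProvider sets env_vars={}, so is_configured() is always True
# and registration succeeds without extra environment variables.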