unique_toolkit 0.5.55__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- unique_toolkit/_common/validate_required_values.py +21 -0
- unique_toolkit/app/__init__.py +20 -0
- unique_toolkit/app/schemas.py +73 -7
- unique_toolkit/chat/__init__.py +5 -4
- unique_toolkit/chat/constants.py +3 -0
- unique_toolkit/chat/functions.py +661 -0
- unique_toolkit/chat/schemas.py +11 -11
- unique_toolkit/chat/service.py +273 -430
- unique_toolkit/content/__init__.py +1 -0
- unique_toolkit/content/constants.py +2 -0
- unique_toolkit/content/functions.py +475 -0
- unique_toolkit/content/service.py +163 -315
- unique_toolkit/content/utils.py +32 -0
- unique_toolkit/embedding/__init__.py +3 -0
- unique_toolkit/embedding/constants.py +2 -0
- unique_toolkit/embedding/functions.py +79 -0
- unique_toolkit/embedding/service.py +47 -34
- unique_toolkit/evaluators/__init__.py +1 -0
- unique_toolkit/evaluators/constants.py +1 -0
- unique_toolkit/evaluators/context_relevancy/constants.py +3 -3
- unique_toolkit/evaluators/context_relevancy/utils.py +5 -2
- unique_toolkit/evaluators/hallucination/utils.py +2 -1
- unique_toolkit/language_model/__init__.py +1 -0
- unique_toolkit/language_model/constants.py +4 -0
- unique_toolkit/language_model/functions.py +362 -0
- unique_toolkit/language_model/service.py +246 -293
- unique_toolkit/short_term_memory/__init__.py +5 -0
- unique_toolkit/short_term_memory/constants.py +1 -0
- unique_toolkit/short_term_memory/functions.py +175 -0
- unique_toolkit/short_term_memory/service.py +153 -27
- {unique_toolkit-0.5.55.dist-info → unique_toolkit-0.6.0.dist-info}/METADATA +33 -7
- unique_toolkit-0.6.0.dist-info/RECORD +64 -0
- unique_toolkit-0.5.55.dist-info/RECORD +0 -50
- {unique_toolkit-0.5.55.dist-info → unique_toolkit-0.6.0.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.5.55.dist-info → unique_toolkit-0.6.0.dist-info}/WHEEL +0 -0
unique_toolkit/embedding/functions.py

```diff
@@ -0,0 +1,79 @@
+import logging
+
+import unique_sdk
+
+from unique_toolkit.embedding.constants import DEFAULT_TIMEOUT, DOMAIN_NAME
+from unique_toolkit.embedding.schemas import Embeddings
+
+logger = logging.getLogger(f"toolkit.{DOMAIN_NAME}.{__name__}")
+
+
+def embed_texts(
+    user_id: str,
+    company_id: str,
+    texts: list[str],
+    timeout: int = DEFAULT_TIMEOUT,
+) -> Embeddings:
+    """
+    Embed text.
+
+    Args:
+        user_id (str): The user ID.
+        company_id (str): The company ID.
+        texts (list[str]): The texts to embed.
+        timeout (int): The timeout in milliseconds. Defaults to 600000.
+
+    Returns:
+        Embeddings: The Embedding object.
+
+    Raises:
+        Exception: If an error occurs.
+    """
+
+    try:
+        data = {
+            "user_id": user_id,
+            "company_id": company_id,
+            "texts": texts,
+            "timeout": timeout,
+        }
+        response = unique_sdk.Embeddings.create(**data)
+        return Embeddings(**response)
+    except Exception as e:
+        logger.error(f"Error embedding texts: {e}")
+        raise e
+
+
+async def embed_texts_async(
+    user_id: str,
+    company_id: str,
+    texts: list[str],
+    timeout: int = DEFAULT_TIMEOUT,
+) -> Embeddings:
+    """
+    Embed text asynchronously.
+
+    Args:
+        user_id (str): The user ID.
+        company_id (str): The company ID.
+        texts (list[str]): The texts to embed.
+        timeout (int): The timeout in milliseconds. Defaults to 600000.
+
+    Returns:
+        Embeddings: The Embedding object.
+
+    Raises:
+        Exception: If an error occurs.
+    """
+    try:
+        data = {
+            "user_id": user_id,
+            "company_id": company_id,
+            "texts": texts,
+            "timeout": timeout,
+        }
+        response = await unique_sdk.Embeddings.create_async(**data)
+        return Embeddings(**response)
+    except Exception as e:
+        logger.error(f"Error embedding texts: {e}")
+        raise e
```
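The new module-level helpers take `user_id` and `company_id` explicitly instead of reading them from an event, so they can be called from any context. A minimal usage sketch (the IDs below are placeholders):

```python
from unique_toolkit.embedding.functions import embed_texts

# Placeholder IDs; in practice these come from the incoming event or app config.
result = embed_texts(
    user_id="user_123",
    company_id="company_456",
    texts=["first passage", "second passage"],
)
# `result` is the parsed `Embeddings` schema object built from the
# unique_sdk.Embeddings.create response.
```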
unique_toolkit/embedding/service.py

```diff
@@ -1,10 +1,10 @@
-import …
-from typing import Optional
-
-import unique_sdk
+from typing_extensions import deprecated
 
 from unique_toolkit._common._base_service import BaseService
-from unique_toolkit.…
+from unique_toolkit._common.validate_required_values import validate_required_values
+from unique_toolkit.app.schemas import BaseEvent, Event
+from unique_toolkit.embedding.constants import DEFAULT_TIMEOUT
+from unique_toolkit.embedding.functions import embed_texts, embed_texts_async
 from unique_toolkit.embedding.schemas import Embeddings
 
 
@@ -13,14 +13,37 @@ class EmbeddingService(BaseService):
     Provides methods to interact with the Embedding service.
 
     Attributes:
-        …
-        …
+        company_id (str | None): The company ID.
+        user_id (str | None): The user ID.
     """
 
-    def __init__(
-        …
+    def __init__(
+        self,
+        event: Event | BaseEvent | None = None,
+        company_id: str | None = None,
+        user_id: str | None = None,
+    ):
+        self._event = event
+        if event:
+            self.company_id = event.company_id
+            self.user_id = event.user_id
+        else:
+            [company_id, user_id] = validate_required_values([company_id, user_id])
+            self.company_id = company_id
+            self.user_id = user_id
+
+    @property
+    @deprecated(
+        "The event property is deprecated and will be removed in a future version."
+    )
+    def event(self) -> Event | BaseEvent | None:
+        """
+        Get the event object (deprecated).
 
-        …
+        Returns:
+            Event | BaseEvent | None: The event object.
+        """
+        return self._event
 
     def embed_texts(
         self,
@@ -32,7 +55,7 @@ class EmbeddingService(BaseService):
 
         Args:
             text (str): The text to embed.
-            timeout (int): The timeout in milliseconds. Defaults to …
+            timeout (int): The timeout in milliseconds. Defaults to 600000.
 
         Returns:
             Embeddings: The Embedding object.
@@ -40,13 +63,12 @@ class EmbeddingService(BaseService):
         Raises:
             Exception: If an error occurs.
         """
-        …
-        …
-        …
-        …
-        …
-        …
-            raise e
+        return embed_texts(
+            user_id=self.user_id,
+            company_id=self.company_id,
+            texts=texts,
+            timeout=timeout,
+        )
 
     async def embed_texts_async(
         self,
@@ -58,7 +80,7 @@ class EmbeddingService(BaseService):
 
         Args:
             text (str): The text to embed.
-            timeout (int): The timeout in milliseconds. Defaults to …
+            timeout (int): The timeout in milliseconds. Defaults to 600000.
 
         Returns:
             Embeddings: The Embedding object.
@@ -66,18 +88,9 @@ class EmbeddingService(BaseService):
         Raises:
            Exception: If an error occurs.
         """
-        …
-        …
-        …
-        …
-        …
-        …
-            raise e
-
-    def _get_request_obj(self, texts: list[str], timeout: int) -> dict:
-        return {
-            "user_id": self.event.user_id,
-            "company_id": self.event.company_id,
-            "texts": texts,
-            "timeout": timeout,
-        }
+        return await embed_texts_async(
+            user_id=self.user_id,
+            company_id=self.company_id,
+            texts=texts,
+            timeout=timeout,
+        )
```
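`EmbeddingService` now accepts either an event or explicit IDs; constructing it with neither fails in `validate_required_values`. A sketch of both construction styles (IDs are placeholders):

```python
from unique_toolkit.embedding.service import EmbeddingService

# New in 0.6.0: construct directly from IDs (placeholders shown).
service = EmbeddingService(company_id="company_456", user_id="user_123")
embeddings = service.embed_texts(texts=["hello world"])

# Still supported: construct from an incoming Event/BaseEvent; the `event`
# property remains readable but is deprecated.
# service = EmbeddingService(event=event)
```

The methods themselves are now thin wrappers that forward `self.user_id` and `self.company_id` to the module-level functions above.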
unique_toolkit/evaluators/__init__.py

```diff
@@ -0,0 +1 @@
+from .constants import DOMAIN_NAME as DOMAIN_NAME
```
unique_toolkit/evaluators/constants.py

```diff
@@ -0,0 +1 @@
+DOMAIN_NAME = "evaluators"
```
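The per-domain `DOMAIN_NAME` constants feed the logger naming convention used by the new `functions` modules (`logging.getLogger(f"toolkit.{DOMAIN_NAME}.{__name__}")`), which lets applications tune logging per domain. A small sketch of the effect:

```python
import logging

# Configuring the domain subtree once covers every module logger beneath it.
logging.getLogger("toolkit.evaluators").setLevel(logging.DEBUG)

# Example of a module logger name produced by the convention; note that
# language_model/functions.py below imports DOMAIN_NAME from the evaluators
# package, so its logger also lands under this subtree.
log = logging.getLogger("toolkit.evaluators.unique_toolkit.language_model.functions")
log.debug("inherits the domain-level configuration")
```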
unique_toolkit/evaluators/context_relevancy/constants.py

```diff
@@ -4,6 +4,7 @@ from unique_toolkit.evaluators.context_relevancy.prompts import (
     CONTEXT_RELEVANCY_METRIC_USER_MSG,
 )
 from unique_toolkit.evaluators.schemas import (
+    EvaluationMetricInputFieldName,
     EvaluationMetricName,
 )
 from unique_toolkit.language_model.infos import LanguageModel
@@ -14,9 +15,8 @@ USER_MSG_KEY = "userPrompt"
 
 # Required input fields for context relevancy evaluation
 context_relevancy_required_input_fields = [
-    …
-    …
-    "context_texts",
+    EvaluationMetricInputFieldName.INPUT_TEXT,
+    EvaluationMetricInputFieldName.CONTEXT_TEXTS,
 ]
 
 
```
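The required-field list swaps bare strings for `EvaluationMetricInputFieldName` members. The enum's definition is not part of this diff; if it derives from `str` (a common pattern for field-name enums), existing string comparisons keep working. A sketch under that assumption:

```python
from enum import Enum

# Assumed shape; the real definition lives in unique_toolkit.evaluators.schemas
# and is not shown in this diff.
class EvaluationMetricInputFieldName(str, Enum):
    INPUT_TEXT = "input_text"
    CONTEXT_TEXTS = "context_texts"

# str-derived members still compare equal to the previous literal values:
assert EvaluationMetricInputFieldName.CONTEXT_TEXTS == "context_texts"
```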
unique_toolkit/evaluators/context_relevancy/utils.py

```diff
@@ -79,7 +79,9 @@ async def check_context_relevancy_async(
     try:
         msgs = _get_msgs(input, config)
         result = await LanguageModelService.complete_async_util(
-            company_id=company_id,
+            company_id=company_id,
+            messages=msgs,
+            model_name=model_name,
         )
         result_content = result.choices[0].message.content
         if not result_content:
@@ -89,7 +91,8 @@ async def check_context_relevancy_async(
             user_message=error_message,
         )
         return parse_eval_metric_result(
-            result_content,
+            result_content,  # type: ignore
+            EvaluationMetricName.CONTEXT_RELEVANCY,
         )
     except Exception as e:
         error_message = "Error occurred during context relevancy metric analysis"
```
unique_toolkit/evaluators/hallucination/utils.py

```diff
@@ -91,7 +91,8 @@ async def check_hallucination_async(
             user_message=error_message,
         )
         return parse_eval_metric_result(
-            result_content,
+            result_content,  # type: ignore
+            EvaluationMetricName.HALLUCINATION,
         )
     except Exception as e:
         error_message = "Error occurred during hallucination metric analysis"
```
unique_toolkit/language_model/functions.py

```diff
@@ -0,0 +1,362 @@
+import logging
+from typing import Optional, Type, cast
+
+import unique_sdk
+from pydantic import BaseModel
+
+from unique_toolkit.content.schemas import ContentChunk
+from unique_toolkit.evaluators import DOMAIN_NAME
+from unique_toolkit.language_model.constants import (
+    DEFAULT_COMPLETE_TEMPERATURE,
+    DEFAULT_COMPLETE_TIMEOUT,
+)
+from unique_toolkit.language_model.infos import LanguageModelName
+from unique_toolkit.language_model.schemas import (
+    LanguageModelMessages,
+    LanguageModelResponse,
+    LanguageModelStreamResponse,
+    LanguageModelTool,
+)
+
+logger = logging.getLogger(f"toolkit.{DOMAIN_NAME}.{__name__}")
+
+
+def complete(
+    company_id: str,
+    messages: LanguageModelMessages,
+    model_name: LanguageModelName | str,
+    temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
+    timeout: int = DEFAULT_COMPLETE_TIMEOUT,
+    tools: Optional[list[LanguageModelTool]] = None,
+    other_options: Optional[dict] = None,
+    structured_output_model: Optional[Type[BaseModel]] = None,
+    structured_output_enforce_schema: bool = False,
+) -> LanguageModelResponse:
+    """
+    Calls the completion endpoint synchronously without streaming the response.
+
+    Args:
+        company_id (str): The company ID associated with the request.
+        messages (LanguageModelMessages): The messages to complete.
+        model_name (LanguageModelName | str): The model name to use for the completion.
+        temperature (float): The temperature setting for the completion. Defaults to 0.
+        timeout (int): The timeout value in milliseconds. Defaults to 240_000.
+        tools (Optional[list[LanguageModelTool]]): Optional list of tools to include.
+        other_options (Optional[dict]): Additional options to use. Defaults to None.
+
+    Returns:
+        LanguageModelResponse: The response object containing the completed result.
+    """
+    options, model, messages_dict, _ = _prepare_completion_params_util(
+        messages=messages,
+        model_name=model_name,
+        temperature=temperature,
+        tools=tools,
+        other_options=other_options,
+        structured_output_model=structured_output_model,
+        structured_output_enforce_schema=structured_output_enforce_schema,
+    )
+
+    try:
+        response = unique_sdk.ChatCompletion.create(
+            company_id=company_id,
+            model=model,
+            messages=cast(
+                list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+                messages_dict,
+            ),
+            timeout=timeout,
+            options=options,  # type: ignore
+        )
+        return LanguageModelResponse(**response)
+    except Exception as e:
+        logger.error(f"Error completing: {e}")
+        raise e
+
+
+async def complete_async(
+    company_id: str,
+    messages: LanguageModelMessages,
+    model_name: LanguageModelName | str,
+    temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
+    timeout: int = DEFAULT_COMPLETE_TIMEOUT,
+    tools: Optional[list[LanguageModelTool]] = None,
+    other_options: Optional[dict] = None,
+    structured_output_model: Optional[Type[BaseModel]] = None,
+    structured_output_enforce_schema: bool = False,
+) -> LanguageModelResponse:
+    """
+    Calls the completion endpoint asynchronously without streaming the response.
+
+    This method sends a request to the completion endpoint using the provided messages, model name,
+    temperature, timeout, and optional tools. It returns a `LanguageModelResponse` object containing
+    the completed result.
+
+    Args:
+        company_id (str): The company ID associated with the request.
+        messages (LanguageModelMessages): The messages to complete.
+        model_name (LanguageModelName | str): The model name to use for the completion.
+        temperature (float): The temperature setting for the completion. Defaults to 0.
+        timeout (int): The timeout value in milliseconds for the request. Defaults to 240_000.
+        tools (Optional[list[LanguageModelTool]]): Optional list of tools to include in the request.
+        other_options (Optional[dict]): The other options to use. Defaults to None.
+
+    Returns:
+        LanguageModelResponse: The response object containing the completed result.
+
+    Raises:
+        Exception: If an error occurs during the request, an exception is raised and logged.
+    """
+    options, model, messages_dict, _ = _prepare_completion_params_util(
+        messages=messages,
+        model_name=model_name,
+        temperature=temperature,
+        tools=tools,
+        other_options=other_options,
+        structured_output_model=structured_output_model,
+        structured_output_enforce_schema=structured_output_enforce_schema,
+    )
+
+    try:
+        response = await unique_sdk.ChatCompletion.create_async(
+            company_id=company_id,
+            model=model,
+            messages=cast(
+                list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+                messages_dict,
+            ),
+            timeout=timeout,
+            options=options,  # type: ignore
+        )
+        return LanguageModelResponse(**response)
+    except Exception as e:
+        logger.error(f"Error completing: {e}")  # type: ignore
+        raise e
+
+
+def stream_complete_to_chat(
+    company_id: str,
+    user_id: str,
+    assistant_message_id: str,
+    user_message_id: str,
+    chat_id: str,
+    assistant_id: str,
+    messages: LanguageModelMessages,
+    model_name: LanguageModelName | str,
+    content_chunks: list[ContentChunk] = [],
+    debug_info: dict = {},
+    temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
+    timeout: int = DEFAULT_COMPLETE_TIMEOUT,
+    tools: Optional[list[LanguageModelTool]] = None,
+    start_text: Optional[str] = None,
+    other_options: Optional[dict] = None,
+) -> LanguageModelStreamResponse:
+    """
+    Streams a completion synchronously.
+
+    Args:
+        company_id (str): The company ID associated with the request.
+        user_id (str): The user ID for the request.
+        assistant_message_id (str): The assistant message ID.
+        user_message_id (str): The user message ID.
+        chat_id (str): The chat ID.
+        assistant_id (str): The assistant ID.
+        messages (LanguageModelMessages): The messages to complete.
+        model_name (LanguageModelName | str): The model name.
+        content_chunks (list[ContentChunk]): Content chunks for context.
+        debug_info (dict): Debug information.
+        temperature (float): Temperature setting.
+        timeout (int): Timeout in milliseconds.
+        tools (Optional[list[LanguageModelTool]]): Optional tools.
+        start_text (Optional[str]): Starting text.
+        other_options (Optional[dict]): Additional options.
+
+    Returns:
+        LanguageModelStreamResponse: The streaming response object.
+    """
+    options, model, messages_dict, search_context = _prepare_completion_params_util(
+        messages=messages,
+        model_name=model_name,
+        temperature=temperature,
+        tools=tools,
+        other_options=other_options,
+        content_chunks=content_chunks,
+    )
+
+    try:
+        response = unique_sdk.Integrated.chat_stream_completion(
+            user_id=user_id,
+            company_id=company_id,
+            assistantMessageId=assistant_message_id,
+            userMessageId=user_message_id,
+            messages=cast(
+                list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+                messages_dict,
+            ),
+            chatId=chat_id,
+            searchContext=search_context,
+            model=model,
+            timeout=timeout,
+            assistantId=assistant_id,
+            debugInfo=debug_info,
+            options=options,  # type: ignore
+            startText=start_text,
+        )
+        return LanguageModelStreamResponse(**response)
+    except Exception as e:
+        logger.error(f"Error streaming completion: {e}")
+        raise e
+
+
+async def stream_complete_to_chat_async(
+    company_id: str,
+    user_id: str,
+    assistant_message_id: str,
+    user_message_id: str,
+    chat_id: str,
+    assistant_id: str,
+    messages: LanguageModelMessages,
+    model_name: LanguageModelName | str,
+    content_chunks: list[ContentChunk] = [],
+    debug_info: dict = {},
+    temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
+    timeout: int = DEFAULT_COMPLETE_TIMEOUT,
+    tools: Optional[list[LanguageModelTool]] = None,
+    start_text: Optional[str] = None,
+    other_options: Optional[dict] = None,
+) -> LanguageModelStreamResponse:
+    """
+    Streams a completion asynchronously.
+
+    Args: [same as stream_complete]
+
+    Returns:
+        LanguageModelStreamResponse: The streaming response object.
+    """
+    options, model, messages_dict, search_context = _prepare_completion_params_util(
+        messages=messages,
+        model_name=model_name,
+        temperature=temperature,
+        tools=tools,
+        other_options=other_options,
+        content_chunks=content_chunks,
+    )
+
+    try:
+        response = await unique_sdk.Integrated.chat_stream_completion_async(
+            user_id=user_id,
+            company_id=company_id,
+            assistantMessageId=assistant_message_id,
+            userMessageId=user_message_id,
+            messages=cast(
+                list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+                messages_dict,
+            ),
+            chatId=chat_id,
+            searchContext=search_context,
+            model=model,
+            timeout=timeout,
+            assistantId=assistant_id,
+            debugInfo=debug_info,
+            options=options,  # type: ignore
+            startText=start_text,
+        )
+        return LanguageModelStreamResponse(**response)
+    except Exception as e:
+        logger.error(f"Error streaming completion: {e}")
+        raise e
+
+
+def _add_tools_to_options(
+    options: dict,
+    tools: Optional[list[LanguageModelTool]],
+) -> dict:
+    if tools:
+        options["tools"] = [
+            {
+                "type": "function",
+                "function": tool.model_dump(exclude_none=True),
+            }
+            for tool in tools
+        ]
+    return options
+
+
+def _to_search_context(chunks: list[ContentChunk]) -> dict | None:
+    if not chunks:
+        return None
+    return [
+        unique_sdk.Integrated.SearchResult(
+            id=chunk.id,
+            chunkId=chunk.chunk_id,
+            key=chunk.key,
+            title=chunk.title,
+            url=chunk.url,
+            startPage=chunk.start_page,
+            endPage=chunk.end_page,
+            order=chunk.order,
+            object=chunk.object,
+        )  # type: ignore
+        for chunk in chunks
+    ]
+
+
+def _add_response_format_to_options(
+    options: dict,
+    structured_output_model: Type[BaseModel],
+    structured_output_enforce_schema: bool = False,
+) -> dict:
+    options["responseFormat"] = {
+        "type": "json_schema",
+        "json_schema": {
+            "name": structured_output_model.__name__,
+            "strict": structured_output_enforce_schema,
+            "schema": structured_output_model.model_json_schema(),
+        },
+    }
+    return options
+
+
+def _prepare_completion_params_util(
+    messages: LanguageModelMessages,
+    model_name: LanguageModelName | str,
+    temperature: float,
+    tools: Optional[list[LanguageModelTool]] = None,
+    other_options: Optional[dict] = None,
+    content_chunks: Optional[list[ContentChunk]] = None,
+    structured_output_model: Optional[Type[BaseModel]] = None,
+    structured_output_enforce_schema: bool = False,
+) -> tuple[dict, str, dict, Optional[dict]]:
+    """
+    Prepares common parameters for completion requests.
+
+    Returns:
+        tuple containing:
+        - options (dict): Combined options including tools and temperature
+        - model (str): Resolved model name
+        - messages_dict (dict): Processed messages
+        - search_context (Optional[dict]): Processed content chunks if provided
+    """
+
+    options = _add_tools_to_options({}, tools)
+    if structured_output_model:
+        options = _add_response_format_to_options(
+            options, structured_output_model, structured_output_enforce_schema
+        )
+    options["temperature"] = temperature
+    if other_options:
+        options.update(other_options)
+
+    model = model_name.name if isinstance(model_name, LanguageModelName) else model_name
+
+    # Different methods need different message dump parameters
+    messages_dict = messages.model_dump(
+        exclude_none=True,
+        by_alias=content_chunks is not None,  # Use by_alias for streaming methods
+    )
+
+    search_context = (
+        _to_search_context(content_chunks) if content_chunks is not None else None
+    )
+
+    return options, model, messages_dict, search_context
```
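The module-level `complete` mirrors the former service method but takes `company_id` explicitly and adds structured output via `structured_output_model`, which `_add_response_format_to_options` turns into the `responseFormat` JSON-schema payload shown above. A usage sketch (the company ID and model string are placeholders, and `LanguageModelUserMessage` plus list-style `LanguageModelMessages` construction are assumptions about the schemas module, which this diff does not show):

```python
from pydantic import BaseModel

from unique_toolkit.language_model.functions import complete
from unique_toolkit.language_model.schemas import (
    LanguageModelMessages,
    LanguageModelUserMessage,  # assumed message class; not shown in this diff
)

class CityFacts(BaseModel):
    name: str
    population: int

response = complete(
    company_id="company_456",  # placeholder
    messages=LanguageModelMessages(
        [LanguageModelUserMessage(content="Give the name and population of Zurich.")]
    ),
    model_name="AZURE_GPT_4o_2024_0806",  # placeholder model string
    structured_output_model=CityFacts,
    structured_output_enforce_schema=True,
)
# The model is instructed to emit JSON matching CityFacts' schema.
print(response.choices[0].message.content)
```

Passing a `LanguageModelName` enum member resolves to its `.name`; plain strings are forwarded as-is, as `_prepare_completion_params_util` shows.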