unique_toolkit 0.7.13__py3-none-any.whl → 0.7.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/validators.py +54 -5
- unique_toolkit/app/schemas.py +15 -20
- unique_toolkit/chat/service.py +313 -65
- unique_toolkit/content/functions.py +12 -3
- unique_toolkit/content/service.py +5 -0
- unique_toolkit/evaluators/config.py +9 -18
- unique_toolkit/evaluators/context_relevancy/constants.py +4 -2
- unique_toolkit/evaluators/context_relevancy/utils.py +32 -18
- unique_toolkit/evaluators/hallucination/constants.py +2 -2
- unique_toolkit/evaluators/hallucination/utils.py +40 -30
- unique_toolkit/language_model/functions.py +23 -19
- unique_toolkit/language_model/infos.py +6 -4
- unique_toolkit/language_model/schemas.py +116 -32
- unique_toolkit/protocols/support.py +28 -0
- {unique_toolkit-0.7.13.dist-info → unique_toolkit-0.7.17.dist-info}/METADATA +18 -2
- {unique_toolkit-0.7.13.dist-info → unique_toolkit-0.7.17.dist-info}/RECORD +18 -17
- {unique_toolkit-0.7.13.dist-info → unique_toolkit-0.7.17.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.7.13.dist-info → unique_toolkit-0.7.17.dist-info}/WHEEL +0 -0
```diff
@@ -1,6 +1,7 @@
 import logging
 from pathlib import Path
 
+import unique_sdk
 from requests import Response
 from typing_extensions import deprecated
 
@@ -365,6 +366,7 @@ class ContentService:
         scope_id: str | None = None,
         chat_id: str | None = None,
         skip_ingestion: bool = False,
+        ingestion_config: unique_sdk.Content.IngestionConfig | None = None,
     ) -> Content:
         """
         Uploads content to the knowledge base.
@@ -390,6 +392,7 @@ class ContentService:
             scope_id=scope_id,
             chat_id=chat_id,
             skip_ingestion=skip_ingestion,
+            ingestion_config=ingestion_config,
         )
 
     def upload_content(
@@ -400,6 +403,7 @@ class ContentService:
         scope_id: str | None = None,
         chat_id: str | None = None,
         skip_ingestion: bool = False,
+        ingestion_config: unique_sdk.Content.IngestionConfig | None = None,
     ):
         """
         Uploads content to the knowledge base.
@@ -425,6 +429,7 @@ class ContentService:
            scope_id=scope_id,
            chat_id=chat_id,
            skip_ingestion=skip_ingestion,
+           ingestion_config=ingestion_config,
        )
 
     def request_content_by_id(
```
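Both upload paths now forward an optional `ingestion_config` to `unique_sdk`. A minimal sketch of a caller passing it; the `ContentService` instance, the non-keyword upload parameters, and the config key are assumptions for illustration, since only `scope_id`, `chat_id`, `skip_ingestion`, and `ingestion_config` appear in this diff:

```python
from unique_toolkit.content.service import ContentService


def upload_report(content_service: ContentService) -> None:
    # Placeholder key: consult unique_sdk.Content.IngestionConfig for the
    # actually supported fields; the diff only shows the parameter's type.
    ingestion_config = {"chunkStrategy": "default"}

    content_service.upload_content(
        path_to_content="./report.pdf",  # assumed pre-existing parameter names
        content_name="report.pdf",
        mime_type="application/pdf",
        scope_id="scope_123",
        skip_ingestion=False,
        ingestion_config=ingestion_config,  # new optional argument
    )
```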
```diff
@@ -1,35 +1,26 @@
 from humps import camelize
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict
 
-from unique_toolkit._common.validators import
+from unique_toolkit._common.validators import LMI, LanguageModelInfo
 from unique_toolkit.evaluators.schemas import (
     EvaluationMetricName,
 )
 from unique_toolkit.language_model.infos import (
-    LanguageModel,
     LanguageModelName,
 )
 
-model_config = ConfigDict(
-    alias_generator=camelize,
-    populate_by_name=True,
-    arbitrary_types_allowed=True,
-    validate_default=True,
-    json_encoders={LanguageModel: lambda v: v.display_name},
-)
-
 
 class EvaluationMetricConfig(BaseModel):
-    model_config =
+    model_config = ConfigDict(
+        alias_generator=camelize,
+        populate_by_name=True,
+        validate_default=True,
+    )
 
     enabled: bool = False
     name: EvaluationMetricName
-    language_model:
-    LanguageModelName.AZURE_GPT_35_TURBO_0125
+    language_model: LMI = LanguageModelInfo.from_name(
+        LanguageModelName.AZURE_GPT_35_TURBO_0125,
     )
     custom_prompts: dict[str, str] = {}
     score_to_emoji: dict[str, str] = {}
-
-    @field_validator("language_model", mode="before")
-    def validate_language_model(cls, value: LanguageModelName | LanguageModel):
-        return validate_and_init_language_model(value)
```
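`EvaluationMetricConfig.language_model` is now typed as `LMI` and defaults to a `LanguageModelInfo` rather than a bare model name, and the Pydantic `model_config` moved from module level into the class. A minimal sketch of constructing the config under the new schema (values are illustrative; whether a plain `LanguageModelName` is still coerced depends on the `LMI` annotation in `_common/validators.py`, which is not shown here):

```python
from unique_toolkit.evaluators.config import EvaluationMetricConfig
from unique_toolkit.evaluators.schemas import EvaluationMetricName
from unique_toolkit.language_model.infos import LanguageModelInfo, LanguageModelName

config = EvaluationMetricConfig(
    enabled=True,
    name=EvaluationMetricName.CONTEXT_RELEVANCY,
    # Pass a LanguageModelInfo explicitly instead of relying on the removed
    # field_validator to convert a LanguageModelName.
    language_model=LanguageModelInfo.from_name(
        LanguageModelName.AZURE_GPT_35_TURBO_0125,
    ),
)
```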
```diff
@@ -7,7 +7,7 @@ from unique_toolkit.evaluators.schemas import (
     EvaluationMetricInputFieldName,
     EvaluationMetricName,
 )
-from unique_toolkit.language_model.infos import
+from unique_toolkit.language_model.infos import LanguageModelInfo
 from unique_toolkit.language_model.service import LanguageModelName
 
 SYSTEM_MSG_KEY = "systemPrompt"
@@ -23,7 +23,9 @@ context_relevancy_required_input_fields = [
 default_config = EvaluationMetricConfig(
     enabled=False,
     name=EvaluationMetricName.CONTEXT_RELEVANCY,
-    language_model=
+    language_model=LanguageModelInfo.from_name(
+        LanguageModelName.AZURE_GPT_35_TURBO_0125
+    ),
     score_to_emoji={"LOW": "🟢", "MEDIUM": "🟡", "HIGH": "🔴"},
     custom_prompts={
         SYSTEM_MSG_KEY: CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
```
```diff
@@ -22,6 +22,7 @@ from unique_toolkit.evaluators.schemas import (
     EvaluationMetricName,
     EvaluationMetricResult,
 )
+from unique_toolkit.language_model import LanguageModelName
 from unique_toolkit.language_model.schemas import (
     LanguageModelMessages,
     LanguageModelSystemMessage,
@@ -34,12 +35,12 @@ logger = logging.getLogger(__name__)
 
 async def check_context_relevancy_async(
     company_id: str,
-
+    evaluation_metric_input: EvaluationMetricInput,
     config: EvaluationMetricConfig,
     logger: logging.Logger = logger,
 ) -> EvaluationMetricResult | None:
-    """
-
+    """Analyzes the relevancy of the context provided for the given evaluation_metric_input and output.
+
     The analysis classifies the context relevancy level as:
     - low
     - medium
@@ -47,14 +48,14 @@ async def check_context_relevancy_async(
 
     This method performs the following steps:
     1. Logs the start of the analysis using the provided `logger`.
-    2. Validates the required fields in the `
+    2. Validates the required fields in the `evaluation_metric_input` data.
     3. Retrieves the messages using the `_get_msgs` method.
     4. Calls `LanguageModelService.complete_async_util` to get a completion result.
     5. Parses and returns the evaluation metric result based on the content of the completion result.
 
     Args:
         company_id (str): The company ID for the analysis.
-
+        evaluation_metric_input (EvaluationMetricInput): The evaluation_metric_input data used for evaluation, including the generated output and reference information.
         config (EvaluationMetricConfig): Configuration settings for the evaluation.
         logger (Optional[logging.Logger], optional): The logger used for logging information and errors. Defaults to the logger for the current module.
 
@@ -63,13 +64,23 @@ async def check_context_relevancy_async(
 
     Raises:
         EvaluatorException: If required fields are missing or an error occurs during the evaluation.
+
     """
-
-
+    model_group_name = (
+        config.language_model.name.value
+        if isinstance(config.language_model.name, LanguageModelName)
+        else config.language_model.name
+    )
+    logger.info(f"Analyzing context relevancy with {model_group_name}.")
 
-
+    evaluation_metric_input.validate_required_fields(
+        context_relevancy_required_input_fields,
+    )
 
-    if
+    if (
+        evaluation_metric_input.context_texts
+        and len(evaluation_metric_input.context_texts) == 0
+    ):
         error_message = "No context texts provided."
         raise EvaluatorException(
             user_message=error_message,
@@ -77,11 +88,11 @@ async def check_context_relevancy_async(
         )
 
     try:
-        msgs = _get_msgs(
+        msgs = _get_msgs(evaluation_metric_input, config)
         result = await LanguageModelService.complete_async_util(
             company_id=company_id,
             messages=msgs,
-            model_name=
+            model_name=model_group_name,
         )
         result_content = result.choices[0].message.content
         if not result_content:
```
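The `model_name` passed to `LanguageModelService.complete_async_util` is now derived from `config.language_model.name`, which holds either a `LanguageModelName` enum member for the built-in models or, presumably, a plain string for custom ones. The resolution pattern in isolation, using a stand-in `LanguageModelInfo`:

```python
from unique_toolkit.language_model.infos import LanguageModelInfo, LanguageModelName

info = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_35_TURBO_0125)

# Enum members expose the wire-level identifier via .value; anything that is
# already a plain string is passed through unchanged.
model_group_name = (
    info.name.value if isinstance(info.name, LanguageModelName) else info.name
)
```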
```diff
@@ -104,25 +115,28 @@
 
 
 def _get_msgs(
-
+    evaluation_metric_input: EvaluationMetricInput,
     config: EvaluationMetricConfig,
-):
-    """
-
+) -> LanguageModelMessages:
+    """Composes the messages for context relevancy analysis.
+
+    The messages are based on the provided evaluation_metric_input and configuration.
 
     Args:
-
+        evaluation_metric_input (EvaluationMetricInput): The evaluation_metric_input data that includes context texts for the analysis.
         config (EvaluationMetricConfig): The configuration settings for composing messages.
 
     Returns:
-        LanguageModelMessages: The composed messages as per the provided
+        LanguageModelMessages: The composed messages as per the provided evaluation_metric_input and configuration.
+
     """
     system_msg_content = _get_system_prompt(config)
     system_msg = LanguageModelSystemMessage(content=system_msg_content)
 
     user_msg_templ = Template(_get_user_prompt(config))
     user_msg_content = user_msg_templ.substitute(
-
+        evaluation_metric_input_text=evaluation_metric_input.evaluation_metric_input_text,
+        contexts_text=evaluation_metric_input.get_joined_context_texts(),
     )
     user_msg = LanguageModelUserMessage(content=user_msg_content)
     return LanguageModelMessages([system_msg, user_msg])
```
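With the second parameter renamed to `evaluation_metric_input`, a call to the evaluator now reads as follows. A sketch assuming `EvaluationMetricInput` is importable from `unique_toolkit.evaluators.schemas` (the diff only shows it being used, not where it lives) and that an instance is already available:

```python
from unique_toolkit.evaluators.context_relevancy.constants import default_config
from unique_toolkit.evaluators.context_relevancy.utils import (
    check_context_relevancy_async,
)
from unique_toolkit.evaluators.schemas import EvaluationMetricInput  # assumed location


async def evaluate(metric_input: EvaluationMetricInput):
    # Keyword names follow the signature shown in the hunks above.
    return await check_context_relevancy_async(
        company_id="company_123",
        evaluation_metric_input=metric_input,
        config=default_config,
    )
```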
```diff
@@ -10,7 +10,7 @@ from unique_toolkit.evaluators.schemas import (
     EvaluationMetricName,
 )
 from unique_toolkit.language_model.infos import (
-
+    LanguageModelInfo,
     LanguageModelName,
 )
 
@@ -23,7 +23,7 @@ USER_MSG_DEFAULT_KEY = "userPromptDefault"
 hallucination_metric_default_config = EvaluationMetricConfig(
     enabled=False,
     name=EvaluationMetricName.HALLUCINATION,
-    language_model=
+    language_model=LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4_0613),
     score_to_emoji={"LOW": "🟢", "MEDIUM": "🟡", "HIGH": "🔴"},
     custom_prompts={
         SYSTEM_MSG_KEY: HALLUCINATION_METRIC_SYSTEM_MSG,
```
```diff
@@ -20,6 +20,7 @@ from unique_toolkit.evaluators.schemas import (
     EvaluationMetricName,
     EvaluationMetricResult,
 )
+from unique_toolkit.language_model import LanguageModelName
 from unique_toolkit.language_model.schemas import (
     LanguageModelMessages,
     LanguageModelSystemMessage,
@@ -43,8 +44,9 @@ async def check_hallucination_async(
     config: EvaluationMetricConfig,
     logger: logging.Logger = logger,
 ) -> EvaluationMetricResult | None:
-    """
-
+    """Analyze the level of hallucination in the generated output.
+
+    by comparing it with the provided input
     and the contexts or history. The analysis classifies the hallucination level as:
     - low
     - medium
@@ -72,16 +74,23 @@ async def check_hallucination_async(
 
     Raises:
         EvaluatorException: If the context texts are empty, required fields are missing, or an error occurs during the evaluation.
+
     """
-
-
+    model_group_name = (
+        config.language_model.name.value
+        if isinstance(config.language_model.name, LanguageModelName)
+        else config.language_model.name
+    )
+    logger.info(f"Analyzing level of hallucination with {model_group_name}.")
 
     input.validate_required_fields(hallucination_required_input_fields)
 
     try:
         msgs = _get_msgs(input, config, logger)
         result = await LanguageModelService.complete_async_util(
-            company_id=company_id,
+            company_id=company_id,
+            messages=msgs,
+            model_name=model_group_name,
         )
         result_content = result.choices[0].message.content
         if not result_content:
@@ -104,71 +113,72 @@
 
 
 def _get_msgs(
-
+    evaluation_metric_input: EvaluationMetricInput,
     config: EvaluationMetricConfig,
     logger: logging.Logger,
 ):
-    """
-    Composes the messages for hallucination analysis based on the provided input and configuration.
+    """Composes the messages for hallucination analysis based on the provided evaluation_metric_input and configuration.
 
     This method decides how to compose the messages based on the availability of context texts and history
-    message texts in the `
+    message texts in the `evaluation_metric_input`
 
     Args:
-
+        evaluation_metric_input (EvaluationMetricInput): The evaluation_metric_input data that includes context texts and history message texts
             for the analysis.
         config (EvaluationMetricConfig): The configuration settings for composing messages.
         logger (Optional[logging.Logger], optional): The logger used for logging debug information.
             Defaults to the logger for the current module.
 
     Returns:
-        The composed messages as per the provided
+        The composed messages as per the provided evaluation_metric_input and configuration. The exact type and structure
         depend on the implementation of the `compose_msgs` and `compose_msgs_default` methods.
 
     """
-    if
+    if (
+        evaluation_metric_input.context_texts
+        or evaluation_metric_input.history_messages
+    ):
         logger.debug("Using context / history for hallucination evaluation.")
-        return _compose_msgs(
-
-
-    return _compose_msgs_default(input, config)
+        return _compose_msgs(evaluation_metric_input, config)
+    logger.debug("No contexts and history provided for hallucination evaluation.")
+    return _compose_msgs_default(evaluation_metric_input, config)
 
 
 def _compose_msgs(
-
+    evaluation_metric_input: EvaluationMetricInput,
     config: EvaluationMetricConfig,
 ):
-    """
-    Composes the hallucination analysis messages.
-    """
+    """Composes the hallucination analysis messages."""
     system_msg_content = _get_system_prompt_with_contexts(config)
     system_msg = LanguageModelSystemMessage(content=system_msg_content)
 
     user_msg_templ = Template(_get_user_prompt_with_contexts(config))
     user_msg_content = user_msg_templ.substitute(
-
-        contexts_text=
-
-
+        evaluation_metric_input_text=evaluation_metric_input.evaluation_metric_input_text,
+        contexts_text=evaluation_metric_input.get_joined_context_texts(
+            tag_name="reference",
+        ),
+        history_messages_text=evaluation_metric_input.get_joined_history_texts(
+            tag_name="conversation",
+        ),
+        output_text=evaluation_metric_input.output_text,
     )
     user_msg = LanguageModelUserMessage(content=user_msg_content)
     return LanguageModelMessages([system_msg, user_msg])
 
 
 def _compose_msgs_default(
-
+    evaluation_metric_input: EvaluationMetricInput,
     config: EvaluationMetricConfig,
 ):
-    """
-    Composes the hallucination analysis prompt without messages.
-    """
+    """Composes the hallucination analysis prompt without messages."""
     system_msg_content = _get_system_prompt_default(config)
     system_msg = LanguageModelSystemMessage(content=system_msg_content)
 
     user_msg_templ = Template(_get_user_prompt_default(config))
     user_msg_content = user_msg_templ.substitute(
-
-        output_text=
+        evaluation_metric_input_text=evaluation_metric_input.evaluation_metric_input_text,
+        output_text=evaluation_metric_input.output_text,
    )
     user_msg = LanguageModelUserMessage(content=user_msg_content)
     return LanguageModelMessages([system_msg, user_msg])
```
```diff
@@ -1,5 +1,5 @@
 import logging
-from typing import
+from typing import cast
 
 import unique_sdk
 from pydantic import BaseModel
@@ -29,11 +29,10 @@ def complete(
     timeout: int = DEFAULT_COMPLETE_TIMEOUT,
     tools: list[LanguageModelTool] | None = None,
     other_options: dict | None = None,
-    structured_output_model:
+    structured_output_model: type[BaseModel] | None = None,
     structured_output_enforce_schema: bool = False,
 ) -> LanguageModelResponse:
-    """
-    Calls the completion endpoint synchronously without streaming the response.
+    """Call the completion endpoint synchronously without streaming the response.
 
     Args:
         company_id (str): The company ID associated with the request.
```
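In both `complete` and `complete_async`, `structured_output_model` is now annotated as `type[BaseModel] | None`, i.e. the Pydantic model class itself is passed rather than an instance, and it is translated into the `responseFormat` option. A minimal synchronous sketch, assuming the remaining parameters keep the names shown in the surrounding docstrings:

```python
from pydantic import BaseModel

from unique_toolkit.language_model.functions import complete
from unique_toolkit.language_model.infos import LanguageModelName
from unique_toolkit.language_model.schemas import (
    LanguageModelMessages,
    LanguageModelUserMessage,
)


class WeatherReport(BaseModel):
    city: str
    temperature_celsius: float


messages = LanguageModelMessages(
    [LanguageModelUserMessage(content="Weather in Zurich as JSON, please.")]
)

response = complete(
    company_id="company_123",
    messages=messages,
    model_name=LanguageModelName.AZURE_GPT_35_TURBO_0125,
    structured_output_model=WeatherReport,  # the class, not an instance
    structured_output_enforce_schema=True,  # presumably asks the backend to enforce the schema
)
```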
```diff
@@ -46,6 +45,7 @@ def complete(
 
     Returns:
         LanguageModelResponse: The response object containing the completed result.
+
     """
     options, model, messages_dict, _ = _prepare_completion_params_util(
         messages=messages,
@@ -62,7 +62,7 @@ def complete(
         company_id=company_id,
         model=model,
         messages=cast(
-            list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+            "list[unique_sdk.Integrated.ChatCompletionRequestMessage]",
             messages_dict,
         ),
         timeout=timeout,
@@ -82,11 +82,10 @@ async def complete_async(
     timeout: int = DEFAULT_COMPLETE_TIMEOUT,
     tools: list[LanguageModelTool] | None = None,
     other_options: dict | None = None,
-    structured_output_model:
+    structured_output_model: type[BaseModel] | None = None,
     structured_output_enforce_schema: bool = False,
 ) -> LanguageModelResponse:
-    """
-    Calls the completion endpoint asynchronously without streaming the response.
+    """Call the completion endpoint asynchronously without streaming the response.
 
     This method sends a request to the completion endpoint using the provided messages, model name,
     temperature, timeout, and optional tools. It returns a `LanguageModelResponse` object containing
@@ -105,7 +104,9 @@ async def complete_async(
         LanguageModelResponse: The response object containing the completed result.
 
     Raises:
-        Exception: If an error occurs during the request, an exception is raised
+        Exception: If an error occurs during the request, an exception is raised
+            and logged.
+
     """
     options, model, messages_dict, _ = _prepare_completion_params_util(
         messages=messages,
@@ -122,7 +123,7 @@ async def complete_async(
         company_id=company_id,
         model=model,
         messages=cast(
-            list[unique_sdk.Integrated.ChatCompletionRequestMessage],
+            "list[unique_sdk.Integrated.ChatCompletionRequestMessage]",
             messages_dict,
         ),
         timeout=timeout,
@@ -130,7 +131,7 @@ async def complete_async(
         )
         return LanguageModelResponse(**response)
     except Exception as e:
-        logger.
+        logger.exception(f"Error completing: {e}")
         raise e
 
 
@@ -163,14 +164,14 @@ def _to_search_context(chunks: list[ContentChunk]) -> dict | None:
                 endPage=chunk.end_page,
                 order=chunk.order,
                 object=chunk.object,
-            )
+            )
         for chunk in chunks
     ]
 
 
 def _add_response_format_to_options(
     options: dict,
-    structured_output_model:
+    structured_output_model: type[BaseModel],
     structured_output_enforce_schema: bool = False,
 ) -> dict:
     options["responseFormat"] = {
@@ -191,11 +192,10 @@ def _prepare_completion_params_util(
     tools: list[LanguageModelTool] | None = None,
     other_options: dict | None = None,
     content_chunks: list[ContentChunk] | None = None,
-    structured_output_model:
+    structured_output_model: type[BaseModel] | None = None,
     structured_output_enforce_schema: bool = False,
 ) -> tuple[dict, str, dict, dict | None]:
-    """
-    Prepares common parameters for completion requests.
+    """Prepare common parameters for completion requests.
 
     Returns:
         tuple containing:
@@ -203,18 +203,22 @@ def _prepare_completion_params_util(
         - model (str): Resolved model name
         - messages_dict (dict): Processed messages
         - search_context (dict | None): Processed content chunks if provided
-    """
 
+    """
     options = _add_tools_to_options({}, tools)
     if structured_output_model:
         options = _add_response_format_to_options(
-            options,
+            options,
+            structured_output_model,
+            structured_output_enforce_schema,
         )
     options["temperature"] = temperature
     if other_options:
         options.update(other_options)
 
-    model =
+    model = (
+        model_name.value if isinstance(model_name, LanguageModelName) else model_name
+    )
 
     # Different methods need different message dump parameters
     messages_dict = messages.model_dump(
```
```diff
@@ -492,8 +492,10 @@ class LanguageModelInfo(BaseModel):
 
 @deprecated(
     """
-    Use `LanguageModelInfo` instead of `LanguageModel
-
+    Use `LanguageModelInfo` instead of `LanguageModel`.
+
+    `LanguageModel` will be deprecated on 31.12.2025
+    """,
 )
 class LanguageModel:
     _info: ClassVar[LanguageModelInfo]
```
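The wrapper class `LanguageModel` now carries an explicit deprecation notice pointing at `LanguageModelInfo`. A minimal sketch of the suggested replacement; the attribute names echo the `info` property docstring in the next hunk:

```python
from unique_toolkit.language_model.infos import LanguageModelInfo, LanguageModelName

# Resolve model metadata directly instead of going through the LanguageModel
# wrapper, which the notice above marks as deprecated on 31.12.2025.
info = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4_0613)
print(info.name, info.version, info.provider)
```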
```diff
@@ -503,8 +505,8 @@ class LanguageModel:
 
     @property
     def info(self) -> LanguageModelInfo:
-        """
-
+        """Return all infos about the model.
+
         - name
         - version
         - provider
```