aiqtoolkit 1.1.0a20250429__py3-none-any.whl → 1.1.0a20250502__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aiqtoolkit has been flagged as potentially problematic; see the registry's advisory page for details.
- aiq/agent/react_agent/register.py +2 -2
- aiq/agent/reasoning_agent/reasoning_agent.py +1 -1
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/builder/component_utils.py +5 -5
- aiq/builder/front_end.py +4 -4
- aiq/builder/function_base.py +4 -4
- aiq/builder/function_info.py +1 -1
- aiq/builder/intermediate_step_manager.py +10 -8
- aiq/builder/workflow_builder.py +1 -1
- aiq/cli/cli_utils/validation.py +1 -1
- aiq/cli/commands/configure/channel/add.py +1 -1
- aiq/cli/commands/configure/channel/channel.py +3 -1
- aiq/cli/commands/configure/channel/remove.py +1 -1
- aiq/cli/commands/configure/channel/update.py +1 -1
- aiq/cli/commands/configure/configure.py +2 -2
- aiq/cli/commands/info/info.py +2 -2
- aiq/cli/commands/info/list_components.py +2 -2
- aiq/cli/commands/registry/publish.py +3 -3
- aiq/cli/commands/registry/pull.py +3 -3
- aiq/cli/commands/registry/registry.py +3 -1
- aiq/cli/commands/registry/remove.py +3 -3
- aiq/cli/commands/registry/search.py +3 -3
- aiq/cli/commands/start.py +4 -4
- aiq/cli/commands/uninstall.py +2 -2
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +2 -2
- aiq/cli/commands/workflow/workflow_commands.py +14 -8
- aiq/cli/entrypoint.py +1 -1
- aiq/data_models/api_server.py +73 -57
- aiq/data_models/component_ref.py +7 -7
- aiq/data_models/discovery_metadata.py +7 -7
- aiq/data_models/intermediate_step.py +2 -2
- aiq/eval/register.py +1 -0
- aiq/eval/remote_workflow.py +1 -1
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +263 -0
- aiq/eval/tunable_rag_evaluator/register.py +50 -0
- aiq/front_ends/console/console_front_end_config.py +1 -1
- aiq/front_ends/fastapi/fastapi_front_end_config.py +5 -5
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +27 -18
- aiq/front_ends/fastapi/response_helpers.py +33 -19
- aiq/memory/__init__.py +2 -2
- aiq/meta/pypi.md +18 -18
- aiq/observability/async_otel_listener.py +157 -10
- aiq/profiler/callbacks/agno_callback_handler.py +2 -2
- aiq/profiler/callbacks/langchain_callback_handler.py +1 -1
- aiq/profiler/callbacks/llama_index_callback_handler.py +1 -1
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- aiq/profiler/decorators/function_tracking.py +1 -1
- aiq/profiler/profile_runner.py +1 -1
- aiq/registry_handlers/local/local_handler.py +5 -5
- aiq/registry_handlers/local/register_local.py +1 -1
- aiq/registry_handlers/package_utils.py +2 -2
- aiq/registry_handlers/pypi/pypi_handler.py +5 -5
- aiq/registry_handlers/pypi/register_pypi.py +3 -3
- aiq/registry_handlers/registry_handler_base.py +7 -7
- aiq/registry_handlers/rest/register_rest.py +4 -4
- aiq/registry_handlers/rest/rest_handler.py +5 -5
- aiq/registry_handlers/schemas/package.py +1 -1
- aiq/registry_handlers/schemas/publish.py +4 -4
- aiq/registry_handlers/schemas/pull.py +5 -4
- aiq/registry_handlers/schemas/search.py +7 -7
- aiq/retriever/models.py +1 -1
- aiq/runtime/loader.py +6 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/retriever.py +1 -1
- aiq/utils/io/yaml_tools.py +75 -6
- aiq/utils/settings/global_settings.py +1 -1
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/METADATA +24 -21
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/RECORD +74 -71
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/entry_points.txt +0 -0
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.1.0a20250429.dist-info → aiqtoolkit-1.1.0a20250502.dist-info}/top_level.txt +0 -0
aiq/data_models/api_server.py
CHANGED

@@ -32,14 +32,63 @@ from aiq.data_models.interactive import HumanPrompt
 from aiq.utils.type_converter import GlobalTypeConverter


+class ChatContentType(str, Enum):
+    """
+    ChatContentType is an Enum that represents the type of Chat content.
+    """
+    TEXT = "text"
+    IMAGE_URL = "image_url"
+    INPUT_AUDIO = "input_audio"
+
+
+class InputAudio(BaseModel):
+    data: str = "default"
+    format: str = "default"
+
+
+class AudioContent(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    type: typing.Literal[ChatContentType.INPUT_AUDIO] = ChatContentType.INPUT_AUDIO
+    input_audio: InputAudio = InputAudio()
+
+
+class ImageUrl(BaseModel):
+    url: HttpUrl = HttpUrl(url="http://default.com")
+
+
+class ImageContent(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    type: typing.Literal[ChatContentType.IMAGE_URL] = ChatContentType.IMAGE_URL
+    image_url: ImageUrl = ImageUrl()
+
+
+class TextContent(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    type: typing.Literal[ChatContentType.TEXT] = ChatContentType.TEXT
+    text: str = "default"
+
+
+class Security(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    api_key: str = "default"
+    token: str = "default"
+
+
+UserContent = typing.Annotated[TextContent | ImageContent | AudioContent, Discriminator("type")]
+
+
 class Message(BaseModel):
-    content: str
+    content: str | list[UserContent]
     role: str


 class AIQChatRequest(BaseModel):
     """
-    AIQChatRequest is a data model that represents a request to the
+    AIQChatRequest is a data model that represents a request to the AIQ Toolkit chat API.
     """

     # Allow extra fields in the model_config to support derived models
@@ -65,6 +114,20 @@ class AIQChatRequest(BaseModel):
                               max_tokens=max_tokens,
                               top_p=top_p)

+    @staticmethod
+    def from_content(content: list[UserContent],
+                     *,
+                     model: str | None = None,
+                     temperature: float | None = None,
+                     max_tokens: int | None = None,
+                     top_p: float | None = None) -> "AIQChatRequest":
+
+        return AIQChatRequest(messages=[Message(content=content, role="user")],
+                              model=model,
+                              temperature=temperature,
+                              max_tokens=max_tokens,
+                              top_p=top_p)
+

 class AIQChoiceMessage(BaseModel):
     content: str | None = None
@@ -88,8 +151,8 @@ class AIQUsage(BaseModel):

 class AIQResponseSerializable(abc.ABC):
     """
-    AIQChatResponseSerializable is an abstract class that defines the interface for serializing output for the
-    chat streaming API.
+    AIQChatResponseSerializable is an abstract class that defines the interface for serializing output for the AIQ
+    Toolkit chat streaming API.
     """

     @abstractmethod
@@ -111,7 +174,7 @@ class AIQResponseBaseModelIntermediate(BaseModel, AIQResponseSerializable):

 class AIQChatResponse(AIQResponseBaseModelOutput):
     """
-    AIQChatResponse is a data model that represents a response from the
+    AIQChatResponse is a data model that represents a response from the AIQ Toolkit chat API.
     """

     # Allow extra fields in the model_config to support derived models
@@ -152,7 +215,7 @@ class AIQChatResponse(AIQResponseBaseModelOutput):

 class AIQChatResponseChunk(AIQResponseBaseModelOutput):
     """
-    AIQChatResponseChunk is a data model that represents a response chunk from the
+    AIQChatResponseChunk is a data model that represents a response chunk from the AIQ Toolkit chat streaming API.
     """

     # Allow extra fields in the model_config to support derived models
@@ -191,7 +254,7 @@ class AIQChatResponseChunk(AIQResponseBaseModelOutput):

 class AIQResponseIntermediateStep(AIQResponseBaseModelIntermediate):
     """
-    AIQResponseSerializedStep is a data model that represents a serialized step in the
+    AIQResponseSerializedStep is a data model that represents a serialized step in the AIQ Toolkit chat streaming API.
     """

     # Allow extra fields in the model_config to support derived models
@@ -231,15 +294,6 @@ class UserMessageContentRoleType(str, Enum):
     ASSISTANT = "assistant"


-class ChatContentType(str, Enum):
-    """
-    ChatContentType is an Enum that represents the type of Chat content.
-    """
-    TEXT = "text"
-    IMAGE_URL = "image_url"
-    INPUT_AUDIO = "input_audio"
-
-
 class WebSocketMessageType(str, Enum):
     """
     WebSocketMessageType is an Enum that represents WebSocket Message types.
@@ -270,46 +324,6 @@ class WebSocketMessageStatus(str, Enum):
     COMPLETE = "complete"


-class InputAudio(BaseModel):
-    data: str = "default"
-    format: str = "default"
-
-
-class AudioContent(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    type: typing.Literal[ChatContentType.INPUT_AUDIO] = ChatContentType.INPUT_AUDIO
-    input_audio: InputAudio = InputAudio()
-
-
-class ImageUrl(BaseModel):
-    url: HttpUrl = HttpUrl(url="http://default.com")
-
-
-class ImageContent(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    type: typing.Literal[ChatContentType.IMAGE_URL] = ChatContentType.IMAGE_URL
-    image_url: ImageUrl = ImageUrl()
-
-
-class TextContent(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    type: typing.Literal[ChatContentType.TEXT] = ChatContentType.TEXT
-    text: str = "default"
-
-
-class Security(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    api_key: str = "default"
-    token: str = "default"
-
-
-UserContent = typing.Annotated[TextContent | ImageContent | AudioContent, Discriminator("type")]
-
-
 class UserMessages(BaseModel):
     model_config = ConfigDict(extra="forbid")

@@ -487,7 +501,9 @@ GlobalTypeConverter.register_converter(_generate_response_to_chat_response)

 # ======== AIQChatRequest Converters ========
 def _aiq_chat_request_to_string(data: AIQChatRequest) -> str:
-
+    if isinstance(data.messages[-1].content, str):
+        return data.messages[-1].content
+    return str(data.messages[-1].content)


 GlobalTypeConverter.register_converter(_aiq_chat_request_to_string)
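The hunks above replace `Message.content: str` with a tagged union: each content model carries a literal `type` field, and `UserContent` relies on Pydantic's `Discriminator("type")` to pick the right model during validation. A minimal sketch of composing a multimodal request with the classes this diff adds; the URL and model name are illustrative only:

```python
# Sketch only: uses the models defined in the diff above; the URL and
# model name below are hypothetical examples, not values from the package.
from aiq.data_models.api_server import AIQChatRequest
from aiq.data_models.api_server import ImageContent
from aiq.data_models.api_server import ImageUrl
from aiq.data_models.api_server import TextContent

# A mixed text + image user message. The literal "type" field on each
# content model is what Discriminator("type") dispatches on.
content = [
    TextContent(text="What is in this picture?"),
    ImageContent(image_url=ImageUrl(url="http://example.com/cat.png")),
]

# from_content wraps the list in a single user-role Message.
request = AIQChatRequest.from_content(content, model="example-model")
print(request.messages[0].role)  # "user"
```

Plain string content remains valid (`content: str | list[UserContent]`), which is why the updated `_aiq_chat_request_to_string` converter checks `isinstance(..., str)` before falling back to `str()`.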
aiq/data_models/component_ref.py
CHANGED

@@ -43,7 +43,7 @@ class ComponentRefNode(HashableBaseModel):

     Args:
         ref_name (ComponentRef): The name of the component runtime instance.
-        component_group (ComponentGroup): The component group in an
+        component_group (ComponentGroup): The component group in an AIQ Toolkit configuration object.
     """

     ref_name: "ComponentRef"
@@ -70,7 +70,7 @@ class ComponentRef(str, ABC):
         """Provides the component group this ComponentRef object represents.

         Returns:
-            ComponentGroup: A component group of the
+            ComponentGroup: A component group of the AIQ Toolkit configuration object
         """

         pass
@@ -82,7 +82,7 @@ class ComponentRef(str, ABC):

 class EmbedderRef(ComponentRef):
     """
-    A reference to an embedder in an
+    A reference to an embedder in an AIQ Toolkit configuration object.
     """

     @property
@@ -93,7 +93,7 @@ class EmbedderRef(ComponentRef):

 class FunctionRef(ComponentRef):
     """
-    A reference to a function in an
+    A reference to a function in an AIQ Toolkit configuration object.
     """

     @property
@@ -104,7 +104,7 @@ class FunctionRef(ComponentRef):

 class LLMRef(ComponentRef):
     """
-    A reference to an LLM in an
+    A reference to an LLM in an AIQ Toolkit configuration object.
     """

     @property
@@ -115,7 +115,7 @@ class LLMRef(ComponentRef):

 class MemoryRef(ComponentRef):
     """
-    A reference to a memory in an
+    A reference to a memory in an AIQ Toolkit configuration object.
     """

     @property
@@ -126,7 +126,7 @@ class MemoryRef(ComponentRef):

 class RetrieverRef(ComponentRef):
     """
-    A reference to a retriever in an
+    A reference to a retriever in an AIQ Toolkit configuration object.
     """

     @property

aiq/data_models/discovery_metadata.py
CHANGED

@@ -55,11 +55,11 @@ class DiscoveryMetadata(BaseModel):
     """A data model representing metadata about each registered component to faciliate its discovery.

     Args:
-        package (str): The name of the package containing the
-        version (str): The version number of the package containing the
-        component_type (AIQComponentEnum): The type of
-        component_name (str): The registered name of the
-        description (str): Description of the
+        package (str): The name of the package containing the AIQ Toolkit component.
+        version (str): The version number of the package containing the AIQ Toolkit component.
+        component_type (AIQComponentEnum): The type of AIQ Toolkit component this metadata represents.
+        component_name (str): The registered name of the AIQ Toolkit component.
+        description (str): Description of the AIQ Toolkit component pulled from its config objects docstrings.
         developer_notes (str): Other notes to a developers to aid in the use of the component.
         status (DiscoveryStatusEnum): Provides the status of the metadata discovery process.
     """
@@ -129,7 +129,7 @@ class DiscoveryMetadata(BaseModel):
     @staticmethod
     def from_config_type(config_type: type["TypedBaseModelT"],
                          component_type: AIQComponentEnum = AIQComponentEnum.UNDEFINED) -> "DiscoveryMetadata":
-        """Generates discovery metadata from an
+        """Generates discovery metadata from an AIQ Toolkit config object.

         Args:
             config_type (type[TypedBaseModelT]): A registered component's configuration object.
@@ -204,7 +204,7 @@ class DiscoveryMetadata(BaseModel):
         """Generates discovery metadata from an installed package name.

         Args:
-            package_name (str): The name of the
+            package_name (str): The name of the AIQ Toolkit plugin package containing registered components.
             package_version (str, optional): The version of the package, Defaults to None.

         Returns:

aiq/data_models/intermediate_step.py
CHANGED

@@ -98,7 +98,7 @@ class TraceMetadata(BaseModel):

 class IntermediateStepPayload(BaseModel):
     """
-    AIQIntermediateStep is a data model that represents an intermediate step in the
+    AIQIntermediateStep is a data model that represents an intermediate step in the AIQ Toolkit. Intermediate steps are
     captured while a request is running and can be used to show progress or to evaluate the path a workflow took to get
     a response.
     """
@@ -203,7 +203,7 @@ class IntermediateStepPayload(BaseModel):

 class IntermediateStep(BaseModel):
     """
-    AIQIntermediateStep is a data model that represents an intermediate step in the
+    AIQIntermediateStep is a data model that represents an intermediate step in the AIQ Toolkit. Intermediate steps are
     captured while a request is running and can be used to show progress or to evaluate the path a workflow took to get
     a response.
     """
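For context on what these reworded docstrings describe: `ComponentRef` subclasses are string-like handles that name entries in an AIQ Toolkit configuration object. A hedged sketch, assuming only what the hunks above show (`ComponentRef` derives from `str` and each subclass implements the `component_group` property); the reference name is hypothetical:

```python
# Sketch under stated assumptions: LLMRef is a str subclass per the
# "class ComponentRef(str, ABC)" hunk above. "nim_llm" is a made-up name.
from aiq.data_models.component_ref import LLMRef

llm_ref = LLMRef("nim_llm")
assert isinstance(llm_ref, str)  # usable anywhere a plain name string is expected
print(llm_ref.component_group)   # the ComponentGroup this reference belongs to
```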
aiq/eval/register.py
CHANGED

@@ -20,3 +20,4 @@
 from .rag_evaluator.register import register_ragas_evaluator
 from .swe_bench_evaluator.register import register_swe_bench_evaluator
 from .trajectory_evaluator.register import register_trajectory_evaluator
+from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
aiq/eval/remote_workflow.py
CHANGED

@@ -52,7 +52,7 @@ class EvaluationRemoteWorkflowHandler:

         try:
             # Use the streaming endpoint
-            endpoint = f"{self.config.endpoint}/generate/
+            endpoint = f"{self.config.endpoint}/generate/full"
             async with session.post(endpoint, json=payload) as response:
                 response.raise_for_status()  # Raise an exception for HTTP errors

aiq/eval/tunable_rag_evaluator/__init__.py
File without changes
aiq/eval/tunable_rag_evaluator/evaluate.py
ADDED

@@ -0,0 +1,263 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+
+from langchain.output_parsers import ResponseSchema
+from langchain.output_parsers import StructuredOutputParser
+from langchain.schema import HumanMessage
+from langchain.schema import SystemMessage
+from langchain_core.language_models import BaseChatModel
+from tqdm import tqdm
+
+from aiq.eval.evaluator.evaluator_model import EvalInput
+from aiq.eval.evaluator.evaluator_model import EvalInputItem
+from aiq.eval.evaluator.evaluator_model import EvalOutput
+from aiq.eval.evaluator.evaluator_model import EvalOutputItem
+from aiq.eval.utils.tqdm_position_registry import TqdmPositionRegistry
+
+logger = logging.getLogger(__name__)
+
+# pylint: disable=line-too-long
+# flake8: noqa: E501
+
+
+def evaluation_prompt(judge_llm_prompt: str,
+                      question: str,
+                      answer_description: str,
+                      generated_answer: str,
+                      format_instructions: str,
+                      default_scoring: bool):
+    """
+    This function generates a prompt for the judge LLM to evaluate the generated answer.
+    """
+
+    DEFAULT_SCORING_INSTRUCTIONS = """
+    The coverage score is a measure of how well the generated answer covers the critical aspects mentioned in the expected answer. A low coverage score indicates that the generated answer misses critical aspects of the expected answer. A middle coverage score indicates that the generated answer covers some of the must-haves of the expected answer but lacks other details. A high coverage score indicates that all of the expected aspects are present in the generated answer.
+    The correctness score is a measure of how well the generated answer matches the expected answer. A low correctness score indicates that the generated answer is incorrect or does not match the expected answer. A middle correctness score indicates that the generated answer is correct but lacks some details. A high correctness score indicates that the generated answer is exactly the same as the expected answer.
+    The relevance score is a measure of how well the generated answer is relevant to the question. A low relevance score indicates that the generated answer is not relevant to the question. A middle relevance score indicates that the generated answer is somewhat relevant to the question. A high relevance score indicates that the generated answer is exactly relevant to the question.
+    The reasoning is a 1-2 sentence explanation for the scoring.
+    """
+
+    DEFAULT_EVAL_PROMPT = (f"You are an intelligent assistant that responds strictly in JSON format."
+                           f"Judge based on the following scoring rubric: {DEFAULT_SCORING_INSTRUCTIONS}"
+                           f"{judge_llm_prompt}\n"
+                           f"{format_instructions}\n"
+                           f"Here is the user's query: {question}"
+                           f"Here is the description of the expected answer: {answer_description}"
+                           f"Here is the generated answer: {generated_answer}")
+
+    EVAL_PROMPT = (f"You are an intelligent assistant that responds strictly in JSON format. {judge_llm_prompt}\n"
+                   f"{format_instructions}\n"
+                   f"Here is the user's query: {question}"
+                   f"Here is the description of the expected answer: {answer_description}"
+                   f"Here is the generated answer: {generated_answer}")
+
+    return EVAL_PROMPT if not default_scoring else DEFAULT_EVAL_PROMPT
+
+
+class TunableRagEvaluator:
+    '''Tunable RAG evaluator class with customizable LLM prompt for scoring.'''
+
+    def __init__(self,
+                 llm: BaseChatModel,
+                 judge_llm_prompt: str,
+                 max_concurrency: int,
+                 default_scoring: bool,
+                 default_score_weights: dict):
+        self.llm = llm
+        self.max_concurrency = max_concurrency
+        self.judge_llm_prompt = judge_llm_prompt
+        self.semaphore = asyncio.Semaphore(self.max_concurrency)
+        self.default_scoring = default_scoring
+        # Use user-provided weights if available; otherwise, set equal weights for each score
+        self.default_score_weights = default_score_weights if default_score_weights else {
+            "coverage": 1 / 3, "correctness": 1 / 3, "relevance": 1 / 3
+        }
+
+    async def evaluate(self, eval_input: EvalInput) -> EvalOutput:
+        '''Evaluate function'''
+
+        async def process_item(item):
+            """Compute RAG evaluation for an individual item"""
+            question = item.input_obj
+            answer_description = item.expected_output_obj
+            generated_answer = item.output_obj
+
+            # Call judge LLM to generate score
+            score = 0.0
+
+            default_evaluation_schema = [
+                ResponseSchema(
+                    name="coverage_score",
+                    description=
+                    "Score for the coverage of all critical aspects mentioned in the expected answer. Ex. 0.5",
+                    type="float"),
+                ResponseSchema(
+                    name="correctness_score",
+                    description=
+                    "Score for the accuracy of the generated answer compared to the expected answer. Ex. 0.5",
+                    type="float"),
+                ResponseSchema(name="relevance_score",
+                               description="Score for the relevance of the generated answer to the question. Ex. 0.5",
+                               type="float"),
+                ResponseSchema(
+                    name="reasoning",
+                    description=
+                    "1-2 summarized sentences of reasoning for the scores. Ex. 'The generated answer covers all critical aspects mentioned in the expected answer, is correct, and is relevant to the question.'",
+                    type="string"),
+            ]
+
+            custom_evaluation_schema = [
+                ResponseSchema(name="score", description="Score for the generated answer. Ex. 0.5", type="float"),
+                ResponseSchema(
+                    name="reasoning",
+                    description=
+                    "1-2 sentence reasoning for the score. Ex. 'The generated answer is exactly the same as the description of the expected answer.'",
+                    type="string"),
+            ]
+
+            if self.default_scoring:
+                evaluation_schema = default_evaluation_schema
+            else:
+                evaluation_schema = custom_evaluation_schema
+
+            llm_input_response_parser = StructuredOutputParser.from_response_schemas(evaluation_schema)
+            format_instructions = llm_input_response_parser.get_format_instructions()
+
+            eval_prompt = evaluation_prompt(judge_llm_prompt=self.judge_llm_prompt,
+                                            question=question,
+                                            answer_description=answer_description,
+                                            generated_answer=generated_answer,
+                                            format_instructions=format_instructions,
+                                            default_scoring=self.default_scoring)
+
+            messages = [
+                SystemMessage(content="You must respond only in JSON format."), HumanMessage(content=eval_prompt)
+            ]
+
+            response = await self.llm.ainvoke(messages)
+
+            # Initialize default values to handle service errors
+            coverage_score = 0.0
+            correctness_score = 0.0
+            relevance_score = 0.0
+            reasoning = "Error in evaluator from parsing judge LLM response."
+
+            try:
+                parsed_response = llm_input_response_parser.parse(response.content)
+                if self.default_scoring:
+                    try:
+                        coverage_score = parsed_response["coverage_score"]
+                        correctness_score = parsed_response["correctness_score"]
+                        relevance_score = parsed_response["relevance_score"]
+                        reasoning = parsed_response["reasoning"]
+                    except KeyError as e:
+                        logger.error("Missing required keys in default scoring response: %s",
+                                     ", ".join(str(arg) for arg in e.args))
+                        reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
+
+                    coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
+                    correctness_weight = self.default_score_weights.get("correctness", 1 / 3)
+                    relevance_weight = self.default_score_weights.get("relevance", 1 / 3)
+
+                    # Calculate score
+                    total_weight = coverage_weight + correctness_weight + relevance_weight
+                    coverage_weight = coverage_weight / total_weight
+                    correctness_weight = correctness_weight / total_weight
+                    relevance_weight = relevance_weight / total_weight
+
+                    if round(coverage_weight + correctness_weight + relevance_weight, 2) != 1:
+                        logger.warning("The sum of the default score weights is not 1. The weights will be normalized.")
+                        coverage_weight = coverage_weight / (coverage_weight + correctness_weight + relevance_weight)
+                        correctness_weight = correctness_weight / (coverage_weight + correctness_weight +
+                                                                   relevance_weight)
+                        relevance_weight = relevance_weight / (coverage_weight + correctness_weight + relevance_weight)
+
+                    score = (coverage_weight * coverage_score + correctness_weight * correctness_score +
+                             relevance_weight * relevance_score)
+
+                else:
+                    try:
+                        score = parsed_response["score"]
+                        reasoning = parsed_response["reasoning"]
+                    except KeyError as e:
+                        logger.error("Missing required keys in custom scoring response: %s",
+                                     ", ".join(str(arg) for arg in e.args))
+                        reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
+                        raise
+            except (KeyError, ValueError) as e:
+                logger.error("Error parsing judge LLM response: %s", e)
+                score = 0.0
+                reasoning = "Error in evaluator from parsing judge LLM response."
+
+            if self.default_scoring:
+                reasoning = {
+                    "question": question,
+                    "answer_description": answer_description,
+                    "generated_answer": generated_answer,
+                    "score_breakdown": {
+                        "coverage_score": coverage_score,
+                        "correctness_score": correctness_score,
+                        "relevance_score": relevance_score,
+                    },
+                    "reasoning": reasoning,
+                }
+            else:
+                reasoning = {
+                    "question": question,
+                    "answer_description": answer_description,
+                    "generated_answer": generated_answer,
+                    "reasoning": reasoning
+                }
+
+            return score, reasoning
+
+        async def wrapped_process(item: EvalInputItem) -> tuple[float, dict]:
+            """
+            Process an item asynchronously and update the progress bar.
+            Use the semaphore to limit the number of concurrent items.
+            """
+            async with self.semaphore:
+                result = await process_item(item)
+                # Update the progress bar
+                pbar.update(1)
+                return result
+
+        try:
+            # Claim a tqdm position to display the progress bar
+            tqdm_position = TqdmPositionRegistry.claim()
+            # Create a progress bar
+            pbar = tqdm(total=len(eval_input.eval_input_items), desc="Evaluating RAG", position=tqdm_position)
+            # Process items concurrently with a limit on concurrency
+            results = await asyncio.gather(*[wrapped_process(item) for item in eval_input.eval_input_items])
+        finally:
+            pbar.close()
+            TqdmPositionRegistry.release(tqdm_position)
+
+        # Extract scores and reasonings
+        sample_scores, sample_reasonings = zip(*results) if results else ([], [])
+
+        # Compute average score
+        avg_score = round(sum(sample_scores) / len(sample_scores), 2) if sample_scores else 0.0
+
+        # Construct EvalOutputItems
+        eval_output_items = [
+            EvalOutputItem(id=item.id, score=score, reasoning=reasoning)
+            for item, score, reasoning in zip(eval_input.eval_input_items, sample_scores, sample_reasonings)
+        ]
+
+        return EvalOutput(average_score=avg_score, eval_output_items=eval_output_items)
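The default-scoring branch above first normalizes the three weights by their sum, then takes a weighted average of the judge's coverage, correctness, and relevance scores. A standalone sketch of that arithmetic with hypothetical values:

```python
# Hypothetical scores and weights; mirrors the normalization and weighted
# average performed in TunableRagEvaluator's default-scoring branch above.
weights = {"coverage": 0.5, "correctness": 0.3, "relevance": 0.2}
scores = {"coverage": 0.8, "correctness": 1.0, "relevance": 0.6}

total = sum(weights.values())  # dividing by the sum makes the weights add to 1
normalized = {name: w / total for name, w in weights.items()}

score = sum(normalized[name] * scores[name] for name in scores)
print(round(score, 2))  # 0.5 * 0.8 + 0.3 * 1.0 + 0.2 * 0.6 = 0.82
```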
aiq/eval/tunable_rag_evaluator/register.py
ADDED

@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+
+from aiq.builder.builder import EvalBuilder
+from aiq.builder.evaluator import EvaluatorInfo
+from aiq.builder.framework_enum import LLMFrameworkEnum
+from aiq.cli.register_workflow import register_evaluator
+from aiq.data_models.component_ref import LLMRef
+from aiq.data_models.evaluator import EvaluatorBaseConfig
+
+
+class TunableRagEvaluatorConfig(EvaluatorBaseConfig, name="tunable_rag_evaluator"):
+    '''Configuration for tunable RAG evaluator'''
+    llm_name: LLMRef = Field(description="Name of the judge LLM")
+    judge_llm_prompt: str = Field(description="LLM prompt for the judge LLM")
+    default_scoring: bool = Field(description="Whether to use default scoring", default=False)
+    default_score_weights: dict = Field(
+        default={
+            "coverage": 0.5, "correctness": 0.3, "relevance": 0.2
+        },
+        description="Weights for the different scoring components when using default scoring")
+
+
+@register_evaluator(config_type=TunableRagEvaluatorConfig)
+async def register_tunable_rag_evaluator(config: TunableRagEvaluatorConfig, builder: EvalBuilder):
+    '''Register tunable RAG evaluator'''
+    from .evaluate import TunableRagEvaluator
+
+    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    evaluator = TunableRagEvaluator(llm,
+                                    config.judge_llm_prompt,
+                                    builder.get_max_concurrency(),
+                                    config.default_scoring,
+                                    config.default_score_weights)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluator.evaluate, description="Tunable RAG Evaluator")
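A hedged sketch of instantiating this config directly, using only the fields declared above; the LLM name and prompt are made up, and `EvaluatorBaseConfig` may impose additional fields not shown in this diff:

```python
# Sketch only: field names come from TunableRagEvaluatorConfig above;
# "judge_llm" and the prompt text are hypothetical values.
from aiq.data_models.component_ref import LLMRef
from aiq.eval.tunable_rag_evaluator.register import TunableRagEvaluatorConfig

config = TunableRagEvaluatorConfig(
    llm_name=LLMRef("judge_llm"),  # must name an LLM entry in the workflow config
    judge_llm_prompt="Score the generated answer against the expected answer.",
    default_scoring=True,
    default_score_weights={"coverage": 0.5, "correctness": 0.3, "relevance": 0.2},
)
```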
aiq/front_ends/console/console_front_end_config.py
CHANGED

@@ -22,7 +22,7 @@ from aiq.data_models.front_end import FrontEndBaseConfig

 class ConsoleFrontEndConfig(FrontEndBaseConfig, name="console"):
     """
-    A front end that allows an
+    A front end that allows an AIQ Toolkit workflow to be run from the console.
     """

     input_query: list[str] | None = Field(default=None,