aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/base.py +170 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +146 -112
- aiq/agent/react_agent/prompt.py +1 -6
- aiq/agent/react_agent/register.py +36 -35
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/agent/tool_calling_agent/register.py +1 -1
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +37 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +44 -12
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +421 -61
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +2 -1
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +124 -12
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/dataset_handler.py +2 -1
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/evaluate.py +23 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +38 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +1 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +87 -2
- aiq/eval/evaluate.py +208 -27
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +1 -0
- aiq/eval/intermediate_step_adapter.py +11 -5
- aiq/eval/rag_evaluator/evaluate.py +55 -15
- aiq/eval/rag_evaluator/register.py +6 -1
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/trajectory_evaluator/evaluate.py +22 -65
- aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
- aiq/eval/tunable_rag_evaluator/register.py +2 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/output_uploader.py +10 -1
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/job_store.py +47 -25
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +108 -89
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +3 -2
- aiq/llm/register.py +1 -0
- aiq/meta/pypi.md +12 -12
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +36 -39
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/inference_metrics_model.py +3 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
- aiq/profiler/inference_optimization/data_models.py +2 -2
- aiq/profiler/inference_optimization/llm_metrics.py +2 -2
- aiq/profiler/profile_runner.py +61 -21
- aiq/runtime/loader.py +9 -3
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +25 -7
- aiq/runtime/user_metadata.py +2 -3
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +41 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +6 -3
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/METADATA +53 -21
- aiqtoolkit-1.2.0rc2.dist-info/RECORD +436 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -148
- aiq/observability/async_otel_listener.py +0 -429
- aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from datetime import timezone
|
|
19
|
+
from enum import Enum
|
|
20
|
+
|
|
21
|
+
import httpx
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic import ConfigDict
|
|
24
|
+
from pydantic import Field
|
|
25
|
+
from pydantic import SecretStr
|
|
26
|
+
|
|
27
|
+
from aiq.data_models.common import BaseModelRegistryTag
|
|
28
|
+
from aiq.data_models.common import TypedBaseModel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AuthProviderBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """
    Common base class for every authentication provider configuration model.
    """

    # Reject unknown keys by default so that mistyped options surface as validation
    # errors instead of being silently accepted.
    model_config = ConfigDict(extra="forbid")


# Generic type variable restricted to authentication provider configuration models.
AuthProviderBaseConfigT = typing.TypeVar("AuthProviderBaseConfigT", bound=AuthProviderBaseConfig)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CredentialLocation(str, Enum):
    """
    Part of an outgoing HTTP request in which a credential is placed.
    """
    HEADER = "header"  # an HTTP header
    QUERY = "query"  # a URL query parameter
    COOKIE = "cookie"  # a cookie
    BODY = "body"  # the request body
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AuthFlowType(str, Enum):
    """
    Identifiers for the supported kinds of authentication flow.
    """
    API_KEY = "api_key"
    OAUTH2_CLIENT_CREDENTIALS = "oauth2_client_credentials"
    # Note: the serialized value uses the "auth_code_flow" spelling, not "authorization_code".
    OAUTH2_AUTHORIZATION_CODE = "oauth2_auth_code_flow"
    OAUTH2_PASSWORD = "oauth2_password"
    OAUTH2_DEVICE_CODE = "oauth2_device_code"
    HTTP_BASIC = "http_basic"
    NONE = "none"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class AuthenticatedContext(BaseModel):
    """
    Request components (headers, query parameters, cookies, body) that carry
    authentication, together with optional free-form metadata.
    """

    # httpx container types are not pydantic models, so arbitrary types must be permitted.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    headers: dict[str, str] | httpx.Headers | None = Field(
        default=None,
        description="HTTP headers used for authentication.",
    )
    query_params: dict[str, str] | httpx.QueryParams | None = Field(
        default=None,
        description="Query parameters used for authentication.",
    )
    cookies: dict[str, str] | httpx.Cookies | None = Field(
        default=None,
        description="Cookies used for authentication.",
    )
    body: dict[str, str] | None = Field(
        default=None,
        description="Authenticated Body value, if applicable.",
    )
    metadata: dict[str, typing.Any] | None = Field(
        default=None,
        description="Additional metadata for the request.",
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class HeaderAuthScheme(str, Enum):
    """
    Scheme names that can be used when authenticating through an HTTP header.
    """
    BEARER = "Bearer"
    X_API_KEY = "X-API-Key"
    BASIC = "Basic"
    CUSTOM = "Custom"  # caller-defined scheme
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class HTTPMethod(str, Enum):
    """
    HTTP request methods; each value is the method name in upper case.
    """
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    HEAD = "HEAD"
    OPTIONS = "OPTIONS"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class CredentialKind(str, Enum):
    """
    Discriminator values distinguishing the credential model variants.
    """
    HEADER = "header"
    QUERY = "query"
    COOKIE = "cookie"
    BASIC = "basic_auth"  # username/password pair
    BEARER = "bearer_token"  # token rendered into an Authorization-style header
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class _CredBase(BaseModel):
    """
    Shared parent of every credential model; ``kind`` tells the variants apart.
    """
    model_config = ConfigDict(extra="forbid")

    kind: CredentialKind
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class HeaderCred(_CredBase):
    """
    Credential delivered as a named HTTP request header.
    """
    kind: typing.Literal[CredentialKind.HEADER] = CredentialKind.HEADER
    # Header name.
    name: str
    # Header value, kept secret so it is masked in reprs.
    value: SecretStr
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class QueryCred(_CredBase):
    """
    Credential delivered as a named URL query parameter.
    """
    kind: typing.Literal[CredentialKind.QUERY] = CredentialKind.QUERY
    # Query parameter name.
    name: str
    # Query parameter value, kept secret so it is masked in reprs.
    value: SecretStr
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class CookieCred(_CredBase):
    """
    Credential delivered as a named cookie on the HTTP request.
    """
    kind: typing.Literal[CredentialKind.COOKIE] = CredentialKind.COOKIE
    # Cookie name.
    name: str
    # Cookie value, kept secret so it is masked in reprs.
    value: SecretStr
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class BasicAuthCred(_CredBase):
    """
    Username/password pair for HTTP Basic Authentication.
    """
    kind: typing.Literal[CredentialKind.BASIC] = CredentialKind.BASIC
    # Both parts are secrets so neither leaks through repr/str.
    username: SecretStr
    password: SecretStr
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class BearerTokenCred(_CredBase):
    """
    Token credential sent in a header as ``"<scheme> <token>"``.
    """
    kind: typing.Literal[CredentialKind.BEARER] = CredentialKind.BEARER
    # Secret token value.
    token: SecretStr
    # Scheme prefix placed before the token in the header value.
    scheme: str = "Bearer"
    # Header that receives the "<scheme> <token>" value.
    header_name: str = "Authorization"
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Any one of the supported credential models; pydantic picks the concrete class by
# inspecting the ``kind`` field of the incoming data.
Credential = typing.Annotated[
    typing.Union[HeaderCred, QueryCred, CookieCred, BasicAuthCred, BearerTokenCred],
    Field(discriminator="kind"),
]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class AuthResult(BaseModel):
|
|
181
|
+
"""
|
|
182
|
+
Represents the result of an authentication process.
|
|
183
|
+
"""
|
|
184
|
+
credentials: list[Credential] = Field(default_factory=list,
|
|
185
|
+
description="List of credentials used for authentication.")
|
|
186
|
+
token_expires_at: datetime | None = Field(default=None, description="Expiration time of the token, if applicable.")
|
|
187
|
+
raw: dict[str, typing.Any] = Field(default_factory=dict,
|
|
188
|
+
description="Raw response data from the authentication process.")
|
|
189
|
+
|
|
190
|
+
model_config = ConfigDict(extra="forbid")
|
|
191
|
+
|
|
192
|
+
def is_expired(self) -> bool:
|
|
193
|
+
"""
|
|
194
|
+
Checks if the authentication token has expired.
|
|
195
|
+
"""
|
|
196
|
+
return bool(self.token_expires_at and datetime.now(timezone.utc) >= self.token_expires_at)
|
|
197
|
+
|
|
198
|
+
def as_requests_kwargs(self) -> dict[str, typing.Any]:
|
|
199
|
+
"""
|
|
200
|
+
Converts the authentication credentials into a format suitable for use with the `httpx` library.
|
|
201
|
+
"""
|
|
202
|
+
kw: dict[str, typing.Any] = {"headers": {}, "params": {}, "cookies": {}}
|
|
203
|
+
|
|
204
|
+
for cred in self.credentials:
|
|
205
|
+
match cred:
|
|
206
|
+
case HeaderCred():
|
|
207
|
+
kw["headers"][cred.name] = cred.value.get_secret_value()
|
|
208
|
+
case QueryCred():
|
|
209
|
+
kw["params"][cred.name] = cred.value.get_secret_value()
|
|
210
|
+
case CookieCred():
|
|
211
|
+
kw["cookies"][cred.name] = cred.value.get_secret_value()
|
|
212
|
+
case BearerTokenCred():
|
|
213
|
+
kw["headers"][cred.header_name] = (f"{cred.scheme} {cred.token.get_secret_value()}")
|
|
214
|
+
case BasicAuthCred():
|
|
215
|
+
kw["auth"] = (
|
|
216
|
+
cred.username.get_secret_value(),
|
|
217
|
+
cred.password.get_secret_value(),
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
return kw
|
|
221
|
+
|
|
222
|
+
def attach(self, target_kwargs: dict[str, typing.Any]) -> None:
|
|
223
|
+
"""
|
|
224
|
+
Attaches the authentication credentials to the target request kwargs.
|
|
225
|
+
"""
|
|
226
|
+
merged = self.as_requests_kwargs()
|
|
227
|
+
for k, v in merged.items():
|
|
228
|
+
if isinstance(v, dict):
|
|
229
|
+
target_kwargs.setdefault(k, {}).update(v)
|
|
230
|
+
else:
|
|
231
|
+
target_kwargs[k] = v
|
aiq/data_models/common.py
CHANGED
|
@@ -21,6 +21,8 @@ from hashlib import sha512
|
|
|
21
21
|
from pydantic import AliasChoices
|
|
22
22
|
from pydantic import BaseModel
|
|
23
23
|
from pydantic import Field
|
|
24
|
+
from pydantic.json_schema import GenerateJsonSchema
|
|
25
|
+
from pydantic.json_schema import JsonSchemaMode
|
|
24
26
|
|
|
25
27
|
_LT = typing.TypeVar("_LT")
|
|
26
28
|
|
|
@@ -67,8 +69,8 @@ def subclass_depth(cls: type) -> int:
|
|
|
67
69
|
Compute a class' subclass depth.
|
|
68
70
|
"""
|
|
69
71
|
depth = 0
|
|
70
|
-
while (cls is not object):
|
|
71
|
-
cls = cls.__base__
|
|
72
|
+
while (cls is not object and cls.__base__ is not None):
|
|
73
|
+
cls = cls.__base__ # type: ignore
|
|
72
74
|
depth += 1
|
|
73
75
|
return depth
|
|
74
76
|
|
|
@@ -93,7 +95,8 @@ class TypedBaseModel(BaseModel):
|
|
|
93
95
|
Subclass of Pydantic BaseModel that allows for specifying the object type. Use in Pydantic discriminated unions.
|
|
94
96
|
"""
|
|
95
97
|
|
|
96
|
-
type: str = Field(
|
|
98
|
+
type: str = Field(default="unknown",
|
|
99
|
+
init=False,
|
|
97
100
|
serialization_alias="_type",
|
|
98
101
|
validation_alias=AliasChoices('type', '_type'),
|
|
99
102
|
description="The type of the object",
|
|
@@ -101,6 +104,7 @@ class TypedBaseModel(BaseModel):
|
|
|
101
104
|
repr=False)
|
|
102
105
|
|
|
103
106
|
full_type: typing.ClassVar[str]
|
|
107
|
+
_typed_model_name: typing.ClassVar[str | None] = None
|
|
104
108
|
|
|
105
109
|
def __init_subclass__(cls, name: str | None = None):
|
|
106
110
|
super().__init_subclass__()
|
|
@@ -117,14 +121,38 @@ class TypedBaseModel(BaseModel):
|
|
|
117
121
|
|
|
118
122
|
full_name = f"{package_name}/{name}"
|
|
119
123
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
type_field.default = name
|
|
124
|
+
# Store the type name as a class attribute - no field manipulation needed!
|
|
125
|
+
cls._typed_model_name = name # type: ignore
|
|
123
126
|
cls.full_type = full_name
|
|
124
127
|
|
|
128
|
+
def model_post_init(self, __context):
    """
    Set the ``type`` field to the name registered in ``__init_subclass__``.

    Classes without a registered name keep the field default ``"unknown"``. The parent
    ``model_post_init`` is then called so that cooperative post-init hooks elsewhere in
    the MRO still run.
    """
    type_name = getattr(type(self), '_typed_model_name', None)
    if type_name is not None:
        # Write directly, bypassing pydantic's __setattr__ (validation / frozen checks).
        object.__setattr__(self, 'type', type_name)
    super().model_post_init(__context)
|
|
133
|
+
|
|
134
|
+
@classmethod
def model_json_schema(cls,
                      by_alias: bool = True,
                      ref_template: str = '#/$defs/{model}',
                      schema_generator: "type[GenerateJsonSchema]" = GenerateJsonSchema,
                      mode: JsonSchemaMode = 'validation') -> dict:
    """
    Generate the JSON schema with the registered type name as the ``type`` default.

    The ``type`` field's placeholder default ``"unknown"`` is replaced by the
    subclass's registered name. Depending on ``by_alias`` and ``mode``, the property
    appears under ``type`` (field name / validation alias) or ``_type`` (serialization
    alias), so both keys are patched.
    """
    schema = super().model_json_schema(by_alias=by_alias,
                                       ref_template=ref_template,
                                       schema_generator=schema_generator,
                                       mode=mode)

    type_name = getattr(cls, '_typed_model_name', None)
    properties = schema.get('properties')
    if type_name is not None and isinstance(properties, dict):
        for key in ('type', '_type'):
            prop = properties.get(key)
            if isinstance(prop, dict):
                prop['default'] = type_name

    return schema
|
|
152
|
+
|
|
125
153
|
@classmethod
|
|
126
154
|
def static_type(cls):
|
|
127
|
-
return cls
|
|
155
|
+
return getattr(cls, '_typed_model_name')
|
|
128
156
|
|
|
129
157
|
@classmethod
|
|
130
158
|
def static_full_type(cls):
|
aiq/data_models/component.py
CHANGED
|
@@ -20,27 +20,35 @@ logger = logging.getLogger(__name__)
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class AIQComponentEnum(StrEnum):
|
|
23
|
+
# Keep sorted!!!
|
|
24
|
+
AUTHENTICATION_PROVIDER = "auth_provider"
|
|
25
|
+
EMBEDDER_CLIENT = "embedder_client"
|
|
26
|
+
EMBEDDER_PROVIDER = "embedder_provider"
|
|
27
|
+
EVALUATOR = "evaluator"
|
|
23
28
|
FRONT_END = "front_end"
|
|
24
29
|
FUNCTION = "function"
|
|
25
|
-
|
|
26
|
-
LLM_PROVIDER = "llm_provider"
|
|
30
|
+
ITS_STRATEGY = "its_strategy"
|
|
27
31
|
LLM_CLIENT = "llm_client"
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
EVALUATOR = "evaluator"
|
|
32
|
+
LLM_PROVIDER = "llm_provider"
|
|
33
|
+
LOGGING = "logging"
|
|
31
34
|
MEMORY = "memory"
|
|
32
|
-
|
|
33
|
-
|
|
35
|
+
OBJECT_STORE = "object_store"
|
|
36
|
+
PACKAGE = "package"
|
|
34
37
|
REGISTRY_HANDLER = "registry_handler"
|
|
35
|
-
|
|
38
|
+
RETRIEVER_CLIENT = "retriever_client"
|
|
39
|
+
RETRIEVER_PROVIDER = "retriever_provider"
|
|
40
|
+
TOOL_WRAPPER = "tool_wrapper"
|
|
36
41
|
TRACING = "tracing"
|
|
37
|
-
PACKAGE = "package"
|
|
38
42
|
UNDEFINED = "undefined"
|
|
39
43
|
|
|
40
44
|
|
|
41
45
|
class ComponentGroup(StrEnum):
|
|
46
|
+
# Keep sorted!!!
|
|
47
|
+
AUTHENTICATION = "authentication"
|
|
42
48
|
EMBEDDERS = "embedders"
|
|
43
49
|
FUNCTIONS = "functions"
|
|
50
|
+
ITS_STRATEGIES = "its_strategies"
|
|
44
51
|
LLMS = "llms"
|
|
45
52
|
MEMORY = "memory"
|
|
53
|
+
OBJECT_STORES = "object_stores"
|
|
46
54
|
RETRIEVERS = "retrievers"
|
aiq/data_models/component_ref.py
CHANGED
|
@@ -124,6 +124,17 @@ class MemoryRef(ComponentRef):
|
|
|
124
124
|
return ComponentGroup.MEMORY
|
|
125
125
|
|
|
126
126
|
|
|
127
|
+
class ObjectStoreRef(ComponentRef):
    """
    A reference to an object store in an AIQ Toolkit configuration object.
    """

    @property
    @override
    def component_group(self):
        """Return the component group this reference resolves in (object stores)."""
        # Uses the same `override` decorator as the sibling refs; `typing.override` only
        # exists on Python 3.12+.
        return ComponentGroup.OBJECT_STORES
|
|
136
|
+
|
|
137
|
+
|
|
127
138
|
class RetrieverRef(ComponentRef):
|
|
128
139
|
"""
|
|
129
140
|
A reference to a retriever in an AIQ Toolkit configuration object.
|
|
@@ -133,3 +144,25 @@ class RetrieverRef(ComponentRef):
|
|
|
133
144
|
@override
|
|
134
145
|
def component_group(self):
|
|
135
146
|
return ComponentGroup.RETRIEVERS
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class AuthenticationRef(ComponentRef):
    """
    Reference, inside an AIQ Toolkit configuration object, to an API authentication provider.
    """

    @property
    @override
    def component_group(self):
        """Return the authentication component group."""
        group = ComponentGroup.AUTHENTICATION
        return group
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class ITSStrategyRef(ComponentRef):
    """
    A reference to an ITS (inference-time scaling) strategy in an AIQ Toolkit configuration object.
    """

    @property
    @override
    def component_group(self):
        """Return the component group this reference resolves in (ITS strategies)."""
        return ComponentGroup.ITS_STRATEGIES
|
aiq/data_models/config.py
CHANGED
|
@@ -29,15 +29,18 @@ from aiq.data_models.evaluate import EvalConfig
|
|
|
29
29
|
from aiq.data_models.front_end import FrontEndBaseConfig
|
|
30
30
|
from aiq.data_models.function import EmptyFunctionConfig
|
|
31
31
|
from aiq.data_models.function import FunctionBaseConfig
|
|
32
|
+
from aiq.data_models.its_strategy import ITSStrategyBaseConfig
|
|
32
33
|
from aiq.data_models.logging import LoggingBaseConfig
|
|
33
34
|
from aiq.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
35
|
from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
35
36
|
|
|
37
|
+
from .authentication import AuthProviderBaseConfig
|
|
36
38
|
from .common import HashableBaseModel
|
|
37
39
|
from .common import TypedBaseModel
|
|
38
40
|
from .embedder import EmbedderBaseConfig
|
|
39
41
|
from .llm import LLMBaseConfig
|
|
40
42
|
from .memory import MemoryBaseConfig
|
|
43
|
+
from .object_store import ObjectStoreBaseConfig
|
|
41
44
|
from .retriever import RetrieverBaseConfig
|
|
42
45
|
|
|
43
46
|
logger = logging.getLogger(__name__)
|
|
@@ -57,12 +60,16 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
|
|
58
61
|
if (info.field_name in ('workflow', 'functions')):
|
|
59
62
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
63
|
+
elif (info.field_name == "authentication"):
|
|
64
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
60
65
|
elif (info.field_name == "llms"):
|
|
61
66
|
registered_keys = GlobalTypeRegistry.get().get_registered_llm_providers()
|
|
62
67
|
elif (info.field_name == "embedders"):
|
|
63
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_embedder_providers()
|
|
64
69
|
elif (info.field_name == "memory"):
|
|
65
70
|
registered_keys = GlobalTypeRegistry.get().get_registered_memorys()
|
|
71
|
+
elif (info.field_name == "object_stores"):
|
|
72
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_object_stores()
|
|
66
73
|
elif (info.field_name == "retrievers"):
|
|
67
74
|
registered_keys = GlobalTypeRegistry.get().get_registered_retriever_providers()
|
|
68
75
|
elif (info.field_name == "tracing"):
|
|
@@ -73,6 +80,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
73
80
|
registered_keys = GlobalTypeRegistry.get().get_registered_evaluators()
|
|
74
81
|
elif (info.field_name == "front_ends"):
|
|
75
82
|
registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
|
|
83
|
+
elif (info.field_name == "its_strategies"):
|
|
84
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_its_strategies()
|
|
76
85
|
|
|
77
86
|
else:
|
|
78
87
|
assert False, f"Unknown field name {info.field_name} in validator"
|
|
@@ -242,12 +251,21 @@ class AIQConfig(HashableBaseModel):
|
|
|
242
251
|
# Memory Configuration
|
|
243
252
|
memory: dict[str, MemoryBaseConfig] = {}
|
|
244
253
|
|
|
254
|
+
# Object Stores Configuration
|
|
255
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = {}
|
|
256
|
+
|
|
245
257
|
# Retriever Configuration
|
|
246
258
|
retrievers: dict[str, RetrieverBaseConfig] = {}
|
|
247
259
|
|
|
260
|
+
# ITS Strategies
|
|
261
|
+
its_strategies: dict[str, ITSStrategyBaseConfig] = {}
|
|
262
|
+
|
|
248
263
|
# Workflow Configuration
|
|
249
264
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
250
265
|
|
|
266
|
+
# Authentication Configuration
|
|
267
|
+
authentication: dict[str, AuthProviderBaseConfig] = {}
|
|
268
|
+
|
|
251
269
|
# Evaluation Options
|
|
252
270
|
eval: EvalConfig = EvalConfig()
|
|
253
271
|
|
|
@@ -263,9 +281,20 @@ class AIQConfig(HashableBaseModel):
|
|
|
263
281
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
264
282
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
265
283
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
284
|
+
stream.write(f"Number of Object Stores: {len(self.object_stores)}\n")
|
|
266
285
|
stream.write(f"Number of Retrievers: {len(self.retrievers)}\n")
|
|
267
|
-
|
|
268
|
-
|
|
286
|
+
stream.write(f"Number of ITS Strategies: {len(self.its_strategies)}\n")
|
|
287
|
+
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
|
+
|
|
289
|
+
@field_validator("functions",
|
|
290
|
+
"llms",
|
|
291
|
+
"embedders",
|
|
292
|
+
"memory",
|
|
293
|
+
"retrievers",
|
|
294
|
+
"workflow",
|
|
295
|
+
"its_strategies",
|
|
296
|
+
"authentication",
|
|
297
|
+
mode="wrap")
|
|
269
298
|
@classmethod
|
|
270
299
|
def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
|
271
300
|
|
|
@@ -286,27 +315,45 @@ class AIQConfig(HashableBaseModel):
|
|
|
286
315
|
typing.Annotated[type_registry.compute_annotation(LLMBaseConfig),
|
|
287
316
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
288
317
|
|
|
318
|
+
AuthenticationProviderAnnotation = dict[str,
|
|
319
|
+
typing.Annotated[
|
|
320
|
+
type_registry.compute_annotation(AuthProviderBaseConfig),
|
|
321
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
322
|
+
|
|
289
323
|
EmbeddersAnnotation = dict[str,
|
|
290
324
|
typing.Annotated[type_registry.compute_annotation(EmbedderBaseConfig),
|
|
291
325
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
292
326
|
|
|
293
327
|
FunctionsAnnotation = dict[str,
|
|
294
|
-
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig
|
|
328
|
+
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
295
329
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
296
330
|
|
|
297
331
|
MemoryAnnotation = dict[str,
|
|
298
332
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
299
333
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
300
334
|
|
|
335
|
+
ObjectStoreAnnotation = dict[str,
|
|
336
|
+
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
+
|
|
301
339
|
RetrieverAnnotation = dict[str,
|
|
302
340
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
303
341
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
304
342
|
|
|
343
|
+
ITSStrategyAnnotation = dict[str,
|
|
344
|
+
typing.Annotated[type_registry.compute_annotation(ITSStrategyBaseConfig),
|
|
345
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
|
+
|
|
305
347
|
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
306
348
|
Discriminator(TypedBaseModel.discriminator)]
|
|
307
349
|
|
|
308
350
|
should_rebuild = False
|
|
309
351
|
|
|
352
|
+
auth_providers_field = cls.model_fields.get("authentication")
|
|
353
|
+
if auth_providers_field is not None and auth_providers_field.annotation != AuthenticationProviderAnnotation:
|
|
354
|
+
auth_providers_field.annotation = AuthenticationProviderAnnotation
|
|
355
|
+
should_rebuild = True
|
|
356
|
+
|
|
310
357
|
llms_field = cls.model_fields.get("llms")
|
|
311
358
|
if llms_field is not None and llms_field.annotation != LLMsAnnotation:
|
|
312
359
|
llms_field.annotation = LLMsAnnotation
|
|
@@ -327,11 +374,21 @@ class AIQConfig(HashableBaseModel):
|
|
|
327
374
|
memory_field.annotation = MemoryAnnotation
|
|
328
375
|
should_rebuild = True
|
|
329
376
|
|
|
377
|
+
object_stores_field = cls.model_fields.get("object_stores")
|
|
378
|
+
if object_stores_field is not None and object_stores_field.annotation != ObjectStoreAnnotation:
|
|
379
|
+
object_stores_field.annotation = ObjectStoreAnnotation
|
|
380
|
+
should_rebuild = True
|
|
381
|
+
|
|
330
382
|
retrievers_field = cls.model_fields.get("retrievers")
|
|
331
383
|
if retrievers_field is not None and retrievers_field.annotation != RetrieverAnnotation:
|
|
332
384
|
retrievers_field.annotation = RetrieverAnnotation
|
|
333
385
|
should_rebuild = True
|
|
334
386
|
|
|
387
|
+
its_strategies_field = cls.model_fields.get("its_strategies")
|
|
388
|
+
if its_strategies_field is not None and its_strategies_field.annotation != ITSStrategyAnnotation:
|
|
389
|
+
its_strategies_field.annotation = ITSStrategyAnnotation
|
|
390
|
+
should_rebuild = True
|
|
391
|
+
|
|
335
392
|
workflow_field = cls.model_fields.get("workflow")
|
|
336
393
|
if workflow_field is not None and workflow_field.annotation != WorkflowAnnotation:
|
|
337
394
|
workflow_field.annotation = WorkflowAnnotation
|
aiq/data_models/embedder.py
CHANGED
aiq/data_models/evaluate.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from enum import Enum
|
|
17
18
|
from pathlib import Path
|
|
18
19
|
|
|
19
20
|
from pydantic import BaseModel
|
|
@@ -28,6 +29,12 @@ from aiq.data_models.intermediate_step import IntermediateStepType
|
|
|
28
29
|
from aiq.data_models.profiler import ProfilerConfig
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
class JobEvictionPolicy(str, Enum):
|
|
33
|
+
"""Policy for evicting old jobs when max_jobs is exceeded."""
|
|
34
|
+
TIME_CREATED = "time_created"
|
|
35
|
+
TIME_MODIFIED = "time_modified"
|
|
36
|
+
|
|
37
|
+
|
|
31
38
|
class EvalCustomScriptConfig(BaseModel):
|
|
32
39
|
# Path to the script to run
|
|
33
40
|
script: Path
|
|
@@ -35,6 +42,16 @@ class EvalCustomScriptConfig(BaseModel):
|
|
|
35
42
|
kwargs: dict[str, str] = {}
|
|
36
43
|
|
|
37
44
|
|
|
45
|
+
class JobManagementConfig(BaseModel):
|
|
46
|
+
# Whether to append a unique job ID to the output directory for each run
|
|
47
|
+
append_job_id_to_output_dir: bool = False
|
|
48
|
+
# Maximum number of jobs to keep in the output directory. Oldest jobs will be evicted.
|
|
49
|
+
# A value of 0 means no limit.
|
|
50
|
+
max_jobs: int = 0
|
|
51
|
+
# Policy for evicting old jobs. Defaults to using time_created.
|
|
52
|
+
eviction_policy: JobEvictionPolicy = JobEvictionPolicy.TIME_CREATED
|
|
53
|
+
|
|
54
|
+
|
|
38
55
|
class EvalOutputConfig(BaseModel):
|
|
39
56
|
# Output directory for the workflow and evaluation results
|
|
40
57
|
dir: Path = Path("/tmp/aiq/examples/default/")
|
|
@@ -46,6 +63,8 @@ class EvalOutputConfig(BaseModel):
|
|
|
46
63
|
s3: EvalS3Config | None = None
|
|
47
64
|
# Whether to cleanup the output directory before running the workflow
|
|
48
65
|
cleanup: bool = True
|
|
66
|
+
# Job management configuration (job id, eviction, etc.)
|
|
67
|
+
job_management: JobManagementConfig = JobManagementConfig()
|
|
49
68
|
# Filter for the workflow output steps
|
|
50
69
|
workflow_output_step_filter: list[IntermediateStepType] | None = None
|
|
51
70
|
|
|
@@ -53,6 +72,10 @@ class EvalOutputConfig(BaseModel):
|
|
|
53
72
|
class EvalGeneralConfig(BaseModel):
|
|
54
73
|
max_concurrency: int = 8
|
|
55
74
|
|
|
75
|
+
# Workflow alias for displaying in evaluation UI, if not provided,
|
|
76
|
+
# the workflow type will be used
|
|
77
|
+
workflow_alias: str | None = None
|
|
78
|
+
|
|
56
79
|
# Output directory for the workflow and evaluation results
|
|
57
80
|
output_dir: Path = Path("/tmp/aiq/examples/default/")
|
|
58
81
|
|
|
@@ -26,6 +26,7 @@ class FunctionDependencies(BaseModel):
|
|
|
26
26
|
llms: set[str] = Field(default_factory=set)
|
|
27
27
|
embedders: set[str] = Field(default_factory=set)
|
|
28
28
|
memory_clients: set[str] = Field(default_factory=set)
|
|
29
|
+
object_stores: set[str] = Field(default_factory=set)
|
|
29
30
|
retrievers: set[str] = Field(default_factory=set)
|
|
30
31
|
|
|
31
32
|
@field_serializer("functions", when_used="json")
|
|
@@ -44,6 +45,10 @@ class FunctionDependencies(BaseModel):
|
|
|
44
45
|
def serialize_memory_clients(self, v: set[str]) -> list[str]:
|
|
45
46
|
return list(v)
|
|
46
47
|
|
|
48
|
+
@field_serializer("object_stores", when_used="json")
|
|
49
|
+
def serialize_object_stores(self, v: set[str]) -> list[str]:
|
|
50
|
+
return list(v)
|
|
51
|
+
|
|
47
52
|
@field_serializer("retrievers", when_used="json")
|
|
48
53
|
def serialize_retrievers(self, v: set[str]) -> list[str]:
|
|
49
54
|
return list(v)
|
|
@@ -60,5 +65,8 @@ class FunctionDependencies(BaseModel):
|
|
|
60
65
|
def add_memory_client(self, memory_client: str):
|
|
61
66
|
self.memory_clients.add(memory_client) # pylint: disable=no-member
|
|
62
67
|
|
|
68
|
+
def add_object_store(self, object_store: str):
|
|
69
|
+
self.object_stores.add(object_store) # pylint: disable=no-member
|
|
70
|
+
|
|
63
71
|
def add_retriever(self, retriever: str):
|
|
64
72
|
self.retrievers.add(retriever) # pylint: disable=no-member
|