aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiqtoolkit might be problematic; see the package's registry page for more details.

Files changed (220):
  1. aiq/agent/base.py +170 -8
  2. aiq/agent/dual_node.py +1 -1
  3. aiq/agent/react_agent/agent.py +146 -112
  4. aiq/agent/react_agent/prompt.py +1 -6
  5. aiq/agent/react_agent/register.py +36 -35
  6. aiq/agent/rewoo_agent/agent.py +36 -35
  7. aiq/agent/rewoo_agent/register.py +2 -2
  8. aiq/agent/tool_calling_agent/agent.py +3 -7
  9. aiq/agent/tool_calling_agent/register.py +1 -1
  10. aiq/authentication/__init__.py +14 -0
  11. aiq/authentication/api_key/__init__.py +14 -0
  12. aiq/authentication/api_key/api_key_auth_provider.py +92 -0
  13. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  14. aiq/authentication/api_key/register.py +26 -0
  15. aiq/authentication/exceptions/__init__.py +14 -0
  16. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  17. aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
  18. aiq/authentication/exceptions/call_back_exceptions.py +38 -0
  19. aiq/authentication/exceptions/request_exceptions.py +54 -0
  20. aiq/authentication/http_basic_auth/__init__.py +0 -0
  21. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  22. aiq/authentication/http_basic_auth/register.py +30 -0
  23. aiq/authentication/interfaces.py +93 -0
  24. aiq/authentication/oauth2/__init__.py +14 -0
  25. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  26. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  27. aiq/authentication/oauth2/register.py +25 -0
  28. aiq/authentication/register.py +21 -0
  29. aiq/builder/builder.py +64 -2
  30. aiq/builder/component_utils.py +16 -3
  31. aiq/builder/context.py +37 -0
  32. aiq/builder/eval_builder.py +43 -2
  33. aiq/builder/function.py +44 -12
  34. aiq/builder/function_base.py +1 -1
  35. aiq/builder/intermediate_step_manager.py +6 -8
  36. aiq/builder/user_interaction_manager.py +3 -0
  37. aiq/builder/workflow.py +23 -18
  38. aiq/builder/workflow_builder.py +421 -61
  39. aiq/cli/commands/info/list_mcp.py +103 -16
  40. aiq/cli/commands/sizing/__init__.py +14 -0
  41. aiq/cli/commands/sizing/calc.py +294 -0
  42. aiq/cli/commands/sizing/sizing.py +27 -0
  43. aiq/cli/commands/start.py +2 -1
  44. aiq/cli/entrypoint.py +2 -0
  45. aiq/cli/register_workflow.py +80 -0
  46. aiq/cli/type_registry.py +151 -30
  47. aiq/data_models/api_server.py +124 -12
  48. aiq/data_models/authentication.py +231 -0
  49. aiq/data_models/common.py +35 -7
  50. aiq/data_models/component.py +17 -9
  51. aiq/data_models/component_ref.py +33 -0
  52. aiq/data_models/config.py +60 -3
  53. aiq/data_models/dataset_handler.py +2 -1
  54. aiq/data_models/embedder.py +1 -0
  55. aiq/data_models/evaluate.py +23 -0
  56. aiq/data_models/function_dependencies.py +8 -0
  57. aiq/data_models/interactive.py +10 -1
  58. aiq/data_models/intermediate_step.py +38 -5
  59. aiq/data_models/its_strategy.py +30 -0
  60. aiq/data_models/llm.py +1 -0
  61. aiq/data_models/memory.py +1 -0
  62. aiq/data_models/object_store.py +44 -0
  63. aiq/data_models/profiler.py +1 -0
  64. aiq/data_models/retry_mixin.py +35 -0
  65. aiq/data_models/span.py +187 -0
  66. aiq/data_models/telemetry_exporter.py +2 -2
  67. aiq/embedder/nim_embedder.py +2 -1
  68. aiq/embedder/openai_embedder.py +2 -1
  69. aiq/eval/config.py +19 -1
  70. aiq/eval/dataset_handler/dataset_handler.py +87 -2
  71. aiq/eval/evaluate.py +208 -27
  72. aiq/eval/evaluator/base_evaluator.py +73 -0
  73. aiq/eval/evaluator/evaluator_model.py +1 -0
  74. aiq/eval/intermediate_step_adapter.py +11 -5
  75. aiq/eval/rag_evaluator/evaluate.py +55 -15
  76. aiq/eval/rag_evaluator/register.py +6 -1
  77. aiq/eval/remote_workflow.py +7 -2
  78. aiq/eval/runners/__init__.py +14 -0
  79. aiq/eval/runners/config.py +39 -0
  80. aiq/eval/runners/multi_eval_runner.py +54 -0
  81. aiq/eval/trajectory_evaluator/evaluate.py +22 -65
  82. aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
  83. aiq/eval/tunable_rag_evaluator/register.py +2 -0
  84. aiq/eval/usage_stats.py +41 -0
  85. aiq/eval/utils/output_uploader.py +10 -1
  86. aiq/eval/utils/weave_eval.py +184 -0
  87. aiq/experimental/__init__.py +0 -0
  88. aiq/experimental/decorators/__init__.py +0 -0
  89. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  90. aiq/experimental/inference_time_scaling/__init__.py +0 -0
  91. aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
  92. aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
  93. aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
  94. aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
  95. aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
  96. aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
  97. aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
  98. aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
  99. aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
  100. aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
  101. aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
  102. aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
  103. aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
  104. aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
  105. aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
  106. aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
  107. aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
  108. aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
  109. aiq/experimental/inference_time_scaling/register.py +36 -0
  110. aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
  111. aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
  112. aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
  113. aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
  114. aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
  115. aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
  116. aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
  117. aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
  118. aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
  119. aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
  120. aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
  121. aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
  122. aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
  123. aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
  124. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  125. aiq/front_ends/console/console_front_end_plugin.py +11 -2
  126. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  127. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  128. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  129. aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
  130. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  131. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
  132. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
  133. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  134. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  135. aiq/front_ends/fastapi/job_store.py +47 -25
  136. aiq/front_ends/fastapi/main.py +2 -0
  137. aiq/front_ends/fastapi/message_handler.py +108 -89
  138. aiq/front_ends/fastapi/step_adaptor.py +2 -1
  139. aiq/llm/aws_bedrock_llm.py +57 -0
  140. aiq/llm/nim_llm.py +2 -1
  141. aiq/llm/openai_llm.py +3 -2
  142. aiq/llm/register.py +1 -0
  143. aiq/meta/pypi.md +12 -12
  144. aiq/object_store/__init__.py +20 -0
  145. aiq/object_store/in_memory_object_store.py +74 -0
  146. aiq/object_store/interfaces.py +84 -0
  147. aiq/object_store/models.py +36 -0
  148. aiq/object_store/register.py +20 -0
  149. aiq/observability/__init__.py +14 -0
  150. aiq/observability/exporter/__init__.py +14 -0
  151. aiq/observability/exporter/base_exporter.py +449 -0
  152. aiq/observability/exporter/exporter.py +78 -0
  153. aiq/observability/exporter/file_exporter.py +33 -0
  154. aiq/observability/exporter/processing_exporter.py +269 -0
  155. aiq/observability/exporter/raw_exporter.py +52 -0
  156. aiq/observability/exporter/span_exporter.py +264 -0
  157. aiq/observability/exporter_manager.py +335 -0
  158. aiq/observability/mixin/__init__.py +14 -0
  159. aiq/observability/mixin/batch_config_mixin.py +26 -0
  160. aiq/observability/mixin/collector_config_mixin.py +23 -0
  161. aiq/observability/mixin/file_mixin.py +288 -0
  162. aiq/observability/mixin/file_mode.py +23 -0
  163. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  164. aiq/observability/mixin/serialize_mixin.py +61 -0
  165. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  166. aiq/observability/processor/__init__.py +14 -0
  167. aiq/observability/processor/batching_processor.py +316 -0
  168. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  169. aiq/observability/processor/processor.py +68 -0
  170. aiq/observability/register.py +36 -39
  171. aiq/observability/utils/__init__.py +14 -0
  172. aiq/observability/utils/dict_utils.py +236 -0
  173. aiq/observability/utils/time_utils.py +31 -0
  174. aiq/profiler/calc/__init__.py +14 -0
  175. aiq/profiler/calc/calc_runner.py +623 -0
  176. aiq/profiler/calc/calculations.py +288 -0
  177. aiq/profiler/calc/data_models.py +176 -0
  178. aiq/profiler/calc/plot.py +345 -0
  179. aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
  180. aiq/profiler/data_models.py +24 -0
  181. aiq/profiler/inference_metrics_model.py +3 -0
  182. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
  183. aiq/profiler/inference_optimization/data_models.py +2 -2
  184. aiq/profiler/inference_optimization/llm_metrics.py +2 -2
  185. aiq/profiler/profile_runner.py +61 -21
  186. aiq/runtime/loader.py +9 -3
  187. aiq/runtime/runner.py +23 -9
  188. aiq/runtime/session.py +25 -7
  189. aiq/runtime/user_metadata.py +2 -3
  190. aiq/tool/chat_completion.py +74 -0
  191. aiq/tool/code_execution/README.md +152 -0
  192. aiq/tool/code_execution/code_sandbox.py +151 -72
  193. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  194. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
  195. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
  196. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
  197. aiq/tool/code_execution/register.py +7 -3
  198. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  199. aiq/tool/mcp/exceptions.py +142 -0
  200. aiq/tool/mcp/mcp_client.py +41 -6
  201. aiq/tool/mcp/mcp_tool.py +3 -2
  202. aiq/tool/register.py +1 -0
  203. aiq/tool/server_tools.py +6 -3
  204. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  205. aiq/utils/exception_handlers/mcp.py +211 -0
  206. aiq/utils/io/model_processing.py +28 -0
  207. aiq/utils/log_utils.py +37 -0
  208. aiq/utils/string_utils.py +38 -0
  209. aiq/utils/type_converter.py +18 -2
  210. aiq/utils/type_utils.py +87 -0
  211. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/METADATA +53 -21
  212. aiqtoolkit-1.2.0rc2.dist-info/RECORD +436 -0
  213. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/WHEEL +1 -1
  214. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/entry_points.txt +3 -0
  215. aiq/front_ends/fastapi/websocket.py +0 -148
  216. aiq/observability/async_otel_listener.py +0 -429
  217. aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
  218. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  219. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  220. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/top_level.txt +0 -0
aiq/data_models/authentication.py ADDED
@@ -0,0 +1,231 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+ from datetime import datetime
18
+ from datetime import timezone
19
+ from enum import Enum
20
+
21
+ import httpx
22
+ from pydantic import BaseModel
23
+ from pydantic import ConfigDict
24
+ from pydantic import Field
25
+ from pydantic import SecretStr
26
+
27
+ from aiq.data_models.common import BaseModelRegistryTag
28
+ from aiq.data_models.common import TypedBaseModel
29
+
30
+
31
class AuthProviderBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """
    Common base class for every authentication provider configuration.

    Concrete provider configurations subclass this; their type name is supplied through the
    ``name=`` class keyword handled by ``TypedBaseModel``.
    """

    # Unknown keys are rejected so that misspelled or mistyped options fail validation
    # instead of being silently ignored.
    model_config = ConfigDict(extra="forbid")


# Type variable bound to AuthProviderBaseConfig, for code that is generic over the concrete
# provider configuration class.
AuthProviderBaseConfigT = typing.TypeVar("AuthProviderBaseConfigT", bound=AuthProviderBaseConfig)
41
+
42
+
43
class CredentialLocation(str, Enum):
    """
    Where in an HTTP request a credential is carried.

    Members compare equal to their lowercase string values.
    """
    HEADER = "header"
    QUERY = "query"
    COOKIE = "cookie"
    BODY = "body"
51
+
52
+
53
class AuthFlowType(str, Enum):
    """
    Identifiers for the supported authentication flows.

    Note that ``OAUTH2_AUTHORIZATION_CODE`` is serialized as ``"oauth2_auth_code_flow"``.
    """
    API_KEY = "api_key"
    OAUTH2_CLIENT_CREDENTIALS = "oauth2_client_credentials"
    OAUTH2_AUTHORIZATION_CODE = "oauth2_auth_code_flow"
    OAUTH2_PASSWORD = "oauth2_password"
    OAUTH2_DEVICE_CODE = "oauth2_device_code"
    HTTP_BASIC = "http_basic"
    NONE = "none"
64
+
65
+
66
class AuthenticatedContext(BaseModel):
    """
    Authentication material to apply to an outgoing request.

    Every field defaults to ``None``. Headers, query parameters and cookies accept either plain
    dictionaries or the corresponding ``httpx`` container types.
    """
    # arbitrary_types_allowed is required so the httpx container classes can be used as field types.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    headers: dict[str, str] | httpx.Headers | None = Field(
        default=None, description="HTTP headers used for authentication.")
    query_params: dict[str, str] | httpx.QueryParams | None = Field(
        default=None, description="Query parameters used for authentication.")
    cookies: dict[str, str] | httpx.Cookies | None = Field(
        default=None, description="Cookies used for authentication.")
    body: dict[str, str] | None = Field(
        default=None, description="Authenticated Body value, if applicable.")
    metadata: dict[str, typing.Any] | None = Field(
        default=None, description="Additional metadata for the request.")
78
+
79
+
80
class HeaderAuthScheme(str, Enum):
    """
    Names of the header-based authentication schemes.

    Values keep their conventional HTTP capitalization (e.g. ``"Bearer"``, ``"X-API-Key"``).
    """
    BEARER = "Bearer"
    X_API_KEY = "X-API-Key"
    BASIC = "Basic"
    CUSTOM = "Custom"
88
+
89
+
90
class HTTPMethod(str, Enum):
    """
    HTTP request methods, with values in uppercase as they appear on the wire.
    """
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    HEAD = "HEAD"
    OPTIONS = "OPTIONS"
101
+
102
+
103
class CredentialKind(str, Enum):
    """
    Discriminator values identifying each credential model.

    These are the values stored in the ``kind`` field of the credential classes.
    """
    HEADER = "header"
    QUERY = "query"
    COOKIE = "cookie"
    BASIC = "basic_auth"
    BEARER = "bearer_token"
112
+
113
+
114
class _CredBase(BaseModel):
    """
    Shared base for the credential models.

    Each subclass narrows ``kind`` to a single ``CredentialKind`` literal, which serves as the
    discriminator of the ``Credential`` union.
    """
    # Unexpected keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    kind: CredentialKind
120
+
121
+
122
class HeaderCred(_CredBase):
    """
    Credential carried in an HTTP header: ``name`` is the header name, ``value`` the secret value.
    """
    kind: typing.Literal[CredentialKind.HEADER] = CredentialKind.HEADER
    name: str
    value: SecretStr
129
+
130
+
131
class QueryCred(_CredBase):
    """
    Credential carried as a URL query parameter: ``name`` is the parameter name, ``value`` the secret value.
    """
    kind: typing.Literal[CredentialKind.QUERY] = CredentialKind.QUERY
    name: str
    value: SecretStr
138
+
139
+
140
class CookieCred(_CredBase):
    """
    Credential carried as a request cookie: ``name`` is the cookie name, ``value`` the secret value.
    """
    kind: typing.Literal[CredentialKind.COOKIE] = CredentialKind.COOKIE
    name: str
    value: SecretStr
147
+
148
+
149
class BasicAuthCred(_CredBase):
    """
    Username/password pair for HTTP Basic authentication; both parts are held as secrets.
    """
    kind: typing.Literal[CredentialKind.BASIC] = CredentialKind.BASIC
    username: SecretStr
    password: SecretStr
156
+
157
+
158
class BearerTokenCred(_CredBase):
    """
    Token credential sent as ``"<scheme> <token>"`` in a header.

    By default the scheme is ``"Bearer"`` and the header is ``"Authorization"``; both can be overridden.
    """
    kind: typing.Literal[CredentialKind.BEARER] = CredentialKind.BEARER
    token: SecretStr
    scheme: str = "Bearer"
    header_name: str = "Authorization"
166
+
167
+
168
# Discriminated union of all credential models: pydantic picks the concrete class from the
# value of the ``kind`` field.
Credential = typing.Annotated[
    typing.Union[HeaderCred, QueryCred, CookieCred, BasicAuthCred, BearerTokenCred],
    Field(discriminator="kind"),
]
178
+
179
+
180
+ class AuthResult(BaseModel):
181
+ """
182
+ Represents the result of an authentication process.
183
+ """
184
+ credentials: list[Credential] = Field(default_factory=list,
185
+ description="List of credentials used for authentication.")
186
+ token_expires_at: datetime | None = Field(default=None, description="Expiration time of the token, if applicable.")
187
+ raw: dict[str, typing.Any] = Field(default_factory=dict,
188
+ description="Raw response data from the authentication process.")
189
+
190
+ model_config = ConfigDict(extra="forbid")
191
+
192
+ def is_expired(self) -> bool:
193
+ """
194
+ Checks if the authentication token has expired.
195
+ """
196
+ return bool(self.token_expires_at and datetime.now(timezone.utc) >= self.token_expires_at)
197
+
198
+ def as_requests_kwargs(self) -> dict[str, typing.Any]:
199
+ """
200
+ Converts the authentication credentials into a format suitable for use with the `httpx` library.
201
+ """
202
+ kw: dict[str, typing.Any] = {"headers": {}, "params": {}, "cookies": {}}
203
+
204
+ for cred in self.credentials:
205
+ match cred:
206
+ case HeaderCred():
207
+ kw["headers"][cred.name] = cred.value.get_secret_value()
208
+ case QueryCred():
209
+ kw["params"][cred.name] = cred.value.get_secret_value()
210
+ case CookieCred():
211
+ kw["cookies"][cred.name] = cred.value.get_secret_value()
212
+ case BearerTokenCred():
213
+ kw["headers"][cred.header_name] = (f"{cred.scheme} {cred.token.get_secret_value()}")
214
+ case BasicAuthCred():
215
+ kw["auth"] = (
216
+ cred.username.get_secret_value(),
217
+ cred.password.get_secret_value(),
218
+ )
219
+
220
+ return kw
221
+
222
+ def attach(self, target_kwargs: dict[str, typing.Any]) -> None:
223
+ """
224
+ Attaches the authentication credentials to the target request kwargs.
225
+ """
226
+ merged = self.as_requests_kwargs()
227
+ for k, v in merged.items():
228
+ if isinstance(v, dict):
229
+ target_kwargs.setdefault(k, {}).update(v)
230
+ else:
231
+ target_kwargs[k] = v
aiq/data_models/common.py CHANGED
@@ -21,6 +21,8 @@ from hashlib import sha512
21
21
  from pydantic import AliasChoices
22
22
  from pydantic import BaseModel
23
23
  from pydantic import Field
24
+ from pydantic.json_schema import GenerateJsonSchema
25
+ from pydantic.json_schema import JsonSchemaMode
24
26
 
25
27
  _LT = typing.TypeVar("_LT")
26
28
 
@@ -67,8 +69,8 @@ def subclass_depth(cls: type) -> int:
67
69
  Compute a class' subclass depth.
68
70
  """
69
71
  depth = 0
70
- while (cls is not object):
71
- cls = cls.__base__
72
+ while (cls is not object and cls.__base__ is not None):
73
+ cls = cls.__base__ # type: ignore
72
74
  depth += 1
73
75
  return depth
74
76
 
@@ -93,7 +95,8 @@ class TypedBaseModel(BaseModel):
93
95
  Subclass of Pydantic BaseModel that allows for specifying the object type. Use in Pydantic discriminated unions.
94
96
  """
95
97
 
96
- type: str = Field(init=False,
98
+ type: str = Field(default="unknown",
99
+ init=False,
97
100
  serialization_alias="_type",
98
101
  validation_alias=AliasChoices('type', '_type'),
99
102
  description="The type of the object",
@@ -101,6 +104,7 @@ class TypedBaseModel(BaseModel):
101
104
  repr=False)
102
105
 
103
106
  full_type: typing.ClassVar[str]
107
+ _typed_model_name: typing.ClassVar[str | None] = None
104
108
 
105
109
  def __init_subclass__(cls, name: str | None = None):
106
110
  super().__init_subclass__()
@@ -117,14 +121,38 @@ class TypedBaseModel(BaseModel):
117
121
 
118
122
  full_name = f"{package_name}/{name}"
119
123
 
120
- type_field = cls.model_fields.get("type")
121
- if type_field is not None:
122
- type_field.default = name
124
+ # Store the type name as a class attribute - no field manipulation needed!
125
+ cls._typed_model_name = name # type: ignore
123
126
  cls.full_type = full_name
124
127
 
128
def model_post_init(self, __context):
    """Set the ``type`` field to the concrete subclass's type name after instance creation.

    ``__init_subclass__`` stores the type name on the class as ``_typed_model_name``. The field
    itself is declared with the placeholder default ``"unknown"``. Any other ``model_post_init``
    in the MRO runs first; then the class-level name, if set, is copied onto the instance.

    Args:
        __context: The context object pydantic passes to ``model_post_init``.
    """
    # Cooperate with other bases/mixins in the MRO that also define model_post_init.
    super().model_post_init(__context)
    typed_name = getattr(self.__class__, '_typed_model_name', None)
    if typed_name is not None:
        # object.__setattr__ bypasses BaseModel.__setattr__ when writing the field value.
        object.__setattr__(self, 'type', typed_name)
    # If no type name is set, the field retains its default "unknown" value
133
+
134
@classmethod
def model_json_schema(cls,
                      by_alias: bool = True,
                      ref_template: str = '#/$defs/{model}',
                      schema_generator: "type[GenerateJsonSchema]" = GenerateJsonSchema,
                      mode: JsonSchemaMode = 'validation') -> dict:
    """Override to provide correct default for type field in schema.

    The ``type`` field is declared with the placeholder default ``"unknown"``; the concrete name
    lives on the class as ``_typed_model_name``. This patches the generated schema so the default
    shows the concrete type name.

    The property key depends on ``mode`` and ``by_alias``: it is ``type`` for the field name or
    the first validation alias, and ``_type`` for the serialization alias. Both keys are checked.
    Only this model's top-level ``properties`` are patched; nested models under ``$defs`` are left
    untouched.

    Args:
        by_alias: Whether to use field aliases as property names.
        ref_template: Template for ``$ref`` values.
        schema_generator: Schema generator class to use.
        mode: ``'validation'`` or ``'serialization'``.

    Returns:
        The JSON schema as a dict.
    """
    schema = super().model_json_schema(by_alias=by_alias,
                                       ref_template=ref_template,
                                       schema_generator=schema_generator,
                                       mode=mode)

    typed_name = getattr(cls, '_typed_model_name', None)
    properties = schema.get('properties')
    if typed_name is None or not isinstance(properties, dict):
        return schema

    # Fix the type field default to show the actual component type instead of "unknown"
    for key in ('type', '_type'):
        field_schema = properties.get(key)
        if isinstance(field_schema, dict):
            field_schema['default'] = typed_name

    return schema
152
+
125
153
  @classmethod
126
154
  def static_type(cls):
127
- return cls.model_fields.get("type").default
155
+ return getattr(cls, '_typed_model_name')
128
156
 
129
157
  @classmethod
130
158
  def static_full_type(cls):
aiq/data_models/component.py CHANGED
@@ -20,27 +20,35 @@ logger = logging.getLogger(__name__)
20
20
 
21
21
 
22
22
  class AIQComponentEnum(StrEnum):
23
+ # Keep sorted!!!
24
+ AUTHENTICATION_PROVIDER = "auth_provider"
25
+ EMBEDDER_CLIENT = "embedder_client"
26
+ EMBEDDER_PROVIDER = "embedder_provider"
27
+ EVALUATOR = "evaluator"
23
28
  FRONT_END = "front_end"
24
29
  FUNCTION = "function"
25
- TOOL_WRAPPER = "tool_wrapper"
26
- LLM_PROVIDER = "llm_provider"
30
+ ITS_STRATEGY = "its_strategy"
27
31
  LLM_CLIENT = "llm_client"
28
- EMBEDDER_PROVIDER = "embedder_provider"
29
- EMBEDDER_CLIENT = "embedder_client"
30
- EVALUATOR = "evaluator"
32
+ LLM_PROVIDER = "llm_provider"
33
+ LOGGING = "logging"
31
34
  MEMORY = "memory"
32
- RETRIEVER_PROVIDER = "retriever_provider"
33
- RETRIEVER_CLIENT = "retriever_client"
35
+ OBJECT_STORE = "object_store"
36
+ PACKAGE = "package"
34
37
  REGISTRY_HANDLER = "registry_handler"
35
- LOGGING = "logging"
38
+ RETRIEVER_CLIENT = "retriever_client"
39
+ RETRIEVER_PROVIDER = "retriever_provider"
40
+ TOOL_WRAPPER = "tool_wrapper"
36
41
  TRACING = "tracing"
37
- PACKAGE = "package"
38
42
  UNDEFINED = "undefined"
39
43
 
40
44
 
41
45
  class ComponentGroup(StrEnum):
46
+ # Keep sorted!!!
47
+ AUTHENTICATION = "authentication"
42
48
  EMBEDDERS = "embedders"
43
49
  FUNCTIONS = "functions"
50
+ ITS_STRATEGIES = "its_strategies"
44
51
  LLMS = "llms"
45
52
  MEMORY = "memory"
53
+ OBJECT_STORES = "object_stores"
46
54
  RETRIEVERS = "retrievers"
aiq/data_models/component_ref.py CHANGED
@@ -124,6 +124,17 @@ class MemoryRef(ComponentRef):
124
124
  return ComponentGroup.MEMORY
125
125
 
126
126
 
127
class ObjectStoreRef(ComponentRef):
    """
    A reference to an object store in an AIQ Toolkit configuration object.
    """

    @property
    @override  # same decorator the sibling *Ref classes use; typing.override only exists on Python 3.12+
    def component_group(self):
        """The component group this reference belongs to: ``ComponentGroup.OBJECT_STORES``."""
        return ComponentGroup.OBJECT_STORES
136
+
137
+
127
138
  class RetrieverRef(ComponentRef):
128
139
  """
129
140
  A reference to a retriever in an AIQ Toolkit configuration object.
@@ -133,3 +144,25 @@ class RetrieverRef(ComponentRef):
133
144
  @override
134
145
  def component_group(self):
135
146
  return ComponentGroup.RETRIEVERS
147
+
148
+
149
class AuthenticationRef(ComponentRef):
    """
    Reference, by name, to an API authentication provider declared in an AIQ Toolkit configuration object.
    """

    @property
    @override
    def component_group(self):
        """The component group this reference belongs to: ``ComponentGroup.AUTHENTICATION``."""
        return ComponentGroup.AUTHENTICATION
158
+
159
+
160
class ITSStrategyRef(ComponentRef):
    """
    A reference to an ITS strategy in an AIQ Toolkit configuration object.
    """

    @property
    @override
    def component_group(self):
        """The component group this reference belongs to: ``ComponentGroup.ITS_STRATEGIES``."""
        return ComponentGroup.ITS_STRATEGIES
aiq/data_models/config.py CHANGED
@@ -29,15 +29,18 @@ from aiq.data_models.evaluate import EvalConfig
29
29
  from aiq.data_models.front_end import FrontEndBaseConfig
30
30
  from aiq.data_models.function import EmptyFunctionConfig
31
31
  from aiq.data_models.function import FunctionBaseConfig
32
+ from aiq.data_models.its_strategy import ITSStrategyBaseConfig
32
33
  from aiq.data_models.logging import LoggingBaseConfig
33
34
  from aiq.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
35
  from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
35
36
 
37
+ from .authentication import AuthProviderBaseConfig
36
38
  from .common import HashableBaseModel
37
39
  from .common import TypedBaseModel
38
40
  from .embedder import EmbedderBaseConfig
39
41
  from .llm import LLMBaseConfig
40
42
  from .memory import MemoryBaseConfig
43
+ from .object_store import ObjectStoreBaseConfig
41
44
  from .retriever import RetrieverBaseConfig
42
45
 
43
46
  logger = logging.getLogger(__name__)
@@ -57,12 +60,16 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
57
60
 
58
61
  if (info.field_name in ('workflow', 'functions')):
59
62
  registered_keys = GlobalTypeRegistry.get().get_registered_functions()
63
+ elif (info.field_name == "authentication"):
64
+ registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
60
65
  elif (info.field_name == "llms"):
61
66
  registered_keys = GlobalTypeRegistry.get().get_registered_llm_providers()
62
67
  elif (info.field_name == "embedders"):
63
68
  registered_keys = GlobalTypeRegistry.get().get_registered_embedder_providers()
64
69
  elif (info.field_name == "memory"):
65
70
  registered_keys = GlobalTypeRegistry.get().get_registered_memorys()
71
+ elif (info.field_name == "object_stores"):
72
+ registered_keys = GlobalTypeRegistry.get().get_registered_object_stores()
66
73
  elif (info.field_name == "retrievers"):
67
74
  registered_keys = GlobalTypeRegistry.get().get_registered_retriever_providers()
68
75
  elif (info.field_name == "tracing"):
@@ -73,6 +80,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
73
80
  registered_keys = GlobalTypeRegistry.get().get_registered_evaluators()
74
81
  elif (info.field_name == "front_ends"):
75
82
  registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
83
+ elif (info.field_name == "its_strategies"):
84
+ registered_keys = GlobalTypeRegistry.get().get_registered_its_strategies()
76
85
 
77
86
  else:
78
87
  assert False, f"Unknown field name {info.field_name} in validator"
@@ -242,12 +251,21 @@ class AIQConfig(HashableBaseModel):
242
251
  # Memory Configuration
243
252
  memory: dict[str, MemoryBaseConfig] = {}
244
253
 
254
+ # Object Stores Configuration
255
+ object_stores: dict[str, ObjectStoreBaseConfig] = {}
256
+
245
257
  # Retriever Configuration
246
258
  retrievers: dict[str, RetrieverBaseConfig] = {}
247
259
 
260
+ # ITS Strategies
261
+ its_strategies: dict[str, ITSStrategyBaseConfig] = {}
262
+
248
263
  # Workflow Configuration
249
264
  workflow: FunctionBaseConfig = EmptyFunctionConfig()
250
265
 
266
+ # Authentication Configuration
267
+ authentication: dict[str, AuthProviderBaseConfig] = {}
268
+
251
269
  # Evaluation Options
252
270
  eval: EvalConfig = EvalConfig()
253
271
 
@@ -263,9 +281,20 @@ class AIQConfig(HashableBaseModel):
263
281
  stream.write(f"Number of LLMs: {len(self.llms)}\n")
264
282
  stream.write(f"Number of Embedders: {len(self.embedders)}\n")
265
283
  stream.write(f"Number of Memory: {len(self.memory)}\n")
284
+ stream.write(f"Number of Object Stores: {len(self.object_stores)}\n")
266
285
  stream.write(f"Number of Retrievers: {len(self.retrievers)}\n")
267
-
268
- @field_validator("functions", "llms", "embedders", "memory", "retrievers", "workflow", mode="wrap")
286
+ stream.write(f"Number of ITS Strategies: {len(self.its_strategies)}\n")
287
+ stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
288
+
289
+ @field_validator("functions",
290
+ "llms",
291
+ "embedders",
292
+ "memory",
293
+ "retrievers",
294
+ "workflow",
295
+ "its_strategies",
296
+ "authentication",
297
+ mode="wrap")
269
298
  @classmethod
270
299
  def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
271
300
 
@@ -286,27 +315,45 @@ class AIQConfig(HashableBaseModel):
286
315
  typing.Annotated[type_registry.compute_annotation(LLMBaseConfig),
287
316
  Discriminator(TypedBaseModel.discriminator)]]
288
317
 
318
+ AuthenticationProviderAnnotation = dict[str,
319
+ typing.Annotated[
320
+ type_registry.compute_annotation(AuthProviderBaseConfig),
321
+ Discriminator(TypedBaseModel.discriminator)]]
322
+
289
323
  EmbeddersAnnotation = dict[str,
290
324
  typing.Annotated[type_registry.compute_annotation(EmbedderBaseConfig),
291
325
  Discriminator(TypedBaseModel.discriminator)]]
292
326
 
293
327
  FunctionsAnnotation = dict[str,
294
- typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig, ),
328
+ typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
295
329
  Discriminator(TypedBaseModel.discriminator)]]
296
330
 
297
331
  MemoryAnnotation = dict[str,
298
332
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
299
333
  Discriminator(TypedBaseModel.discriminator)]]
300
334
 
335
+ ObjectStoreAnnotation = dict[str,
336
+ typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
337
+ Discriminator(TypedBaseModel.discriminator)]]
338
+
301
339
  RetrieverAnnotation = dict[str,
302
340
  typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
303
341
  Discriminator(TypedBaseModel.discriminator)]]
304
342
 
343
+ ITSStrategyAnnotation = dict[str,
344
+ typing.Annotated[type_registry.compute_annotation(ITSStrategyBaseConfig),
345
+ Discriminator(TypedBaseModel.discriminator)]]
346
+
305
347
  WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
306
348
  Discriminator(TypedBaseModel.discriminator)]
307
349
 
308
350
  should_rebuild = False
309
351
 
352
+ auth_providers_field = cls.model_fields.get("authentication")
353
+ if auth_providers_field is not None and auth_providers_field.annotation != AuthenticationProviderAnnotation:
354
+ auth_providers_field.annotation = AuthenticationProviderAnnotation
355
+ should_rebuild = True
356
+
310
357
  llms_field = cls.model_fields.get("llms")
311
358
  if llms_field is not None and llms_field.annotation != LLMsAnnotation:
312
359
  llms_field.annotation = LLMsAnnotation
@@ -327,11 +374,21 @@ class AIQConfig(HashableBaseModel):
327
374
  memory_field.annotation = MemoryAnnotation
328
375
  should_rebuild = True
329
376
 
377
+ object_stores_field = cls.model_fields.get("object_stores")
378
+ if object_stores_field is not None and object_stores_field.annotation != ObjectStoreAnnotation:
379
+ object_stores_field.annotation = ObjectStoreAnnotation
380
+ should_rebuild = True
381
+
330
382
  retrievers_field = cls.model_fields.get("retrievers")
331
383
  if retrievers_field is not None and retrievers_field.annotation != RetrieverAnnotation:
332
384
  retrievers_field.annotation = RetrieverAnnotation
333
385
  should_rebuild = True
334
386
 
387
+ its_strategies_field = cls.model_fields.get("its_strategies")
388
+ if its_strategies_field is not None and its_strategies_field.annotation != ITSStrategyAnnotation:
389
+ its_strategies_field.annotation = ITSStrategyAnnotation
390
+ should_rebuild = True
391
+
335
392
  workflow_field = cls.model_fields.get("workflow")
336
393
  if workflow_field is not None and workflow_field.annotation != WorkflowAnnotation:
337
394
  workflow_field.annotation = WorkflowAnnotation
@@ -30,7 +30,8 @@ from aiq.data_models.common import TypedBaseModel
30
30
 
31
31
  class EvalS3Config(BaseModel):
32
32
 
33
- endpoint_url: str
33
+ endpoint_url: str | None = None
34
+ region_name: str | None = None
34
35
  bucket: str
35
36
  access_key: str
36
37
  secret_key: str
@@ -20,6 +20,7 @@ from .common import TypedBaseModel
20
20
 
21
21
 
22
22
  class EmbedderBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
+ """ Base configuration for embedding model providers. """
23
24
  pass
24
25
 
25
26
 
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import typing
17
+ from enum import Enum
17
18
  from pathlib import Path
18
19
 
19
20
  from pydantic import BaseModel
@@ -28,6 +29,12 @@ from aiq.data_models.intermediate_step import IntermediateStepType
28
29
  from aiq.data_models.profiler import ProfilerConfig
29
30
 
30
31
 
32
+ class JobEvictionPolicy(str, Enum):
33
+ """Policy for evicting old jobs when max_jobs is exceeded."""
34
+ TIME_CREATED = "time_created"
35
+ TIME_MODIFIED = "time_modified"
36
+
37
+
31
38
  class EvalCustomScriptConfig(BaseModel):
32
39
  # Path to the script to run
33
40
  script: Path
@@ -35,6 +42,16 @@ class EvalCustomScriptConfig(BaseModel):
35
42
  kwargs: dict[str, str] = {}
36
43
 
37
44
 
45
+ class JobManagementConfig(BaseModel):
46
+ # Whether to append a unique job ID to the output directory for each run
47
+ append_job_id_to_output_dir: bool = False
48
+ # Maximum number of jobs to keep in the output directory. Oldest jobs will be evicted.
49
+ # A value of 0 means no limit.
50
+ max_jobs: int = 0
51
+ # Policy for evicting old jobs. Defaults to using time_created.
52
+ eviction_policy: JobEvictionPolicy = JobEvictionPolicy.TIME_CREATED
53
+
54
+
38
55
  class EvalOutputConfig(BaseModel):
39
56
  # Output directory for the workflow and evaluation results
40
57
  dir: Path = Path("/tmp/aiq/examples/default/")
@@ -46,6 +63,8 @@ class EvalOutputConfig(BaseModel):
46
63
  s3: EvalS3Config | None = None
47
64
  # Whether to cleanup the output directory before running the workflow
48
65
  cleanup: bool = True
66
+ # Job management configuration (job id, eviction, etc.)
67
+ job_management: JobManagementConfig = JobManagementConfig()
49
68
  # Filter for the workflow output steps
50
69
  workflow_output_step_filter: list[IntermediateStepType] | None = None
51
70
 
@@ -53,6 +72,10 @@ class EvalOutputConfig(BaseModel):
53
72
  class EvalGeneralConfig(BaseModel):
54
73
  max_concurrency: int = 8
55
74
 
75
+ # Workflow alias for displaying in evaluation UI, if not provided,
76
+ # the workflow type will be used
77
+ workflow_alias: str | None = None
78
+
56
79
  # Output directory for the workflow and evaluation results
57
80
  output_dir: Path = Path("/tmp/aiq/examples/default/")
58
81
 
@@ -26,6 +26,7 @@ class FunctionDependencies(BaseModel):
26
26
  llms: set[str] = Field(default_factory=set)
27
27
  embedders: set[str] = Field(default_factory=set)
28
28
  memory_clients: set[str] = Field(default_factory=set)
29
+ object_stores: set[str] = Field(default_factory=set)
29
30
  retrievers: set[str] = Field(default_factory=set)
30
31
 
31
32
  @field_serializer("functions", when_used="json")
@@ -44,6 +45,10 @@ class FunctionDependencies(BaseModel):
44
45
  def serialize_memory_clients(self, v: set[str]) -> list[str]:
45
46
  return list(v)
46
47
 
48
+ @field_serializer("object_stores", when_used="json")
49
+ def serialize_object_stores(self, v: set[str]) -> list[str]:
50
+ return list(v)
51
+
47
52
  @field_serializer("retrievers", when_used="json")
48
53
  def serialize_retrievers(self, v: set[str]) -> list[str]:
49
54
  return list(v)
@@ -60,5 +65,8 @@ class FunctionDependencies(BaseModel):
60
65
  def add_memory_client(self, memory_client: str):
61
66
  self.memory_clients.add(memory_client) # pylint: disable=no-member
62
67
 
68
+ def add_object_store(self, object_store: str):
69
+ self.object_stores.add(object_store) # pylint: disable=no-member
70
+
63
71
  def add_retriever(self, retriever: str):
64
72
  self.retrievers.add(retriever) # pylint: disable=no-member