aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. See the package registry's page for this release for more details.
- aiq/agent/base.py +170 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +146 -112
- aiq/agent/react_agent/prompt.py +1 -6
- aiq/agent/react_agent/register.py +36 -35
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/agent/tool_calling_agent/register.py +1 -1
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +37 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +44 -12
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +421 -61
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +2 -1
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +124 -12
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/dataset_handler.py +2 -1
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/evaluate.py +23 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +38 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +1 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +87 -2
- aiq/eval/evaluate.py +208 -27
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +1 -0
- aiq/eval/intermediate_step_adapter.py +11 -5
- aiq/eval/rag_evaluator/evaluate.py +55 -15
- aiq/eval/rag_evaluator/register.py +6 -1
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/trajectory_evaluator/evaluate.py +22 -65
- aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
- aiq/eval/tunable_rag_evaluator/register.py +2 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/output_uploader.py +10 -1
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/job_store.py +47 -25
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +108 -89
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +3 -2
- aiq/llm/register.py +1 -0
- aiq/meta/pypi.md +12 -12
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +36 -39
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/inference_metrics_model.py +3 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
- aiq/profiler/inference_optimization/data_models.py +2 -2
- aiq/profiler/inference_optimization/llm_metrics.py +2 -2
- aiq/profiler/profile_runner.py +61 -21
- aiq/runtime/loader.py +9 -3
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +25 -7
- aiq/runtime/user_metadata.py +2 -3
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +41 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +6 -3
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/METADATA +53 -21
- aiqtoolkit-1.2.0rc1.dist-info/RECORD +436 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -148
- aiq/observability/async_otel_listener.py +0 -429
- aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/top_level.txt +0 -0
aiq/data_models/interactive.py
CHANGED
|
@@ -33,6 +33,7 @@ class HumanPromptModelType(str, Enum):
|
|
|
33
33
|
RADIO = "radio"
|
|
34
34
|
CHECKBOX = "checkbox"
|
|
35
35
|
DROPDOWN = "dropdown"
|
|
36
|
+
OAUTH_CONSENT = "oauth_consent"
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class BinaryChoiceOptionsType(str, Enum):
|
|
@@ -145,6 +146,14 @@ class HumanPromptNotification(HumanPromptBase):
|
|
|
145
146
|
input_type: typing.Literal[HumanPromptModelType.NOTIFICATION] = HumanPromptModelType.NOTIFICATION
|
|
146
147
|
|
|
147
148
|
|
|
149
|
+
class _HumanPromptOAuthConsent(HumanPromptBase):
    """
    Prompt telling the UI to open the authentication page so the user can finish an OAuth consent flow.

    It carries no fields beyond those of ``HumanPromptBase``; only its ``input_type`` tag distinguishes it.
    """

    # Tag value "oauth_consent"; the HumanPrompt union discriminates on ``input_type`` to pick this model.
    input_type: typing.Literal[HumanPromptModelType.OAUTH_CONSENT] = HumanPromptModelType.OAUTH_CONSENT
|
|
155
|
+
|
|
156
|
+
|
|
148
157
|
class HumanPromptBinary(HumanPromptBase):
|
|
149
158
|
"""
|
|
150
159
|
Represents a binary interaction.
|
|
@@ -190,7 +199,7 @@ class HumanPromptDropdown(HumanPromptMultipleChoiceBase):
|
|
|
190
199
|
|
|
191
200
|
|
|
192
201
|
HumanPrompt = typing.Annotated[HumanPromptText | HumanPromptNotification | HumanPromptBinary | HumanPromptRadio
|
|
193
|
-
| HumanPromptCheckbox | HumanPromptDropdown,
|
|
202
|
+
| HumanPromptCheckbox | HumanPromptDropdown | _HumanPromptOAuthConsent,
|
|
194
203
|
Discriminator("input_type")]
|
|
195
204
|
|
|
196
205
|
|
|
aiq/data_models/intermediate_step.py
CHANGED
@@ -17,6 +17,7 @@ import time
|
|
|
17
17
|
import typing
|
|
18
18
|
import uuid
|
|
19
19
|
from enum import Enum
|
|
20
|
+
from typing import Literal
|
|
20
21
|
|
|
21
22
|
from pydantic import BaseModel
|
|
22
23
|
from pydantic import ConfigDict
|
|
@@ -82,6 +83,26 @@ class UsageInfo(BaseModel):
|
|
|
82
83
|
seconds_between_calls: int = 0
|
|
83
84
|
|
|
84
85
|
|
|
86
|
+
class ToolParameters(BaseModel):
    """Parameter schema of a tool (function): its argument properties, required names and strictness flags."""

    # Argument name -> schema for that argument (presumably a JSON-schema fragment; contents are untyped here).
    properties: dict[str, typing.Any] = Field(description="The properties of the function parameters.")
    required: list[str] = Field(
        default_factory=list,
        description="The required properties of the function parameters.",
    )
    # Python attribute is ``type_``; the key ``type`` is used on input (and on output with ``by_alias=True``).
    type_: Literal["object"] = Field(
        default="object",
        description="The type of the function parameters.",
        alias="type",
    )
    additionalProperties: bool = Field(
        default=False,
        description="Enable function parameters allow additional properties.",
    )
    strict: bool = Field(
        default=True,
        description="Ensure function calls reliably adhere to the function schema.",
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ToolDetails(BaseModel):
    """Name, human-readable description and parameter schema of a single callable tool."""

    # All three fields are required (no default given).
    name: str = Field(description="The name of the function.")
    description: str = Field(description="The description of the function.")
    parameters: ToolParameters = Field(description="The parameters of the function.")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ToolSchema(BaseModel):
    """One tool entry of a tool-calling request: a ``"function"`` tag plus the function's details."""

    # Required; the only accepted value is the literal "function".
    type: Literal["function"] = Field(description="The type of the tool.")
    function: ToolDetails = Field(description="The function details.")
|
|
104
|
+
|
|
105
|
+
|
|
85
106
|
class TraceMetadata(BaseModel):
|
|
86
107
|
chat_responses: typing.Any | None = None
|
|
87
108
|
chat_inputs: typing.Any | None = None
|
|
@@ -91,6 +112,8 @@ class TraceMetadata(BaseModel):
|
|
|
91
112
|
span_inputs: typing.Any | None = None
|
|
92
113
|
span_outputs: typing.Any | None = None
|
|
93
114
|
provided_metadata: typing.Any | None = None
|
|
115
|
+
tools_schema: list[ToolSchema] = Field(default_factory=list,
|
|
116
|
+
description="The schema of tools used in a tool calling request.")
|
|
94
117
|
|
|
95
118
|
# Allow extra fields in the model_config to support derived models
|
|
96
119
|
model_config = ConfigDict(extra="allow")
|
|
@@ -211,9 +234,23 @@ class IntermediateStep(BaseModel):
|
|
|
211
234
|
# Allow extra fields in the model_config to support derived models
|
|
212
235
|
model_config = ConfigDict(extra="forbid")
|
|
213
236
|
|
|
214
|
-
|
|
237
|
+
parent_id: str
|
|
238
|
+
"""
|
|
239
|
+
The parent step ID for the current step. The parent ID is the ID of the last START step which has a different UUID
|
|
240
|
+
than the current step. This value is different from the function_ancestry.parent_id value which tracks the last
|
|
241
|
+
parent FUNCTION step. For the first START step, the parent_id is 'root'.
|
|
242
|
+
"""
|
|
243
|
+
|
|
244
|
+
function_ancestry: InvocationNode
|
|
245
|
+
"""
|
|
246
|
+
The function ancestry for the current step showing the current AIQ function that was being executed when the step
|
|
247
|
+
was created.
|
|
248
|
+
"""
|
|
215
249
|
|
|
216
250
|
payload: IntermediateStepPayload
|
|
251
|
+
"""
|
|
252
|
+
The payload for the current step.
|
|
253
|
+
"""
|
|
217
254
|
|
|
218
255
|
# ===== Payload Properties =====
|
|
219
256
|
@property
|
|
@@ -263,7 +300,3 @@ class IntermediateStep(BaseModel):
|
|
|
263
300
|
@property
|
|
264
301
|
def event_state(self) -> IntermediateStepState:
|
|
265
302
|
return self.payload.event_state
|
|
266
|
-
|
|
267
|
-
@property
|
|
268
|
-
def parent_id(self) -> str | None:
|
|
269
|
-
return self.function_ancestry.function_id if self.function_ancestry else None
|
|
aiq/data_models/its_strategy.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ITSStrategyBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """
    Base configuration class for Inference Time Scaling (ITS) strategies.

    Declares no fields of its own; concrete strategy configurations derive from it.
    """


# Generic placeholder for "some subclass of ITSStrategyBaseConfig".
ITSStrategyBaseConfigT = typing.TypeVar("ITSStrategyBaseConfigT", bound=ITSStrategyBaseConfig)
|
aiq/data_models/llm.py
CHANGED
aiq/data_models/memory.py
CHANGED
aiq/data_models/object_store.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ObjectStoreBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Base configuration for object stores; declares no fields of its own."""


# Generic placeholder for "some subclass of ObjectStoreBaseConfig".
ObjectStoreBaseConfigT = typing.TypeVar("ObjectStoreBaseConfigT", bound=ObjectStoreBaseConfig)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class KeyAlreadyExistsError(Exception):
|
|
30
|
+
|
|
31
|
+
def __init__(self, key: str, additional_message: str | None = None):
|
|
32
|
+
parts = [f"Key already exists: {key}."]
|
|
33
|
+
if additional_message:
|
|
34
|
+
parts.append(additional_message)
|
|
35
|
+
super().__init__(" ".join(parts))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class NoSuchKeyError(Exception):
|
|
39
|
+
|
|
40
|
+
def __init__(self, key: str, additional_message: str | None = None):
|
|
41
|
+
parts = [f"No object found with key: {key}."]
|
|
42
|
+
if additional_message:
|
|
43
|
+
parts.append(additional_message)
|
|
44
|
+
super().__init__(" ".join(parts))
|
aiq/data_models/profiler.py
CHANGED
aiq/data_models/retry_mixin.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RetryMixin(BaseModel):
    """Adds automatic-retry settings to a configuration model.

    Every field is declared with ``exclude=True``, so none of them appear in ``model_dump()`` output.
    """

    do_auto_retry: bool = Field(
        default=True,
        description="Whether to automatically retry method calls that fail with a retryable error.",
        exclude=True,
    )
    num_retries: int = Field(
        default=5,
        description="Number of times to retry a method call that fails with a retryable error.",
        exclude=True,
    )
    # Lambdas give each instance its own fresh list instead of sharing one mutable default.
    retry_on_status_codes: list[int | str] = Field(
        default_factory=lambda: [429, 500, 502, 503, 504],
        description="List of HTTP status codes that should trigger a retry.",
        exclude=True,
    )
    retry_on_errors: list[str] | None = Field(
        default_factory=lambda: ["Too Many Requests"],
        description="List of error substrings that should trigger a retry.",
        exclude=True,
    )
|
aiq/data_models/span.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
import uuid
|
|
19
|
+
from enum import Enum
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic import Field
|
|
24
|
+
from pydantic import field_validator
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SpanKind(Enum):
    """Category of a span. Each member's value equals its name."""

    # Model / tool / orchestration kinds.
    LLM = "LLM"
    TOOL = "TOOL"
    WORKFLOW = "WORKFLOW"
    TASK = "TASK"
    FUNCTION = "FUNCTION"
    CUSTOM = "CUSTOM"
    SPAN = "SPAN"
    # Retrieval-pipeline and agent kinds.
    EMBEDDER = "EMBEDDER"
    RETRIEVER = "RETRIEVER"
    AGENT = "AGENT"
    RERANKER = "RERANKER"
    GUARDRAIL = "GUARDRAIL"
    EVALUATOR = "EVALUATOR"
    # Fallback used when an event type has no mapping.
    UNKNOWN = "UNKNOWN"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Event type name -> SpanKind. Every listed kind gets "<KIND>_START" and "<KIND>_END"; LLM additionally gets
# "LLM_NEW_TOKEN". Key order is START, END(, NEW_TOKEN) per kind, in the order listed below.
# SpanKind.UNKNOWN is intentionally absent: it is the lookup fallback.
EVENT_TYPE_TO_SPAN_KIND_MAP = {
    f"{kind.value}_{phase}": kind
    for kind in (
        SpanKind.LLM,
        SpanKind.TOOL,
        SpanKind.WORKFLOW,
        SpanKind.TASK,
        SpanKind.FUNCTION,
        SpanKind.CUSTOM,
        SpanKind.SPAN,
        SpanKind.EMBEDDER,
        SpanKind.RETRIEVER,
        SpanKind.AGENT,
        SpanKind.RERANKER,
        SpanKind.GUARDRAIL,
        SpanKind.EVALUATOR,
    )
    for phase in (("START", "END", "NEW_TOKEN") if kind is SpanKind.LLM else ("START", "END"))
}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def event_type_to_span_kind(event_type: str) -> SpanKind:
    """Look up the span kind for an event type name.

    Args:
        event_type (str): Event type name, e.g. ``"LLM_START"``.

    Returns:
        SpanKind: The kind from ``EVENT_TYPE_TO_SPAN_KIND_MAP``, or ``SpanKind.UNKNOWN`` if the name is absent.
    """
    # Map values are always SpanKind members, so None can only mean "not present".
    found = EVENT_TYPE_TO_SPAN_KIND_MAP.get(event_type)
    return SpanKind.UNKNOWN if found is None else found
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class SpanAttributes(Enum):
    """Attribute keys written onto spans; each member's value is the attribute key string."""

    AIQ_SPAN_KIND = "aiq.span.kind"
    # Input payload and its MIME type.
    INPUT_VALUE = "input.value"
    INPUT_MIME_TYPE = "input.mime_type"
    # LLM token counts.
    LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
    LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
    LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
    # Output payload and its MIME type.
    OUTPUT_VALUE = "output.value"
    OUTPUT_MIME_TYPE = "output.mime_type"
    # AIQ usage statistics.
    AIQ_USAGE_NUM_LLM_CALLS = "aiq.usage.num_llm_calls"
    AIQ_USAGE_SECONDS_BETWEEN_CALLS = "aiq.usage.seconds_between_calls"
    AIQ_USAGE_TOKEN_COUNT_PROMPT = "aiq.usage.token_count.prompt"
    AIQ_USAGE_TOKEN_COUNT_COMPLETION = "aiq.usage.token_count.completion"
    AIQ_USAGE_TOKEN_COUNT_TOTAL = "aiq.usage.token_count.total"
    AIQ_EVENT_TYPE = "aiq.event_type"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class MimeTypes(Enum):
    """MIME type strings used for span input/output values."""

    TEXT = "text/plain"
    JSON = "application/json"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class SpanStatusCode(Enum):
    """Outcome code of a span; each member's value equals its name."""

    OK = "OK"
    ERROR = "ERROR"
    UNSET = "UNSET"
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class SpanEvent(BaseModel):
    """A named, timestamped event attached to a span."""

    # Default is integer nanoseconds since the Unix epoch. time.time_ns() is exact, whereas
    # int(time.time() * 1e9) loses sub-microsecond precision in the float round-trip.
    # The annotation stays ``float`` so callers passing fractional timestamps keep working.
    timestamp: float = Field(default_factory=time.time_ns, description="The timestamp of the event.")
    name: str = Field(description="The name of the event.")
    attributes: dict[str, Any] = Field(default_factory=dict, description="The attributes of the event.")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class SpanStatus(BaseModel):
    """Final status of a span: a code (``OK`` unless set otherwise) and an optional message."""

    code: SpanStatusCode = Field(
        default=SpanStatusCode.OK,
        description="The status code of the span.",
    )
    message: str | None = Field(
        default=None,
        description="The status message of the span.",
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class SpanContext(BaseModel):
    """Trace and span identifiers, generated from random UUID4 values when not supplied."""

    # Whole 128-bit integer of a fresh UUID4 (note: 6 of those bits are fixed by the UUID version/variant).
    trace_id: int = Field(
        default_factory=lambda: uuid.uuid4().int,
        description="The 128-bit trace ID of the span.",
    )
    # Low 64 bits of a fresh UUID4; ``% 2**64`` is equivalent to masking with ``2**64 - 1`` for non-negative ints.
    span_id: int = Field(
        default_factory=lambda: uuid.uuid4().int % (1 << 64),
        description="The 64-bit span ID of the span.",
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class Span(BaseModel):
    """A unit of work in a trace with timing, attributes, events and status.

    Default timestamps are integer nanoseconds since the Unix epoch.
    """

    name: str = Field(description="The name of the span.")
    # default_factory instead of default=None: pydantic does not run field validators on default values
    # (unless validate_default is set), so with default=None a span built without a context kept None.
    context: SpanContext | None = Field(default_factory=SpanContext, description="The context of the span.")
    parent: "Span | None" = Field(default=None, description="The parent span of the span.")
    # time.time_ns() is exact; int(time.time() * 1e9) loses precision in the float round-trip.
    start_time: int = Field(default_factory=time.time_ns, description="The start time of the span.")
    end_time: int | None = Field(default=None, description="The end time of the span.")
    attributes: dict[str, Any] = Field(default_factory=dict, description="The attributes of the span.")
    events: list[SpanEvent] = Field(default_factory=list, description="The events of the span.")
    status: SpanStatus = Field(default_factory=SpanStatus, description="The status of the span.")

    @field_validator('context', mode='before')
    @classmethod
    def set_default_context(cls, v: SpanContext | None) -> SpanContext:
        """Replace an explicitly passed ``None`` context with a freshly generated one.

        Args:
            v (SpanContext | None): The context supplied by the caller.

        Returns:
            SpanContext: ``v`` itself, or a new ``SpanContext`` when ``v`` is None.
        """
        if v is None:
            return SpanContext()
        return v

    def set_attribute(self, key: str, value: Any) -> None:
        """Set (or overwrite) one attribute on the span.

        Args:
            key (str): The key of the attribute.
            value (Any): The value of the attribute.
        """
        self.attributes[key] = value

    def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
        """Append an event, timestamped now, to the span.

        Args:
            name (str): The name of the event.
            attributes (dict[str, Any] | None): The attributes of the event; None means no attributes.
        """
        if attributes is None:
            attributes = {}
        # Rebinds ``events`` to a new list rather than appending in place.
        self.events = self.events + [SpanEvent(name=name, attributes=attributes)]

    def end(self, end_time: int | None = None) -> None:
        """Mark the span as ended.

        Args:
            end_time (int | None): End timestamp; defaults to the current time in epoch nanoseconds.
        """
        if end_time is None:
            end_time = time.time_ns()
        self.end_time = end_time
|
|
aiq/data_models/telemetry_exporter.py
CHANGED
@@ -15,8 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
17
|
|
|
18
|
-
from .common import BaseModelRegistryTag
|
|
19
|
-
from .common import TypedBaseModel
|
|
18
|
+
from aiq.data_models.common import BaseModelRegistryTag
|
|
19
|
+
from aiq.data_models.common import TypedBaseModel
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class TelemetryExporterBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
aiq/embedder/nim_embedder.py
CHANGED
|
@@ -24,6 +24,7 @@ from aiq.builder.builder import Builder
|
|
|
24
24
|
from aiq.builder.embedder import EmbedderProviderInfo
|
|
25
25
|
from aiq.cli.register_workflow import register_embedder_provider
|
|
26
26
|
from aiq.data_models.embedder import EmbedderBaseConfig
|
|
27
|
+
from aiq.data_models.retry_mixin import RetryMixin
|
|
27
28
|
|
|
28
29
|
allowed_truncate_values = ["NONE", "START", "END"]
|
|
29
30
|
|
|
@@ -37,7 +38,7 @@ def option_in_allowed_values(v):
|
|
|
37
38
|
TruncationOption = typing.Annotated[str, AfterValidator(option_in_allowed_values)]
|
|
38
39
|
|
|
39
40
|
|
|
40
|
-
class NIMEmbedderModelConfig(EmbedderBaseConfig, name="nim"):
|
|
41
|
+
class NIMEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="nim"):
|
|
41
42
|
"""A NVIDIA Inference Microservice (NIM) embedder provider to be used with an embedder client."""
|
|
42
43
|
|
|
43
44
|
api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
|
aiq/embedder/openai_embedder.py
CHANGED
|
@@ -21,9 +21,10 @@ from aiq.builder.builder import Builder
|
|
|
21
21
|
from aiq.builder.embedder import EmbedderProviderInfo
|
|
22
22
|
from aiq.cli.register_workflow import register_embedder_provider
|
|
23
23
|
from aiq.data_models.embedder import EmbedderBaseConfig
|
|
24
|
+
from aiq.data_models.retry_mixin import RetryMixin
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
class OpenAIEmbedderModelConfig(EmbedderBaseConfig, name="openai"):
|
|
27
|
+
class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
|
|
27
28
|
"""An OpenAI LLM provider to be used with an LLM client."""
|
|
28
29
|
|
|
29
30
|
model_config = ConfigDict(protected_namespaces=())
|
aiq/eval/config.py
CHANGED
|
@@ -17,13 +17,18 @@ from pathlib import Path
|
|
|
17
17
|
|
|
18
18
|
from pydantic import BaseModel
|
|
19
19
|
|
|
20
|
+
from aiq.eval.evaluator.evaluator_model import EvalInput
|
|
21
|
+
from aiq.eval.evaluator.evaluator_model import EvalOutput
|
|
22
|
+
from aiq.eval.usage_stats import UsageStats
|
|
23
|
+
from aiq.profiler.data_models import ProfilerResults
|
|
24
|
+
|
|
20
25
|
|
|
21
26
|
class EvaluationRunConfig(BaseModel):
|
|
22
27
|
"""
|
|
23
28
|
Parameters used for a single evaluation run.
|
|
24
29
|
"""
|
|
25
30
|
config_file: Path
|
|
26
|
-
dataset: str | None # dataset file path can be specified in the config file
|
|
31
|
+
dataset: str | None = None # dataset file path can be specified in the config file
|
|
27
32
|
result_json_path: str = "$"
|
|
28
33
|
skip_workflow: bool = False
|
|
29
34
|
skip_completed_entries: bool = False
|
|
@@ -31,6 +36,14 @@ class EvaluationRunConfig(BaseModel):
|
|
|
31
36
|
endpoint_timeout: int = 300
|
|
32
37
|
reps: int = 1
|
|
33
38
|
override: tuple[tuple[str, str], ...] = ()
|
|
39
|
+
# If false, the output will not be written to the output directory. This is
|
|
40
|
+
# useful when running evaluation via another tool.
|
|
41
|
+
write_output: bool = True
|
|
42
|
+
# if true, the dataset is adjusted to a multiple of the concurrency
|
|
43
|
+
adjust_dataset_size: bool = False
|
|
44
|
+
# number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
|
|
45
|
+
# concurrency. The is only used if adjust_dataset_size is true
|
|
46
|
+
num_passes: int = 0
|
|
34
47
|
|
|
35
48
|
|
|
36
49
|
class EvaluationRunOutput(BaseModel):
|
|
@@ -40,3 +53,8 @@ class EvaluationRunOutput(BaseModel):
|
|
|
40
53
|
workflow_output_file: Path | None
|
|
41
54
|
evaluator_output_files: list[Path]
|
|
42
55
|
workflow_interrupted: bool
|
|
56
|
+
|
|
57
|
+
eval_input: EvalInput
|
|
58
|
+
evaluation_results: list[tuple[str, EvalOutput]]
|
|
59
|
+
usage_stats: UsageStats | None = None
|
|
60
|
+
profiler_results: ProfilerResults
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
+
import math
|
|
17
18
|
|
|
18
19
|
import pandas as pd
|
|
19
20
|
|
|
@@ -33,12 +34,23 @@ class DatasetHandler:
|
|
|
33
34
|
One DatasetHandler object is needed for each dataset to be evaluated.
|
|
34
35
|
"""
|
|
35
36
|
|
|
36
|
-
def __init__(self,
|
|
37
|
+
def __init__(self,
|
|
38
|
+
dataset_config: EvalDatasetConfig,
|
|
39
|
+
reps: int,
|
|
40
|
+
concurrency: int,
|
|
41
|
+
num_passes: int | None = None,
|
|
42
|
+
adjust_dataset_size: bool = False):
|
|
37
43
|
from aiq.eval.intermediate_step_adapter import IntermediateStepAdapter
|
|
38
44
|
|
|
39
45
|
self.dataset_config = dataset_config
|
|
40
46
|
self.dataset_filter = DatasetFilter(dataset_config.filter)
|
|
41
47
|
self.reps = reps
|
|
48
|
+
|
|
49
|
+
# number of passes at specific concurrency
|
|
50
|
+
self.concurrency = concurrency
|
|
51
|
+
self.num_passes = num_passes
|
|
52
|
+
self.adjust_dataset_size = adjust_dataset_size
|
|
53
|
+
|
|
42
54
|
# Helpers
|
|
43
55
|
self.intermediate_step_adapter = IntermediateStepAdapter()
|
|
44
56
|
|
|
@@ -81,6 +93,7 @@ class DatasetHandler:
|
|
|
81
93
|
output_obj=row.get(self.generated_answer_key, "") if structured else "",
|
|
82
94
|
trajectory=row.get(self.trajectory_key, []) if structured else [],
|
|
83
95
|
expected_trajectory=row.get(self.expected_trajectory_key, []) if structured else [],
|
|
96
|
+
full_dataset_entry=row.to_dict(),
|
|
84
97
|
)
|
|
85
98
|
|
|
86
99
|
# if input dataframe is empty return an empty list
|
|
@@ -108,6 +121,63 @@ class DatasetHandler:
|
|
|
108
121
|
|
|
109
122
|
return input_df
|
|
110
123
|
|
|
124
|
+
def adjust_dataset(self, input_df: pd.DataFrame) -> pd.DataFrame:
|
|
125
|
+
"""
|
|
126
|
+
Adjust the dataset so its length is a multiple of concurrency.
|
|
127
|
+
|
|
128
|
+
If num_passes > 0:
|
|
129
|
+
dataset size is adjusted to concurrency * num_passes
|
|
130
|
+
else:
|
|
131
|
+
dataset size is adjusted to the largest multiple of concurrency
|
|
132
|
+
that is less than or equal to the current dataset size
|
|
133
|
+
"""
|
|
134
|
+
if self.concurrency <= 0:
|
|
135
|
+
raise ValueError("Concurrency must be > 0")
|
|
136
|
+
|
|
137
|
+
if self.num_passes < 0:
|
|
138
|
+
raise ValueError("num_passes must be >= 0")
|
|
139
|
+
|
|
140
|
+
original_size = input_df.shape[0]
|
|
141
|
+
|
|
142
|
+
# Calculate target size
|
|
143
|
+
if self.num_passes > 0:
|
|
144
|
+
# When num_passes is specified, always use concurrency * num_passes
|
|
145
|
+
# This respects the user's intent for exact number of passes
|
|
146
|
+
target_size = self.concurrency * self.num_passes
|
|
147
|
+
else:
|
|
148
|
+
# When num_passes = 0, use the largest multiple of concurrency <= original_size
|
|
149
|
+
# If original_size < concurrency, we need at least concurrency rows
|
|
150
|
+
if original_size >= self.concurrency:
|
|
151
|
+
target_size = (original_size // self.concurrency) * self.concurrency
|
|
152
|
+
else:
|
|
153
|
+
target_size = self.concurrency
|
|
154
|
+
|
|
155
|
+
if target_size == 0:
|
|
156
|
+
raise ValueError("Input dataset too small for even one batch at given concurrency.")
|
|
157
|
+
|
|
158
|
+
id_col = self.dataset_config.id_key
|
|
159
|
+
|
|
160
|
+
# If we need more rows than we have, replicate the dataset
|
|
161
|
+
if original_size < target_size:
|
|
162
|
+
# Clean existing _rep suffix if present
|
|
163
|
+
input_df[id_col] = input_df[id_col].astype(str).str.replace(r"_rep\d+$", "", regex=True)
|
|
164
|
+
|
|
165
|
+
# Calculate how many complete copies we need
|
|
166
|
+
copies_needed = math.ceil(target_size / original_size)
|
|
167
|
+
|
|
168
|
+
# Create the replicated dataframe
|
|
169
|
+
replicated_dfs = []
|
|
170
|
+
for i in range(copies_needed):
|
|
171
|
+
df_copy = input_df.copy()
|
|
172
|
+
if i > 0: # Add suffix to all but the first copy
|
|
173
|
+
df_copy[id_col] = df_copy[id_col].astype(str) + f"_rep{i}"
|
|
174
|
+
replicated_dfs.append(df_copy)
|
|
175
|
+
|
|
176
|
+
input_df = pd.concat(replicated_dfs, ignore_index=True)
|
|
177
|
+
|
|
178
|
+
# Return exactly the target size
|
|
179
|
+
return input_df.head(target_size)
|
|
180
|
+
|
|
111
181
|
def get_eval_input_from_dataset(self, dataset: str) -> EvalInput:
|
|
112
182
|
# read the dataset and convert it to EvalInput
|
|
113
183
|
|
|
@@ -126,9 +196,14 @@ class DatasetHandler:
|
|
|
126
196
|
input_df = self.dataset_filter.apply_filters(input_df)
|
|
127
197
|
input_df.drop_duplicates(subset=[self.dataset_config.id_key], inplace=True)
|
|
128
198
|
|
|
199
|
+
if self.reps > 1 and self.adjust_dataset_size:
|
|
200
|
+
raise ValueError("reps and adjust_dataset_size are mutually exclusive")
|
|
201
|
+
|
|
129
202
|
# If more than one repetition is needed, replicate the rows
|
|
130
203
|
if self.reps > 1:
|
|
131
204
|
input_df = self.setup_reps(input_df)
|
|
205
|
+
elif self.adjust_dataset_size:
|
|
206
|
+
input_df = self.adjust_dataset(input_df)
|
|
132
207
|
|
|
133
208
|
# Convert the DataFrame to a list of EvalInput objects
|
|
134
209
|
return self.get_eval_input_from_df(input_df)
|
|
@@ -151,6 +226,16 @@ class DatasetHandler:
|
|
|
151
226
|
allow re-running evaluation using the orignal config file and '--skip_workflow' option.
|
|
152
227
|
"""
|
|
153
228
|
|
|
229
|
+
def parse_if_json_string(value):
|
|
230
|
+
if isinstance(value, str):
|
|
231
|
+
try:
|
|
232
|
+
return json.loads(value)
|
|
233
|
+
except json.JSONDecodeError:
|
|
234
|
+
return value
|
|
235
|
+
if hasattr(value, "model_dump"):
|
|
236
|
+
return value.model_dump()
|
|
237
|
+
return value
|
|
238
|
+
|
|
154
239
|
indent = 2
|
|
155
240
|
if self.is_structured_input():
|
|
156
241
|
# Extract structured data from EvalInputItems
|
|
@@ -164,6 +249,6 @@ class DatasetHandler:
|
|
|
164
249
|
} for item in eval_input.eval_input_items]
|
|
165
250
|
else:
|
|
166
251
|
# Unstructured case: return only raw output objects as a JSON array
|
|
167
|
-
data = [
|
|
252
|
+
data = [parse_if_json_string(item.output_obj) for item in eval_input.eval_input_items]
|
|
168
253
|
|
|
169
254
|
return json.dumps(data, indent=indent, ensure_ascii=False, default=str)
|