aiqtoolkit 1.1.0a20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +198 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +116 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +372 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +125 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +134 -0
- aiq/builder/workflow_builder.py +733 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +233 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +34 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +37 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +36 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +307 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +869 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +550 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +269 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +101 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +164 -0
- aiq/eval/evaluate.py +322 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +22 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +181 -0
- aiq/front_ends/fastapi/step_adaptor.py +315 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +102 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +59 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +270 -0
- aiq/observability/register.py +97 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +81 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +116 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +75 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +67 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +36 -0
- aiq/tool/retriever.py +89 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +50 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
- aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
- aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,550 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import abc
import datetime
import typing
import uuid
from abc import abstractmethod
from enum import Enum

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Discriminator
from pydantic import Field
from pydantic import HttpUrl
from pydantic import conlist
from pydantic import field_validator
from pydantic_core.core_schema import ValidationInfo

from aiq.data_models.interactive import HumanPrompt
from aiq.utils.type_converter import GlobalTypeConverter
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Message(BaseModel):
    """A chat message made of a text ``content`` and a ``role`` string."""

    content: str
    role: str
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AIQChatRequest(BaseModel):
    """
    Request payload for the AgentIQ chat API.

    Fields that are not declared here are accepted (``extra="allow"``) so that
    derived models can carry additional data.
    """

    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    # NOTE(review): ``conlist(...)`` is supplied as ``Annotated`` metadata rather than
    # as the field type itself -- confirm that ``min_length=1`` is actually enforced.
    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
    model: str | None = None
    temperature: float | None = None
    max_tokens: int | None = None
    top_p: float | None = None

    @staticmethod
    def from_string(data: str,
                    *,
                    model: str | None = None,
                    temperature: float | None = None,
                    max_tokens: int | None = None,
                    top_p: float | None = None) -> "AIQChatRequest":
        """
        Build a request holding a single message with role ``"user"``.

        Args:
            data: Text used as the content of the single message.
            model: Optional value for the ``model`` field.
            temperature: Optional value for the ``temperature`` field.
            max_tokens: Optional value for the ``max_tokens`` field.
            top_p: Optional value for the ``top_p`` field.

        Returns:
            A new ``AIQChatRequest`` whose ``messages`` list has exactly one entry.
        """
        user_message = Message(content=data, role="user")

        return AIQChatRequest(messages=[user_message],
                              model=model,
                              temperature=temperature,
                              max_tokens=max_tokens,
                              top_p=top_p)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class AIQChoiceMessage(BaseModel):
    """Message carried by an ``AIQChoice``; both fields are optional and default to ``None``."""

    content: str | None = None
    role: str | None = None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class AIQChoice(BaseModel):
    """
    One entry of a response's ``choices`` list: a message, its position and an optional finish reason.

    Undeclared fields are accepted (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    message: AIQChoiceMessage
    finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None
    index: int
    # logprobs: AIQChoiceLogprobs | None = None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class AIQUsage(BaseModel):
    """Token counters attached to a chat response; all three counts are required integers."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AIQResponseSerializable(abc.ABC):
    """
    AIQResponseSerializable is an abstract class that defines the interface for serializing output for the AgentIQ
    chat streaming API.

    Concrete subclasses must implement ``get_stream_data``.
    """

    @abstractmethod
    def get_stream_data(self) -> str:
        """Return this object rendered as a string for the stream."""
        pass
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class AIQResponseBaseModelOutput(BaseModel, AIQResponseSerializable):
    """Pydantic base for serializable responses whose stream frames are prefixed with ``data: ``."""

    def get_stream_data(self) -> str:
        """Return the model dumped to JSON, prefixed with ``data: `` and followed by a blank line."""
        serialized = self.model_dump_json()
        return f"data: {serialized}\n\n"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AIQResponseBaseModelIntermediate(BaseModel, AIQResponseSerializable):
    """Pydantic base for serializable responses whose stream frames are prefixed with ``intermediate_data: ``."""

    def get_stream_data(self) -> str:
        """Return the model dumped to JSON, prefixed with ``intermediate_data: `` and followed by a blank line."""
        serialized = self.model_dump_json()
        return f"intermediate_data: {serialized}\n\n"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class AIQChatResponse(AIQResponseBaseModelOutput):
    """
    AIQChatResponse is a data model that represents a response from the AgentIQ chat API.
    """

    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")
    id: str
    object: str
    model: str = ""
    created: datetime.datetime
    choices: list[AIQChoice]
    usage: AIQUsage | None = None

    @staticmethod
    def from_string(data: str,
                    *,
                    id_: str | None = None,
                    object_: str | None = None,
                    model: str | None = None,
                    created: datetime.datetime | None = None,
                    usage: AIQUsage | None = None) -> "AIQChatResponse":
        """
        Wrap a plain string in a response holding one choice with ``finish_reason="stop"``.

        Args:
            data: Text used as the content of the single choice message.
            id_: Response id; a random UUID4 string is generated when ``None``.
            object_: Object tag; ``"chat.completion"`` when ``None``.
            model: Model name; the empty string when ``None``.
            created: Creation time; the current UTC time when ``None``.
            usage: Optional usage information, passed through unchanged.

        Returns:
            The populated ``AIQChatResponse``.
        """
        response_id = str(uuid.uuid4()) if id_ is None else id_
        object_tag = "chat.completion" if object_ is None else object_
        model_name = "" if model is None else model
        created_at = datetime.datetime.now(datetime.timezone.utc) if created is None else created

        only_choice = AIQChoice(index=0, message=AIQChoiceMessage(content=data), finish_reason="stop")

        return AIQChatResponse(id=response_id,
                               object=object_tag,
                               model=model_name,
                               created=created_at,
                               choices=[only_choice],
                               usage=usage)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class AIQChatResponseChunk(AIQResponseBaseModelOutput):
    """
    AIQChatResponseChunk is a data model that represents a response chunk from the AgentIQ chat streaming API.
    """

    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    id: str
    choices: list[AIQChoice]
    created: datetime.datetime
    model: str = ""
    object: str = "chat.completion.chunk"

    @staticmethod
    def from_string(data: str,
                    *,
                    id_: str | None = None,
                    created: datetime.datetime | None = None,
                    model: str | None = None,
                    object_: str | None = None) -> "AIQChatResponseChunk":
        """
        Wrap a plain string in a chunk holding one choice with ``finish_reason="stop"``.

        Args:
            data: Text used as the content of the single choice message.
            id_: Chunk id; a random UUID4 string is generated when ``None``.
            created: Creation time; the current UTC time when ``None``.
            model: Model name; the empty string when ``None``.
            object_: Object tag; ``"chat.completion.chunk"`` when ``None``.

        Returns:
            The populated ``AIQChatResponseChunk``.
        """
        chunk_id = str(uuid.uuid4()) if id_ is None else id_
        created_at = datetime.datetime.now(datetime.timezone.utc) if created is None else created
        model_name = "" if model is None else model
        object_tag = "chat.completion.chunk" if object_ is None else object_

        only_choice = AIQChoice(index=0, message=AIQChoiceMessage(content=data), finish_reason="stop")

        return AIQChatResponseChunk(id=chunk_id,
                                    choices=[only_choice],
                                    created=created_at,
                                    model=model_name,
                                    object=object_tag)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class AIQResponseIntermediateStep(AIQResponseBaseModelIntermediate):
    """
    AIQResponseIntermediateStep is a data model that represents a serialized step in the AgentIQ chat streaming API.
    """

    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    id: str
    parent_id: str | None = None
    type: str = "markdown"
    name: str
    payload: str
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class AIQResponsePayloadOutput(BaseModel, AIQResponseSerializable):
    """Serializable wrapper around an arbitrary ``payload`` value."""

    payload: typing.Any

    def get_stream_data(self) -> str:
        """
        Return the payload as a ``data: `` stream frame.

        A pydantic model payload is dumped to JSON first; any other payload is
        interpolated into the frame as-is.
        """
        body = self.payload

        if isinstance(body, BaseModel):
            body = body.model_dump_json()

        return f"data: {body}\n\n"
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class AIQGenerateResponse(BaseModel):
    """Generate-style response: a required ``output`` string plus optional intermediate steps."""

    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    # (fixme) define the intermediate step model
    intermediate_steps: list[tuple] | None = None
    output: str
    value: str | None = "default"
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class UserMessageContentRoleType(str, Enum):
    """
    UserMessageContentRoleType is an Enum that lists the roles accepted on a user message entry.
    """
    USER = "user"
    ASSISTANT = "assistant"
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class ChatContentType(str, Enum):
    """
    ChatContentType is an Enum that lists the kinds of chat content: text, image URL and input audio.
    """
    TEXT = "text"
    IMAGE_URL = "image_url"
    INPUT_AUDIO = "input_audio"
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class WebSocketMessageType(str, Enum):
    """
    WebSocketMessageType is an Enum that lists the WebSocket message types.
    """
    USER_MESSAGE = "user_message"
    RESPONSE_MESSAGE = "system_response_message"
    INTERMEDIATE_STEP_MESSAGE = "system_intermediate_message"
    SYSTEM_INTERACTION_MESSAGE = "system_interaction_message"
    USER_INTERACTION_MESSAGE = "user_interaction_message"
    ERROR_MESSAGE = "error_message"
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class WorkflowSchemaType(str, Enum):
    """
    WorkflowSchemaType is an Enum that represents Workflow response types.
    """
    GENERATE_STREAM = "generate_stream"
    CHAT_STREAM = "chat_stream"
    GENERATE = "generate"
    CHAT = "chat"
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class WebSocketMessageStatus(str, Enum):
    """
    WebSocketMessageStatus is an Enum that marks a WebSocket message as in progress or complete.
    """
    IN_PROGRESS = "in_progress"
    COMPLETE = "complete"
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class InputAudio(BaseModel):
    """Audio payload described by two strings, ``data`` and ``format``, each defaulting to ``"default"``."""

    data: str = "default"
    format: str = "default"
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class AudioContent(BaseModel):
    """Content item tagged ``ChatContentType.INPUT_AUDIO``; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Fixed tag; ``UserContent`` uses the ``type`` field as its discriminator.
    type: typing.Literal[ChatContentType.INPUT_AUDIO] = ChatContentType.INPUT_AUDIO
    input_audio: InputAudio = InputAudio()
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class ImageUrl(BaseModel):
    """Holder for an HTTP(S) image URL; defaults to the placeholder ``http://default.com``."""

    url: HttpUrl = HttpUrl(url="http://default.com")
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class ImageContent(BaseModel):
    """Content item tagged ``ChatContentType.IMAGE_URL``; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Fixed tag; ``UserContent`` uses the ``type`` field as its discriminator.
    type: typing.Literal[ChatContentType.IMAGE_URL] = ChatContentType.IMAGE_URL
    image_url: ImageUrl = ImageUrl()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class TextContent(BaseModel):
    """Content item tagged ``ChatContentType.TEXT``; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Fixed tag; ``UserContent`` uses the ``type`` field as its discriminator.
    type: typing.Literal[ChatContentType.TEXT] = ChatContentType.TEXT
    text: str = "default"
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class Security(BaseModel):
    """Security section of a WebSocket message: ``api_key`` and ``token`` strings, each defaulting to ``"default"``."""

    model_config = ConfigDict(extra="forbid")

    api_key: str = "default"
    token: str = "default"
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# Union of the supported content items, discriminated on each item's ``type`` field.
UserContent = typing.Annotated[
    TextContent | ImageContent | AudioContent,
    Discriminator("type"),
]
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class UserMessages(BaseModel):
    """One message entry: a role plus a list of content items; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    role: UserMessageContentRoleType
    content: list[UserContent]
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class UserMessageContent(BaseModel):
    """Content of a user WebSocket message: a list of ``UserMessages``; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")
    messages: list[UserMessages]
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class User(BaseModel):
    """User section of a WebSocket message: ``name`` and ``email`` strings, each defaulting to ``"default"``."""

    model_config = ConfigDict(extra="forbid")

    name: str = "default"
    email: str = "default"
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class ErrorTypes(str, Enum):
    """
    ErrorTypes is an Enum that lists the error codes carried by an ``Error``.
    """
    UNKNOWN_ERROR = "unknown_error"
    INVALID_MESSAGE = "invalid_message"
    INVALID_MESSAGE_TYPE = "invalid_message_type"
    INVALID_USER_MESSAGE_CONTENT = "invalid_user_message_content"
    INVALID_DATA_CONTENT = "invalid_data_content"
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class Error(BaseModel):
    """Error section of a WebSocket message: a code plus ``message`` and ``details`` strings."""

    model_config = ConfigDict(extra="forbid")

    code: ErrorTypes = ErrorTypes.UNKNOWN_ERROR
    message: str = "default"
    details: str = "default"
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class WebSocketUserMessage(BaseModel):
    """
    WebSocket message tagged ``WebSocketMessageType.USER_MESSAGE``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.USER_MESSAGE]
    schema_type: WorkflowSchemaType
    id: str = "default"
    thread_id: str = "default"
    content: UserMessageContent
    user: User = User()
    security: Security = Security()
    error: Error = Error()
    schema_version: str = "1.0.0"
    # Use a factory so the default is computed per instance; a plain default
    # expression would be evaluated once at import time and shared by every message.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class WebSocketUserInteractionResponseMessage(BaseModel):
    """
    WebSocket message tagged ``WebSocketMessageType.USER_INTERACTION_MESSAGE``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    type: typing.Literal[WebSocketMessageType.USER_INTERACTION_MESSAGE]
    id: str = "default"
    thread_id: str = "default"
    content: UserMessageContent
    user: User = User()
    security: Security = Security()
    error: Error = Error()
    schema_version: str = "1.0.0"
    # Use a factory so the default is computed per instance; a plain default
    # expression would be evaluated once at import time and shared by every message.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class SystemIntermediateStepContent(BaseModel):
    """Content of an intermediate-step message: ``name`` and ``payload`` strings; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")
    name: str
    payload: str
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class WebSocketSystemIntermediateStepMessage(BaseModel):
    """
    WebSocket message tagged ``WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE]
    id: str = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    intermediate_parent_id: str | None = "default"
    update_message_id: str | None = "default"
    content: SystemIntermediateStepContent
    status: WebSocketMessageStatus
    # Use a factory so the default is computed per instance; a plain default
    # expression would be evaluated once at import time and shared by every message.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
class SystemResponseContent(BaseModel):
    """Content of a system response message: an optional ``text`` string; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    text: str | None = None
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class WebSocketSystemResponseTokenMessage(BaseModel):
    """
    WebSocket message tagged either ``WebSocketMessageType.RESPONSE_MESSAGE`` or
    ``WebSocketMessageType.ERROR_MESSAGE``; the accepted ``content`` model depends on that tag.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.RESPONSE_MESSAGE, WebSocketMessageType.ERROR_MESSAGE]
    id: str | None = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    content: SystemResponseContent | Error | AIQGenerateResponse
    status: WebSocketMessageStatus
    # Use a factory so the default is computed per instance; a plain default
    # expression would be evaluated once at import time and shared by every message.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))

    @field_validator("content")
    @classmethod
    def validate_content_by_type(cls, value: SystemResponseContent | Error | AIQGenerateResponse, info: ValidationInfo):
        """
        Check that ``content`` matches the message ``type``.

        Args:
            value: The already-validated ``content`` value.
            info: Validation context; ``info.data`` is read for the ``type`` field.

        Returns:
            ``value`` unchanged when it is consistent with ``type``.

        Raises:
            ValueError: If ``type`` is ``ERROR_MESSAGE`` and ``value`` is not an ``Error``, or if ``type`` is
                ``RESPONSE_MESSAGE`` and ``value`` is neither ``SystemResponseContent`` nor ``AIQGenerateResponse``.
        """
        # ``.get`` is used so a missing "type" entry simply skips both checks.
        message_type = info.data.get("type")

        if message_type == WebSocketMessageType.ERROR_MESSAGE and not isinstance(value, Error):
            raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")

        if message_type == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
                value, (SystemResponseContent, AIQGenerateResponse)):
            raise ValueError(
                f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
        return value
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class WebSocketSystemInteractionMessage(BaseModel):
    """
    Pydantic model for a message whose ``type`` is ``WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE`` and whose
    ``content`` is a ``HumanPrompt``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Allow extra fields in the model_config to support derived models
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[
        WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE
    id: str | None = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    content: HumanPrompt
    status: WebSocketMessageStatus
    # NOTE: this default expression is evaluated only once, when the class body runs, so on its own
    # it would stamp every message with the same (import-time) value. `model_post_init` below
    # replaces it with the actual creation time whenever the caller did not supply a timestamp.
    timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))

    def model_post_init(self, context: typing.Any, /) -> None:
        """
        Stamp the message with the current UTC time when no ``timestamp`` was provided.

        Parameters
        ----------
        context : typing.Any
            Validation context handed over by Pydantic; forwarded unchanged to the parent hook.
        """
        super().model_post_init(context)

        # `model_fields_set` only contains fields that were explicitly passed in, so an absent
        # "timestamp" means the stale class-level default is currently in place.
        if "timestamp" not in self.model_fields_set:
            self.timestamp = str(datetime.datetime.now(datetime.timezone.utc))
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
# ======== AIQGenerateResponse Converters ========
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def _generate_response_to_str(response: AIQGenerateResponse) -> str:
    """Converts an AIQGenerateResponse object to a string by returning its ``output`` attribute"""
    output_text = response.output
    return output_text


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_generate_response_to_str)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def _generate_response_to_chat_response(response: AIQGenerateResponse) -> AIQChatResponse:
    """Converts an AIQGenerateResponse object to an AIQChatResponse object built from its ``output``"""
    text = response.output

    # Usage is simulated: no prompt tokens are counted and the number of whitespace-separated
    # words in the output stands in for the completion token count.
    simulated_prompt_tokens = 0
    simulated_completion_tokens = len(text.split())
    simulated_usage = AIQUsage(prompt_tokens=simulated_prompt_tokens,
                               completion_tokens=simulated_completion_tokens,
                               total_tokens=simulated_prompt_tokens + simulated_completion_tokens)

    return AIQChatResponse.from_string(text, usage=simulated_usage)


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_generate_response_to_chat_response)
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
# ======== AIQChatRequest Converters ========
|
|
489
|
+
def _aiq_chat_request_to_string(data: AIQChatRequest) -> str:
    """Converts an AIQChatRequest object to a string by returning the content of its last message"""
    last_message = data.messages[-1]
    return last_message.content


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_aiq_chat_request_to_string)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _string_to_aiq_chat_request(data: str) -> AIQChatRequest:
    """Converts a string to an AIQChatRequest object, passing an empty string as the model name"""
    chat_request = AIQChatRequest.from_string(data, model="")
    return chat_request


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_string_to_aiq_chat_request)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
# ======== AIQChatResponse Converters ========
|
|
504
|
+
def _aiq_chat_response_to_string(data: AIQChatResponse) -> str:
    """Converts an AIQChatResponse object to a string using the message content of its first choice"""
    first_choice_content = data.choices[0].message.content

    # Falsy content (e.g. None) is reported as an empty string
    if not first_choice_content:
        return ""
    return first_choice_content


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_aiq_chat_response_to_string)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def _string_to_aiq_chat_response(data: str) -> AIQChatResponse:
    """Converts a string to an AIQChatResponse object with simulated usage information"""

    # Usage is simulated: no prompt tokens are counted and the number of whitespace-separated
    # words in the string stands in for the completion token count.
    simulated_prompt_tokens = 0
    simulated_completion_tokens = len(data.split())
    simulated_usage = AIQUsage(prompt_tokens=simulated_prompt_tokens,
                               completion_tokens=simulated_completion_tokens,
                               total_tokens=simulated_prompt_tokens + simulated_completion_tokens)

    return AIQChatResponse.from_string(data, usage=simulated_usage)


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_string_to_aiq_chat_response)
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _chat_response_to_chat_response_chunk(data: AIQChatResponse) -> AIQChatResponseChunk:
    """Converts an AIQChatResponse object to an AIQChatResponseChunk carrying the same id, choices, created and model"""
    copied_fields = {
        "id": data.id,
        "choices": data.choices,
        "created": data.created,
        "model": data.model,
    }
    return AIQChatResponseChunk(**copied_fields)


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
# ======== AIQChatResponseChunk Converters ========
|
|
536
|
+
def _aiq_chat_response_chunk_to_string(data: AIQChatResponseChunk) -> str:
    """Converts an AIQChatResponseChunk object to a string using the message content of its first choice"""
    first_choice_content = data.choices[0].message.content

    # Falsy content (e.g. None) is reported as an empty string
    if not first_choice_content:
        return ""
    return first_choice_content


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_aiq_chat_response_chunk_to_string)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def _string_to_aiq_chat_response_chunk(data: str) -> AIQChatResponseChunk:
    """Converts a string to an AIQChatResponseChunk object"""
    response_chunk = AIQChatResponseChunk.from_string(data)
    return response_chunk


# Make the conversion available through the global type converter
GlobalTypeConverter.register_converter(_string_to_aiq_chat_response_chunk)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import inspect
|
|
17
|
+
import sys
|
|
18
|
+
import typing
|
|
19
|
+
from hashlib import sha512
|
|
20
|
+
|
|
21
|
+
from pydantic import AliasChoices
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic import Field
|
|
24
|
+
|
|
25
|
+
# Unconstrained, module-private type variable (not referenced in the visible part of this module).
_LT = typing.TypeVar("_LT")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class HashableBaseModel(BaseModel):
    """
    Subclass of a Pydantic BaseModel that is hashable. Use in objects that need to be hashed for caching purposes.

    The hash is derived from the class qualified name and the JSON dump of the model. Equality and ordering
    compare that hash value against the hash value of the other operand.
    """

    def __hash__(self):
        # Prefix with the class name so that different model classes with identical JSON do not collide
        payload = f"{self.__class__.__qualname__}::{self.model_dump_json()}"
        digest = sha512(payload.encode('utf-8', errors='ignore')).digest()
        return int.from_bytes(bytes=digest, byteorder=sys.byteorder)

    # Unhashable objects (list, dict, set, ...) expose `__hash__ = None`; calling it would fail with
    # "'NoneType' object is not callable". The comparisons below return NotImplemented for such
    # operands so that Python applies its regular fallback (`==` -> False, `!=` -> True, `<`/`>` -> TypeError).

    def __lt__(self, other):
        if getattr(other, "__hash__", None) is None:
            return NotImplemented
        return self.__hash__() < other.__hash__()

    def __eq__(self, other):
        if getattr(other, "__hash__", None) is None:
            return NotImplemented
        return self.__hash__() == other.__hash__()

    def __ne__(self, other):
        if getattr(other, "__hash__", None) is None:
            return NotImplemented
        return self.__hash__() != other.__hash__()

    def __gt__(self, other):
        if getattr(other, "__hash__", None) is None:
            return NotImplemented
        return self.__hash__() > other.__hash__()

    @classmethod
    def generate_json_schema(cls) -> dict[str, typing.Any]:
        """Return the JSON schema of the model as produced by ``model_json_schema()``."""
        return cls.model_json_schema()

    @classmethod
    def write_json_schema(cls, schema_path: str) -> None:
        """
        Write the JSON schema of the model to ``schema_path`` as UTF-8 encoded, 2-space indented JSON.

        Parameters
        ----------
        schema_path : str
            Path of the file to create or overwrite.
        """

        import json

        schema = cls.generate_json_schema()

        with open(schema_path, "w", encoding="utf-8") as f:
            json.dump(schema, f, indent=2)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def subclass_depth(cls: type) -> int:
    """
    Count how many ``__base__`` hops separate a class from ``object`` (``object`` itself has depth 0).
    """
    levels = 0
    current = cls
    # Walk up the chain of primary base classes until the root of the hierarchy is reached
    while current is not object:
        levels += 1
        current = current.__base__
    return levels
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _get_origin_or_base(cls: type) -> type:
    """
    Return ``typing.get_origin(cls)`` when it yields an origin, otherwise return ``cls`` unchanged.
    """
    generic_origin = typing.get_origin(cls)
    return cls if generic_origin is None else generic_origin
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class BaseModelRegistryTag:
    """
    Empty marker class; it defines no attributes or methods of its own.
    """
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class TypedBaseModel(BaseModel):
    """
    Subclass of Pydantic BaseModel that allows for specifying the object type. Use in Pydantic discriminated unions.

    A subclass declares its type name through the ``name`` class keyword, e.g. ``class Foo(TypedBaseModel, name="foo")``.
    The name becomes the default of the ``type`` field and ``"<package>/<name>"`` is stored in ``full_type``.
    """

    # Accepted as either "type" or "_type" on input and always serialized as "_type"
    type: str = Field(init=False,
                      serialization_alias="_type",
                      validation_alias=AliasChoices('type', '_type'),
                      description="The type of the object",
                      title="Type",
                      repr=False)

    # Only assigned for subclasses which are declared with a `name`
    full_type: typing.ClassVar[str]

    def __init_subclass__(cls, name: str | None = None):
        """
        Register the type name of a subclass.

        Parameters
        ----------
        name : str | None
            The type name of the subclass. When omitted, nothing is registered for the subclass.
        """
        super().__init_subclass__()

        if (name is not None):
            module = inspect.getmodule(cls)

            assert module is not None, f"Module not found for class {cls} when registering {name}"
            package_name: str | None = module.__package__

            # If the package name is not set, then we use the module name. Must have some namespace which will be unique
            if (not package_name):
                package_name = module.__name__

            full_name = f"{package_name}/{name}"

            type_field = cls.model_fields.get("type")
            if type_field is not None:
                type_field.default = name
            cls.full_type = full_name

    @classmethod
    def static_type(cls):
        """Return the default value of the ``type`` field of this class."""
        return cls.model_fields.get("type").default

    @classmethod
    def static_full_type(cls):
        """Return the ``full_type`` class variable of this class."""
        return cls.full_type

    @staticmethod
    def discriminator(v: typing.Any) -> str | None:
        """
        Extract the type tag from a serialized dictionary or from an object.

        Returns
        -------
        str | None
            The value stored under ``"_type"`` (falling back to ``"type"``) for dictionaries, otherwise the ``type``
            attribute of the object. ``None`` is returned when no tag is present.
        """
        # If its serialized, then we use the alias
        if isinstance(v, dict):
            return v.get("_type", v.get("type"))

        # Otherwise we use the property. Objects without a `type` attribute yield None, in line with the
        # declared return type, instead of raising an AttributeError.
        return getattr(v, "type", None)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Type variable bound to TypedBaseModel, i.e. it can stand for TypedBaseModel or any of its subclasses.
TypedBaseModelT = typing.TypeVar("TypedBaseModelT", bound=TypedBaseModel)
|