openhands-sdk 1.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openhands/sdk/__init__.py +111 -0
- openhands/sdk/agent/__init__.py +8 -0
- openhands/sdk/agent/agent.py +650 -0
- openhands/sdk/agent/base.py +457 -0
- openhands/sdk/agent/prompts/in_context_learning_example.j2 +169 -0
- openhands/sdk/agent/prompts/in_context_learning_example_suffix.j2 +3 -0
- openhands/sdk/agent/prompts/model_specific/anthropic_claude.j2 +3 -0
- openhands/sdk/agent/prompts/model_specific/google_gemini.j2 +1 -0
- openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5-codex.j2 +2 -0
- openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5.j2 +3 -0
- openhands/sdk/agent/prompts/security_policy.j2 +22 -0
- openhands/sdk/agent/prompts/security_risk_assessment.j2 +21 -0
- openhands/sdk/agent/prompts/self_documentation.j2 +15 -0
- openhands/sdk/agent/prompts/system_prompt.j2 +132 -0
- openhands/sdk/agent/prompts/system_prompt_interactive.j2 +14 -0
- openhands/sdk/agent/prompts/system_prompt_long_horizon.j2 +40 -0
- openhands/sdk/agent/prompts/system_prompt_planning.j2 +40 -0
- openhands/sdk/agent/prompts/system_prompt_tech_philosophy.j2 +122 -0
- openhands/sdk/agent/utils.py +228 -0
- openhands/sdk/context/__init__.py +28 -0
- openhands/sdk/context/agent_context.py +264 -0
- openhands/sdk/context/condenser/__init__.py +18 -0
- openhands/sdk/context/condenser/base.py +100 -0
- openhands/sdk/context/condenser/llm_summarizing_condenser.py +248 -0
- openhands/sdk/context/condenser/no_op_condenser.py +14 -0
- openhands/sdk/context/condenser/pipeline_condenser.py +56 -0
- openhands/sdk/context/condenser/prompts/summarizing_prompt.j2 +59 -0
- openhands/sdk/context/condenser/utils.py +149 -0
- openhands/sdk/context/prompts/__init__.py +6 -0
- openhands/sdk/context/prompts/prompt.py +114 -0
- openhands/sdk/context/prompts/templates/ask_agent_template.j2 +11 -0
- openhands/sdk/context/prompts/templates/skill_knowledge_info.j2 +8 -0
- openhands/sdk/context/prompts/templates/system_message_suffix.j2 +32 -0
- openhands/sdk/context/skills/__init__.py +28 -0
- openhands/sdk/context/skills/exceptions.py +11 -0
- openhands/sdk/context/skills/skill.py +720 -0
- openhands/sdk/context/skills/trigger.py +36 -0
- openhands/sdk/context/skills/types.py +48 -0
- openhands/sdk/context/view.py +503 -0
- openhands/sdk/conversation/__init__.py +40 -0
- openhands/sdk/conversation/base.py +281 -0
- openhands/sdk/conversation/conversation.py +152 -0
- openhands/sdk/conversation/conversation_stats.py +85 -0
- openhands/sdk/conversation/event_store.py +157 -0
- openhands/sdk/conversation/events_list_base.py +17 -0
- openhands/sdk/conversation/exceptions.py +50 -0
- openhands/sdk/conversation/fifo_lock.py +133 -0
- openhands/sdk/conversation/impl/__init__.py +5 -0
- openhands/sdk/conversation/impl/local_conversation.py +665 -0
- openhands/sdk/conversation/impl/remote_conversation.py +956 -0
- openhands/sdk/conversation/persistence_const.py +9 -0
- openhands/sdk/conversation/response_utils.py +41 -0
- openhands/sdk/conversation/secret_registry.py +126 -0
- openhands/sdk/conversation/serialization_diff.py +0 -0
- openhands/sdk/conversation/state.py +392 -0
- openhands/sdk/conversation/stuck_detector.py +311 -0
- openhands/sdk/conversation/title_utils.py +191 -0
- openhands/sdk/conversation/types.py +45 -0
- openhands/sdk/conversation/visualizer/__init__.py +12 -0
- openhands/sdk/conversation/visualizer/base.py +67 -0
- openhands/sdk/conversation/visualizer/default.py +373 -0
- openhands/sdk/critic/__init__.py +15 -0
- openhands/sdk/critic/base.py +38 -0
- openhands/sdk/critic/impl/__init__.py +12 -0
- openhands/sdk/critic/impl/agent_finished.py +83 -0
- openhands/sdk/critic/impl/empty_patch.py +49 -0
- openhands/sdk/critic/impl/pass_critic.py +42 -0
- openhands/sdk/event/__init__.py +42 -0
- openhands/sdk/event/base.py +149 -0
- openhands/sdk/event/condenser.py +82 -0
- openhands/sdk/event/conversation_error.py +25 -0
- openhands/sdk/event/conversation_state.py +104 -0
- openhands/sdk/event/llm_completion_log.py +39 -0
- openhands/sdk/event/llm_convertible/__init__.py +20 -0
- openhands/sdk/event/llm_convertible/action.py +139 -0
- openhands/sdk/event/llm_convertible/message.py +142 -0
- openhands/sdk/event/llm_convertible/observation.py +141 -0
- openhands/sdk/event/llm_convertible/system.py +61 -0
- openhands/sdk/event/token.py +16 -0
- openhands/sdk/event/types.py +11 -0
- openhands/sdk/event/user_action.py +21 -0
- openhands/sdk/git/exceptions.py +43 -0
- openhands/sdk/git/git_changes.py +249 -0
- openhands/sdk/git/git_diff.py +129 -0
- openhands/sdk/git/models.py +21 -0
- openhands/sdk/git/utils.py +189 -0
- openhands/sdk/hooks/__init__.py +30 -0
- openhands/sdk/hooks/config.py +180 -0
- openhands/sdk/hooks/conversation_hooks.py +227 -0
- openhands/sdk/hooks/executor.py +155 -0
- openhands/sdk/hooks/manager.py +170 -0
- openhands/sdk/hooks/types.py +40 -0
- openhands/sdk/io/__init__.py +6 -0
- openhands/sdk/io/base.py +48 -0
- openhands/sdk/io/cache.py +85 -0
- openhands/sdk/io/local.py +119 -0
- openhands/sdk/io/memory.py +54 -0
- openhands/sdk/llm/__init__.py +45 -0
- openhands/sdk/llm/exceptions/__init__.py +45 -0
- openhands/sdk/llm/exceptions/classifier.py +50 -0
- openhands/sdk/llm/exceptions/mapping.py +54 -0
- openhands/sdk/llm/exceptions/types.py +101 -0
- openhands/sdk/llm/llm.py +1140 -0
- openhands/sdk/llm/llm_registry.py +122 -0
- openhands/sdk/llm/llm_response.py +59 -0
- openhands/sdk/llm/message.py +656 -0
- openhands/sdk/llm/mixins/fn_call_converter.py +1288 -0
- openhands/sdk/llm/mixins/non_native_fc.py +97 -0
- openhands/sdk/llm/options/__init__.py +1 -0
- openhands/sdk/llm/options/chat_options.py +93 -0
- openhands/sdk/llm/options/common.py +19 -0
- openhands/sdk/llm/options/responses_options.py +67 -0
- openhands/sdk/llm/router/__init__.py +10 -0
- openhands/sdk/llm/router/base.py +117 -0
- openhands/sdk/llm/router/impl/multimodal.py +76 -0
- openhands/sdk/llm/router/impl/random.py +22 -0
- openhands/sdk/llm/streaming.py +9 -0
- openhands/sdk/llm/utils/metrics.py +312 -0
- openhands/sdk/llm/utils/model_features.py +192 -0
- openhands/sdk/llm/utils/model_info.py +90 -0
- openhands/sdk/llm/utils/model_prompt_spec.py +98 -0
- openhands/sdk/llm/utils/retry_mixin.py +128 -0
- openhands/sdk/llm/utils/telemetry.py +362 -0
- openhands/sdk/llm/utils/unverified_models.py +156 -0
- openhands/sdk/llm/utils/verified_models.py +65 -0
- openhands/sdk/logger/__init__.py +22 -0
- openhands/sdk/logger/logger.py +195 -0
- openhands/sdk/logger/rolling.py +113 -0
- openhands/sdk/mcp/__init__.py +24 -0
- openhands/sdk/mcp/client.py +76 -0
- openhands/sdk/mcp/definition.py +106 -0
- openhands/sdk/mcp/exceptions.py +19 -0
- openhands/sdk/mcp/tool.py +270 -0
- openhands/sdk/mcp/utils.py +83 -0
- openhands/sdk/observability/__init__.py +4 -0
- openhands/sdk/observability/laminar.py +166 -0
- openhands/sdk/observability/utils.py +20 -0
- openhands/sdk/py.typed +0 -0
- openhands/sdk/secret/__init__.py +19 -0
- openhands/sdk/secret/secrets.py +92 -0
- openhands/sdk/security/__init__.py +6 -0
- openhands/sdk/security/analyzer.py +111 -0
- openhands/sdk/security/confirmation_policy.py +61 -0
- openhands/sdk/security/llm_analyzer.py +29 -0
- openhands/sdk/security/risk.py +100 -0
- openhands/sdk/tool/__init__.py +34 -0
- openhands/sdk/tool/builtins/__init__.py +34 -0
- openhands/sdk/tool/builtins/finish.py +106 -0
- openhands/sdk/tool/builtins/think.py +117 -0
- openhands/sdk/tool/registry.py +184 -0
- openhands/sdk/tool/schema.py +286 -0
- openhands/sdk/tool/spec.py +39 -0
- openhands/sdk/tool/tool.py +481 -0
- openhands/sdk/utils/__init__.py +22 -0
- openhands/sdk/utils/async_executor.py +115 -0
- openhands/sdk/utils/async_utils.py +39 -0
- openhands/sdk/utils/cipher.py +68 -0
- openhands/sdk/utils/command.py +90 -0
- openhands/sdk/utils/deprecation.py +166 -0
- openhands/sdk/utils/github.py +44 -0
- openhands/sdk/utils/json.py +48 -0
- openhands/sdk/utils/models.py +570 -0
- openhands/sdk/utils/paging.py +63 -0
- openhands/sdk/utils/pydantic_diff.py +85 -0
- openhands/sdk/utils/pydantic_secrets.py +64 -0
- openhands/sdk/utils/truncate.py +117 -0
- openhands/sdk/utils/visualize.py +58 -0
- openhands/sdk/workspace/__init__.py +17 -0
- openhands/sdk/workspace/base.py +158 -0
- openhands/sdk/workspace/local.py +189 -0
- openhands/sdk/workspace/models.py +35 -0
- openhands/sdk/workspace/remote/__init__.py +8 -0
- openhands/sdk/workspace/remote/async_remote_workspace.py +149 -0
- openhands/sdk/workspace/remote/base.py +164 -0
- openhands/sdk/workspace/remote/remote_workspace_mixin.py +323 -0
- openhands/sdk/workspace/workspace.py +49 -0
- openhands_sdk-1.7.3.dist-info/METADATA +17 -0
- openhands_sdk-1.7.3.dist-info/RECORD +180 -0
- openhands_sdk-1.7.3.dist-info/WHEEL +5 -0
- openhands_sdk-1.7.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from abc import ABC
|
|
6
|
+
from typing import Annotated, Any, ClassVar, Literal, NoReturn, Self, Union
|
|
7
|
+
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
Discriminator,
|
|
11
|
+
Field,
|
|
12
|
+
Tag,
|
|
13
|
+
TypeAdapter,
|
|
14
|
+
ValidationError,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_core import ErrorDetails
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
_rebuild_required = True
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_abstract(type_: type) -> bool:
|
|
24
|
+
"""Determine whether the class directly extends ABC or contains abstract methods"""
|
|
25
|
+
try:
|
|
26
|
+
return inspect.isabstract(type_) or ABC in type_.__bases__
|
|
27
|
+
except Exception:
|
|
28
|
+
return False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _get_all_subclasses(cls) -> set[type]:
|
|
32
|
+
"""
|
|
33
|
+
Recursively finds and returns all (loaded) subclasses of a given class.
|
|
34
|
+
"""
|
|
35
|
+
result = set()
|
|
36
|
+
for subclass in cls.__subclasses__():
|
|
37
|
+
result.add(subclass)
|
|
38
|
+
result.update(_get_all_subclasses(subclass))
|
|
39
|
+
return result
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def rebuild_all():
    """Force a pydantic rebuild of every loaded polymorphic model class.

    All subclasses of ``OpenHandsModel`` are rebuilt first, then all
    subclasses of ``DiscriminatedUnionMixin``.
    """
    global _rebuild_required
    # The flag is cleared before any rebuild runs: ``_rebuild_if_required`` is
    # invoked from validation/schema hooks, and must not re-enter this function
    # while the rebuild is in progress.
    _rebuild_required = False
    for root in (OpenHandsModel, DiscriminatedUnionMixin):
        for model_cls in _get_all_subclasses(root):
            model_cls.model_rebuild(force=True)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def kind_of(obj) -> str:
    """Return the discriminator ("kind") value for ``obj``.

    * dict: the value stored under the ``"kind"`` key (``KeyError`` if absent)
    * object exposing ``__name__`` (e.g. a class): that name
    * anything else: the name of the object's class
    """
    if isinstance(obj, dict):
        return obj["kind"]
    target = obj if hasattr(obj, "__name__") else obj.__class__
    return target.__name__
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _create_enhanced_discriminated_union_error_message(
|
|
62
|
+
invalid_kind: str, cls_name: str, valid_kinds: list[str]
|
|
63
|
+
) -> str:
|
|
64
|
+
"""Create an enhanced error message for discriminated union validation failures."""
|
|
65
|
+
possible_kinds_str = ", ".join(sorted(valid_kinds)) if valid_kinds else "none"
|
|
66
|
+
return (
|
|
67
|
+
f"Unexpected kind '{invalid_kind}' for {cls_name}. "
|
|
68
|
+
f"Expected one of: {possible_kinds_str}. "
|
|
69
|
+
f"If you receive this error when trying to wrap a "
|
|
70
|
+
f"DiscriminatedUnion instance inside another pydantic model, "
|
|
71
|
+
f"you may need to use OpenHandsModel instead of BaseModel "
|
|
72
|
+
f"to make sure that an invalid schema has not been cached."
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _extract_invalid_kind_from_validation_error(error: ErrorDetails) -> str:
|
|
77
|
+
"""Extract the invalid kind from a Pydantic validation error."""
|
|
78
|
+
input_value = error.get("input")
|
|
79
|
+
if input_value is not None and hasattr(input_value, "kind"):
|
|
80
|
+
return input_value.kind
|
|
81
|
+
elif isinstance(input_value, dict) and "kind" in input_value:
|
|
82
|
+
return input_value["kind"]
|
|
83
|
+
else:
|
|
84
|
+
return kind_of(input_value)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _handle_discriminated_union_validation_error(
|
|
88
|
+
validation_error: ValidationError, cls_name: str, valid_kinds: list[str]
|
|
89
|
+
) -> NoReturn:
|
|
90
|
+
"""Handle discriminated union validation errors with enhanced messages."""
|
|
91
|
+
for error in validation_error.errors():
|
|
92
|
+
if error.get("type") == "union_tag_invalid":
|
|
93
|
+
invalid_kind = _extract_invalid_kind_from_validation_error(error)
|
|
94
|
+
error_msg = _create_enhanced_discriminated_union_error_message(
|
|
95
|
+
invalid_kind, cls_name, valid_kinds
|
|
96
|
+
)
|
|
97
|
+
raise ValueError(error_msg) from validation_error
|
|
98
|
+
|
|
99
|
+
# If it's not a discriminated union error, re-raise the original error
|
|
100
|
+
raise validation_error
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_known_concrete_subclasses(cls) -> list[type]:
    """Return every loaded concrete (non-abstract) descendant of ``cls``.

    The list is ordered by ``(module, qualified name)``. Distinct classes that
    happen to share a module and name are all kept; nothing is de-duplicated.
    """
    collected: list[type] = []
    for child in cls.__subclasses__():
        # Descendants of ``child`` are gathered before ``child`` itself. The
        # sort below is stable, so this only matters for entries with equal keys.
        collected += get_known_concrete_subclasses(child)
        if _is_abstract(child):
            continue
        collected.append(child)

    def _sort_key(t: type) -> tuple[str, str]:
        # The qualified name distinguishes nested/local classes (such as
        # classes defined inside a test function).
        return t.__module__, getattr(t, "__qualname__", t.__name__)

    return sorted(collected, key=_sort_key)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class OpenHandsModel(BaseModel):
    """Base model for classes that are, or contain fields that are, discriminated unions.

    Whenever a new subclass is defined, a module-level flag is set. The next
    time an instance is initialized, a model is validated, or a JSON schema is
    requested, all polymorphic mappings are regenerated (see
    ``_rebuild_if_required``).
    """

    def model_post_init(self, _context):
        """Refresh the polymorphic mappings, if needed, after an instance is built."""
        _rebuild_if_required()

    @classmethod
    def model_validate(cls, *args, **kwargs) -> Self:
        """Validate as pydantic does, after refreshing polymorphic mappings."""
        _rebuild_if_required()
        return super().model_validate(*args, **kwargs)

    @classmethod
    def model_validate_json(cls, *args, **kwargs) -> Self:
        """Validate JSON as pydantic does, after refreshing polymorphic mappings."""
        _rebuild_if_required()
        return super().model_validate_json(*args, **kwargs)

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
        """Generate the JSON schema, after refreshing polymorphic mappings."""
        _rebuild_if_required()
        return super().model_json_schema(*args, **kwargs)

    def model_dump_json(self, **kwargs) -> str:
        """Serialize the model to a JSON string via ``model_dump`` + ``json.dumps``.

        Args:
            **kwargs: Options for ``model_dump``. ``mode`` is always forced to
                ``"json"``. ``indent`` and ``ensure_ascii`` are JSON formatting
                options rather than ``model_dump`` options, so they are passed
                to ``json.dumps`` instead (defaults: ``None`` and ``False``).

        Returns:
            The JSON text.
        """
        # This was overridden because it seems there is a bug where sometimes
        # duplicate fields are produced by model_dump_json which does not appear
        # in model_dump
        #
        # Formatting-only options must not reach model_dump(), which would
        # reject them as unexpected keyword arguments.
        indent = kwargs.pop("indent", None)
        ensure_ascii = kwargs.pop("ensure_ascii", False)
        kwargs["mode"] = "json"
        return json.dumps(
            self.model_dump(**kwargs), indent=indent, ensure_ascii=ensure_ascii
        )

    def __init_subclass__(cls, **kwargs):
        """
        When a new subclass is defined, mark that we will need
        to rebuild everything
        """
        global _rebuild_required
        _rebuild_required = True

        return super().__init_subclass__(**kwargs)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class DiscriminatedUnionMixin(OpenHandsModel, ABC):
|
|
163
|
+
"""A Base class for members of tagged unions discriminated by the class name.
|
|
164
|
+
|
|
165
|
+
This class provides automatic subclass registration and discriminated union
|
|
166
|
+
functionality. Each subclass is automatically registered when defined and
|
|
167
|
+
can be used for polymorphic serialization/deserialization.
|
|
168
|
+
|
|
169
|
+
Child classes will automatically have a type field defined, which is used as a
|
|
170
|
+
discriminator for union types.
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
__pydantic_core_schema__: ClassVar[Any]
|
|
174
|
+
__pydantic_validator__: ClassVar[Any]
|
|
175
|
+
__pydantic_serializer__: ClassVar[Any]
|
|
176
|
+
|
|
177
|
+
kind: str = Field(default="") # We dynamically update on a per class basis
|
|
178
|
+
|
|
179
|
+
@classmethod
|
|
180
|
+
def resolve_kind(cls, kind: str) -> type:
|
|
181
|
+
for subclass in get_known_concrete_subclasses(cls):
|
|
182
|
+
if subclass.__name__ == kind:
|
|
183
|
+
return subclass
|
|
184
|
+
|
|
185
|
+
# Generate enhanced error message for unknown kind
|
|
186
|
+
valid_kinds = [
|
|
187
|
+
subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
|
|
188
|
+
]
|
|
189
|
+
error_msg = _create_enhanced_discriminated_union_error_message(
|
|
190
|
+
kind, cls.__name__, valid_kinds
|
|
191
|
+
)
|
|
192
|
+
raise ValueError(error_msg)
|
|
193
|
+
|
|
194
|
+
@classmethod
|
|
195
|
+
def __get_pydantic_core_schema__(cls, source_type, handler):
|
|
196
|
+
"""Generate discriminated union schema for TypeAdapter compatibility."""
|
|
197
|
+
if cls.__name__ == "DiscriminatedUnionMixin":
|
|
198
|
+
return handler(source_type)
|
|
199
|
+
|
|
200
|
+
if _is_abstract(source_type):
|
|
201
|
+
_rebuild_if_required()
|
|
202
|
+
serializable_type = source_type.get_serializable_type()
|
|
203
|
+
# If there are subclasses, generate schema for the discriminated union
|
|
204
|
+
if serializable_type is not source_type:
|
|
205
|
+
from pydantic_core import core_schema
|
|
206
|
+
|
|
207
|
+
# Generate the base schema
|
|
208
|
+
base_schema = handler.generate_schema(serializable_type)
|
|
209
|
+
|
|
210
|
+
# Wrap it with a custom validation function that provides
|
|
211
|
+
# enhanced error messages
|
|
212
|
+
def validate_with_enhanced_error(value, handler_func, info): # noqa: ARG001
|
|
213
|
+
try:
|
|
214
|
+
return handler_func(value)
|
|
215
|
+
except ValidationError as e:
|
|
216
|
+
valid_kinds = [
|
|
217
|
+
subclass.__name__
|
|
218
|
+
for subclass in get_known_concrete_subclasses(source_type)
|
|
219
|
+
]
|
|
220
|
+
_handle_discriminated_union_validation_error(
|
|
221
|
+
e, source_type.__name__, valid_kinds
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# Create a with_info_wrap_validator_function schema
|
|
225
|
+
return core_schema.with_info_wrap_validator_function(
|
|
226
|
+
validate_with_enhanced_error,
|
|
227
|
+
base_schema,
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
return handler(source_type)
|
|
231
|
+
|
|
232
|
+
@classmethod
|
|
233
|
+
def __get_pydantic_json_schema__(cls, core_schema, handler):
|
|
234
|
+
"""Add discriminator to OpenAPI schema and ensure component generation."""
|
|
235
|
+
json_schema = handler(core_schema)
|
|
236
|
+
|
|
237
|
+
# Add discriminator if this is a oneOf schema
|
|
238
|
+
if isinstance(json_schema, dict) and "oneOf" in json_schema:
|
|
239
|
+
# Add title for abstract classes to encourage separate component creation
|
|
240
|
+
if _is_abstract(cls) and "title" not in json_schema:
|
|
241
|
+
json_schema["title"] = cls.__name__
|
|
242
|
+
|
|
243
|
+
if "discriminator" not in json_schema:
|
|
244
|
+
mapping = {}
|
|
245
|
+
for option in json_schema["oneOf"]:
|
|
246
|
+
if "$ref" in option:
|
|
247
|
+
kind = option["$ref"].split("/")[-1]
|
|
248
|
+
mapping[kind] = option["$ref"]
|
|
249
|
+
|
|
250
|
+
if mapping:
|
|
251
|
+
json_schema["discriminator"] = {
|
|
252
|
+
"propertyName": "kind",
|
|
253
|
+
"mapping": mapping,
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
return json_schema
|
|
257
|
+
|
|
258
|
+
@classmethod
|
|
259
|
+
def model_rebuild(
|
|
260
|
+
cls,
|
|
261
|
+
*,
|
|
262
|
+
force=False,
|
|
263
|
+
raise_errors=True,
|
|
264
|
+
_parent_namespace_depth=2,
|
|
265
|
+
_types_namespace=None,
|
|
266
|
+
):
|
|
267
|
+
if cls == DiscriminatedUnionMixin:
|
|
268
|
+
pass
|
|
269
|
+
if _is_abstract(cls):
|
|
270
|
+
subclasses = get_known_concrete_subclasses(cls)
|
|
271
|
+
kinds = [subclass.__name__ for subclass in subclasses]
|
|
272
|
+
if kinds:
|
|
273
|
+
kind_field = cls.model_fields["kind"]
|
|
274
|
+
kind_field.annotation = Literal[tuple(kinds)] # type: ignore
|
|
275
|
+
kind_field.default = kinds[0]
|
|
276
|
+
|
|
277
|
+
type_adapter = TypeAdapter(cls.get_serializable_type())
|
|
278
|
+
cls.__pydantic_core_schema__ = type_adapter.core_schema
|
|
279
|
+
cls.__pydantic_validator__ = type_adapter.validator
|
|
280
|
+
cls.__pydantic_serializer__ = type_adapter.serializer
|
|
281
|
+
return
|
|
282
|
+
|
|
283
|
+
return super().model_rebuild(
|
|
284
|
+
force=force,
|
|
285
|
+
raise_errors=raise_errors,
|
|
286
|
+
_parent_namespace_depth=_parent_namespace_depth,
|
|
287
|
+
_types_namespace=_types_namespace,
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
@classmethod
|
|
291
|
+
def get_serializable_type(cls) -> type:
|
|
292
|
+
"""
|
|
293
|
+
Custom method to get the union of all currently loaded
|
|
294
|
+
non absract subclasses
|
|
295
|
+
"""
|
|
296
|
+
|
|
297
|
+
# If the class is not abstract return self
|
|
298
|
+
if not _is_abstract(cls):
|
|
299
|
+
return cls
|
|
300
|
+
|
|
301
|
+
subclasses = list(get_known_concrete_subclasses(cls))
|
|
302
|
+
if not subclasses:
|
|
303
|
+
return cls
|
|
304
|
+
|
|
305
|
+
if len(subclasses) == 1:
|
|
306
|
+
# Returning the concrete type ensures Pydantic instantiates the subclass
|
|
307
|
+
# (e.g. Agent) rather than the abstract base (e.g. AgentBase) when there is
|
|
308
|
+
# only ONE concrete subclass.
|
|
309
|
+
return subclasses[0]
|
|
310
|
+
|
|
311
|
+
serializable_type = Annotated[
|
|
312
|
+
Union[*tuple(Annotated[t, Tag(t.__name__)] for t in subclasses)],
|
|
313
|
+
Discriminator(kind_of),
|
|
314
|
+
]
|
|
315
|
+
return serializable_type # type: ignore
|
|
316
|
+
|
|
317
|
+
@classmethod
|
|
318
|
+
def model_validate(cls, obj: Any, **kwargs) -> Self:
|
|
319
|
+
try:
|
|
320
|
+
if _is_abstract(cls):
|
|
321
|
+
resolved = cls.resolve_kind(kind_of(obj))
|
|
322
|
+
else:
|
|
323
|
+
resolved = super()
|
|
324
|
+
result = resolved.model_validate(obj, **kwargs)
|
|
325
|
+
return result # type: ignore
|
|
326
|
+
except ValidationError as e:
|
|
327
|
+
valid_kinds = [
|
|
328
|
+
subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
|
|
329
|
+
]
|
|
330
|
+
_handle_discriminated_union_validation_error(e, cls.__name__, valid_kinds)
|
|
331
|
+
|
|
332
|
+
@classmethod
|
|
333
|
+
def model_validate_json(
|
|
334
|
+
cls,
|
|
335
|
+
json_data: str | bytes | bytearray,
|
|
336
|
+
**kwargs,
|
|
337
|
+
) -> Self:
|
|
338
|
+
data = json.loads(json_data)
|
|
339
|
+
if _is_abstract(cls):
|
|
340
|
+
resolved = cls.resolve_kind(kind_of(data))
|
|
341
|
+
else:
|
|
342
|
+
resolved = super()
|
|
343
|
+
result = resolved.model_validate(data, **kwargs)
|
|
344
|
+
return result # type: ignore
|
|
345
|
+
|
|
346
|
+
def __init_subclass__(cls, **kwargs):
|
|
347
|
+
super().__init_subclass__(**kwargs)
|
|
348
|
+
|
|
349
|
+
# If concrete, stamp kind Literal and collision check
|
|
350
|
+
if not _is_abstract(cls):
|
|
351
|
+
# 1) Stamp discriminator
|
|
352
|
+
cls.kind = cls.__name__
|
|
353
|
+
cls.__annotations__["kind"] = Literal[cls.__name__]
|
|
354
|
+
|
|
355
|
+
# 2) Collision check
|
|
356
|
+
mro = cls.mro()
|
|
357
|
+
union_class = mro[mro.index(DiscriminatedUnionMixin) - 1]
|
|
358
|
+
concretes = get_known_concrete_subclasses(union_class) # sorted list
|
|
359
|
+
kinds: dict[str, type] = {}
|
|
360
|
+
for sub in concretes:
|
|
361
|
+
k = kind_of(sub)
|
|
362
|
+
if k in kinds and kinds[k] is not sub:
|
|
363
|
+
raise ValueError(
|
|
364
|
+
f"Duplicate kind detected for {union_class} : {cls}, {sub}"
|
|
365
|
+
)
|
|
366
|
+
kinds[k] = sub
|
|
367
|
+
|
|
368
|
+
# Rebuild any abstract union owners in the MRO that rely on subclass sets
|
|
369
|
+
for base in cls.mro():
|
|
370
|
+
# Stop when we pass ourselves
|
|
371
|
+
if base is cls:
|
|
372
|
+
continue
|
|
373
|
+
# Only rebuild abstract DiscriminatedUnion owners
|
|
374
|
+
if (
|
|
375
|
+
isinstance(base, type)
|
|
376
|
+
and issubclass(base, DiscriminatedUnionMixin)
|
|
377
|
+
and _is_abstract(base)
|
|
378
|
+
):
|
|
379
|
+
base.model_rebuild(force=True)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _rebuild_if_required():
    """Run ``rebuild_all`` when the module-level rebuild flag is set."""
    if not _rebuild_required:
        return
    rebuild_all()
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _extract_discriminated_unions(schema: dict) -> dict:
|
|
388
|
+
"""Extract inline discriminated unions as separate components.
|
|
389
|
+
|
|
390
|
+
Recursively scans the schema and extracts any inline discriminated union
|
|
391
|
+
(oneOf + discriminator + title) as a separate component, replacing it with a $ref.
|
|
392
|
+
Also deduplicates schemas with identical titles.
|
|
393
|
+
"""
|
|
394
|
+
import json
|
|
395
|
+
import re
|
|
396
|
+
from collections import defaultdict
|
|
397
|
+
|
|
398
|
+
if not isinstance(schema, dict):
|
|
399
|
+
return schema
|
|
400
|
+
|
|
401
|
+
# OpenAPI schema names must match this pattern
|
|
402
|
+
valid_name_pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
403
|
+
|
|
404
|
+
schemas = schema.get("components", {}).get("schemas", {})
|
|
405
|
+
extracted = {}
|
|
406
|
+
|
|
407
|
+
def _find_and_extract(obj, path=""):
|
|
408
|
+
if not isinstance(obj, dict):
|
|
409
|
+
return obj
|
|
410
|
+
|
|
411
|
+
# Extract inline discriminated unions
|
|
412
|
+
if "oneOf" in obj and "discriminator" in obj and "title" in obj:
|
|
413
|
+
title = obj["title"]
|
|
414
|
+
if (
|
|
415
|
+
title not in schemas
|
|
416
|
+
and title not in extracted
|
|
417
|
+
and valid_name_pattern.match(title)
|
|
418
|
+
):
|
|
419
|
+
extracted[title] = {
|
|
420
|
+
"oneOf": obj["oneOf"],
|
|
421
|
+
"discriminator": obj["discriminator"],
|
|
422
|
+
"title": title,
|
|
423
|
+
}
|
|
424
|
+
return {"$ref": f"#/components/schemas/{title}"}
|
|
425
|
+
|
|
426
|
+
# Recursively process nested structures
|
|
427
|
+
result = {}
|
|
428
|
+
for key, value in obj.items():
|
|
429
|
+
if isinstance(value, dict):
|
|
430
|
+
result[key] = _find_and_extract(value, f"{path}.{key}")
|
|
431
|
+
elif isinstance(value, list):
|
|
432
|
+
result[key] = [
|
|
433
|
+
_find_and_extract(item, f"{path}.{key}[]") for item in value
|
|
434
|
+
]
|
|
435
|
+
else:
|
|
436
|
+
result[key] = value
|
|
437
|
+
return result
|
|
438
|
+
|
|
439
|
+
schema = _find_and_extract(schema)
|
|
440
|
+
|
|
441
|
+
if extracted and "components" in schema and "schemas" in schema["components"]:
|
|
442
|
+
schema["components"]["schemas"].update(extracted)
|
|
443
|
+
|
|
444
|
+
# Deduplicate schemas with same title (prefer *-Output over *-Input over base)
|
|
445
|
+
schemas = schema.get("components", {}).get("schemas", {})
|
|
446
|
+
title_to_names = defaultdict(list)
|
|
447
|
+
for name, defn in schemas.items():
|
|
448
|
+
if isinstance(defn, dict):
|
|
449
|
+
title_to_names[defn.get("title", name)].append(name)
|
|
450
|
+
|
|
451
|
+
to_remove = {}
|
|
452
|
+
for title, names in title_to_names.items():
|
|
453
|
+
if len(names) > 1:
|
|
454
|
+
# Prefer: *-Output > *-Input > base name
|
|
455
|
+
keep = sorted(
|
|
456
|
+
names,
|
|
457
|
+
key=lambda n: (
|
|
458
|
+
0 if n.endswith("-Output") else 1 if n.endswith("-Input") else 2,
|
|
459
|
+
n,
|
|
460
|
+
),
|
|
461
|
+
)[0]
|
|
462
|
+
for name in names:
|
|
463
|
+
if name != keep:
|
|
464
|
+
to_remove[name] = keep
|
|
465
|
+
|
|
466
|
+
if to_remove:
|
|
467
|
+
schema_str = json.dumps(schema)
|
|
468
|
+
for old, new in to_remove.items():
|
|
469
|
+
schema_str = schema_str.replace(
|
|
470
|
+
f'"#/components/schemas/{old}"', f'"#/components/schemas/{new}"'
|
|
471
|
+
)
|
|
472
|
+
schema = json.loads(schema_str)
|
|
473
|
+
for old in to_remove:
|
|
474
|
+
schema["components"]["schemas"].pop(old, None)
|
|
475
|
+
|
|
476
|
+
return schema
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def _patch_fastapi_discriminated_union_support():
    """Patch FastAPI to handle discriminated union schemas without $ref.

    This ensures discriminated unions from DiscriminatedUnionMixin work correctly
    with FastAPI's OpenAPI schema generation. The patch prevents KeyError when
    FastAPI encounters schemas without $ref keys (which discriminated unions use).

    Also extracts inline discriminated unions as separate schema components for
    better OpenAPI documentation and Swagger UI display.

    Skips patching if SKIP_FASTAPI_DISCRIMINATED_UNION_FIX environment variable is set.
    Nothing is patched when FastAPI is not installed or the private helpers
    this relies on are missing.
    """
    # Skip patching if environment variable flag is defined
    if os.environ.get("SKIP_FASTAPI_DISCRIMINATED_UNION_FIX"):
        logger.debug(
            "Skipping FastAPI discriminated union patch due to environment variable"
        )
        return

    try:
        import fastapi._compat.v2 as fastapi_v2
        from fastapi import FastAPI

        # Resolve everything the patch depends on BEFORE installing anything,
        # so an incompatible FastAPI version leaves it completely untouched
        # instead of failing later during schema generation.
        required = ("_remap_definitions_and_field_mappings", "_replace_refs")
        missing = [name for name in required if not hasattr(fastapi_v2, name)]
        if missing:
            raise AttributeError(f"fastapi._compat.v2 is missing {missing}")
        _replace_refs = fastapi_v2._replace_refs
        _original_openapi = FastAPI.openapi

        def _patched_remap_definitions_and_field_mappings(**kwargs):
            """Patched version that handles schemas w/o $ref (discriminated unions)."""
            field_mapping = kwargs.get("field_mapping", {})
            model_name_map = kwargs.get("model_name_map", {})

            # Build old_name -> new_name map, skipping schemas without $ref
            old_name_to_new_name_map = {}
            for field_key, schema in field_mapping.items():
                # NOTE(review): assumes each key is a sequence whose first
                # element exposes ``type_`` — confirm against FastAPI.
                model = field_key[0].type_
                if model not in model_name_map:
                    continue
                new_name = model_name_map[model]

                # Skip schemas without $ref (discriminated unions)
                if "$ref" not in schema:
                    continue

                old_name = schema["$ref"].split("/")[-1]
                if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
                    continue
                old_name_to_new_name_map[old_name] = new_name

            # Replace refs using FastAPI's helper
            new_field_mapping = {
                field_key: _replace_refs(
                    schema=schema,
                    old_name_to_new_name_map=old_name_to_new_name_map,
                )
                for field_key, schema in field_mapping.items()
            }

            definitions = kwargs.get("definitions", {})
            new_definitions = {}
            for key, value in definitions.items():
                new_key = old_name_to_new_name_map.get(key, key)
                new_definitions[new_key] = _replace_refs(
                    schema=value,
                    old_name_to_new_name_map=old_name_to_new_name_map,
                )

            return new_field_mapping, new_definitions

        def _patched_openapi(self):
            """Patched openapi() that extracts discriminated unions."""
            schema = _original_openapi(self)
            return _extract_discriminated_unions(schema)

        # Apply both patches together
        fastapi_v2._remap_definitions_and_field_mappings = (
            _patched_remap_definitions_and_field_mappings
        )
        FastAPI.openapi = _patched_openapi

    except (ImportError, AttributeError) as exc:
        # FastAPI not available or internal API changed
        logger.debug("FastAPI discriminated union patch not applied: %s", exc)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
# Always call the FastAPI patch after DiscriminatedUnionMixin definition
|
|
570
|
+
_patch_fastapi_discriminated_union_support()
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Pagination utilities for iterating over paginated search results."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
4
|
+
from typing import Any, Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PageProtocol[T](Protocol):
|
|
8
|
+
"""Protocol for page objects returned by search functions.
|
|
9
|
+
|
|
10
|
+
All page objects should have:
|
|
11
|
+
- items: A list of items of type T
|
|
12
|
+
- next_page_id: Optional string for pagination
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
items: list[T]
|
|
16
|
+
next_page_id: str | None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def page_iterator[T](
|
|
20
|
+
search_func: Callable[..., Awaitable[PageProtocol[T]]],
|
|
21
|
+
*args: Any,
|
|
22
|
+
**kwargs: Any,
|
|
23
|
+
) -> AsyncGenerator[T, None]:
|
|
24
|
+
"""
|
|
25
|
+
Iterate over items from paginated search results.
|
|
26
|
+
|
|
27
|
+
This utility function handles pagination automatically by calling the search
|
|
28
|
+
function repeatedly with updated page_id parameters until all pages are
|
|
29
|
+
exhausted.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
search_func: An async function that returns a PageProtocol[T] object
|
|
33
|
+
with 'items' and 'next_page_id' attributes
|
|
34
|
+
*args: Positional arguments to pass to the search function
|
|
35
|
+
**kwargs: Keyword arguments to pass to the search function
|
|
36
|
+
|
|
37
|
+
Yields:
|
|
38
|
+
Individual items of type T from each page
|
|
39
|
+
|
|
40
|
+
Example:
|
|
41
|
+
async for event in page_iterator(event_service.search_events, limit=50):
|
|
42
|
+
await send_event(event, websocket)
|
|
43
|
+
|
|
44
|
+
async for conversation in page_iterator(
|
|
45
|
+
conversation_service.search_conversations,
|
|
46
|
+
execution_status=ConversationExecutionStatus.RUNNING
|
|
47
|
+
):
|
|
48
|
+
print(conversation.title)
|
|
49
|
+
"""
|
|
50
|
+
page_id = kwargs.pop("page_id", None)
|
|
51
|
+
|
|
52
|
+
while True:
|
|
53
|
+
# Call the search function with current page_id
|
|
54
|
+
page = await search_func(*args, page_id=page_id, **kwargs)
|
|
55
|
+
|
|
56
|
+
# Yield each item from the current page
|
|
57
|
+
for item in page.items:
|
|
58
|
+
yield item
|
|
59
|
+
|
|
60
|
+
# Check if there are more pages
|
|
61
|
+
page_id = page.next_page_id
|
|
62
|
+
if not page_id:
|
|
63
|
+
break
|