openhands-sdk 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openhands/sdk/agent/agent.py +4 -1
- openhands/sdk/agent/base.py +53 -15
- openhands/sdk/context/condenser/__init__.py +2 -0
- openhands/sdk/context/condenser/base.py +59 -8
- openhands/sdk/context/condenser/llm_summarizing_condenser.py +38 -10
- openhands/sdk/context/skills/skill.py +336 -118
- openhands/sdk/context/view.py +2 -0
- openhands/sdk/conversation/impl/remote_conversation.py +99 -55
- openhands/sdk/conversation/state.py +9 -5
- openhands/sdk/llm/llm.py +1 -2
- openhands/sdk/llm/options/chat_options.py +4 -1
- openhands/sdk/llm/utils/verified_models.py +1 -1
- openhands/sdk/mcp/tool.py +3 -1
- openhands/sdk/tool/schema.py +6 -3
- openhands/sdk/utils/models.py +198 -472
- openhands/sdk/workspace/base.py +22 -0
- openhands/sdk/workspace/local.py +16 -0
- {openhands_sdk-1.7.3.dist-info → openhands_sdk-1.7.4.dist-info}/METADATA +2 -2
- {openhands_sdk-1.7.3.dist-info → openhands_sdk-1.7.4.dist-info}/RECORD +21 -21
- {openhands_sdk-1.7.3.dist-info → openhands_sdk-1.7.4.dist-info}/WHEEL +0 -0
- {openhands_sdk-1.7.3.dist-info → openhands_sdk-1.7.4.dist-info}/top_level.txt +0 -0
openhands/sdk/utils/models.py
CHANGED
|
@@ -1,23 +1,38 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
import json
|
|
3
2
|
import logging
|
|
4
|
-
import
|
|
3
|
+
import threading
|
|
5
4
|
from abc import ABC
|
|
6
|
-
from typing import Annotated, Any,
|
|
5
|
+
from typing import Annotated, Any, Self, Union
|
|
7
6
|
|
|
8
7
|
from pydantic import (
|
|
9
8
|
BaseModel,
|
|
10
9
|
Discriminator,
|
|
11
|
-
|
|
10
|
+
ModelWrapValidatorHandler,
|
|
11
|
+
SerializationInfo,
|
|
12
|
+
SerializerFunctionWrapHandler,
|
|
12
13
|
Tag,
|
|
13
|
-
|
|
14
|
-
|
|
14
|
+
ValidationInfo,
|
|
15
|
+
computed_field,
|
|
16
|
+
model_serializer,
|
|
17
|
+
model_validator,
|
|
15
18
|
)
|
|
16
|
-
from
|
|
19
|
+
from pydantic.json_schema import JsonSchemaValue
|
|
20
|
+
from pydantic_core import CoreSchema
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
logger = logging.getLogger(__name__)
|
|
20
|
-
|
|
24
|
+
|
|
25
|
+
# Thread-local storage for tracking schemas currently being generated.
|
|
26
|
+
# This prevents infinite recursion when generating JSON schemas for
|
|
27
|
+
# discriminated unions that reference each other.
|
|
28
|
+
_thread_local = threading.local()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _get_schemas_in_progress() -> dict[type, JsonSchemaValue]:
|
|
32
|
+
"""Get the thread-local dict for tracking in-progress schema generation."""
|
|
33
|
+
if not hasattr(_thread_local, "schemas_in_progress"):
|
|
34
|
+
_thread_local.schemas_in_progress = {}
|
|
35
|
+
return _thread_local.schemas_in_progress
|
|
21
36
|
|
|
22
37
|
|
|
23
38
|
def _is_abstract(type_: type) -> bool:
|
|
@@ -28,27 +43,6 @@ def _is_abstract(type_: type) -> bool:
|
|
|
28
43
|
return False
|
|
29
44
|
|
|
30
45
|
|
|
31
|
-
def _get_all_subclasses(cls) -> set[type]:
|
|
32
|
-
"""
|
|
33
|
-
Recursively finds and returns all (loaded) subclasses of a given class.
|
|
34
|
-
"""
|
|
35
|
-
result = set()
|
|
36
|
-
for subclass in cls.__subclasses__():
|
|
37
|
-
result.add(subclass)
|
|
38
|
-
result.update(_get_all_subclasses(subclass))
|
|
39
|
-
return result
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def rebuild_all():
|
|
43
|
-
"""Rebuild all polymorphic classes."""
|
|
44
|
-
global _rebuild_required
|
|
45
|
-
_rebuild_required = False
|
|
46
|
-
for cls in _get_all_subclasses(OpenHandsModel):
|
|
47
|
-
cls.model_rebuild(force=True)
|
|
48
|
-
for cls in _get_all_subclasses(DiscriminatedUnionMixin):
|
|
49
|
-
cls.model_rebuild(force=True)
|
|
50
|
-
|
|
51
|
-
|
|
52
46
|
def kind_of(obj) -> str:
|
|
53
47
|
"""Get the string value for the kind tag"""
|
|
54
48
|
if isinstance(obj, dict):
|
|
@@ -58,46 +52,15 @@ def kind_of(obj) -> str:
|
|
|
58
52
|
return obj.__name__
|
|
59
53
|
|
|
60
54
|
|
|
61
|
-
def
|
|
62
|
-
|
|
63
|
-
)
|
|
64
|
-
"""
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
f"DiscriminatedUnion instance inside another pydantic model, "
|
|
71
|
-
f"you may need to use OpenHandsModel instead of BaseModel "
|
|
72
|
-
f"to make sure that an invalid schema has not been cached."
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def _extract_invalid_kind_from_validation_error(error: ErrorDetails) -> str:
|
|
77
|
-
"""Extract the invalid kind from a Pydantic validation error."""
|
|
78
|
-
input_value = error.get("input")
|
|
79
|
-
if input_value is not None and hasattr(input_value, "kind"):
|
|
80
|
-
return input_value.kind
|
|
81
|
-
elif isinstance(input_value, dict) and "kind" in input_value:
|
|
82
|
-
return input_value["kind"]
|
|
83
|
-
else:
|
|
84
|
-
return kind_of(input_value)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _handle_discriminated_union_validation_error(
|
|
88
|
-
validation_error: ValidationError, cls_name: str, valid_kinds: list[str]
|
|
89
|
-
) -> NoReturn:
|
|
90
|
-
"""Handle discriminated union validation errors with enhanced messages."""
|
|
91
|
-
for error in validation_error.errors():
|
|
92
|
-
if error.get("type") == "union_tag_invalid":
|
|
93
|
-
invalid_kind = _extract_invalid_kind_from_validation_error(error)
|
|
94
|
-
error_msg = _create_enhanced_discriminated_union_error_message(
|
|
95
|
-
invalid_kind, cls_name, valid_kinds
|
|
96
|
-
)
|
|
97
|
-
raise ValueError(error_msg) from validation_error
|
|
98
|
-
|
|
99
|
-
# If it's not a discriminated union error, re-raise the original error
|
|
100
|
-
raise validation_error
|
|
55
|
+
def _get_all_subclasses(cls) -> set[type]:
|
|
56
|
+
"""
|
|
57
|
+
Recursively finds and returns all (loaded) subclasses of a given class.
|
|
58
|
+
"""
|
|
59
|
+
result = set()
|
|
60
|
+
for subclass in cls.__subclasses__():
|
|
61
|
+
result.add(subclass)
|
|
62
|
+
result.update(_get_all_subclasses(subclass))
|
|
63
|
+
return result
|
|
101
64
|
|
|
102
65
|
|
|
103
66
|
def get_known_concrete_subclasses(cls) -> list[type]:
|
|
@@ -115,176 +78,194 @@ def get_known_concrete_subclasses(cls) -> list[type]:
|
|
|
115
78
|
return out
|
|
116
79
|
|
|
117
80
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
81
|
+
def _get_checked_concrete_subclasses(cls: type) -> dict[str, type]:
|
|
82
|
+
result = {}
|
|
83
|
+
for sub in get_known_concrete_subclasses(cls):
|
|
84
|
+
existing = result.get(sub.__name__)
|
|
85
|
+
if existing:
|
|
86
|
+
raise ValueError(
|
|
87
|
+
f"Duplicate class definition for {cls.__module__}.{cls.__name__}: "
|
|
88
|
+
f"{existing.__module__}.{existing.__name__} : "
|
|
89
|
+
f"{sub.__module__}.{sub.__name__}"
|
|
90
|
+
)
|
|
91
|
+
if "<locals>" in sub.__qualname__:
|
|
92
|
+
raise ValueError(
|
|
93
|
+
f"Local classes not supported! {sub.__module__}.{sub.__name__} "
|
|
94
|
+
f"/ {cls.__module__}.{cls.__name__} "
|
|
95
|
+
"(Since they may not exist at deserialization time)"
|
|
96
|
+
)
|
|
97
|
+
result[sub.__name__] = sub
|
|
98
|
+
return result
|
|
125
99
|
|
|
126
|
-
def model_post_init(self, _context):
|
|
127
|
-
_rebuild_if_required()
|
|
128
100
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
_rebuild_if_required()
|
|
132
|
-
return super().model_validate(*args, **kwargs)
|
|
101
|
+
class OpenHandsModel(BaseModel):
|
|
102
|
+
"""Deprecated: This class exists only for backward compatibility.
|
|
133
103
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
_rebuild_if_required()
|
|
137
|
-
return super().model_validate_json(*args, **kwargs)
|
|
104
|
+
This class is no longer required for discriminated union support.
|
|
105
|
+
New code should extend pydantic.BaseModel directly instead of OpenHandsModel.
|
|
138
106
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
return super().model_json_schema(*args, **kwargs)
|
|
143
|
-
|
|
144
|
-
def model_dump_json(self, **kwargs):
|
|
145
|
-
# This was overridden because it seems there is a bug where sometimes
|
|
146
|
-
# duplicate fields are produced by model_dump_json which does not appear
|
|
147
|
-
# in model_dump
|
|
148
|
-
kwargs["mode"] = "json"
|
|
149
|
-
return json.dumps(self.model_dump(**kwargs), ensure_ascii=False)
|
|
150
|
-
|
|
151
|
-
def __init_subclass__(cls, **kwargs):
|
|
152
|
-
"""
|
|
153
|
-
When a new subclass is defined, mark that we will need
|
|
154
|
-
to rebuild everything
|
|
155
|
-
"""
|
|
156
|
-
global _rebuild_required
|
|
157
|
-
_rebuild_required = True
|
|
107
|
+
Existing code that extends OpenHandsModel will continue to work, but
|
|
108
|
+
migration to BaseModel is recommended.
|
|
109
|
+
"""
|
|
158
110
|
|
|
159
|
-
return super().__init_subclass__(**kwargs)
|
|
160
111
|
|
|
112
|
+
class DiscriminatedUnionMixin(OpenHandsModel):
|
|
113
|
+
@computed_field
|
|
114
|
+
@property
|
|
115
|
+
def kind(self) -> str:
|
|
116
|
+
return self.__class__.__name__
|
|
161
117
|
|
|
162
|
-
|
|
163
|
-
|
|
118
|
+
@model_validator(mode="wrap")
|
|
119
|
+
@classmethod
|
|
120
|
+
def _validate_subtype(
|
|
121
|
+
cls, data: Any, handler: ModelWrapValidatorHandler[Self], info: ValidationInfo
|
|
122
|
+
) -> Self:
|
|
123
|
+
if isinstance(data, cls):
|
|
124
|
+
return data
|
|
125
|
+
kind = data.pop("kind", None)
|
|
126
|
+
if not _is_abstract(cls):
|
|
127
|
+
# Sanity check: if we're validating a concrete class directly,
|
|
128
|
+
# the kind (if provided) should match the class name. This should
|
|
129
|
+
# always be true at this point since resolve_kind() would have
|
|
130
|
+
# already routed to the correct subclass.
|
|
131
|
+
assert kind is None or kind == cls.__name__
|
|
132
|
+
return handler(data)
|
|
133
|
+
if kind is None:
|
|
134
|
+
subclasses = _get_checked_concrete_subclasses(cls)
|
|
135
|
+
if not subclasses:
|
|
136
|
+
raise ValueError(
|
|
137
|
+
f"No kinds defined for {cls.__module__}.{cls.__name__}"
|
|
138
|
+
)
|
|
139
|
+
elif len(subclasses) == 1:
|
|
140
|
+
# If there is ony 1 possible implementation, then we do not need
|
|
141
|
+
# to state the kind explicitly - it can only be this!
|
|
142
|
+
kind = next(iter(subclasses))
|
|
143
|
+
else:
|
|
144
|
+
# There is more than 1 kind defined but the input did not specify
|
|
145
|
+
# This will cause an error to be raised
|
|
146
|
+
kind = ""
|
|
147
|
+
subclass = cls.resolve_kind(kind)
|
|
148
|
+
return subclass.model_validate(data, context=info.context)
|
|
149
|
+
|
|
150
|
+
@model_serializer(mode="wrap")
|
|
151
|
+
def _serialize_by_kind(
|
|
152
|
+
self, handler: SerializerFunctionWrapHandler, info: SerializationInfo
|
|
153
|
+
):
|
|
154
|
+
if isinstance(self, dict):
|
|
155
|
+
# Sometimes pydantic passes a dict in here.
|
|
156
|
+
return self
|
|
157
|
+
if self._is_handler_for_current_class(handler):
|
|
158
|
+
result = handler(self)
|
|
159
|
+
return result
|
|
160
|
+
|
|
161
|
+
# Delegate to the implementing class
|
|
162
|
+
result = self.model_dump(
|
|
163
|
+
mode=info.mode,
|
|
164
|
+
context=info.context,
|
|
165
|
+
by_alias=info.by_alias,
|
|
166
|
+
exclude_unset=info.exclude_unset,
|
|
167
|
+
exclude_defaults=info.exclude_defaults,
|
|
168
|
+
exclude_none=info.exclude_none,
|
|
169
|
+
exclude_computed_fields=info.exclude_computed_fields,
|
|
170
|
+
round_trip=info.round_trip,
|
|
171
|
+
serialize_as_any=info.serialize_as_any,
|
|
172
|
+
)
|
|
173
|
+
return result
|
|
164
174
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
175
|
+
def _is_handler_for_current_class(
|
|
176
|
+
self, handler: SerializerFunctionWrapHandler
|
|
177
|
+
) -> bool:
|
|
178
|
+
"""Check if the handler is for this class by parsing its repr string.
|
|
168
179
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
180
|
+
WARNING: This is a fragile approach that relies on Pydantic's internal
|
|
181
|
+
repr format for SerializerFunctionWrapHandler. The handler is a Pydantic
|
|
182
|
+
wrapper around a Rust function that provides no public API for determining
|
|
183
|
+
which class it serializes. Parsing the repr string is the only available
|
|
184
|
+
mechanism.
|
|
172
185
|
|
|
173
|
-
|
|
174
|
-
__pydantic_validator__: ClassVar[Any]
|
|
175
|
-
__pydantic_serializer__: ClassVar[Any]
|
|
186
|
+
Expected format: `SerializationCallable(serializer=<ClassName>)`
|
|
176
187
|
|
|
177
|
-
|
|
188
|
+
If Pydantic changes this format, multiple unit tests will fail immediately,
|
|
189
|
+
including tests in test_discriminated_union.py that verify serialization
|
|
190
|
+
behavior across the class hierarchy.
|
|
191
|
+
"""
|
|
192
|
+
# should be in the format `SerializationCallable(serializer=<NAME>)`
|
|
193
|
+
repr_str = str(handler)
|
|
178
194
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
for subclass in get_known_concrete_subclasses(cls):
|
|
182
|
-
if subclass.__name__ == kind:
|
|
183
|
-
return subclass
|
|
184
|
-
|
|
185
|
-
# Generate enhanced error message for unknown kind
|
|
186
|
-
valid_kinds = [
|
|
187
|
-
subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
|
|
188
|
-
]
|
|
189
|
-
error_msg = _create_enhanced_discriminated_union_error_message(
|
|
190
|
-
kind, cls.__name__, valid_kinds
|
|
191
|
-
)
|
|
192
|
-
raise ValueError(error_msg)
|
|
195
|
+
# Get everything after =
|
|
196
|
+
_, name = repr_str.split("=", 1)
|
|
193
197
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
"""Generate discriminated union schema for TypeAdapter compatibility."""
|
|
197
|
-
if cls.__name__ == "DiscriminatedUnionMixin":
|
|
198
|
-
return handler(source_type)
|
|
199
|
-
|
|
200
|
-
if _is_abstract(source_type):
|
|
201
|
-
_rebuild_if_required()
|
|
202
|
-
serializable_type = source_type.get_serializable_type()
|
|
203
|
-
# If there are subclasses, generate schema for the discriminated union
|
|
204
|
-
if serializable_type is not source_type:
|
|
205
|
-
from pydantic_core import core_schema
|
|
206
|
-
|
|
207
|
-
# Generate the base schema
|
|
208
|
-
base_schema = handler.generate_schema(serializable_type)
|
|
209
|
-
|
|
210
|
-
# Wrap it with a custom validation function that provides
|
|
211
|
-
# enhanced error messages
|
|
212
|
-
def validate_with_enhanced_error(value, handler_func, info): # noqa: ARG001
|
|
213
|
-
try:
|
|
214
|
-
return handler_func(value)
|
|
215
|
-
except ValidationError as e:
|
|
216
|
-
valid_kinds = [
|
|
217
|
-
subclass.__name__
|
|
218
|
-
for subclass in get_known_concrete_subclasses(source_type)
|
|
219
|
-
]
|
|
220
|
-
_handle_discriminated_union_validation_error(
|
|
221
|
-
e, source_type.__name__, valid_kinds
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
# Create a with_info_wrap_validator_function schema
|
|
225
|
-
return core_schema.with_info_wrap_validator_function(
|
|
226
|
-
validate_with_enhanced_error,
|
|
227
|
-
base_schema,
|
|
228
|
-
)
|
|
198
|
+
# Cut off the )
|
|
199
|
+
name = name[:-1]
|
|
229
200
|
|
|
230
|
-
|
|
201
|
+
result = self.__class__.__name__ == name
|
|
202
|
+
return result
|
|
231
203
|
|
|
232
204
|
@classmethod
|
|
233
|
-
def __get_pydantic_json_schema__(
|
|
234
|
-
|
|
235
|
-
|
|
205
|
+
def __get_pydantic_json_schema__(
|
|
206
|
+
cls, core_schema: CoreSchema, handler: Any
|
|
207
|
+
) -> JsonSchemaValue:
|
|
208
|
+
schemas_in_progress = _get_schemas_in_progress()
|
|
209
|
+
|
|
210
|
+
# First we check if we are already generating a schema
|
|
211
|
+
schema = schemas_in_progress.get(cls)
|
|
212
|
+
if schema:
|
|
213
|
+
return schema
|
|
214
|
+
|
|
215
|
+
# Set a temp schema to prevent infinite recursion
|
|
216
|
+
schemas_in_progress[cls] = {"$ref": f"#/$defs/{cls.__name__}"}
|
|
217
|
+
try:
|
|
218
|
+
if _is_abstract(cls):
|
|
219
|
+
subclasses = _get_checked_concrete_subclasses(cls)
|
|
220
|
+
if not subclasses:
|
|
221
|
+
raise ValueError(f"No subclasses defined for {cls.__name__}")
|
|
222
|
+
if len(subclasses) == 1:
|
|
223
|
+
# Use the shared generator for single subclass too
|
|
224
|
+
gen = handler.generate_json_schema
|
|
225
|
+
sub_schema = gen.generate_inner(
|
|
226
|
+
next(iter(subclasses.values())).__pydantic_core_schema__
|
|
227
|
+
)
|
|
228
|
+
return sub_schema
|
|
236
229
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
230
|
+
# Use the shared generator to properly register definitions
|
|
231
|
+
gen = handler.generate_json_schema
|
|
232
|
+
schemas = []
|
|
233
|
+
for sub in subclasses.values():
|
|
234
|
+
sub_schema = gen.generate_inner(sub.__pydantic_core_schema__)
|
|
235
|
+
schemas.append(sub_schema)
|
|
242
236
|
|
|
243
|
-
|
|
237
|
+
# Build discriminator mapping from $ref schemas
|
|
244
238
|
mapping = {}
|
|
245
|
-
for option in
|
|
239
|
+
for option in schemas:
|
|
246
240
|
if "$ref" in option:
|
|
247
241
|
kind = option["$ref"].split("/")[-1]
|
|
248
242
|
mapping[kind] = option["$ref"]
|
|
249
243
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
244
|
+
schema = {
|
|
245
|
+
"oneOf": schemas,
|
|
246
|
+
"discriminator": {"propertyName": "kind", "mapping": mapping},
|
|
247
|
+
}
|
|
248
|
+
else:
|
|
249
|
+
schema = handler(core_schema)
|
|
250
|
+
schema["properties"]["kind"] = {
|
|
251
|
+
"const": cls.__name__,
|
|
252
|
+
"title": "Kind",
|
|
253
|
+
"type": "string",
|
|
254
|
+
}
|
|
255
|
+
finally:
|
|
256
|
+
# Reset temp schema
|
|
257
|
+
schemas_in_progress.pop(cls)
|
|
258
|
+
return schema
|
|
257
259
|
|
|
258
260
|
@classmethod
|
|
259
|
-
def
|
|
260
|
-
cls
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
if cls == DiscriminatedUnionMixin:
|
|
268
|
-
pass
|
|
269
|
-
if _is_abstract(cls):
|
|
270
|
-
subclasses = get_known_concrete_subclasses(cls)
|
|
271
|
-
kinds = [subclass.__name__ for subclass in subclasses]
|
|
272
|
-
if kinds:
|
|
273
|
-
kind_field = cls.model_fields["kind"]
|
|
274
|
-
kind_field.annotation = Literal[tuple(kinds)] # type: ignore
|
|
275
|
-
kind_field.default = kinds[0]
|
|
276
|
-
|
|
277
|
-
type_adapter = TypeAdapter(cls.get_serializable_type())
|
|
278
|
-
cls.__pydantic_core_schema__ = type_adapter.core_schema
|
|
279
|
-
cls.__pydantic_validator__ = type_adapter.validator
|
|
280
|
-
cls.__pydantic_serializer__ = type_adapter.serializer
|
|
281
|
-
return
|
|
282
|
-
|
|
283
|
-
return super().model_rebuild(
|
|
284
|
-
force=force,
|
|
285
|
-
raise_errors=raise_errors,
|
|
286
|
-
_parent_namespace_depth=_parent_namespace_depth,
|
|
287
|
-
_types_namespace=_types_namespace,
|
|
261
|
+
def resolve_kind(cls, kind: str) -> type[Self]:
|
|
262
|
+
subclasses = _get_checked_concrete_subclasses(cls)
|
|
263
|
+
subclass = subclasses.get(kind)
|
|
264
|
+
if subclass:
|
|
265
|
+
return subclass
|
|
266
|
+
raise ValueError(
|
|
267
|
+
f"Unknown kind '{kind}' for {cls.__module__}.{cls.__name__}; "
|
|
268
|
+
f"Expected one of: {list(subclasses)}"
|
|
288
269
|
)
|
|
289
270
|
|
|
290
271
|
@classmethod
|
|
@@ -298,7 +279,7 @@ class DiscriminatedUnionMixin(OpenHandsModel, ABC):
|
|
|
298
279
|
if not _is_abstract(cls):
|
|
299
280
|
return cls
|
|
300
281
|
|
|
301
|
-
subclasses =
|
|
282
|
+
subclasses = _get_checked_concrete_subclasses(cls)
|
|
302
283
|
if not subclasses:
|
|
303
284
|
return cls
|
|
304
285
|
|
|
@@ -306,265 +287,10 @@ class DiscriminatedUnionMixin(OpenHandsModel, ABC):
|
|
|
306
287
|
# Returning the concrete type ensures Pydantic instantiates the subclass
|
|
307
288
|
# (e.g. Agent) rather than the abstract base (e.g. AgentBase) when there is
|
|
308
289
|
# only ONE concrete subclass.
|
|
309
|
-
return subclasses
|
|
290
|
+
return next(iter(subclasses.values()))
|
|
310
291
|
|
|
311
292
|
serializable_type = Annotated[
|
|
312
|
-
Union[*tuple(Annotated[t, Tag(
|
|
293
|
+
Union[*tuple(Annotated[t, Tag(n)] for n, t in subclasses.items())],
|
|
313
294
|
Discriminator(kind_of),
|
|
314
295
|
]
|
|
315
296
|
return serializable_type # type: ignore
|
|
316
|
-
|
|
317
|
-
@classmethod
|
|
318
|
-
def model_validate(cls, obj: Any, **kwargs) -> Self:
|
|
319
|
-
try:
|
|
320
|
-
if _is_abstract(cls):
|
|
321
|
-
resolved = cls.resolve_kind(kind_of(obj))
|
|
322
|
-
else:
|
|
323
|
-
resolved = super()
|
|
324
|
-
result = resolved.model_validate(obj, **kwargs)
|
|
325
|
-
return result # type: ignore
|
|
326
|
-
except ValidationError as e:
|
|
327
|
-
valid_kinds = [
|
|
328
|
-
subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
|
|
329
|
-
]
|
|
330
|
-
_handle_discriminated_union_validation_error(e, cls.__name__, valid_kinds)
|
|
331
|
-
|
|
332
|
-
@classmethod
|
|
333
|
-
def model_validate_json(
|
|
334
|
-
cls,
|
|
335
|
-
json_data: str | bytes | bytearray,
|
|
336
|
-
**kwargs,
|
|
337
|
-
) -> Self:
|
|
338
|
-
data = json.loads(json_data)
|
|
339
|
-
if _is_abstract(cls):
|
|
340
|
-
resolved = cls.resolve_kind(kind_of(data))
|
|
341
|
-
else:
|
|
342
|
-
resolved = super()
|
|
343
|
-
result = resolved.model_validate(data, **kwargs)
|
|
344
|
-
return result # type: ignore
|
|
345
|
-
|
|
346
|
-
def __init_subclass__(cls, **kwargs):
|
|
347
|
-
super().__init_subclass__(**kwargs)
|
|
348
|
-
|
|
349
|
-
# If concrete, stamp kind Literal and collision check
|
|
350
|
-
if not _is_abstract(cls):
|
|
351
|
-
# 1) Stamp discriminator
|
|
352
|
-
cls.kind = cls.__name__
|
|
353
|
-
cls.__annotations__["kind"] = Literal[cls.__name__]
|
|
354
|
-
|
|
355
|
-
# 2) Collision check
|
|
356
|
-
mro = cls.mro()
|
|
357
|
-
union_class = mro[mro.index(DiscriminatedUnionMixin) - 1]
|
|
358
|
-
concretes = get_known_concrete_subclasses(union_class) # sorted list
|
|
359
|
-
kinds: dict[str, type] = {}
|
|
360
|
-
for sub in concretes:
|
|
361
|
-
k = kind_of(sub)
|
|
362
|
-
if k in kinds and kinds[k] is not sub:
|
|
363
|
-
raise ValueError(
|
|
364
|
-
f"Duplicate kind detected for {union_class} : {cls}, {sub}"
|
|
365
|
-
)
|
|
366
|
-
kinds[k] = sub
|
|
367
|
-
|
|
368
|
-
# Rebuild any abstract union owners in the MRO that rely on subclass sets
|
|
369
|
-
for base in cls.mro():
|
|
370
|
-
# Stop when we pass ourselves
|
|
371
|
-
if base is cls:
|
|
372
|
-
continue
|
|
373
|
-
# Only rebuild abstract DiscriminatedUnion owners
|
|
374
|
-
if (
|
|
375
|
-
isinstance(base, type)
|
|
376
|
-
and issubclass(base, DiscriminatedUnionMixin)
|
|
377
|
-
and _is_abstract(base)
|
|
378
|
-
):
|
|
379
|
-
base.model_rebuild(force=True)
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
def _rebuild_if_required():
|
|
383
|
-
if _rebuild_required:
|
|
384
|
-
rebuild_all()
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
def _extract_discriminated_unions(schema: dict) -> dict:
|
|
388
|
-
"""Extract inline discriminated unions as separate components.
|
|
389
|
-
|
|
390
|
-
Recursively scans the schema and extracts any inline discriminated union
|
|
391
|
-
(oneOf + discriminator + title) as a separate component, replacing it with a $ref.
|
|
392
|
-
Also deduplicates schemas with identical titles.
|
|
393
|
-
"""
|
|
394
|
-
import json
|
|
395
|
-
import re
|
|
396
|
-
from collections import defaultdict
|
|
397
|
-
|
|
398
|
-
if not isinstance(schema, dict):
|
|
399
|
-
return schema
|
|
400
|
-
|
|
401
|
-
# OpenAPI schema names must match this pattern
|
|
402
|
-
valid_name_pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
403
|
-
|
|
404
|
-
schemas = schema.get("components", {}).get("schemas", {})
|
|
405
|
-
extracted = {}
|
|
406
|
-
|
|
407
|
-
def _find_and_extract(obj, path=""):
|
|
408
|
-
if not isinstance(obj, dict):
|
|
409
|
-
return obj
|
|
410
|
-
|
|
411
|
-
# Extract inline discriminated unions
|
|
412
|
-
if "oneOf" in obj and "discriminator" in obj and "title" in obj:
|
|
413
|
-
title = obj["title"]
|
|
414
|
-
if (
|
|
415
|
-
title not in schemas
|
|
416
|
-
and title not in extracted
|
|
417
|
-
and valid_name_pattern.match(title)
|
|
418
|
-
):
|
|
419
|
-
extracted[title] = {
|
|
420
|
-
"oneOf": obj["oneOf"],
|
|
421
|
-
"discriminator": obj["discriminator"],
|
|
422
|
-
"title": title,
|
|
423
|
-
}
|
|
424
|
-
return {"$ref": f"#/components/schemas/{title}"}
|
|
425
|
-
|
|
426
|
-
# Recursively process nested structures
|
|
427
|
-
result = {}
|
|
428
|
-
for key, value in obj.items():
|
|
429
|
-
if isinstance(value, dict):
|
|
430
|
-
result[key] = _find_and_extract(value, f"{path}.{key}")
|
|
431
|
-
elif isinstance(value, list):
|
|
432
|
-
result[key] = [
|
|
433
|
-
_find_and_extract(item, f"{path}.{key}[]") for item in value
|
|
434
|
-
]
|
|
435
|
-
else:
|
|
436
|
-
result[key] = value
|
|
437
|
-
return result
|
|
438
|
-
|
|
439
|
-
schema = _find_and_extract(schema)
|
|
440
|
-
|
|
441
|
-
if extracted and "components" in schema and "schemas" in schema["components"]:
|
|
442
|
-
schema["components"]["schemas"].update(extracted)
|
|
443
|
-
|
|
444
|
-
# Deduplicate schemas with same title (prefer *-Output over *-Input over base)
|
|
445
|
-
schemas = schema.get("components", {}).get("schemas", {})
|
|
446
|
-
title_to_names = defaultdict(list)
|
|
447
|
-
for name, defn in schemas.items():
|
|
448
|
-
if isinstance(defn, dict):
|
|
449
|
-
title_to_names[defn.get("title", name)].append(name)
|
|
450
|
-
|
|
451
|
-
to_remove = {}
|
|
452
|
-
for title, names in title_to_names.items():
|
|
453
|
-
if len(names) > 1:
|
|
454
|
-
# Prefer: *-Output > *-Input > base name
|
|
455
|
-
keep = sorted(
|
|
456
|
-
names,
|
|
457
|
-
key=lambda n: (
|
|
458
|
-
0 if n.endswith("-Output") else 1 if n.endswith("-Input") else 2,
|
|
459
|
-
n,
|
|
460
|
-
),
|
|
461
|
-
)[0]
|
|
462
|
-
for name in names:
|
|
463
|
-
if name != keep:
|
|
464
|
-
to_remove[name] = keep
|
|
465
|
-
|
|
466
|
-
if to_remove:
|
|
467
|
-
schema_str = json.dumps(schema)
|
|
468
|
-
for old, new in to_remove.items():
|
|
469
|
-
schema_str = schema_str.replace(
|
|
470
|
-
f'"#/components/schemas/{old}"', f'"#/components/schemas/{new}"'
|
|
471
|
-
)
|
|
472
|
-
schema = json.loads(schema_str)
|
|
473
|
-
for old in to_remove:
|
|
474
|
-
schema["components"]["schemas"].pop(old, None)
|
|
475
|
-
|
|
476
|
-
return schema
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
def _patch_fastapi_discriminated_union_support():
|
|
480
|
-
"""Patch FastAPI to handle discriminated union schemas without $ref.
|
|
481
|
-
|
|
482
|
-
This ensures discriminated unions from DiscriminatedUnionMixin work correctly
|
|
483
|
-
with FastAPI's OpenAPI schema generation. The patch prevents KeyError when
|
|
484
|
-
FastAPI encounters schemas without $ref keys (which discriminated unions use).
|
|
485
|
-
|
|
486
|
-
Also extracts inline discriminated unions as separate schema components for
|
|
487
|
-
better OpenAPI documentation and Swagger UI display.
|
|
488
|
-
|
|
489
|
-
Skips patching if SKIP_FASTAPI_DISCRIMINATED_UNION_FIX environment variable is set.
|
|
490
|
-
"""
|
|
491
|
-
# Skip patching if environment variable flag is defined
|
|
492
|
-
if os.environ.get("SKIP_FASTAPI_DISCRIMINATED_UNION_FIX"):
|
|
493
|
-
logger.debug(
|
|
494
|
-
"Skipping FastAPI discriminated union patch due to environment variable"
|
|
495
|
-
)
|
|
496
|
-
return
|
|
497
|
-
|
|
498
|
-
try:
|
|
499
|
-
import fastapi._compat.v2 as fastapi_v2
|
|
500
|
-
from fastapi import FastAPI
|
|
501
|
-
|
|
502
|
-
_original_remap = fastapi_v2._remap_definitions_and_field_mappings
|
|
503
|
-
|
|
504
|
-
def _patched_remap_definitions_and_field_mappings(**kwargs):
|
|
505
|
-
"""Patched version that handles schemas w/o $ref (discriminated unions)."""
|
|
506
|
-
field_mapping = kwargs.get("field_mapping", {})
|
|
507
|
-
model_name_map = kwargs.get("model_name_map", {})
|
|
508
|
-
|
|
509
|
-
# Build old_name -> new_name map, skipping schemas without $ref
|
|
510
|
-
old_name_to_new_name_map = {}
|
|
511
|
-
for field_key, schema in field_mapping.items():
|
|
512
|
-
model = field_key[0].type_
|
|
513
|
-
if model not in model_name_map:
|
|
514
|
-
continue
|
|
515
|
-
new_name = model_name_map[model]
|
|
516
|
-
|
|
517
|
-
# Skip schemas without $ref (discriminated unions)
|
|
518
|
-
if "$ref" not in schema:
|
|
519
|
-
continue
|
|
520
|
-
|
|
521
|
-
old_name = schema["$ref"].split("/")[-1]
|
|
522
|
-
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
|
|
523
|
-
continue
|
|
524
|
-
old_name_to_new_name_map[old_name] = new_name
|
|
525
|
-
|
|
526
|
-
# Replace refs using FastAPI's helper
|
|
527
|
-
from fastapi._compat.v2 import _replace_refs
|
|
528
|
-
|
|
529
|
-
new_field_mapping = {}
|
|
530
|
-
for field_key, schema in field_mapping.items():
|
|
531
|
-
new_schema = _replace_refs(
|
|
532
|
-
schema=schema,
|
|
533
|
-
old_name_to_new_name_map=old_name_to_new_name_map,
|
|
534
|
-
)
|
|
535
|
-
new_field_mapping[field_key] = new_schema
|
|
536
|
-
|
|
537
|
-
definitions = kwargs.get("definitions", {})
|
|
538
|
-
new_definitions = {}
|
|
539
|
-
for key, value in definitions.items():
|
|
540
|
-
new_key = old_name_to_new_name_map.get(key, key)
|
|
541
|
-
new_value = _replace_refs(
|
|
542
|
-
schema=value,
|
|
543
|
-
old_name_to_new_name_map=old_name_to_new_name_map,
|
|
544
|
-
)
|
|
545
|
-
new_definitions[new_key] = new_value
|
|
546
|
-
|
|
547
|
-
return new_field_mapping, new_definitions
|
|
548
|
-
|
|
549
|
-
# Apply the patch
|
|
550
|
-
fastapi_v2._remap_definitions_and_field_mappings = (
|
|
551
|
-
_patched_remap_definitions_and_field_mappings
|
|
552
|
-
)
|
|
553
|
-
|
|
554
|
-
# Patch FastAPI.openapi() to extract discriminated unions
|
|
555
|
-
_original_openapi = FastAPI.openapi
|
|
556
|
-
|
|
557
|
-
def _patched_openapi(self):
|
|
558
|
-
"""Patched openapi() that extracts discriminated unions."""
|
|
559
|
-
schema = _original_openapi(self)
|
|
560
|
-
return _extract_discriminated_unions(schema)
|
|
561
|
-
|
|
562
|
-
FastAPI.openapi = _patched_openapi
|
|
563
|
-
|
|
564
|
-
except (ImportError, AttributeError):
|
|
565
|
-
# FastAPI not available or internal API changed
|
|
566
|
-
pass
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
# Always call the FastAPI patch after DiscriminatedUnionMixin definition
|
|
570
|
-
_patch_fastapi_discriminated_union_support()
|