openhands-sdk 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,23 +1,38 @@
1
1
  import inspect
2
- import json
3
2
  import logging
4
- import os
3
+ import threading
5
4
  from abc import ABC
6
- from typing import Annotated, Any, ClassVar, Literal, NoReturn, Self, Union
5
+ from typing import Annotated, Any, Self, Union
7
6
 
8
7
  from pydantic import (
9
8
  BaseModel,
10
9
  Discriminator,
11
- Field,
10
+ ModelWrapValidatorHandler,
11
+ SerializationInfo,
12
+ SerializerFunctionWrapHandler,
12
13
  Tag,
13
- TypeAdapter,
14
- ValidationError,
14
+ ValidationInfo,
15
+ computed_field,
16
+ model_serializer,
17
+ model_validator,
15
18
  )
16
- from pydantic_core import ErrorDetails
19
+ from pydantic.json_schema import JsonSchemaValue
20
+ from pydantic_core import CoreSchema
17
21
 
18
22
 
19
23
  logger = logging.getLogger(__name__)
20
- _rebuild_required = True
24
+
25
+ # Thread-local storage for tracking schemas currently being generated.
26
+ # This prevents infinite recursion when generating JSON schemas for
27
+ # discriminated unions that reference each other.
28
+ _thread_local = threading.local()
29
+
30
+
31
+ def _get_schemas_in_progress() -> dict[type, JsonSchemaValue]:
32
+ """Get the thread-local dict for tracking in-progress schema generation."""
33
+ if not hasattr(_thread_local, "schemas_in_progress"):
34
+ _thread_local.schemas_in_progress = {}
35
+ return _thread_local.schemas_in_progress
21
36
 
22
37
 
23
38
  def _is_abstract(type_: type) -> bool:
@@ -28,27 +43,6 @@ def _is_abstract(type_: type) -> bool:
28
43
  return False
29
44
 
30
45
 
31
- def _get_all_subclasses(cls) -> set[type]:
32
- """
33
- Recursively finds and returns all (loaded) subclasses of a given class.
34
- """
35
- result = set()
36
- for subclass in cls.__subclasses__():
37
- result.add(subclass)
38
- result.update(_get_all_subclasses(subclass))
39
- return result
40
-
41
-
42
- def rebuild_all():
43
- """Rebuild all polymorphic classes."""
44
- global _rebuild_required
45
- _rebuild_required = False
46
- for cls in _get_all_subclasses(OpenHandsModel):
47
- cls.model_rebuild(force=True)
48
- for cls in _get_all_subclasses(DiscriminatedUnionMixin):
49
- cls.model_rebuild(force=True)
50
-
51
-
52
46
  def kind_of(obj) -> str:
53
47
  """Get the string value for the kind tag"""
54
48
  if isinstance(obj, dict):
@@ -58,46 +52,15 @@ def kind_of(obj) -> str:
58
52
  return obj.__name__
59
53
 
60
54
 
61
- def _create_enhanced_discriminated_union_error_message(
62
- invalid_kind: str, cls_name: str, valid_kinds: list[str]
63
- ) -> str:
64
- """Create an enhanced error message for discriminated union validation failures."""
65
- possible_kinds_str = ", ".join(sorted(valid_kinds)) if valid_kinds else "none"
66
- return (
67
- f"Unexpected kind '{invalid_kind}' for {cls_name}. "
68
- f"Expected one of: {possible_kinds_str}. "
69
- f"If you receive this error when trying to wrap a "
70
- f"DiscriminatedUnion instance inside another pydantic model, "
71
- f"you may need to use OpenHandsModel instead of BaseModel "
72
- f"to make sure that an invalid schema has not been cached."
73
- )
74
-
75
-
76
- def _extract_invalid_kind_from_validation_error(error: ErrorDetails) -> str:
77
- """Extract the invalid kind from a Pydantic validation error."""
78
- input_value = error.get("input")
79
- if input_value is not None and hasattr(input_value, "kind"):
80
- return input_value.kind
81
- elif isinstance(input_value, dict) and "kind" in input_value:
82
- return input_value["kind"]
83
- else:
84
- return kind_of(input_value)
85
-
86
-
87
- def _handle_discriminated_union_validation_error(
88
- validation_error: ValidationError, cls_name: str, valid_kinds: list[str]
89
- ) -> NoReturn:
90
- """Handle discriminated union validation errors with enhanced messages."""
91
- for error in validation_error.errors():
92
- if error.get("type") == "union_tag_invalid":
93
- invalid_kind = _extract_invalid_kind_from_validation_error(error)
94
- error_msg = _create_enhanced_discriminated_union_error_message(
95
- invalid_kind, cls_name, valid_kinds
96
- )
97
- raise ValueError(error_msg) from validation_error
98
-
99
- # If it's not a discriminated union error, re-raise the original error
100
- raise validation_error
55
+ def _get_all_subclasses(cls) -> set[type]:
56
+ """
57
+ Recursively finds and returns all (loaded) subclasses of a given class.
58
+ """
59
+ result = set()
60
+ for subclass in cls.__subclasses__():
61
+ result.add(subclass)
62
+ result.update(_get_all_subclasses(subclass))
63
+ return result
101
64
 
102
65
 
103
66
  def get_known_concrete_subclasses(cls) -> list[type]:
@@ -115,176 +78,194 @@ def get_known_concrete_subclasses(cls) -> list[type]:
115
78
  return out
116
79
 
117
80
 
118
- class OpenHandsModel(BaseModel):
119
- """
120
- Tags a class where the which may be a discriminated union or contain fields
121
- which contain a discriminated union. The first time an instance is initialized,
122
- the schema is loaded, or a model is validated after a subclass is defined we
123
- regenerate all the polymorphic mappings.
124
- """
81
+ def _get_checked_concrete_subclasses(cls: type) -> dict[str, type]:
82
+ result = {}
83
+ for sub in get_known_concrete_subclasses(cls):
84
+ existing = result.get(sub.__name__)
85
+ if existing:
86
+ raise ValueError(
87
+ f"Duplicate class definition for {cls.__module__}.{cls.__name__}: "
88
+ f"{existing.__module__}.{existing.__name__} : "
89
+ f"{sub.__module__}.{sub.__name__}"
90
+ )
91
+ if "<locals>" in sub.__qualname__:
92
+ raise ValueError(
93
+ f"Local classes not supported! {sub.__module__}.{sub.__name__} "
94
+ f"/ {cls.__module__}.{cls.__name__} "
95
+ "(Since they may not exist at deserialization time)"
96
+ )
97
+ result[sub.__name__] = sub
98
+ return result
125
99
 
126
- def model_post_init(self, _context):
127
- _rebuild_if_required()
128
100
 
129
- @classmethod
130
- def model_validate(cls, *args, **kwargs) -> Self:
131
- _rebuild_if_required()
132
- return super().model_validate(*args, **kwargs)
101
+ class OpenHandsModel(BaseModel):
102
+ """Deprecated: This class exists only for backward compatibility.
133
103
 
134
- @classmethod
135
- def model_validate_json(cls, *args, **kwargs) -> Self:
136
- _rebuild_if_required()
137
- return super().model_validate_json(*args, **kwargs)
104
+ This class is no longer required for discriminated union support.
105
+ New code should extend pydantic.BaseModel directly instead of OpenHandsModel.
138
106
 
139
- @classmethod
140
- def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
141
- _rebuild_if_required()
142
- return super().model_json_schema(*args, **kwargs)
143
-
144
- def model_dump_json(self, **kwargs):
145
- # This was overridden because it seems there is a bug where sometimes
146
- # duplicate fields are produced by model_dump_json which does not appear
147
- # in model_dump
148
- kwargs["mode"] = "json"
149
- return json.dumps(self.model_dump(**kwargs), ensure_ascii=False)
150
-
151
- def __init_subclass__(cls, **kwargs):
152
- """
153
- When a new subclass is defined, mark that we will need
154
- to rebuild everything
155
- """
156
- global _rebuild_required
157
- _rebuild_required = True
107
+ Existing code that extends OpenHandsModel will continue to work, but
108
+ migration to BaseModel is recommended.
109
+ """
158
110
 
159
- return super().__init_subclass__(**kwargs)
160
111
 
112
+ class DiscriminatedUnionMixin(OpenHandsModel):
113
+ @computed_field
114
+ @property
115
+ def kind(self) -> str:
116
+ return self.__class__.__name__
161
117
 
162
- class DiscriminatedUnionMixin(OpenHandsModel, ABC):
163
- """A Base class for members of tagged unions discriminated by the class name.
118
+ @model_validator(mode="wrap")
119
+ @classmethod
120
+ def _validate_subtype(
121
+ cls, data: Any, handler: ModelWrapValidatorHandler[Self], info: ValidationInfo
122
+ ) -> Self:
123
+ if isinstance(data, cls):
124
+ return data
125
+ kind = data.pop("kind", None)
126
+ if not _is_abstract(cls):
127
+ # Sanity check: if we're validating a concrete class directly,
128
+ # the kind (if provided) should match the class name. This should
129
+ # always be true at this point since resolve_kind() would have
130
+ # already routed to the correct subclass.
131
+ assert kind is None or kind == cls.__name__
132
+ return handler(data)
133
+ if kind is None:
134
+ subclasses = _get_checked_concrete_subclasses(cls)
135
+ if not subclasses:
136
+ raise ValueError(
137
+ f"No kinds defined for {cls.__module__}.{cls.__name__}"
138
+ )
139
+ elif len(subclasses) == 1:
140
+ # If there is ony 1 possible implementation, then we do not need
141
+ # to state the kind explicitly - it can only be this!
142
+ kind = next(iter(subclasses))
143
+ else:
144
+ # There is more than 1 kind defined but the input did not specify
145
+ # This will cause an error to be raised
146
+ kind = ""
147
+ subclass = cls.resolve_kind(kind)
148
+ return subclass.model_validate(data, context=info.context)
149
+
150
+ @model_serializer(mode="wrap")
151
+ def _serialize_by_kind(
152
+ self, handler: SerializerFunctionWrapHandler, info: SerializationInfo
153
+ ):
154
+ if isinstance(self, dict):
155
+ # Sometimes pydantic passes a dict in here.
156
+ return self
157
+ if self._is_handler_for_current_class(handler):
158
+ result = handler(self)
159
+ return result
160
+
161
+ # Delegate to the implementing class
162
+ result = self.model_dump(
163
+ mode=info.mode,
164
+ context=info.context,
165
+ by_alias=info.by_alias,
166
+ exclude_unset=info.exclude_unset,
167
+ exclude_defaults=info.exclude_defaults,
168
+ exclude_none=info.exclude_none,
169
+ exclude_computed_fields=info.exclude_computed_fields,
170
+ round_trip=info.round_trip,
171
+ serialize_as_any=info.serialize_as_any,
172
+ )
173
+ return result
164
174
 
165
- This class provides automatic subclass registration and discriminated union
166
- functionality. Each subclass is automatically registered when defined and
167
- can be used for polymorphic serialization/deserialization.
175
+ def _is_handler_for_current_class(
176
+ self, handler: SerializerFunctionWrapHandler
177
+ ) -> bool:
178
+ """Check if the handler is for this class by parsing its repr string.
168
179
 
169
- Child classes will automatically have a type field defined, which is used as a
170
- discriminator for union types.
171
- """
180
+ WARNING: This is a fragile approach that relies on Pydantic's internal
181
+ repr format for SerializerFunctionWrapHandler. The handler is a Pydantic
182
+ wrapper around a Rust function that provides no public API for determining
183
+ which class it serializes. Parsing the repr string is the only available
184
+ mechanism.
172
185
 
173
- __pydantic_core_schema__: ClassVar[Any]
174
- __pydantic_validator__: ClassVar[Any]
175
- __pydantic_serializer__: ClassVar[Any]
186
+ Expected format: `SerializationCallable(serializer=<ClassName>)`
176
187
 
177
- kind: str = Field(default="") # We dynamically update on a per class basis
188
+ If Pydantic changes this format, multiple unit tests will fail immediately,
189
+ including tests in test_discriminated_union.py that verify serialization
190
+ behavior across the class hierarchy.
191
+ """
192
+ # should be in the format `SerializationCallable(serializer=<NAME>)`
193
+ repr_str = str(handler)
178
194
 
179
- @classmethod
180
- def resolve_kind(cls, kind: str) -> type:
181
- for subclass in get_known_concrete_subclasses(cls):
182
- if subclass.__name__ == kind:
183
- return subclass
184
-
185
- # Generate enhanced error message for unknown kind
186
- valid_kinds = [
187
- subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
188
- ]
189
- error_msg = _create_enhanced_discriminated_union_error_message(
190
- kind, cls.__name__, valid_kinds
191
- )
192
- raise ValueError(error_msg)
195
+ # Get everything after =
196
+ _, name = repr_str.split("=", 1)
193
197
 
194
- @classmethod
195
- def __get_pydantic_core_schema__(cls, source_type, handler):
196
- """Generate discriminated union schema for TypeAdapter compatibility."""
197
- if cls.__name__ == "DiscriminatedUnionMixin":
198
- return handler(source_type)
199
-
200
- if _is_abstract(source_type):
201
- _rebuild_if_required()
202
- serializable_type = source_type.get_serializable_type()
203
- # If there are subclasses, generate schema for the discriminated union
204
- if serializable_type is not source_type:
205
- from pydantic_core import core_schema
206
-
207
- # Generate the base schema
208
- base_schema = handler.generate_schema(serializable_type)
209
-
210
- # Wrap it with a custom validation function that provides
211
- # enhanced error messages
212
- def validate_with_enhanced_error(value, handler_func, info): # noqa: ARG001
213
- try:
214
- return handler_func(value)
215
- except ValidationError as e:
216
- valid_kinds = [
217
- subclass.__name__
218
- for subclass in get_known_concrete_subclasses(source_type)
219
- ]
220
- _handle_discriminated_union_validation_error(
221
- e, source_type.__name__, valid_kinds
222
- )
223
-
224
- # Create a with_info_wrap_validator_function schema
225
- return core_schema.with_info_wrap_validator_function(
226
- validate_with_enhanced_error,
227
- base_schema,
228
- )
198
+ # Cut off the )
199
+ name = name[:-1]
229
200
 
230
- return handler(source_type)
201
+ result = self.__class__.__name__ == name
202
+ return result
231
203
 
232
204
  @classmethod
233
- def __get_pydantic_json_schema__(cls, core_schema, handler):
234
- """Add discriminator to OpenAPI schema and ensure component generation."""
235
- json_schema = handler(core_schema)
205
+ def __get_pydantic_json_schema__(
206
+ cls, core_schema: CoreSchema, handler: Any
207
+ ) -> JsonSchemaValue:
208
+ schemas_in_progress = _get_schemas_in_progress()
209
+
210
+ # First we check if we are already generating a schema
211
+ schema = schemas_in_progress.get(cls)
212
+ if schema:
213
+ return schema
214
+
215
+ # Set a temp schema to prevent infinite recursion
216
+ schemas_in_progress[cls] = {"$ref": f"#/$defs/{cls.__name__}"}
217
+ try:
218
+ if _is_abstract(cls):
219
+ subclasses = _get_checked_concrete_subclasses(cls)
220
+ if not subclasses:
221
+ raise ValueError(f"No subclasses defined for {cls.__name__}")
222
+ if len(subclasses) == 1:
223
+ # Use the shared generator for single subclass too
224
+ gen = handler.generate_json_schema
225
+ sub_schema = gen.generate_inner(
226
+ next(iter(subclasses.values())).__pydantic_core_schema__
227
+ )
228
+ return sub_schema
236
229
 
237
- # Add discriminator if this is a oneOf schema
238
- if isinstance(json_schema, dict) and "oneOf" in json_schema:
239
- # Add title for abstract classes to encourage separate component creation
240
- if _is_abstract(cls) and "title" not in json_schema:
241
- json_schema["title"] = cls.__name__
230
+ # Use the shared generator to properly register definitions
231
+ gen = handler.generate_json_schema
232
+ schemas = []
233
+ for sub in subclasses.values():
234
+ sub_schema = gen.generate_inner(sub.__pydantic_core_schema__)
235
+ schemas.append(sub_schema)
242
236
 
243
- if "discriminator" not in json_schema:
237
+ # Build discriminator mapping from $ref schemas
244
238
  mapping = {}
245
- for option in json_schema["oneOf"]:
239
+ for option in schemas:
246
240
  if "$ref" in option:
247
241
  kind = option["$ref"].split("/")[-1]
248
242
  mapping[kind] = option["$ref"]
249
243
 
250
- if mapping:
251
- json_schema["discriminator"] = {
252
- "propertyName": "kind",
253
- "mapping": mapping,
254
- }
255
-
256
- return json_schema
244
+ schema = {
245
+ "oneOf": schemas,
246
+ "discriminator": {"propertyName": "kind", "mapping": mapping},
247
+ }
248
+ else:
249
+ schema = handler(core_schema)
250
+ schema["properties"]["kind"] = {
251
+ "const": cls.__name__,
252
+ "title": "Kind",
253
+ "type": "string",
254
+ }
255
+ finally:
256
+ # Reset temp schema
257
+ schemas_in_progress.pop(cls)
258
+ return schema
257
259
 
258
260
  @classmethod
259
- def model_rebuild(
260
- cls,
261
- *,
262
- force=False,
263
- raise_errors=True,
264
- _parent_namespace_depth=2,
265
- _types_namespace=None,
266
- ):
267
- if cls == DiscriminatedUnionMixin:
268
- pass
269
- if _is_abstract(cls):
270
- subclasses = get_known_concrete_subclasses(cls)
271
- kinds = [subclass.__name__ for subclass in subclasses]
272
- if kinds:
273
- kind_field = cls.model_fields["kind"]
274
- kind_field.annotation = Literal[tuple(kinds)] # type: ignore
275
- kind_field.default = kinds[0]
276
-
277
- type_adapter = TypeAdapter(cls.get_serializable_type())
278
- cls.__pydantic_core_schema__ = type_adapter.core_schema
279
- cls.__pydantic_validator__ = type_adapter.validator
280
- cls.__pydantic_serializer__ = type_adapter.serializer
281
- return
282
-
283
- return super().model_rebuild(
284
- force=force,
285
- raise_errors=raise_errors,
286
- _parent_namespace_depth=_parent_namespace_depth,
287
- _types_namespace=_types_namespace,
261
+ def resolve_kind(cls, kind: str) -> type[Self]:
262
+ subclasses = _get_checked_concrete_subclasses(cls)
263
+ subclass = subclasses.get(kind)
264
+ if subclass:
265
+ return subclass
266
+ raise ValueError(
267
+ f"Unknown kind '{kind}' for {cls.__module__}.{cls.__name__}; "
268
+ f"Expected one of: {list(subclasses)}"
288
269
  )
289
270
 
290
271
  @classmethod
@@ -298,7 +279,7 @@ class DiscriminatedUnionMixin(OpenHandsModel, ABC):
298
279
  if not _is_abstract(cls):
299
280
  return cls
300
281
 
301
- subclasses = list(get_known_concrete_subclasses(cls))
282
+ subclasses = _get_checked_concrete_subclasses(cls)
302
283
  if not subclasses:
303
284
  return cls
304
285
 
@@ -306,265 +287,10 @@ class DiscriminatedUnionMixin(OpenHandsModel, ABC):
306
287
  # Returning the concrete type ensures Pydantic instantiates the subclass
307
288
  # (e.g. Agent) rather than the abstract base (e.g. AgentBase) when there is
308
289
  # only ONE concrete subclass.
309
- return subclasses[0]
290
+ return next(iter(subclasses.values()))
310
291
 
311
292
  serializable_type = Annotated[
312
- Union[*tuple(Annotated[t, Tag(t.__name__)] for t in subclasses)],
293
+ Union[*tuple(Annotated[t, Tag(n)] for n, t in subclasses.items())],
313
294
  Discriminator(kind_of),
314
295
  ]
315
296
  return serializable_type # type: ignore
316
-
317
- @classmethod
318
- def model_validate(cls, obj: Any, **kwargs) -> Self:
319
- try:
320
- if _is_abstract(cls):
321
- resolved = cls.resolve_kind(kind_of(obj))
322
- else:
323
- resolved = super()
324
- result = resolved.model_validate(obj, **kwargs)
325
- return result # type: ignore
326
- except ValidationError as e:
327
- valid_kinds = [
328
- subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
329
- ]
330
- _handle_discriminated_union_validation_error(e, cls.__name__, valid_kinds)
331
-
332
- @classmethod
333
- def model_validate_json(
334
- cls,
335
- json_data: str | bytes | bytearray,
336
- **kwargs,
337
- ) -> Self:
338
- data = json.loads(json_data)
339
- if _is_abstract(cls):
340
- resolved = cls.resolve_kind(kind_of(data))
341
- else:
342
- resolved = super()
343
- result = resolved.model_validate(data, **kwargs)
344
- return result # type: ignore
345
-
346
- def __init_subclass__(cls, **kwargs):
347
- super().__init_subclass__(**kwargs)
348
-
349
- # If concrete, stamp kind Literal and collision check
350
- if not _is_abstract(cls):
351
- # 1) Stamp discriminator
352
- cls.kind = cls.__name__
353
- cls.__annotations__["kind"] = Literal[cls.__name__]
354
-
355
- # 2) Collision check
356
- mro = cls.mro()
357
- union_class = mro[mro.index(DiscriminatedUnionMixin) - 1]
358
- concretes = get_known_concrete_subclasses(union_class) # sorted list
359
- kinds: dict[str, type] = {}
360
- for sub in concretes:
361
- k = kind_of(sub)
362
- if k in kinds and kinds[k] is not sub:
363
- raise ValueError(
364
- f"Duplicate kind detected for {union_class} : {cls}, {sub}"
365
- )
366
- kinds[k] = sub
367
-
368
- # Rebuild any abstract union owners in the MRO that rely on subclass sets
369
- for base in cls.mro():
370
- # Stop when we pass ourselves
371
- if base is cls:
372
- continue
373
- # Only rebuild abstract DiscriminatedUnion owners
374
- if (
375
- isinstance(base, type)
376
- and issubclass(base, DiscriminatedUnionMixin)
377
- and _is_abstract(base)
378
- ):
379
- base.model_rebuild(force=True)
380
-
381
-
382
- def _rebuild_if_required():
383
- if _rebuild_required:
384
- rebuild_all()
385
-
386
-
387
- def _extract_discriminated_unions(schema: dict) -> dict:
388
- """Extract inline discriminated unions as separate components.
389
-
390
- Recursively scans the schema and extracts any inline discriminated union
391
- (oneOf + discriminator + title) as a separate component, replacing it with a $ref.
392
- Also deduplicates schemas with identical titles.
393
- """
394
- import json
395
- import re
396
- from collections import defaultdict
397
-
398
- if not isinstance(schema, dict):
399
- return schema
400
-
401
- # OpenAPI schema names must match this pattern
402
- valid_name_pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
403
-
404
- schemas = schema.get("components", {}).get("schemas", {})
405
- extracted = {}
406
-
407
- def _find_and_extract(obj, path=""):
408
- if not isinstance(obj, dict):
409
- return obj
410
-
411
- # Extract inline discriminated unions
412
- if "oneOf" in obj and "discriminator" in obj and "title" in obj:
413
- title = obj["title"]
414
- if (
415
- title not in schemas
416
- and title not in extracted
417
- and valid_name_pattern.match(title)
418
- ):
419
- extracted[title] = {
420
- "oneOf": obj["oneOf"],
421
- "discriminator": obj["discriminator"],
422
- "title": title,
423
- }
424
- return {"$ref": f"#/components/schemas/{title}"}
425
-
426
- # Recursively process nested structures
427
- result = {}
428
- for key, value in obj.items():
429
- if isinstance(value, dict):
430
- result[key] = _find_and_extract(value, f"{path}.{key}")
431
- elif isinstance(value, list):
432
- result[key] = [
433
- _find_and_extract(item, f"{path}.{key}[]") for item in value
434
- ]
435
- else:
436
- result[key] = value
437
- return result
438
-
439
- schema = _find_and_extract(schema)
440
-
441
- if extracted and "components" in schema and "schemas" in schema["components"]:
442
- schema["components"]["schemas"].update(extracted)
443
-
444
- # Deduplicate schemas with same title (prefer *-Output over *-Input over base)
445
- schemas = schema.get("components", {}).get("schemas", {})
446
- title_to_names = defaultdict(list)
447
- for name, defn in schemas.items():
448
- if isinstance(defn, dict):
449
- title_to_names[defn.get("title", name)].append(name)
450
-
451
- to_remove = {}
452
- for title, names in title_to_names.items():
453
- if len(names) > 1:
454
- # Prefer: *-Output > *-Input > base name
455
- keep = sorted(
456
- names,
457
- key=lambda n: (
458
- 0 if n.endswith("-Output") else 1 if n.endswith("-Input") else 2,
459
- n,
460
- ),
461
- )[0]
462
- for name in names:
463
- if name != keep:
464
- to_remove[name] = keep
465
-
466
- if to_remove:
467
- schema_str = json.dumps(schema)
468
- for old, new in to_remove.items():
469
- schema_str = schema_str.replace(
470
- f'"#/components/schemas/{old}"', f'"#/components/schemas/{new}"'
471
- )
472
- schema = json.loads(schema_str)
473
- for old in to_remove:
474
- schema["components"]["schemas"].pop(old, None)
475
-
476
- return schema
477
-
478
-
479
- def _patch_fastapi_discriminated_union_support():
480
- """Patch FastAPI to handle discriminated union schemas without $ref.
481
-
482
- This ensures discriminated unions from DiscriminatedUnionMixin work correctly
483
- with FastAPI's OpenAPI schema generation. The patch prevents KeyError when
484
- FastAPI encounters schemas without $ref keys (which discriminated unions use).
485
-
486
- Also extracts inline discriminated unions as separate schema components for
487
- better OpenAPI documentation and Swagger UI display.
488
-
489
- Skips patching if SKIP_FASTAPI_DISCRIMINATED_UNION_FIX environment variable is set.
490
- """
491
- # Skip patching if environment variable flag is defined
492
- if os.environ.get("SKIP_FASTAPI_DISCRIMINATED_UNION_FIX"):
493
- logger.debug(
494
- "Skipping FastAPI discriminated union patch due to environment variable"
495
- )
496
- return
497
-
498
- try:
499
- import fastapi._compat.v2 as fastapi_v2
500
- from fastapi import FastAPI
501
-
502
- _original_remap = fastapi_v2._remap_definitions_and_field_mappings
503
-
504
- def _patched_remap_definitions_and_field_mappings(**kwargs):
505
- """Patched version that handles schemas w/o $ref (discriminated unions)."""
506
- field_mapping = kwargs.get("field_mapping", {})
507
- model_name_map = kwargs.get("model_name_map", {})
508
-
509
- # Build old_name -> new_name map, skipping schemas without $ref
510
- old_name_to_new_name_map = {}
511
- for field_key, schema in field_mapping.items():
512
- model = field_key[0].type_
513
- if model not in model_name_map:
514
- continue
515
- new_name = model_name_map[model]
516
-
517
- # Skip schemas without $ref (discriminated unions)
518
- if "$ref" not in schema:
519
- continue
520
-
521
- old_name = schema["$ref"].split("/")[-1]
522
- if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
523
- continue
524
- old_name_to_new_name_map[old_name] = new_name
525
-
526
- # Replace refs using FastAPI's helper
527
- from fastapi._compat.v2 import _replace_refs
528
-
529
- new_field_mapping = {}
530
- for field_key, schema in field_mapping.items():
531
- new_schema = _replace_refs(
532
- schema=schema,
533
- old_name_to_new_name_map=old_name_to_new_name_map,
534
- )
535
- new_field_mapping[field_key] = new_schema
536
-
537
- definitions = kwargs.get("definitions", {})
538
- new_definitions = {}
539
- for key, value in definitions.items():
540
- new_key = old_name_to_new_name_map.get(key, key)
541
- new_value = _replace_refs(
542
- schema=value,
543
- old_name_to_new_name_map=old_name_to_new_name_map,
544
- )
545
- new_definitions[new_key] = new_value
546
-
547
- return new_field_mapping, new_definitions
548
-
549
- # Apply the patch
550
- fastapi_v2._remap_definitions_and_field_mappings = (
551
- _patched_remap_definitions_and_field_mappings
552
- )
553
-
554
- # Patch FastAPI.openapi() to extract discriminated unions
555
- _original_openapi = FastAPI.openapi
556
-
557
- def _patched_openapi(self):
558
- """Patched openapi() that extracts discriminated unions."""
559
- schema = _original_openapi(self)
560
- return _extract_discriminated_unions(schema)
561
-
562
- FastAPI.openapi = _patched_openapi
563
-
564
- except (ImportError, AttributeError):
565
- # FastAPI not available or internal API changed
566
- pass
567
-
568
-
569
- # Always call the FastAPI patch after DiscriminatedUnionMixin definition
570
- _patch_fastapi_discriminated_union_support()