openhands-sdk 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. openhands/sdk/__init__.py +111 -0
  2. openhands/sdk/agent/__init__.py +8 -0
  3. openhands/sdk/agent/agent.py +650 -0
  4. openhands/sdk/agent/base.py +457 -0
  5. openhands/sdk/agent/prompts/in_context_learning_example.j2 +169 -0
  6. openhands/sdk/agent/prompts/in_context_learning_example_suffix.j2 +3 -0
  7. openhands/sdk/agent/prompts/model_specific/anthropic_claude.j2 +3 -0
  8. openhands/sdk/agent/prompts/model_specific/google_gemini.j2 +1 -0
  9. openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5-codex.j2 +2 -0
  10. openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5.j2 +3 -0
  11. openhands/sdk/agent/prompts/security_policy.j2 +22 -0
  12. openhands/sdk/agent/prompts/security_risk_assessment.j2 +21 -0
  13. openhands/sdk/agent/prompts/self_documentation.j2 +15 -0
  14. openhands/sdk/agent/prompts/system_prompt.j2 +132 -0
  15. openhands/sdk/agent/prompts/system_prompt_interactive.j2 +14 -0
  16. openhands/sdk/agent/prompts/system_prompt_long_horizon.j2 +40 -0
  17. openhands/sdk/agent/prompts/system_prompt_planning.j2 +40 -0
  18. openhands/sdk/agent/prompts/system_prompt_tech_philosophy.j2 +122 -0
  19. openhands/sdk/agent/utils.py +228 -0
  20. openhands/sdk/context/__init__.py +28 -0
  21. openhands/sdk/context/agent_context.py +264 -0
  22. openhands/sdk/context/condenser/__init__.py +18 -0
  23. openhands/sdk/context/condenser/base.py +100 -0
  24. openhands/sdk/context/condenser/llm_summarizing_condenser.py +248 -0
  25. openhands/sdk/context/condenser/no_op_condenser.py +14 -0
  26. openhands/sdk/context/condenser/pipeline_condenser.py +56 -0
  27. openhands/sdk/context/condenser/prompts/summarizing_prompt.j2 +59 -0
  28. openhands/sdk/context/condenser/utils.py +149 -0
  29. openhands/sdk/context/prompts/__init__.py +6 -0
  30. openhands/sdk/context/prompts/prompt.py +114 -0
  31. openhands/sdk/context/prompts/templates/ask_agent_template.j2 +11 -0
  32. openhands/sdk/context/prompts/templates/skill_knowledge_info.j2 +8 -0
  33. openhands/sdk/context/prompts/templates/system_message_suffix.j2 +32 -0
  34. openhands/sdk/context/skills/__init__.py +28 -0
  35. openhands/sdk/context/skills/exceptions.py +11 -0
  36. openhands/sdk/context/skills/skill.py +720 -0
  37. openhands/sdk/context/skills/trigger.py +36 -0
  38. openhands/sdk/context/skills/types.py +48 -0
  39. openhands/sdk/context/view.py +503 -0
  40. openhands/sdk/conversation/__init__.py +40 -0
  41. openhands/sdk/conversation/base.py +281 -0
  42. openhands/sdk/conversation/conversation.py +152 -0
  43. openhands/sdk/conversation/conversation_stats.py +85 -0
  44. openhands/sdk/conversation/event_store.py +157 -0
  45. openhands/sdk/conversation/events_list_base.py +17 -0
  46. openhands/sdk/conversation/exceptions.py +50 -0
  47. openhands/sdk/conversation/fifo_lock.py +133 -0
  48. openhands/sdk/conversation/impl/__init__.py +5 -0
  49. openhands/sdk/conversation/impl/local_conversation.py +665 -0
  50. openhands/sdk/conversation/impl/remote_conversation.py +956 -0
  51. openhands/sdk/conversation/persistence_const.py +9 -0
  52. openhands/sdk/conversation/response_utils.py +41 -0
  53. openhands/sdk/conversation/secret_registry.py +126 -0
  54. openhands/sdk/conversation/serialization_diff.py +0 -0
  55. openhands/sdk/conversation/state.py +392 -0
  56. openhands/sdk/conversation/stuck_detector.py +311 -0
  57. openhands/sdk/conversation/title_utils.py +191 -0
  58. openhands/sdk/conversation/types.py +45 -0
  59. openhands/sdk/conversation/visualizer/__init__.py +12 -0
  60. openhands/sdk/conversation/visualizer/base.py +67 -0
  61. openhands/sdk/conversation/visualizer/default.py +373 -0
  62. openhands/sdk/critic/__init__.py +15 -0
  63. openhands/sdk/critic/base.py +38 -0
  64. openhands/sdk/critic/impl/__init__.py +12 -0
  65. openhands/sdk/critic/impl/agent_finished.py +83 -0
  66. openhands/sdk/critic/impl/empty_patch.py +49 -0
  67. openhands/sdk/critic/impl/pass_critic.py +42 -0
  68. openhands/sdk/event/__init__.py +42 -0
  69. openhands/sdk/event/base.py +149 -0
  70. openhands/sdk/event/condenser.py +82 -0
  71. openhands/sdk/event/conversation_error.py +25 -0
  72. openhands/sdk/event/conversation_state.py +104 -0
  73. openhands/sdk/event/llm_completion_log.py +39 -0
  74. openhands/sdk/event/llm_convertible/__init__.py +20 -0
  75. openhands/sdk/event/llm_convertible/action.py +139 -0
  76. openhands/sdk/event/llm_convertible/message.py +142 -0
  77. openhands/sdk/event/llm_convertible/observation.py +141 -0
  78. openhands/sdk/event/llm_convertible/system.py +61 -0
  79. openhands/sdk/event/token.py +16 -0
  80. openhands/sdk/event/types.py +11 -0
  81. openhands/sdk/event/user_action.py +21 -0
  82. openhands/sdk/git/exceptions.py +43 -0
  83. openhands/sdk/git/git_changes.py +249 -0
  84. openhands/sdk/git/git_diff.py +129 -0
  85. openhands/sdk/git/models.py +21 -0
  86. openhands/sdk/git/utils.py +189 -0
  87. openhands/sdk/hooks/__init__.py +30 -0
  88. openhands/sdk/hooks/config.py +180 -0
  89. openhands/sdk/hooks/conversation_hooks.py +227 -0
  90. openhands/sdk/hooks/executor.py +155 -0
  91. openhands/sdk/hooks/manager.py +170 -0
  92. openhands/sdk/hooks/types.py +40 -0
  93. openhands/sdk/io/__init__.py +6 -0
  94. openhands/sdk/io/base.py +48 -0
  95. openhands/sdk/io/cache.py +85 -0
  96. openhands/sdk/io/local.py +119 -0
  97. openhands/sdk/io/memory.py +54 -0
  98. openhands/sdk/llm/__init__.py +45 -0
  99. openhands/sdk/llm/exceptions/__init__.py +45 -0
  100. openhands/sdk/llm/exceptions/classifier.py +50 -0
  101. openhands/sdk/llm/exceptions/mapping.py +54 -0
  102. openhands/sdk/llm/exceptions/types.py +101 -0
  103. openhands/sdk/llm/llm.py +1140 -0
  104. openhands/sdk/llm/llm_registry.py +122 -0
  105. openhands/sdk/llm/llm_response.py +59 -0
  106. openhands/sdk/llm/message.py +656 -0
  107. openhands/sdk/llm/mixins/fn_call_converter.py +1288 -0
  108. openhands/sdk/llm/mixins/non_native_fc.py +97 -0
  109. openhands/sdk/llm/options/__init__.py +1 -0
  110. openhands/sdk/llm/options/chat_options.py +93 -0
  111. openhands/sdk/llm/options/common.py +19 -0
  112. openhands/sdk/llm/options/responses_options.py +67 -0
  113. openhands/sdk/llm/router/__init__.py +10 -0
  114. openhands/sdk/llm/router/base.py +117 -0
  115. openhands/sdk/llm/router/impl/multimodal.py +76 -0
  116. openhands/sdk/llm/router/impl/random.py +22 -0
  117. openhands/sdk/llm/streaming.py +9 -0
  118. openhands/sdk/llm/utils/metrics.py +312 -0
  119. openhands/sdk/llm/utils/model_features.py +192 -0
  120. openhands/sdk/llm/utils/model_info.py +90 -0
  121. openhands/sdk/llm/utils/model_prompt_spec.py +98 -0
  122. openhands/sdk/llm/utils/retry_mixin.py +128 -0
  123. openhands/sdk/llm/utils/telemetry.py +362 -0
  124. openhands/sdk/llm/utils/unverified_models.py +156 -0
  125. openhands/sdk/llm/utils/verified_models.py +65 -0
  126. openhands/sdk/logger/__init__.py +22 -0
  127. openhands/sdk/logger/logger.py +195 -0
  128. openhands/sdk/logger/rolling.py +113 -0
  129. openhands/sdk/mcp/__init__.py +24 -0
  130. openhands/sdk/mcp/client.py +76 -0
  131. openhands/sdk/mcp/definition.py +106 -0
  132. openhands/sdk/mcp/exceptions.py +19 -0
  133. openhands/sdk/mcp/tool.py +270 -0
  134. openhands/sdk/mcp/utils.py +83 -0
  135. openhands/sdk/observability/__init__.py +4 -0
  136. openhands/sdk/observability/laminar.py +166 -0
  137. openhands/sdk/observability/utils.py +20 -0
  138. openhands/sdk/py.typed +0 -0
  139. openhands/sdk/secret/__init__.py +19 -0
  140. openhands/sdk/secret/secrets.py +92 -0
  141. openhands/sdk/security/__init__.py +6 -0
  142. openhands/sdk/security/analyzer.py +111 -0
  143. openhands/sdk/security/confirmation_policy.py +61 -0
  144. openhands/sdk/security/llm_analyzer.py +29 -0
  145. openhands/sdk/security/risk.py +100 -0
  146. openhands/sdk/tool/__init__.py +34 -0
  147. openhands/sdk/tool/builtins/__init__.py +34 -0
  148. openhands/sdk/tool/builtins/finish.py +106 -0
  149. openhands/sdk/tool/builtins/think.py +117 -0
  150. openhands/sdk/tool/registry.py +184 -0
  151. openhands/sdk/tool/schema.py +286 -0
  152. openhands/sdk/tool/spec.py +39 -0
  153. openhands/sdk/tool/tool.py +481 -0
  154. openhands/sdk/utils/__init__.py +22 -0
  155. openhands/sdk/utils/async_executor.py +115 -0
  156. openhands/sdk/utils/async_utils.py +39 -0
  157. openhands/sdk/utils/cipher.py +68 -0
  158. openhands/sdk/utils/command.py +90 -0
  159. openhands/sdk/utils/deprecation.py +166 -0
  160. openhands/sdk/utils/github.py +44 -0
  161. openhands/sdk/utils/json.py +48 -0
  162. openhands/sdk/utils/models.py +570 -0
  163. openhands/sdk/utils/paging.py +63 -0
  164. openhands/sdk/utils/pydantic_diff.py +85 -0
  165. openhands/sdk/utils/pydantic_secrets.py +64 -0
  166. openhands/sdk/utils/truncate.py +117 -0
  167. openhands/sdk/utils/visualize.py +58 -0
  168. openhands/sdk/workspace/__init__.py +17 -0
  169. openhands/sdk/workspace/base.py +158 -0
  170. openhands/sdk/workspace/local.py +189 -0
  171. openhands/sdk/workspace/models.py +35 -0
  172. openhands/sdk/workspace/remote/__init__.py +8 -0
  173. openhands/sdk/workspace/remote/async_remote_workspace.py +149 -0
  174. openhands/sdk/workspace/remote/base.py +164 -0
  175. openhands/sdk/workspace/remote/remote_workspace_mixin.py +323 -0
  176. openhands/sdk/workspace/workspace.py +49 -0
  177. openhands_sdk-1.7.3.dist-info/METADATA +17 -0
  178. openhands_sdk-1.7.3.dist-info/RECORD +180 -0
  179. openhands_sdk-1.7.3.dist-info/WHEEL +5 -0
  180. openhands_sdk-1.7.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,570 @@
1
+ import inspect
2
+ import json
3
+ import logging
4
+ import os
5
+ from abc import ABC
6
+ from typing import Annotated, Any, ClassVar, Literal, NoReturn, Self, Union
7
+
8
+ from pydantic import (
9
+ BaseModel,
10
+ Discriminator,
11
+ Field,
12
+ Tag,
13
+ TypeAdapter,
14
+ ValidationError,
15
+ )
16
+ from pydantic_core import ErrorDetails
17
+
18
+
19
logger = logging.getLogger(__name__)
# True when OpenHandsModel subclasses may have been defined since the last
# rebuild_all(). It is set by OpenHandsModel.__init_subclass__, cleared by
# rebuild_all(), and checked by _rebuild_if_required(), so schema rebuilds are
# deferred until a model is next initialized, validated or asked for its JSON
# schema. It starts True so that the first such use triggers a rebuild.
_rebuild_required = True
21
+
22
+
23
def _is_abstract(type_: type) -> bool:
    """Tell whether ``type_`` should be treated as abstract.

    A class counts as abstract when it still has unimplemented abstract
    methods, or when ``ABC`` is one of its *direct* bases. Any error raised
    while inspecting the object (for example, a non-class without
    ``__bases__``) is treated as "not abstract".
    """
    try:
        if inspect.isabstract(type_):
            return True
        direct_bases = type_.__bases__
        return ABC in direct_bases
    except Exception:
        return False
29
+
30
+
31
def _get_all_subclasses(cls) -> set[type]:
    """
    Collect every currently loaded direct and indirect subclass of ``cls``.

    ``cls`` itself is not included in the result.
    """
    found: set[type] = set()
    pending = list(cls.__subclasses__())
    while pending:
        current = pending.pop()
        if current in found:
            # Reached again via another parent (diamond inheritance).
            continue
        found.add(current)
        pending.extend(current.__subclasses__())
    return found
40
+
41
+
42
def rebuild_all():
    """Rebuild all polymorphic classes.

    Every loaded subclass of ``OpenHandsModel`` is force-rebuilt, and then
    every loaded subclass of ``DiscriminatedUnionMixin`` is force-rebuilt a
    second time. The module-level ``_rebuild_required`` flag is cleared
    *before* any rebuild starts, so a nested ``_rebuild_if_required()`` call
    made while rebuilding is a no-op.
    """
    global _rebuild_required
    _rebuild_required = False
    for root in (OpenHandsModel, DiscriminatedUnionMixin):
        # The subclass set for each root is computed only when that root's
        # turn comes, i.e. after the previous root's rebuilds have finished.
        for model_cls in _get_all_subclasses(root):
            model_cls.model_rebuild(force=True)
50
+
51
+
52
+ def kind_of(obj) -> str:
53
+ """Get the string value for the kind tag"""
54
+ if isinstance(obj, dict):
55
+ return obj["kind"]
56
+ if not hasattr(obj, "__name__"):
57
+ obj = obj.__class__
58
+ return obj.__name__
59
+
60
+
61
def _create_enhanced_discriminated_union_error_message(
    invalid_kind: str, cls_name: str, valid_kinds: list[str]
) -> str:
    """Build the user-facing message for an unknown discriminator ``kind``.

    The valid kinds are listed alphabetically, or as ``none`` when there are
    none. A hint about caching invalid schemas inside plain ``BaseModel``
    wrappers is appended.
    """
    expected = ", ".join(sorted(valid_kinds)) if valid_kinds else "none"
    header = f"Unexpected kind '{invalid_kind}' for {cls_name}. "
    expected_part = f"Expected one of: {expected}. "
    hint = (
        "If you receive this error when trying to wrap a "
        "DiscriminatedUnion instance inside another pydantic model, "
        "you may need to use OpenHandsModel instead of BaseModel "
        "to make sure that an invalid schema has not been cached."
    )
    return header + expected_part + hint
74
+
75
+
76
def _extract_invalid_kind_from_validation_error(error: ErrorDetails) -> str:
    """Return the offending ``kind`` recorded in one pydantic error entry.

    The entry's ``input`` is inspected in this order:

    1. an object (not ``None``) with a ``kind`` attribute yields that
       attribute;
    2. a dict containing a ``"kind"`` key yields that entry;
    3. anything else falls back to ``kind_of``.
    """
    offending = error.get("input")
    has_kind_attribute = offending is not None and hasattr(offending, "kind")
    if has_kind_attribute:
        return offending.kind
    has_kind_key = isinstance(offending, dict) and "kind" in offending
    return offending["kind"] if has_kind_key else kind_of(offending)
85
+
86
+
87
def _handle_discriminated_union_validation_error(
    validation_error: ValidationError, cls_name: str, valid_kinds: list[str]
) -> NoReturn:
    """Re-raise ``validation_error``, upgrading unknown-``kind`` failures.

    If any error entry has type ``"union_tag_invalid"``, the first such entry
    is turned into a ``ValueError`` with an enhanced message, chained to the
    original error. Otherwise the original ``ValidationError`` is re-raised
    unchanged.
    """
    tag_errors = (
        detail
        for detail in validation_error.errors()
        if detail.get("type") == "union_tag_invalid"
    )
    first_tag_error = next(tag_errors, None)
    if first_tag_error is None:
        # If it's not a discriminated union error, re-raise the original error
        raise validation_error

    message = _create_enhanced_discriminated_union_error_message(
        _extract_invalid_kind_from_validation_error(first_tag_error),
        cls_name,
        valid_kinds,
    )
    raise ValueError(message) from validation_error
101
+
102
+
103
def get_known_concrete_subclasses(cls) -> list[type]:
    """Return every loaded, concrete (non-abstract) subclass of ``cls``.

    The whole subclass tree is walked, not only direct children. Each class
    object appears at most once, even when it is reachable through several
    parents (diamond inheritance). Distinct classes that merely share the same
    ``(module, name)`` are all kept.

    The result is sorted by ``(__module__, __qualname__)`` so the order is
    stable across calls. ``__qualname__`` distinguishes nested/local classes
    (like a test-local ``Cat``) that share a ``__name__``. Because the sort is
    stable, classes with identical keys keep their discovery order
    (depth-first, children before their parent).
    """
    seen: set[type] = set()
    found: list[type] = []

    def _walk(parent: type) -> None:
        # Depth-first visit; ``seen`` prevents a class reachable via multiple
        # parents from being reported more than once.
        for sub in parent.__subclasses__():
            if sub in seen:
                continue
            seen.add(sub)
            _walk(sub)
            if not _is_abstract(sub):
                found.append(sub)

    _walk(cls)
    # Sort once, at the top level, rather than at every recursion level.
    found.sort(key=lambda t: (t.__module__, getattr(t, "__qualname__", t.__name__)))
    return found
116
+
117
+
118
class OpenHandsModel(BaseModel):
    """
    Base model for classes that may be a discriminated union, or that contain
    fields holding a discriminated union.

    Defining a new subclass marks the polymorphic mappings as stale. They are
    regenerated lazily: the next time an instance is initialized, a model is
    validated, or a JSON schema is requested.
    """

    def model_post_init(self, _context):
        """Refresh stale polymorphic schemas, then continue the post-init chain."""
        _rebuild_if_required()
        # Keep cooperative inheritance intact, so a model_post_init defined
        # later in the MRO (e.g. by another base class) still runs.
        super().model_post_init(_context)

    @classmethod
    def model_validate(cls, *args, **kwargs) -> Self:
        _rebuild_if_required()
        return super().model_validate(*args, **kwargs)

    @classmethod
    def model_validate_json(cls, *args, **kwargs) -> Self:
        _rebuild_if_required()
        return super().model_validate_json(*args, **kwargs)

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
        _rebuild_if_required()
        return super().model_json_schema(*args, **kwargs)

    def model_dump_json(self, **kwargs):
        """Serialize the model to a JSON string built from ``model_dump``.

        This was overridden because it seems there is a bug where sometimes
        duplicate fields are produced by ``model_dump_json`` which does not
        appear in ``model_dump``.

        ``indent`` and ``ensure_ascii`` are accepted, as in pydantic's
        ``BaseModel.model_dump_json``. They are JSON-formatting options that
        ``model_dump`` does not accept, so they are applied by ``json.dumps``
        instead of being forwarded. ``ensure_ascii`` defaults to ``False``. All
        other keyword arguments go to ``model_dump``, and ``mode`` is always
        forced to ``"json"``.
        """
        indent = kwargs.pop("indent", None)
        ensure_ascii = kwargs.pop("ensure_ascii", False)
        kwargs["mode"] = "json"
        return json.dumps(
            self.model_dump(**kwargs), indent=indent, ensure_ascii=ensure_ascii
        )

    def __init_subclass__(cls, **kwargs):
        """
        When a new subclass is defined, mark that we will need
        to rebuild everything
        """
        global _rebuild_required
        _rebuild_required = True

        return super().__init_subclass__(**kwargs)
160
+
161
+
162
+ class DiscriminatedUnionMixin(OpenHandsModel, ABC):
163
+ """A Base class for members of tagged unions discriminated by the class name.
164
+
165
+ This class provides automatic subclass registration and discriminated union
166
+ functionality. Each subclass is automatically registered when defined and
167
+ can be used for polymorphic serialization/deserialization.
168
+
169
+ Child classes will automatically have a type field defined, which is used as a
170
+ discriminator for union types.
171
+ """
172
+
173
+ __pydantic_core_schema__: ClassVar[Any]
174
+ __pydantic_validator__: ClassVar[Any]
175
+ __pydantic_serializer__: ClassVar[Any]
176
+
177
+ kind: str = Field(default="") # We dynamically update on a per class basis
178
+
179
+ @classmethod
180
+ def resolve_kind(cls, kind: str) -> type:
181
+ for subclass in get_known_concrete_subclasses(cls):
182
+ if subclass.__name__ == kind:
183
+ return subclass
184
+
185
+ # Generate enhanced error message for unknown kind
186
+ valid_kinds = [
187
+ subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
188
+ ]
189
+ error_msg = _create_enhanced_discriminated_union_error_message(
190
+ kind, cls.__name__, valid_kinds
191
+ )
192
+ raise ValueError(error_msg)
193
+
194
+ @classmethod
195
+ def __get_pydantic_core_schema__(cls, source_type, handler):
196
+ """Generate discriminated union schema for TypeAdapter compatibility."""
197
+ if cls.__name__ == "DiscriminatedUnionMixin":
198
+ return handler(source_type)
199
+
200
+ if _is_abstract(source_type):
201
+ _rebuild_if_required()
202
+ serializable_type = source_type.get_serializable_type()
203
+ # If there are subclasses, generate schema for the discriminated union
204
+ if serializable_type is not source_type:
205
+ from pydantic_core import core_schema
206
+
207
+ # Generate the base schema
208
+ base_schema = handler.generate_schema(serializable_type)
209
+
210
+ # Wrap it with a custom validation function that provides
211
+ # enhanced error messages
212
+ def validate_with_enhanced_error(value, handler_func, info): # noqa: ARG001
213
+ try:
214
+ return handler_func(value)
215
+ except ValidationError as e:
216
+ valid_kinds = [
217
+ subclass.__name__
218
+ for subclass in get_known_concrete_subclasses(source_type)
219
+ ]
220
+ _handle_discriminated_union_validation_error(
221
+ e, source_type.__name__, valid_kinds
222
+ )
223
+
224
+ # Create a with_info_wrap_validator_function schema
225
+ return core_schema.with_info_wrap_validator_function(
226
+ validate_with_enhanced_error,
227
+ base_schema,
228
+ )
229
+
230
+ return handler(source_type)
231
+
232
+ @classmethod
233
+ def __get_pydantic_json_schema__(cls, core_schema, handler):
234
+ """Add discriminator to OpenAPI schema and ensure component generation."""
235
+ json_schema = handler(core_schema)
236
+
237
+ # Add discriminator if this is a oneOf schema
238
+ if isinstance(json_schema, dict) and "oneOf" in json_schema:
239
+ # Add title for abstract classes to encourage separate component creation
240
+ if _is_abstract(cls) and "title" not in json_schema:
241
+ json_schema["title"] = cls.__name__
242
+
243
+ if "discriminator" not in json_schema:
244
+ mapping = {}
245
+ for option in json_schema["oneOf"]:
246
+ if "$ref" in option:
247
+ kind = option["$ref"].split("/")[-1]
248
+ mapping[kind] = option["$ref"]
249
+
250
+ if mapping:
251
+ json_schema["discriminator"] = {
252
+ "propertyName": "kind",
253
+ "mapping": mapping,
254
+ }
255
+
256
+ return json_schema
257
+
258
+ @classmethod
259
+ def model_rebuild(
260
+ cls,
261
+ *,
262
+ force=False,
263
+ raise_errors=True,
264
+ _parent_namespace_depth=2,
265
+ _types_namespace=None,
266
+ ):
267
+ if cls == DiscriminatedUnionMixin:
268
+ pass
269
+ if _is_abstract(cls):
270
+ subclasses = get_known_concrete_subclasses(cls)
271
+ kinds = [subclass.__name__ for subclass in subclasses]
272
+ if kinds:
273
+ kind_field = cls.model_fields["kind"]
274
+ kind_field.annotation = Literal[tuple(kinds)] # type: ignore
275
+ kind_field.default = kinds[0]
276
+
277
+ type_adapter = TypeAdapter(cls.get_serializable_type())
278
+ cls.__pydantic_core_schema__ = type_adapter.core_schema
279
+ cls.__pydantic_validator__ = type_adapter.validator
280
+ cls.__pydantic_serializer__ = type_adapter.serializer
281
+ return
282
+
283
+ return super().model_rebuild(
284
+ force=force,
285
+ raise_errors=raise_errors,
286
+ _parent_namespace_depth=_parent_namespace_depth,
287
+ _types_namespace=_types_namespace,
288
+ )
289
+
290
+ @classmethod
291
+ def get_serializable_type(cls) -> type:
292
+ """
293
+ Custom method to get the union of all currently loaded
294
+ non absract subclasses
295
+ """
296
+
297
+ # If the class is not abstract return self
298
+ if not _is_abstract(cls):
299
+ return cls
300
+
301
+ subclasses = list(get_known_concrete_subclasses(cls))
302
+ if not subclasses:
303
+ return cls
304
+
305
+ if len(subclasses) == 1:
306
+ # Returning the concrete type ensures Pydantic instantiates the subclass
307
+ # (e.g. Agent) rather than the abstract base (e.g. AgentBase) when there is
308
+ # only ONE concrete subclass.
309
+ return subclasses[0]
310
+
311
+ serializable_type = Annotated[
312
+ Union[*tuple(Annotated[t, Tag(t.__name__)] for t in subclasses)],
313
+ Discriminator(kind_of),
314
+ ]
315
+ return serializable_type # type: ignore
316
+
317
+ @classmethod
318
+ def model_validate(cls, obj: Any, **kwargs) -> Self:
319
+ try:
320
+ if _is_abstract(cls):
321
+ resolved = cls.resolve_kind(kind_of(obj))
322
+ else:
323
+ resolved = super()
324
+ result = resolved.model_validate(obj, **kwargs)
325
+ return result # type: ignore
326
+ except ValidationError as e:
327
+ valid_kinds = [
328
+ subclass.__name__ for subclass in get_known_concrete_subclasses(cls)
329
+ ]
330
+ _handle_discriminated_union_validation_error(e, cls.__name__, valid_kinds)
331
+
332
+ @classmethod
333
+ def model_validate_json(
334
+ cls,
335
+ json_data: str | bytes | bytearray,
336
+ **kwargs,
337
+ ) -> Self:
338
+ data = json.loads(json_data)
339
+ if _is_abstract(cls):
340
+ resolved = cls.resolve_kind(kind_of(data))
341
+ else:
342
+ resolved = super()
343
+ result = resolved.model_validate(data, **kwargs)
344
+ return result # type: ignore
345
+
346
+ def __init_subclass__(cls, **kwargs):
347
+ super().__init_subclass__(**kwargs)
348
+
349
+ # If concrete, stamp kind Literal and collision check
350
+ if not _is_abstract(cls):
351
+ # 1) Stamp discriminator
352
+ cls.kind = cls.__name__
353
+ cls.__annotations__["kind"] = Literal[cls.__name__]
354
+
355
+ # 2) Collision check
356
+ mro = cls.mro()
357
+ union_class = mro[mro.index(DiscriminatedUnionMixin) - 1]
358
+ concretes = get_known_concrete_subclasses(union_class) # sorted list
359
+ kinds: dict[str, type] = {}
360
+ for sub in concretes:
361
+ k = kind_of(sub)
362
+ if k in kinds and kinds[k] is not sub:
363
+ raise ValueError(
364
+ f"Duplicate kind detected for {union_class} : {cls}, {sub}"
365
+ )
366
+ kinds[k] = sub
367
+
368
+ # Rebuild any abstract union owners in the MRO that rely on subclass sets
369
+ for base in cls.mro():
370
+ # Stop when we pass ourselves
371
+ if base is cls:
372
+ continue
373
+ # Only rebuild abstract DiscriminatedUnion owners
374
+ if (
375
+ isinstance(base, type)
376
+ and issubclass(base, DiscriminatedUnionMixin)
377
+ and _is_abstract(base)
378
+ ):
379
+ base.model_rebuild(force=True)
380
+
381
+
382
def _rebuild_if_required():
    """Call ``rebuild_all()`` only when the module's rebuild flag is set."""
    if not _rebuild_required:
        return
    rebuild_all()
385
+
386
+
387
def _extract_discriminated_unions(schema: dict) -> dict:
    """Extract inline discriminated unions as separate components.

    The OpenAPI document is scanned recursively. Every inline discriminated
    union (a mapping with ``oneOf`` + ``discriminator`` + ``title``) is moved
    into ``components.schemas`` under its title, and each occurrence is
    replaced with a ``$ref``. Repeated occurrences of the same union all point
    at the one extracted component. A union is only extracted when its title is
    a valid OpenAPI component name not already used by an existing component;
    otherwise it stays inline.

    Component schemas that share a title are then deduplicated. ``*-Output``
    is preferred over ``*-Input``, which is preferred over the base name, and
    every reference to a removed component is redirected to the kept one.

    Args:
        schema: OpenAPI document, as produced by ``FastAPI.openapi()``.
            Non-dict values are returned unchanged.

    Returns:
        The transformed document, built from new dict/list containers. The
        input dict is not modified in place.
    """
    import re
    from collections import defaultdict

    if not isinstance(schema, dict):
        return schema

    # OpenAPI schema names must match this pattern
    valid_name_pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
    ref_prefix = "#/components/schemas/"

    schemas = schema.get("components", {}).get("schemas", {})
    extracted: dict = {}

    def _find_and_extract(obj):
        """Return a copy of ``obj`` with inline unions replaced by ``$ref``s."""
        if not isinstance(obj, dict):
            return obj

        # Extract inline discriminated unions
        if "oneOf" in obj and "discriminator" in obj and "title" in obj:
            title = obj["title"]
            candidate = {
                "oneOf": obj["oneOf"],
                "discriminator": obj["discriminator"],
                "title": title,
            }
            if title in extracted:
                # The same union seen again: point at the component extracted
                # earlier. A *different* union with this title stays inline.
                if extracted[title] == candidate:
                    return {"$ref": f"{ref_prefix}{title}"}
            elif (
                isinstance(title, str)
                and title not in schemas
                and valid_name_pattern.match(title)
            ):
                extracted[title] = candidate
                return {"$ref": f"{ref_prefix}{title}"}

        # Recursively process nested structures
        result = {}
        for key, value in obj.items():
            if isinstance(value, dict):
                result[key] = _find_and_extract(value)
            elif isinstance(value, list):
                result[key] = [_find_and_extract(item) for item in value]
            else:
                result[key] = value
        return result

    schema = _find_and_extract(schema)

    if extracted:
        # Create components.schemas when missing, so the new $refs never dangle.
        components = schema.setdefault("components", {})
        components.setdefault("schemas", {}).update(extracted)

    # Deduplicate schemas with same title (prefer *-Output over *-Input over base)
    schemas = schema.get("components", {}).get("schemas", {})
    title_to_names = defaultdict(list)
    for name, defn in schemas.items():
        if isinstance(defn, dict):
            title_to_names[defn.get("title", name)].append(name)

    def _preference(name: str) -> tuple[int, str]:
        """Sort key: *-Output first, then *-Input, then the base name."""
        if name.endswith("-Output"):
            rank = 0
        elif name.endswith("-Input"):
            rank = 1
        else:
            rank = 2
        return rank, name

    renamed_refs: dict[str, str] = {}
    removed: list[str] = []
    for names in title_to_names.values():
        if len(names) < 2:
            continue
        keep = min(names, key=_preference)
        for name in names:
            if name != keep:
                removed.append(name)
                renamed_refs[f"{ref_prefix}{name}"] = f"{ref_prefix}{keep}"

    if renamed_refs:

        def _rewrite(node):
            """Copy ``node``, replacing every string that is a renamed ref."""
            if isinstance(node, dict):
                return {renamed_refs.get(k, k): _rewrite(v) for k, v in node.items()}
            if isinstance(node, (list, tuple)):
                return [_rewrite(item) for item in node]
            if isinstance(node, str):
                return renamed_refs.get(node, node)
            return node

        schema = _rewrite(schema)
        for name in removed:
            schema["components"]["schemas"].pop(name, None)

    return schema
477
+
478
+
479
def _patch_fastapi_discriminated_union_support():
    """Patch FastAPI to handle discriminated union schemas without $ref.

    This ensures discriminated unions from DiscriminatedUnionMixin work correctly
    with FastAPI's OpenAPI schema generation. The patch prevents KeyError when
    FastAPI encounters schemas without $ref keys (which discriminated unions use).

    Also extracts inline discriminated unions as separate schema components for
    better OpenAPI documentation and Swagger UI display.

    Patching is best effort and never raises. It is skipped when:

    * the SKIP_FASTAPI_DISCRIMINATED_UNION_FIX environment variable is set;
    * FastAPI is not importable, or the private ``fastapi._compat.v2`` helpers
      this patch relies on are missing (internal API changed);
    * ``FastAPI.openapi`` was already patched by a previous call.
    """
    # Skip patching if environment variable flag is defined
    if os.environ.get("SKIP_FASTAPI_DISCRIMINATED_UNION_FIX"):
        logger.debug(
            "Skipping FastAPI discriminated union patch due to environment variable"
        )
        return

    try:
        import fastapi._compat.v2 as fastapi_v2
        from fastapi import FastAPI
    except (ImportError, AttributeError):
        # FastAPI not available
        logger.debug("FastAPI not available; skipping discriminated union patch")
        return

    # Resolve the private helpers up front. If this FastAPI version lacks them,
    # do not install a patch that would only fail later, during OpenAPI
    # generation.
    _replace_refs = getattr(fastapi_v2, "_replace_refs", None)
    if _replace_refs is None or not hasattr(
        fastapi_v2, "_remap_definitions_and_field_mappings"
    ):
        logger.debug(
            "FastAPI internal API changed; skipping discriminated union patch"
        )
        return

    if getattr(FastAPI.openapi, "_openhands_discriminated_union_patch", False):
        # Already patched (e.g. this module was reloaded); don't wrap twice.
        return

    def _patched_remap_definitions_and_field_mappings(**kwargs):
        """Patched version that handles schemas w/o $ref (discriminated unions)."""
        # NOTE(review): mirrors FastAPI private internals (field_key[0].type_,
        # _replace_refs keywords) - re-verify on FastAPI upgrades.
        field_mapping = kwargs.get("field_mapping", {})
        model_name_map = kwargs.get("model_name_map", {})

        # Build old_name -> new_name map, skipping schemas without $ref
        old_name_to_new_name_map = {}
        for field_key, schema in field_mapping.items():
            model = field_key[0].type_
            if model not in model_name_map:
                continue
            new_name = model_name_map[model]

            # Skip schemas without $ref (discriminated unions)
            if "$ref" not in schema:
                continue

            old_name = schema["$ref"].split("/")[-1]
            if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
                continue
            old_name_to_new_name_map[old_name] = new_name

        # Replace refs using FastAPI's helper
        new_field_mapping = {}
        for field_key, schema in field_mapping.items():
            new_schema = _replace_refs(
                schema=schema,
                old_name_to_new_name_map=old_name_to_new_name_map,
            )
            new_field_mapping[field_key] = new_schema

        definitions = kwargs.get("definitions", {})
        new_definitions = {}
        for key, value in definitions.items():
            new_key = old_name_to_new_name_map.get(key, key)
            new_value = _replace_refs(
                schema=value,
                old_name_to_new_name_map=old_name_to_new_name_map,
            )
            new_definitions[new_key] = new_value

        return new_field_mapping, new_definitions

    # Apply the patch
    fastapi_v2._remap_definitions_and_field_mappings = (
        _patched_remap_definitions_and_field_mappings
    )

    # Patch FastAPI.openapi() to extract discriminated unions
    _original_openapi = FastAPI.openapi

    def _patched_openapi(self):
        """Patched openapi() that extracts discriminated unions."""
        schema = _original_openapi(self)
        return _extract_discriminated_unions(schema)

    # Marker checked above so repeated calls leave the first patch in place.
    _patched_openapi._openhands_discriminated_union_patch = True
    FastAPI.openapi = _patched_openapi
567
+
568
+
569
# Always apply the FastAPI patch at import time, after DiscriminatedUnionMixin
# and the patch helpers above have been defined. The call is a best-effort
# no-op when FastAPI is unavailable or SKIP_FASTAPI_DISCRIMINATED_UNION_FIX is set.
_patch_fastapi_discriminated_union_support()
@@ -0,0 +1,63 @@
1
+ """Pagination utilities for iterating over paginated search results."""
2
+
3
+ from collections.abc import AsyncGenerator, Awaitable, Callable
4
+ from typing import Any, Protocol
5
+
6
+
7
+ class PageProtocol[T](Protocol):
8
+ """Protocol for page objects returned by search functions.
9
+
10
+ All page objects should have:
11
+ - items: A list of items of type T
12
+ - next_page_id: Optional string for pagination
13
+ """
14
+
15
+ items: list[T]
16
+ next_page_id: str | None
17
+
18
+
19
+ async def page_iterator[T](
20
+ search_func: Callable[..., Awaitable[PageProtocol[T]]],
21
+ *args: Any,
22
+ **kwargs: Any,
23
+ ) -> AsyncGenerator[T, None]:
24
+ """
25
+ Iterate over items from paginated search results.
26
+
27
+ This utility function handles pagination automatically by calling the search
28
+ function repeatedly with updated page_id parameters until all pages are
29
+ exhausted.
30
+
31
+ Args:
32
+ search_func: An async function that returns a PageProtocol[T] object
33
+ with 'items' and 'next_page_id' attributes
34
+ *args: Positional arguments to pass to the search function
35
+ **kwargs: Keyword arguments to pass to the search function
36
+
37
+ Yields:
38
+ Individual items of type T from each page
39
+
40
+ Example:
41
+ async for event in page_iterator(event_service.search_events, limit=50):
42
+ await send_event(event, websocket)
43
+
44
+ async for conversation in page_iterator(
45
+ conversation_service.search_conversations,
46
+ execution_status=ConversationExecutionStatus.RUNNING
47
+ ):
48
+ print(conversation.title)
49
+ """
50
+ page_id = kwargs.pop("page_id", None)
51
+
52
+ while True:
53
+ # Call the search function with current page_id
54
+ page = await search_func(*args, page_id=page_id, **kwargs)
55
+
56
+ # Yield each item from the current page
57
+ for item in page.items:
58
+ yield item
59
+
60
+ # Check if there are more pages
61
+ page_id = page.next_page_id
62
+ if not page_id:
63
+ break