openhands_sdk-1.7.3-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (180)
  1. openhands/sdk/__init__.py +111 -0
  2. openhands/sdk/agent/__init__.py +8 -0
  3. openhands/sdk/agent/agent.py +650 -0
  4. openhands/sdk/agent/base.py +457 -0
  5. openhands/sdk/agent/prompts/in_context_learning_example.j2 +169 -0
  6. openhands/sdk/agent/prompts/in_context_learning_example_suffix.j2 +3 -0
  7. openhands/sdk/agent/prompts/model_specific/anthropic_claude.j2 +3 -0
  8. openhands/sdk/agent/prompts/model_specific/google_gemini.j2 +1 -0
  9. openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5-codex.j2 +2 -0
  10. openhands/sdk/agent/prompts/model_specific/openai_gpt/gpt-5.j2 +3 -0
  11. openhands/sdk/agent/prompts/security_policy.j2 +22 -0
  12. openhands/sdk/agent/prompts/security_risk_assessment.j2 +21 -0
  13. openhands/sdk/agent/prompts/self_documentation.j2 +15 -0
  14. openhands/sdk/agent/prompts/system_prompt.j2 +132 -0
  15. openhands/sdk/agent/prompts/system_prompt_interactive.j2 +14 -0
  16. openhands/sdk/agent/prompts/system_prompt_long_horizon.j2 +40 -0
  17. openhands/sdk/agent/prompts/system_prompt_planning.j2 +40 -0
  18. openhands/sdk/agent/prompts/system_prompt_tech_philosophy.j2 +122 -0
  19. openhands/sdk/agent/utils.py +228 -0
  20. openhands/sdk/context/__init__.py +28 -0
  21. openhands/sdk/context/agent_context.py +264 -0
  22. openhands/sdk/context/condenser/__init__.py +18 -0
  23. openhands/sdk/context/condenser/base.py +100 -0
  24. openhands/sdk/context/condenser/llm_summarizing_condenser.py +248 -0
  25. openhands/sdk/context/condenser/no_op_condenser.py +14 -0
  26. openhands/sdk/context/condenser/pipeline_condenser.py +56 -0
  27. openhands/sdk/context/condenser/prompts/summarizing_prompt.j2 +59 -0
  28. openhands/sdk/context/condenser/utils.py +149 -0
  29. openhands/sdk/context/prompts/__init__.py +6 -0
  30. openhands/sdk/context/prompts/prompt.py +114 -0
  31. openhands/sdk/context/prompts/templates/ask_agent_template.j2 +11 -0
  32. openhands/sdk/context/prompts/templates/skill_knowledge_info.j2 +8 -0
  33. openhands/sdk/context/prompts/templates/system_message_suffix.j2 +32 -0
  34. openhands/sdk/context/skills/__init__.py +28 -0
  35. openhands/sdk/context/skills/exceptions.py +11 -0
  36. openhands/sdk/context/skills/skill.py +720 -0
  37. openhands/sdk/context/skills/trigger.py +36 -0
  38. openhands/sdk/context/skills/types.py +48 -0
  39. openhands/sdk/context/view.py +503 -0
  40. openhands/sdk/conversation/__init__.py +40 -0
  41. openhands/sdk/conversation/base.py +281 -0
  42. openhands/sdk/conversation/conversation.py +152 -0
  43. openhands/sdk/conversation/conversation_stats.py +85 -0
  44. openhands/sdk/conversation/event_store.py +157 -0
  45. openhands/sdk/conversation/events_list_base.py +17 -0
  46. openhands/sdk/conversation/exceptions.py +50 -0
  47. openhands/sdk/conversation/fifo_lock.py +133 -0
  48. openhands/sdk/conversation/impl/__init__.py +5 -0
  49. openhands/sdk/conversation/impl/local_conversation.py +665 -0
  50. openhands/sdk/conversation/impl/remote_conversation.py +956 -0
  51. openhands/sdk/conversation/persistence_const.py +9 -0
  52. openhands/sdk/conversation/response_utils.py +41 -0
  53. openhands/sdk/conversation/secret_registry.py +126 -0
  54. openhands/sdk/conversation/serialization_diff.py +0 -0
  55. openhands/sdk/conversation/state.py +392 -0
  56. openhands/sdk/conversation/stuck_detector.py +311 -0
  57. openhands/sdk/conversation/title_utils.py +191 -0
  58. openhands/sdk/conversation/types.py +45 -0
  59. openhands/sdk/conversation/visualizer/__init__.py +12 -0
  60. openhands/sdk/conversation/visualizer/base.py +67 -0
  61. openhands/sdk/conversation/visualizer/default.py +373 -0
  62. openhands/sdk/critic/__init__.py +15 -0
  63. openhands/sdk/critic/base.py +38 -0
  64. openhands/sdk/critic/impl/__init__.py +12 -0
  65. openhands/sdk/critic/impl/agent_finished.py +83 -0
  66. openhands/sdk/critic/impl/empty_patch.py +49 -0
  67. openhands/sdk/critic/impl/pass_critic.py +42 -0
  68. openhands/sdk/event/__init__.py +42 -0
  69. openhands/sdk/event/base.py +149 -0
  70. openhands/sdk/event/condenser.py +82 -0
  71. openhands/sdk/event/conversation_error.py +25 -0
  72. openhands/sdk/event/conversation_state.py +104 -0
  73. openhands/sdk/event/llm_completion_log.py +39 -0
  74. openhands/sdk/event/llm_convertible/__init__.py +20 -0
  75. openhands/sdk/event/llm_convertible/action.py +139 -0
  76. openhands/sdk/event/llm_convertible/message.py +142 -0
  77. openhands/sdk/event/llm_convertible/observation.py +141 -0
  78. openhands/sdk/event/llm_convertible/system.py +61 -0
  79. openhands/sdk/event/token.py +16 -0
  80. openhands/sdk/event/types.py +11 -0
  81. openhands/sdk/event/user_action.py +21 -0
  82. openhands/sdk/git/exceptions.py +43 -0
  83. openhands/sdk/git/git_changes.py +249 -0
  84. openhands/sdk/git/git_diff.py +129 -0
  85. openhands/sdk/git/models.py +21 -0
  86. openhands/sdk/git/utils.py +189 -0
  87. openhands/sdk/hooks/__init__.py +30 -0
  88. openhands/sdk/hooks/config.py +180 -0
  89. openhands/sdk/hooks/conversation_hooks.py +227 -0
  90. openhands/sdk/hooks/executor.py +155 -0
  91. openhands/sdk/hooks/manager.py +170 -0
  92. openhands/sdk/hooks/types.py +40 -0
  93. openhands/sdk/io/__init__.py +6 -0
  94. openhands/sdk/io/base.py +48 -0
  95. openhands/sdk/io/cache.py +85 -0
  96. openhands/sdk/io/local.py +119 -0
  97. openhands/sdk/io/memory.py +54 -0
  98. openhands/sdk/llm/__init__.py +45 -0
  99. openhands/sdk/llm/exceptions/__init__.py +45 -0
  100. openhands/sdk/llm/exceptions/classifier.py +50 -0
  101. openhands/sdk/llm/exceptions/mapping.py +54 -0
  102. openhands/sdk/llm/exceptions/types.py +101 -0
  103. openhands/sdk/llm/llm.py +1140 -0
  104. openhands/sdk/llm/llm_registry.py +122 -0
  105. openhands/sdk/llm/llm_response.py +59 -0
  106. openhands/sdk/llm/message.py +656 -0
  107. openhands/sdk/llm/mixins/fn_call_converter.py +1288 -0
  108. openhands/sdk/llm/mixins/non_native_fc.py +97 -0
  109. openhands/sdk/llm/options/__init__.py +1 -0
  110. openhands/sdk/llm/options/chat_options.py +93 -0
  111. openhands/sdk/llm/options/common.py +19 -0
  112. openhands/sdk/llm/options/responses_options.py +67 -0
  113. openhands/sdk/llm/router/__init__.py +10 -0
  114. openhands/sdk/llm/router/base.py +117 -0
  115. openhands/sdk/llm/router/impl/multimodal.py +76 -0
  116. openhands/sdk/llm/router/impl/random.py +22 -0
  117. openhands/sdk/llm/streaming.py +9 -0
  118. openhands/sdk/llm/utils/metrics.py +312 -0
  119. openhands/sdk/llm/utils/model_features.py +192 -0
  120. openhands/sdk/llm/utils/model_info.py +90 -0
  121. openhands/sdk/llm/utils/model_prompt_spec.py +98 -0
  122. openhands/sdk/llm/utils/retry_mixin.py +128 -0
  123. openhands/sdk/llm/utils/telemetry.py +362 -0
  124. openhands/sdk/llm/utils/unverified_models.py +156 -0
  125. openhands/sdk/llm/utils/verified_models.py +65 -0
  126. openhands/sdk/logger/__init__.py +22 -0
  127. openhands/sdk/logger/logger.py +195 -0
  128. openhands/sdk/logger/rolling.py +113 -0
  129. openhands/sdk/mcp/__init__.py +24 -0
  130. openhands/sdk/mcp/client.py +76 -0
  131. openhands/sdk/mcp/definition.py +106 -0
  132. openhands/sdk/mcp/exceptions.py +19 -0
  133. openhands/sdk/mcp/tool.py +270 -0
  134. openhands/sdk/mcp/utils.py +83 -0
  135. openhands/sdk/observability/__init__.py +4 -0
  136. openhands/sdk/observability/laminar.py +166 -0
  137. openhands/sdk/observability/utils.py +20 -0
  138. openhands/sdk/py.typed +0 -0
  139. openhands/sdk/secret/__init__.py +19 -0
  140. openhands/sdk/secret/secrets.py +92 -0
  141. openhands/sdk/security/__init__.py +6 -0
  142. openhands/sdk/security/analyzer.py +111 -0
  143. openhands/sdk/security/confirmation_policy.py +61 -0
  144. openhands/sdk/security/llm_analyzer.py +29 -0
  145. openhands/sdk/security/risk.py +100 -0
  146. openhands/sdk/tool/__init__.py +34 -0
  147. openhands/sdk/tool/builtins/__init__.py +34 -0
  148. openhands/sdk/tool/builtins/finish.py +106 -0
  149. openhands/sdk/tool/builtins/think.py +117 -0
  150. openhands/sdk/tool/registry.py +184 -0
  151. openhands/sdk/tool/schema.py +286 -0
  152. openhands/sdk/tool/spec.py +39 -0
  153. openhands/sdk/tool/tool.py +481 -0
  154. openhands/sdk/utils/__init__.py +22 -0
  155. openhands/sdk/utils/async_executor.py +115 -0
  156. openhands/sdk/utils/async_utils.py +39 -0
  157. openhands/sdk/utils/cipher.py +68 -0
  158. openhands/sdk/utils/command.py +90 -0
  159. openhands/sdk/utils/deprecation.py +166 -0
  160. openhands/sdk/utils/github.py +44 -0
  161. openhands/sdk/utils/json.py +48 -0
  162. openhands/sdk/utils/models.py +570 -0
  163. openhands/sdk/utils/paging.py +63 -0
  164. openhands/sdk/utils/pydantic_diff.py +85 -0
  165. openhands/sdk/utils/pydantic_secrets.py +64 -0
  166. openhands/sdk/utils/truncate.py +117 -0
  167. openhands/sdk/utils/visualize.py +58 -0
  168. openhands/sdk/workspace/__init__.py +17 -0
  169. openhands/sdk/workspace/base.py +158 -0
  170. openhands/sdk/workspace/local.py +189 -0
  171. openhands/sdk/workspace/models.py +35 -0
  172. openhands/sdk/workspace/remote/__init__.py +8 -0
  173. openhands/sdk/workspace/remote/async_remote_workspace.py +149 -0
  174. openhands/sdk/workspace/remote/base.py +164 -0
  175. openhands/sdk/workspace/remote/remote_workspace_mixin.py +323 -0
  176. openhands/sdk/workspace/workspace.py +49 -0
  177. openhands_sdk-1.7.3.dist-info/METADATA +17 -0
  178. openhands_sdk-1.7.3.dist-info/RECORD +180 -0
  179. openhands_sdk-1.7.3.dist-info/WHEEL +5 -0
  180. openhands_sdk-1.7.3.dist-info/top_level.txt +1 -0
openhands/sdk/llm/utils/metrics.py
@@ -0,0 +1,312 @@
+ import copy
+ import time
+ from typing import final
+
+ from pydantic import BaseModel, Field, field_validator, model_validator
+
+
+ class Cost(BaseModel):
+     model: str
+     cost: float = Field(ge=0.0, description="Cost must be non-negative")
+     timestamp: float = Field(default_factory=time.time)
+
+     @field_validator("cost")
+     @classmethod
+     def validate_cost(cls, v: float) -> float:
+         if v < 0:
+             raise ValueError("Cost cannot be negative")
+         return v
+
+
+ class ResponseLatency(BaseModel):
+     """Metric tracking the round-trip time per completion call."""
+
+     model: str
+     latency: float = Field(ge=0.0, description="Latency must be non-negative")
+     response_id: str
+
+     @field_validator("latency")
+     @classmethod
+     def validate_latency(cls, v: float) -> float:
+         return max(0.0, v)
+
+
+ class TokenUsage(BaseModel):
+     """Metric tracking detailed token usage per completion call."""
+
+     model: str = Field(default="")
+     prompt_tokens: int = Field(
+         default=0, ge=0, description="Prompt tokens must be non-negative"
+     )
+     completion_tokens: int = Field(
+         default=0, ge=0, description="Completion tokens must be non-negative"
+     )
+     cache_read_tokens: int = Field(
+         default=0, ge=0, description="Cache read tokens must be non-negative"
+     )
+     cache_write_tokens: int = Field(
+         default=0, ge=0, description="Cache write tokens must be non-negative"
+     )
+     reasoning_tokens: int = Field(
+         default=0, ge=0, description="Reasoning tokens must be non-negative"
+     )
+     context_window: int = Field(
+         default=0, ge=0, description="Context window must be non-negative"
+     )
+     per_turn_token: int = Field(
+         default=0, ge=0, description="Per turn tokens must be non-negative"
+     )
+     response_id: str = Field(default="")
+
+     def __add__(self, other: "TokenUsage") -> "TokenUsage":
+         """Add two TokenUsage instances together."""
+         return TokenUsage(
+             model=self.model,
+             prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+             completion_tokens=self.completion_tokens + other.completion_tokens,
+             cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
+             cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
+             reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
+             context_window=max(self.context_window, other.context_window),
+             per_turn_token=other.per_turn_token,
+             response_id=self.response_id,
+         )
+
+
+ class MetricsSnapshot(BaseModel):
+     """A snapshot of metrics at a point in time.
+
+     Does not include lists of individual costs, latencies, or token usages.
+     """
+
+     model_name: str = Field(default="default", description="Name of the model")
+     accumulated_cost: float = Field(
+         default=0.0, ge=0.0, description="Total accumulated cost, must be non-negative"
+     )
+     max_budget_per_task: float | None = Field(
+         default=None, description="Maximum budget per task"
+     )
+     accumulated_token_usage: TokenUsage | None = Field(
+         default=None, description="Accumulated token usage across all calls"
+     )
+
+
+ @final
+ class Metrics(MetricsSnapshot):
+     """Metrics class can record various metrics during running and evaluation.
+     We track:
+     - accumulated_cost and costs
+     - max_budget_per_task (budget limit)
+     - A list of ResponseLatency
+     - A list of TokenUsage (one per call).
+     """
+
+     costs: list[Cost] = Field(
+         default_factory=list, description="List of individual costs"
+     )
+     response_latencies: list[ResponseLatency] = Field(
+         default_factory=list, description="List of response latencies"
+     )
+     token_usages: list[TokenUsage] = Field(
+         default_factory=list, description="List of token usage records"
+     )
+
+     @field_validator("accumulated_cost")
+     @classmethod
+     def validate_accumulated_cost(cls, v: float) -> float:
+         if v < 0:
+             raise ValueError("Total cost cannot be negative.")
+         return v
+
+     @model_validator(mode="after")
+     def initialize_accumulated_token_usage(self) -> "Metrics":
+         if self.accumulated_token_usage is None:
+             self.accumulated_token_usage = TokenUsage(
+                 model=self.model_name,
+                 prompt_tokens=0,
+                 completion_tokens=0,
+                 cache_read_tokens=0,
+                 cache_write_tokens=0,
+                 reasoning_tokens=0,
+                 context_window=0,
+                 response_id="",
+             )
+         return self
+
+     def get_snapshot(self) -> MetricsSnapshot:
+         """Get a snapshot of the current metrics without the detailed lists."""
+         return MetricsSnapshot(
+             model_name=self.model_name,
+             accumulated_cost=self.accumulated_cost,
+             max_budget_per_task=self.max_budget_per_task,
+             accumulated_token_usage=copy.deepcopy(self.accumulated_token_usage)
+             if self.accumulated_token_usage
+             else None,
+         )
+
+     def add_cost(self, value: float) -> None:
+         if value < 0:
+             raise ValueError("Added cost cannot be negative.")
+         self.accumulated_cost += value
+         self.costs.append(Cost(cost=value, model=self.model_name))
+
+     def add_response_latency(self, value: float, response_id: str) -> None:
+         self.response_latencies.append(
+             ResponseLatency(
+                 latency=max(0.0, value), model=self.model_name, response_id=response_id
+             )
+         )
+
+     def add_token_usage(
+         self,
+         prompt_tokens: int,
+         completion_tokens: int,
+         cache_read_tokens: int,
+         cache_write_tokens: int,
+         context_window: int,
+         response_id: str,
+         reasoning_tokens: int = 0,
+     ) -> None:
+         """Add a single usage record."""
+         # Tokens used this turn, for calculating context window usage.
+         per_turn_token = prompt_tokens + completion_tokens
+
+         usage = TokenUsage(
+             model=self.model_name,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             cache_read_tokens=cache_read_tokens,
+             cache_write_tokens=cache_write_tokens,
+             reasoning_tokens=reasoning_tokens,
+             context_window=context_window,
+             per_turn_token=per_turn_token,
+             response_id=response_id,
+         )
+         self.token_usages.append(usage)
+
+         # Update accumulated token usage using the __add__ operator
+         new_usage = TokenUsage(
+             model=self.model_name,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             cache_read_tokens=cache_read_tokens,
+             cache_write_tokens=cache_write_tokens,
+             reasoning_tokens=reasoning_tokens,
+             context_window=context_window,
+             per_turn_token=per_turn_token,
+             response_id="",
+         )
+         if self.accumulated_token_usage is None:
+             self.accumulated_token_usage = new_usage
+         else:
+             self.accumulated_token_usage = self.accumulated_token_usage + new_usage
+
+     def merge(self, other: "Metrics") -> None:
+         """Merge 'other' metrics into this one."""
+         self.accumulated_cost += other.accumulated_cost
+
+         # Keep the max_budget_per_task from other if it's set and this one isn't
+         if self.max_budget_per_task is None and other.max_budget_per_task is not None:
+             self.max_budget_per_task = other.max_budget_per_task
+
+         self.costs += other.costs
+         self.token_usages += other.token_usages
+         self.response_latencies += other.response_latencies
+
+         # Merge accumulated token usage using the __add__ operator
+         if self.accumulated_token_usage is None:
+             self.accumulated_token_usage = other.accumulated_token_usage
+         elif other.accumulated_token_usage is not None:
+             self.accumulated_token_usage = (
+                 self.accumulated_token_usage + other.accumulated_token_usage
+             )
+
+     def get(self) -> dict:
+         """Return the metrics in a dictionary."""
+         return {
+             "accumulated_cost": self.accumulated_cost,
+             "max_budget_per_task": self.max_budget_per_task,
+             "accumulated_token_usage": self.accumulated_token_usage.model_dump()
+             if self.accumulated_token_usage
+             else None,
+             "costs": [cost.model_dump() for cost in self.costs],
+             "response_latencies": [
+                 latency.model_dump() for latency in self.response_latencies
+             ],
+             "token_usages": [usage.model_dump() for usage in self.token_usages],
+         }
+
+     def log(self) -> str:
+         """Log the metrics."""
+         metrics = self.get()
+         logs = ""
+         for key, value in metrics.items():
+             logs += f"{key}: {value}\n"
+         return logs
+
+     def deep_copy(self) -> "Metrics":
+         """Create a deep copy of the Metrics object."""
+         return copy.deepcopy(self)
+
+     def diff(self, baseline: "Metrics") -> "Metrics":
+         """Calculate the difference between current metrics and a baseline.
+
+         This is useful for tracking metrics for specific operations like delegates.
+
+         Args:
+             baseline: A metrics object representing the baseline state
+
+         Returns:
+             A new Metrics object containing only the differences since the baseline
+         """
+         result = Metrics(model_name=self.model_name)
+
+         # Calculate cost difference
+         result.accumulated_cost = self.accumulated_cost - baseline.accumulated_cost
+
+         # Include only costs that were added after the baseline
+         if baseline.costs:
+             last_baseline_timestamp = baseline.costs[-1].timestamp
+             result.costs = [
+                 cost for cost in self.costs if cost.timestamp > last_baseline_timestamp
+             ]
+         else:
+             result.costs = self.costs.copy()
+
+         # Include only response latencies that were added after the baseline
+         result.response_latencies = self.response_latencies[
+             len(baseline.response_latencies) :
+         ]
+
+         # Include only token usages that were added after the baseline
+         result.token_usages = self.token_usages[len(baseline.token_usages) :]
+
+         # Calculate accumulated token usage difference
+         base_usage = baseline.accumulated_token_usage
+         current_usage = self.accumulated_token_usage
+
+         if current_usage is not None and base_usage is not None:
+             result.accumulated_token_usage = TokenUsage(
+                 model=self.model_name,
+                 prompt_tokens=current_usage.prompt_tokens - base_usage.prompt_tokens,
+                 completion_tokens=current_usage.completion_tokens
+                 - base_usage.completion_tokens,
+                 cache_read_tokens=current_usage.cache_read_tokens
+                 - base_usage.cache_read_tokens,
+                 cache_write_tokens=current_usage.cache_write_tokens
+                 - base_usage.cache_write_tokens,
+                 reasoning_tokens=current_usage.reasoning_tokens
+                 - base_usage.reasoning_tokens,
+                 context_window=current_usage.context_window,
+                 per_turn_token=0,
+                 response_id="",
+             )
+         elif current_usage is not None:
+             result.accumulated_token_usage = current_usage
+         else:
+             result.accumulated_token_usage = None
+
+         return result
+
+     def __repr__(self) -> str:
+         return f"Metrics({self.get()})"
openhands/sdk/llm/utils/model_features.py
@@ -0,0 +1,192 @@
+ from dataclasses import dataclass
+
+
+ def model_matches(model: str, patterns: list[str]) -> bool:
+     """Return True if any pattern appears as a substring in the raw model name.
+
+     Matching semantics:
+     - Case-insensitive substring search on full raw model string
+     """
+     raw = (model or "").strip().lower()
+     for pat in patterns:
+         token = pat.strip().lower()
+         if token in raw:
+             return True
+     return False
+
+
+ def apply_ordered_model_rules(model: str, rules: list[str]) -> bool:
+     """Apply ordered include/exclude model rules to determine final support.
+
+     Rules semantics:
+     - Each entry is a substring token. '!' prefix marks an exclude rule.
+     - Case-insensitive substring matching against the raw model string.
+     - Evaluated in order; the last matching rule wins.
+     - If no rule matches, returns False.
+     """
+     raw = (model or "").strip().lower()
+     decided: bool | None = None
+     for rule in rules:
+         token = rule.strip().lower()
+         if not token:
+             continue
+         is_exclude = token.startswith("!")
+         core = token[1:] if is_exclude else token
+         if core and core in raw:
+             decided = not is_exclude
+     return bool(decided)
+
+
+ @dataclass(frozen=True)
+ class ModelFeatures:
+     supports_reasoning_effort: bool
+     supports_extended_thinking: bool
+     supports_prompt_cache: bool
+     supports_stop_words: bool
+     supports_responses_api: bool
+     force_string_serializer: bool
+     send_reasoning_content: bool
+     supports_prompt_cache_retention: bool
+
+
+ # Model lists capturing current behavior. Keep entries lowercase.
+
+ REASONING_EFFORT_MODELS: list[str] = [
+     # Mirror main behavior exactly (no unintended expansion)
+     "o1-2024-12-17",
+     "o1",
+     "o3",
+     "o3-2025-04-16",
+     "o3-mini-2025-01-31",
+     "o3-mini",
+     "o4-mini",
+     "o4-mini-2025-04-16",
+     "gemini-2.5-flash",
+     "gemini-2.5-pro",
+     # OpenAI GPT-5 family (includes mini variants)
+     "gpt-5",
+     # Anthropic Opus 4.5
+     "claude-opus-4-5",
+     # Nova 2 Lite
+     "nova-2-lite",
+ ]
+
+ EXTENDED_THINKING_MODELS: list[str] = [
+     # Anthropic model family
+     # We did not include Sonnet 3.7 and 4 here as they don't bring
+     # significant performance improvements for agents
+     "claude-sonnet-4-5",
+     "claude-haiku-4-5",
+ ]
+
+ PROMPT_CACHE_MODELS: list[str] = [
+     "claude-3-7-sonnet",
+     "claude-sonnet-3-7-latest",
+     "claude-3-5-sonnet",
+     "claude-3-5-haiku",
+     "claude-3-haiku-20240307",
+     "claude-3-opus-20240229",
+     "claude-sonnet-4",
+     "claude-opus-4",
+     # Anthropic Haiku 4.5 variants (dash only; official IDs use hyphens)
+     "claude-haiku-4-5",
+     "claude-opus-4-5",
+ ]
+
+ # Models that support a top-level prompt_cache_retention parameter
+ # Source: OpenAI Prompt Caching docs (extended retention), which list:
+ # - gpt-5.2
+ # - gpt-5.1
+ # - gpt-5.1-codex
+ # - gpt-5.1-codex-mini
+ # - gpt-5.1-chat-latest
+ # - gpt-5
+ # - gpt-5-codex
+ # - gpt-4.1
+ # Use ordered include/exclude rules (last wins) to naturally express exceptions.
+ PROMPT_CACHE_RETENTION_MODELS: list[str] = [
+     # Broad allow for GPT-5 family and GPT-4.1 (covers gpt-5.2 and variants)
+     "gpt-5",
+     "gpt-4.1",
+     # Exclude all mini variants by default
+     "!mini",
+     # Re-allow the explicitly documented supported mini variant
+     "gpt-5.1-codex-mini",
+ ]
+
+ SUPPORTS_STOP_WORDS_FALSE_MODELS: list[str] = [
+     # o-series families don't support stop words
+     "o1",
+     "o3",
+     # grok-4 specific model name (basename)
+     "grok-4-0709",
+     "grok-code-fast-1",
+     # DeepSeek R1 family
+     "deepseek-r1-0528",
+ ]
+
+ # Models that should use the OpenAI Responses API path by default
+ RESPONSES_API_MODELS: list[str] = [
+     # OpenAI GPT-5 family (includes mini variants)
+     "gpt-5",
+     # OpenAI Codex (uses Responses API)
+     "codex-mini-latest",
+ ]
+
+ # Models that require string serializer for tool messages
+ # These models don't support structured content format [{"type":"text","text":"..."}]
+ # and need plain strings instead
+ # NOTE: model_matches uses case-insensitive substring matching, not globbing.
+ # Keep these entries as bare substrings without wildcards.
+ FORCE_STRING_SERIALIZER_MODELS: list[str] = [
+     "deepseek",  # e.g., DeepSeek-V3.2-Exp
+     "glm",  # e.g., GLM-4.5 / GLM-4.6
+     # Kimi K2-Instruct requires string serialization only on Groq
+     "groq/kimi-k2-instruct",  # explicit provider-prefixed IDs
+     # MiniMax-M2 via OpenRouter rejects array content with
+     # "Input should be a valid string" for ChatCompletionToolMessage.content
+     "openrouter/minimax",
+ ]
+
+ # Models that we should send full reasoning content
+ # in the message input
+ SEND_REASONING_CONTENT_MODELS: list[str] = [
+     "kimi-k2-thinking",
+     "openrouter/minimax-m2",  # MiniMax-M2 via OpenRouter (interleaved thinking)
+     "deepseek/deepseek-reasoner",
+ ]
+
+
+ def get_features(model: str) -> ModelFeatures:
+     """Get model features."""
+     return ModelFeatures(
+         supports_reasoning_effort=model_matches(model, REASONING_EFFORT_MODELS),
+         supports_extended_thinking=model_matches(model, EXTENDED_THINKING_MODELS),
+         supports_prompt_cache=model_matches(model, PROMPT_CACHE_MODELS),
+         supports_stop_words=not model_matches(model, SUPPORTS_STOP_WORDS_FALSE_MODELS),
+         supports_responses_api=model_matches(model, RESPONSES_API_MODELS),
+         force_string_serializer=model_matches(model, FORCE_STRING_SERIALIZER_MODELS),
+         send_reasoning_content=model_matches(model, SEND_REASONING_CONTENT_MODELS),
+         # Extended prompt_cache_retention support follows ordered include/exclude rules.
+         supports_prompt_cache_retention=apply_ordered_model_rules(
+             model, PROMPT_CACHE_RETENTION_MODELS
+         ),
+     )
+
+
+ # Default temperature mapping.
+ # Each entry: (pattern, default_temperature)
+ DEFAULT_TEMPERATURE_MODELS: list[tuple[str, float]] = [
+     ("kimi-k2-thinking", 1.0),
+ ]
+
+
+ def get_default_temperature(model: str) -> float:
+     """Return the default temperature for a given model pattern.
+
+     Uses case-insensitive substring matching via model_matches.
+     """
+     for pattern, value in DEFAULT_TEMPERATURE_MODELS:
+         if model_matches(model, [pattern]):
+             return value
+     return 0.0
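A short sketch of how the ordered include/exclude rules resolve; the model names are examples, and the assertions follow directly from the lists above:

from openhands.sdk.llm.utils.model_features import (
    PROMPT_CACHE_RETENTION_MODELS,
    apply_ordered_model_rules,
    get_default_temperature,
    get_features,
)

# "gpt-5.1-codex-mini" matches "gpt-5" (include), then "!mini" (exclude),
# then "gpt-5.1-codex-mini" (include); the last matching rule wins -> True.
assert apply_ordered_model_rules("gpt-5.1-codex-mini", PROMPT_CACHE_RETENTION_MODELS)

# "gpt-5-mini" matches "gpt-5" then "!mini", and nothing re-allows it -> False.
assert not apply_ordered_model_rules("gpt-5-mini", PROMPT_CACHE_RETENTION_MODELS)

# Feature flags use plain substring matching on the raw model string,
# so provider prefixes do not interfere.
features = get_features("litellm_proxy/claude-sonnet-4-5")
assert features.supports_extended_thinking
assert not features.supports_responses_api

assert get_default_temperature("kimi-k2-thinking") == 1.0
assert get_default_temperature("gpt-5") == 0.0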
openhands/sdk/llm/utils/model_info.py
@@ -0,0 +1,90 @@
+ import time
+ from functools import lru_cache
+ from logging import getLogger
+
+ import httpx
+ from litellm.types.utils import ModelInfo
+ from litellm.utils import get_model_info
+ from pydantic import SecretStr
+
+
+ logger = getLogger(__name__)
+
+
+ @lru_cache
+ def _get_model_info_from_litellm_proxy(
+     secret_api_key: SecretStr | str | None,
+     base_url: str,
+     model: str,
+     cache_key: int | None = None,
+ ):
+     logger.debug(f"Get model_info_from_litellm_proxy:{cache_key}")
+     try:
+         headers = {}
+         if isinstance(secret_api_key, SecretStr):
+             secret_api_key = secret_api_key.get_secret_value()
+         if secret_api_key:
+             headers["Authorization"] = f"Bearer {secret_api_key}"
+
+         response = httpx.get(f"{base_url}/v1/model/info", headers=headers)
+         data = response.json().get("data", [])
+         current = next(
+             (
+                 info
+                 for info in data
+                 if info["model_name"] == model.removeprefix("litellm_proxy/")
+             ),
+             None,
+         )
+         if current:
+             model_info = current.get("model_info")
+             logger.debug(f"Got model info from litellm proxy: {model_info}")
+             return model_info
+     except Exception as e:
+         logger.debug(
+             f"Error fetching model info from proxy: {e}",
+             exc_info=True,
+             stack_info=True,
+         )
+
+
+ def get_litellm_model_info(
+     secret_api_key: SecretStr | str | None, base_url: str | None, model: str
+ ) -> ModelInfo | None:
+     # Try to get model info via openrouter or litellm proxy first
+     try:
+         if model.startswith("openrouter"):
+             model_info = get_model_info(model)
+             if model_info:
+                 return model_info
+     except Exception as e:
+         logger.debug(f"get_model_info(openrouter) failed: {e}")
+
+     if model.startswith("litellm_proxy/") and base_url:
+         # Use the current hour as a cache key - only refresh hourly
+         cache_key = int(time.time() / 3600)
+
+         model_info = _get_model_info_from_litellm_proxy(
+             secret_api_key=secret_api_key,
+             base_url=base_url,
+             model=model,
+             cache_key=cache_key,
+         )
+         if model_info:
+             return model_info
+
+     # Fallbacks: try base name variants
+     try:
+         model_info = get_model_info(model.split(":")[0])
+         if model_info:
+             return model_info
+     except Exception:
+         pass
+     try:
+         model_info = get_model_info(model.split("/")[-1])
+         if model_info:
+             return model_info
+     except Exception:
+         pass
+
+     return None
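A hedged usage sketch for the lookup above. It performs live HTTP against a LiteLLM proxy, so the base URL and key below are placeholders, and the returned mapping may be None when nothing matches; the "max_input_tokens" key is the usual LiteLLM model-info field, assumed here for illustration:

from pydantic import SecretStr

from openhands.sdk.llm.utils.model_info import get_litellm_model_info

info = get_litellm_model_info(
    secret_api_key=SecretStr("sk-placeholder"),  # placeholder credential
    base_url="https://litellm.example.com",  # placeholder proxy URL
    model="litellm_proxy/gpt-5",
)
if info is not None:
    print(info.get("max_input_tokens"))

Because the proxy helper is wrapped in lru_cache and keyed on the current hour (cache_key = int(time.time() / 3600)), repeated calls within the same hour reuse the first HTTP response.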
openhands/sdk/llm/utils/model_prompt_spec.py
@@ -0,0 +1,98 @@
+ """Utilities for detecting model families and variants.
+
+ These helpers allow prompts and other systems to tailor behavior for specific
+ LLM providers while keeping naming heuristics centralized.
+ """
+
+ from __future__ import annotations
+
+ from pydantic import BaseModel, ConfigDict
+
+
+ class ModelPromptSpec(BaseModel):
+     """Detected prompt metadata for a given model configuration."""
+
+     model_config = ConfigDict(frozen=True)
+
+     family: str | None = None
+     variant: str | None = None
+
+
+ _MODEL_FAMILY_PATTERNS: dict[str, tuple[str, ...]] = {
+     "openai_gpt": (
+         "gpt-",
+         "o1",
+         "o3",
+         "o4",
+     ),
+     "anthropic_claude": ("claude",),
+     "google_gemini": ("gemini",),
+     "meta_llama": ("llama",),
+     "mistral": ("mistral",),
+     "deepseek": ("deepseek",),
+     "alibaba_qwen": ("qwen",),
+ }
+
+ # Ordered heuristics to pick the most specific variant available for a family.
+ _MODEL_VARIANT_PATTERNS: dict[str, tuple[tuple[str, tuple[str, ...]], ...]] = {
+     "openai_gpt": (
+         ("gpt-5-codex", ("gpt-5-codex", "gpt-5.1-codex")),
+         ("gpt-5", ("gpt-5", "gpt-5.1")),
+     ),
+ }
+
+
+ def _normalize(name: str | None) -> str:
+     return (name or "").strip().lower()
+
+
+ def _match_family(model_name: str) -> str | None:
+     normalized = _normalize(model_name)
+     if not normalized:
+         return None
+
+     for family, patterns in _MODEL_FAMILY_PATTERNS.items():
+         if any(pattern in normalized for pattern in patterns):
+             return family
+     return None
+
+
+ def _match_variant(
+     family: str,
+     model_name: str,
+     canonical_name: str | None = None,
+ ) -> str | None:
+     patterns = _MODEL_VARIANT_PATTERNS.get(family)
+     if not patterns:
+         return None
+
+     # Choose canonical_name if available, otherwise fall back to model_name
+     candidate = _normalize(canonical_name) or _normalize(model_name)
+     if not candidate:
+         return None
+
+     for variant, substrings in patterns:
+         if any(sub in candidate for sub in substrings):
+             return variant
+
+     return None
+
+
+ def get_model_prompt_spec(
+     model_name: str,
+     canonical_name: str | None = None,
+ ) -> ModelPromptSpec:
+     """Return family and variant prompt metadata for the given identifiers."""
+
+     family = _match_family(model_name)
+     if family is None and canonical_name:
+         family = _match_family(canonical_name)
+
+     variant = None
+     if family is not None:
+         variant = _match_variant(family, model_name, canonical_name)
+
+     return ModelPromptSpec(family=family, variant=variant)
+
+
+ __all__ = ["ModelPromptSpec", "get_model_prompt_spec"]
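Finally, a small sketch of the family/variant detection; the assertions follow from the pattern tables above, and the family/variant names line up with the prompt templates under openhands/sdk/agent/prompts/model_specific/ in the file listing:

from openhands.sdk.llm.utils.model_prompt_spec import get_model_prompt_spec

spec = get_model_prompt_spec("litellm_proxy/gpt-5.1-codex")
assert spec.family == "openai_gpt"
assert spec.variant == "gpt-5-codex"  # lines up with openai_gpt/gpt-5-codex.j2

spec = get_model_prompt_spec("claude-sonnet-4-5")
assert spec.family == "anthropic_claude"
assert spec.variant is None  # no variant heuristics defined for this family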