holmesgpt 0.14.3a0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of holmesgpt might be problematic.

Files changed (30)
  1. holmes/__init__.py +1 -1
  2. holmes/clients/robusta_client.py +12 -10
  3. holmes/common/env_vars.py +14 -0
  4. holmes/config.py +51 -4
  5. holmes/core/conversations.py +3 -2
  6. holmes/core/llm.py +198 -72
  7. holmes/core/openai_formatting.py +13 -0
  8. holmes/core/tool_calling_llm.py +129 -95
  9. holmes/core/tools.py +21 -1
  10. holmes/core/tools_utils/token_counting.py +2 -1
  11. holmes/core/tools_utils/tool_context_window_limiter.py +13 -4
  12. holmes/interactive.py +17 -7
  13. holmes/plugins/prompts/_general_instructions.jinja2 +1 -2
  14. holmes/plugins/toolsets/__init__.py +4 -0
  15. holmes/plugins/toolsets/atlas_mongodb/mongodb_atlas.py +0 -1
  16. holmes/plugins/toolsets/azure_sql/azure_sql_toolset.py +0 -1
  17. holmes/plugins/toolsets/grafana/grafana_api.py +1 -1
  18. holmes/plugins/toolsets/investigator/core_investigation.py +14 -13
  19. holmes/plugins/toolsets/opensearch/opensearch_ppl_query_docs.jinja2 +1616 -0
  20. holmes/plugins/toolsets/opensearch/opensearch_query_assist.py +78 -0
  21. holmes/plugins/toolsets/opensearch/opensearch_query_assist_instructions.jinja2 +223 -0
  22. holmes/plugins/toolsets/prometheus/prometheus.py +7 -4
  23. holmes/plugins/toolsets/service_discovery.py +1 -1
  24. holmes/plugins/toolsets/servicenow/servicenow.py +0 -1
  25. holmes/utils/stream.py +30 -1
  26. {holmesgpt-0.14.3a0.dist-info → holmesgpt-0.15.0.dist-info}/METADATA +3 -1
  27. {holmesgpt-0.14.3a0.dist-info → holmesgpt-0.15.0.dist-info}/RECORD +30 -27
  28. {holmesgpt-0.14.3a0.dist-info → holmesgpt-0.15.0.dist-info}/LICENSE.txt +0 -0
  29. {holmesgpt-0.14.3a0.dist-info → holmesgpt-0.15.0.dist-info}/WHEEL +0 -0
  30. {holmesgpt-0.14.3a0.dist-info → holmesgpt-0.15.0.dist-info}/entry_points.txt +0 -0
holmes/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # This is patched by github actions during release
- __version__ = "0.14.3-alpha"
+ __version__ = "0.15.0"

  # Re-export version functions from version module for backward compatibility
  from .version import (
holmes/clients/robusta_client.py CHANGED
@@ -1,8 +1,8 @@
  import logging
- from typing import List, Optional, Dict, Any
+ from typing import Optional, Dict, Any
  import requests  # type: ignore
  from functools import cache
- from pydantic import BaseModel, ConfigDict, Field
+ from pydantic import BaseModel, ConfigDict
  from holmes.common.env_vars import ROBUSTA_API_ENDPOINT

  HOLMES_GET_INFO_URL = f"{ROBUSTA_API_ENDPOINT}/api/holmes/get_info"
@@ -14,13 +14,15 @@ class HolmesInfo(BaseModel):
      latest_version: Optional[str] = None


- class RobustaModelsResponse(BaseModel):
+ class RobustaModel(BaseModel):
      model_config = ConfigDict(extra="ignore")
-     models: List[str]
-     models_args: Dict[str, Any] = Field(
-         default_factory=dict, alias="models_holmes_args"
-     )
-     default_model: Optional[str] = None
+     model: str
+     holmes_args: Optional[dict[str, Any]] = None
+     is_default: bool = False
+
+
+ class RobustaModelsResponse(BaseModel):
+     models: Dict[str, RobustaModel]


  @cache
@@ -30,13 +32,13 @@ def fetch_robusta_models(
      try:
          session_request = {"session_token": token, "account_id": account_id}
          resp = requests.post(
-             f"{ROBUSTA_API_ENDPOINT}/api/llm/models",
+             f"{ROBUSTA_API_ENDPOINT}/api/llm/models/v2",
              json=session_request,
              timeout=10,
          )
          resp.raise_for_status()
          response_json = resp.json()
-         return RobustaModelsResponse(**response_json)
+         return RobustaModelsResponse(**{"models": response_json})
      except Exception:
          logging.exception("Failed to fetch robusta models")
          return None
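
Note: the v2 endpoint returns a JSON object keyed by model name rather than the previous flat list, and fetch_robusta_models now wraps that raw object under a "models" key before validation. A minimal sketch of how such a payload parses; the payload values below are illustrative, not taken from the real API:

from holmes.clients.robusta_client import RobustaModelsResponse

# Hypothetical v2 payload: keys are model names, values describe each model.
response_json = {
    "Robusta/sonnet-4": {
        "model": "anthropic/claude-sonnet-4",
        "holmes_args": {"temperature": 0},
        "is_default": True,
    },
    "Robusta/gpt-4o": {"model": "gpt-4o"},
}

# Same wrapping as fetch_robusta_models()
parsed = RobustaModelsResponse(**{"models": response_json})
default = next((name for name, m in parsed.models.items() if m.is_default), None)
print(default)  # Robusta/sonnet-4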
holmes/common/env_vars.py CHANGED
@@ -2,6 +2,16 @@ import os
  import json
  from typing import Optional

+ # Recommended models for different providers
+ RECOMMENDED_OPENAI_MODEL = "gpt-4.1"
+ RECOMMENDED_ANTHROPIC_MODEL = "anthropic/claude-opus-4-1-20250805"
+
+ # Default model for HolmesGPT
+ DEFAULT_MODEL = RECOMMENDED_OPENAI_MODEL
+ FALLBACK_CONTEXT_WINDOW_SIZE = (
+     200000  # Fallback context window size if it can't be determined from the model
+ )
+

  def load_bool(env_var, default: Optional[bool]) -> Optional[bool]:
      env_value = os.environ.get(env_var)
@@ -38,6 +48,7 @@ DEVELOPMENT_MODE = load_bool("DEVELOPMENT_MODE", False)
  SENTRY_DSN = os.environ.get("SENTRY_DSN", "")
  SENTRY_TRACES_SAMPLE_RATE = float(os.environ.get("SENTRY_TRACES_SAMPLE_RATE", "0.0"))

+ EXTRA_HEADERS = os.environ.get("EXTRA_HEADERS", "")
  THINKING = os.environ.get("THINKING", "")
  REASONING_EFFORT = os.environ.get("REASONING_EFFORT", "").strip().lower()
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.00000001"))
@@ -82,6 +93,9 @@ TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT = float(
      os.environ.get("TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT", 15)
  )

+ # Absolute max tokens to allocate for a single tool response
+ TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS = 25000
+
  MAX_EVIDENCE_DATA_CHARACTERS_BEFORE_TRUNCATION = int(
      os.environ.get("MAX_EVIDENCE_DATA_CHARACTERS_BEFORE_TRUNCATION", 3000)
  )
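
Note: EXTRA_HEADERS is read as a raw string here and, as the llm.py changes below show, parsed with json.loads and handed to litellm as extra_headers. A small sketch of the expected format; the header names are made up:

import json
import os

# Hypothetical headers; any JSON object should work.
os.environ["EXTRA_HEADERS"] = '{"X-Request-Source": "holmes", "X-Team": "sre"}'

EXTRA_HEADERS = os.environ.get("EXTRA_HEADERS", "")
if EXTRA_HEADERS:
    # DefaultLLM.completion() does the equivalent of:
    #   self.args.setdefault("extra_headers", json.loads(EXTRA_HEADERS))
    print(json.loads(EXTRA_HEADERS))  # {'X-Request-Source': 'holmes', 'X-Team': 'sre'}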
holmes/config.py CHANGED
@@ -45,6 +45,9 @@ class SupportedTicketSources(str, Enum):

  class Config(RobustaBaseConfig):
      model: Optional[str] = None
+     api_key: Optional[SecretStr] = (
+         None  # if None, read from OPENAI_API_KEY or AZURE_OPENAI_ENDPOINT env var
+     )
      api_base: Optional[str] = None
      api_version: Optional[str] = None
      fast_model: Optional[str] = None
@@ -95,6 +98,7 @@ class Config(RobustaBaseConfig):
      mcp_servers: Optional[dict[str, dict[str, Any]]] = None

      _server_tool_executor: Optional[ToolExecutor] = None
+     _agui_tool_executor: Optional[ToolExecutor] = None

      # TODO: Separate those fields to facade class, this shouldn't be part of the config.
      _toolset_manager: Optional[ToolsetManager] = PrivateAttr(None)
@@ -242,6 +246,23 @@ class Config(RobustaBaseConfig):
          )
          return ToolExecutor(cli_toolsets)

+     def create_agui_tool_executor(self, dal: Optional["SupabaseDal"]) -> ToolExecutor:
+         """
+         Creates ToolExecutor for the AG-UI server endpoints
+         """
+
+         if self._agui_tool_executor:
+             return self._agui_tool_executor
+
+         # Use same toolset as CLI for AG-UI front-end.
+         agui_toolsets = self.toolset_manager.list_console_toolsets(
+             dal=dal, refresh_status=True
+         )
+
+         self._agui_tool_executor = ToolExecutor(agui_toolsets)
+
+         return self._agui_tool_executor
+
      def create_tool_executor(self, dal: Optional["SupabaseDal"]) -> ToolExecutor:
          """
          Creates ToolExecutor for the server endpoints
@@ -273,6 +294,19 @@
              tool_executor, self.max_steps, self._get_llm(tracer=tracer)
          )

+     def create_agui_toolcalling_llm(
+         self,
+         dal: Optional["SupabaseDal"] = None,
+         model: Optional[str] = None,
+         tracer=None,
+     ) -> "ToolCallingLLM":
+         tool_executor = self.create_agui_tool_executor(dal)
+         from holmes.core.tool_calling_llm import ToolCallingLLM
+
+         return ToolCallingLLM(
+             tool_executor, self.max_steps, self._get_llm(model, tracer)
+         )
+
      def create_toolcalling_llm(
          self,
          dal: Optional["SupabaseDal"] = None,
@@ -441,7 +475,8 @@
      # TODO: move this to the llm model registry
      def _get_llm(self, model_key: Optional[str] = None, tracer=None) -> "DefaultLLM":
          sentry_sdk.set_tag("requested_model", model_key)
-         model_params = self.llm_model_registry.get_model_params(model_key)
+         model_entry = self.llm_model_registry.get_model_params(model_key)
+         model_params = model_entry.model_dump(exclude_none=True)
          api_base = self.api_base
          api_version = self.api_version

@@ -453,6 +488,8 @@
              api_key = f"{account_id} {token}"
          else:
              api_key = model_params.pop("api_key", None)
+             if api_key is not None:
+                 api_key = api_key.get_secret_value()

          model = model_params.pop("model")
          # It's ok if the model does not have api base and api version, which are defaults to None.
@@ -463,10 +500,20 @@
          api_version = model_params.pop("api_version", api_version)
          model_name = model_params.pop("name", None) or model_key or model
          sentry_sdk.set_tag("model_name", model_name)
-         logging.info(f"Creating LLM with model: {model_name}")
-         return DefaultLLM(
-             model, api_key, api_base, api_version, model_params, tracer, model_name
+         llm = DefaultLLM(
+             model=model,
+             api_key=api_key,
+             api_base=api_base,
+             api_version=api_version,
+             args=model_params,
+             tracer=tracer,
+             name=model_name,
+             is_robusta_model=is_robusta_model,
          )  # type: ignore
+         logging.info(
+             f"Using model: {model_name} ({llm.get_context_window_size():,} total tokens, {llm.get_maximum_output_token():,} output tokens)"
+         )
+         return llm

      def get_models_list(self) -> List[str]:
          if self.llm_model_registry and self.llm_model_registry.models:
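
Note: _get_llm now receives a typed ModelEntry from the registry, dumps it back into a kwargs dict, and unwraps the SecretStr api_key before constructing DefaultLLM. A rough sketch of that flow, using made-up values rather than a real configuration:

from pydantic import SecretStr
from holmes.core.llm import ModelEntry

# Hypothetical entry, as it might come from model_list.yaml or the Robusta API.
entry = ModelEntry(
    model="azure/gpt-4.1",
    name="prod-azure",
    api_key=SecretStr("sk-example"),  # made-up key
    api_base="https://example.openai.azure.com",
    api_version="2024-02-01",
)

model_params = entry.model_dump(exclude_none=True)  # what _get_llm() works with
api_key = model_params.pop("api_key", None)
if api_key is not None:
    api_key = api_key.get_secret_value()  # SecretStr -> plain str for litellm
model = model_params.pop("model")
print(model, api_key, model_params.get("api_base"))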
@@ -26,7 +26,8 @@ def calculate_tool_size(
          return DEFAULT_TOOL_SIZE

      context_window = ai.llm.get_context_window_size()
-     message_size_without_tools = ai.llm.count_tokens_for_message(messages_without_tools)
+     tokens = ai.llm.count_tokens(messages_without_tools)
+     message_size_without_tools = tokens.total_tokens
      maximum_output_token = ai.llm.get_maximum_output_token()

      tool_size = min(
@@ -372,13 +373,13 @@ def build_chat_messages(
      )

      ask = add_global_instructions_to_user_prompt(ask, global_instructions)
-
      conversation_history.append(  # type: ignore
          {
              "role": "user",
              "content": ask,
          },
      )
+
      number_of_tools = len(
          [message for message in conversation_history if message.get("role") == "tool"]  # type: ignore
      )
holmes/core/llm.py CHANGED
@@ -1,23 +1,31 @@
  import json
  import logging
+ import os
  from abc import abstractmethod
  from math import floor
- from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

+ import litellm
+ from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
  from litellm.types.utils import ModelResponse, TextCompletionResponse
  import sentry_sdk
+ from pydantic import BaseModel, ConfigDict, SecretStr
+ from typing_extensions import Self
+
+ from holmes.clients.robusta_client import (
+     RobustaModel,
+     RobustaModelsResponse,
+     fetch_robusta_models,
+ )

- from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
- from pydantic import BaseModel
- import litellm
- import os
- from holmes.clients.robusta_client import RobustaModelsResponse, fetch_robusta_models
  from holmes.common.env_vars import (
+     FALLBACK_CONTEXT_WINDOW_SIZE,
      LOAD_ALL_ROBUSTA_MODELS,
      REASONING_EFFORT,
      ROBUSTA_AI,
      ROBUSTA_API_ENDPOINT,
      THINKING,
+     EXTRA_HEADERS,
  )
  from holmes.core.supabase_dal import SupabaseDal
  from holmes.utils.env import environ_get_safe_int, replace_env_vars_values
@@ -36,6 +44,39 @@ OVERRIDE_MAX_CONTENT_SIZE = environ_get_safe_int("OVERRIDE_MAX_CONTENT_SIZE")
  ROBUSTA_AI_MODEL_NAME = "Robusta"


+ class TokenCountMetadata(BaseModel):
+     total_tokens: int
+     tools_tokens: int
+     system_tokens: int
+     user_tokens: int
+     tools_to_call_tokens: int
+     other_tokens: int
+
+
+ class ModelEntry(BaseModel):
+     """ModelEntry represents a single LLM model configuration."""
+
+     model: str
+     # TODO: the name field seems to be redundant, can we remove it?
+     name: Optional[str] = None
+     api_key: Optional[SecretStr] = None
+     base_url: Optional[str] = None
+     is_robusta_model: Optional[bool] = None
+     custom_args: Optional[Dict[str, Any]] = None
+
+     # LLM configurations used services like Azure OpenAI Service
+     api_base: Optional[str] = None
+     api_version: Optional[str] = None
+
+     model_config = ConfigDict(
+         extra="allow",
+     )
+
+     @classmethod
+     def load_from_dict(cls, data: dict) -> Self:
+         return cls.model_validate(data)
+
+
  class LLM:
      @abstractmethod
      def __init__(self):
@@ -50,7 +91,9 @@ class LLM:
          pass

      @abstractmethod
-     def count_tokens_for_message(self, messages: list[dict]) -> int:
+     def count_tokens(
+         self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+     ) -> TokenCountMetadata:
          pass

      @abstractmethod
@@ -73,6 +116,7 @@ class DefaultLLM(LLM):
      api_base: Optional[str]
      api_version: Optional[str]
      args: Dict
+     is_robusta_model: bool

      def __init__(
          self,
@@ -83,6 +127,7 @@ class DefaultLLM(LLM):
          args: Optional[Dict] = None,
          tracer: Optional[Any] = None,
          name: Optional[str] = None,
+         is_robusta_model: bool = False,
      ):
          self.model = model
          self.api_key = api_key
@@ -91,8 +136,11 @@
          self.args = args or {}
          self.tracer = tracer
          self.name = name
+         self.is_robusta_model = is_robusta_model
          self.update_custom_args()
-         self.check_llm(self.model, self.api_key, self.api_base, self.api_version)
+         self.check_llm(
+             self.model, self.api_key, self.api_base, self.api_version, self.args
+         )

      def update_custom_args(self):
          self.max_context_size = self.args.get("custom_args", {}).get("max_context_size")
@@ -104,7 +152,14 @@
          api_key: Optional[str],
          api_base: Optional[str],
          api_version: Optional[str],
+         args: Optional[dict] = None,
      ):
+         if self.is_robusta_model:
+             # The model is assumed correctly configured if it is a robusta model
+             # For robusta models, this code would fail because Holmes has no knowledge of the API keys
+             # to azure or bedrock as all completion API calls go through robusta's LLM proxy
+             return
+         args = args or {}
          logging.debug(f"Checking LiteLLM model {model}")
          lookup = litellm.get_llm_provider(model)
          if not lookup:
@@ -140,10 +195,17 @@
                  "environment variable for proper functionality. For more information, refer to the documentation: "
                  "https://docs.litellm.ai/docs/providers/watsonx#usage---models-in-deployment-spaces"
              )
-         elif provider == "bedrock" and (
-             os.environ.get("AWS_PROFILE") or os.environ.get("AWS_BEARER_TOKEN_BEDROCK")
-         ):
-             model_requirements = {"keys_in_environment": True, "missing_keys": []}
+         elif provider == "bedrock":
+             if os.environ.get("AWS_PROFILE") or os.environ.get(
+                 "AWS_BEARER_TOKEN_BEDROCK"
+             ):
+                 model_requirements = {"keys_in_environment": True, "missing_keys": []}
+             elif args.get("aws_access_key_id") and args.get("aws_secret_access_key"):
+                 return  # break fast.
+             else:
+                 model_requirements = litellm.validate_environment(
+                     model=model, api_key=api_key, api_base=api_base
+                 )
          else:
              model_requirements = litellm.validate_environment(
                  model=model, api_key=api_key, api_base=api_base
@@ -202,39 +264,78 @@
          # Log which lookups we tried
          logging.warning(
              f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
-             f"using default 128k tokens for max_input_tokens. "
+             f"using default {FALLBACK_CONTEXT_WINDOW_SIZE} tokens for max_input_tokens. "
              f"To override, set OVERRIDE_MAX_CONTENT_SIZE environment variable to the correct value for your model."
          )
-         return 128000
+         return FALLBACK_CONTEXT_WINDOW_SIZE

      @sentry_sdk.trace
-     def count_tokens_for_message(self, messages: list[dict]) -> int:
-         total_token_count = 0
+     def count_tokens(
+         self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+     ) -> TokenCountMetadata:
+         # TODO: Add a recount:bool flag to save time. When the flag is false, reuse 'message["token_count"]' for individual messages.
+         # It's only necessary to recount message tokens at the beginning of a session because the LLM model may have changed.
+         # Changing the model requires recounting tokens because the tokenizer may be different
+         total_tokens = 0
+         tools_tokens = 0
+         system_tokens = 0
+         user_tokens = 0
+         other_tokens = 0
+         tools_to_call_tokens = 0
          for message in messages:
-             if "token_count" in message and message["token_count"]:
-                 total_token_count += message["token_count"]
+             # count message tokens individually because it gives us fine grain information about each tool call/message etc.
+             # However be aware that the sum of individual message tokens is not equal to the overall messages token
+             token_count = litellm.token_counter(  # type: ignore
+                 model=self.model, messages=[message]
+             )
+             message["token_count"] = token_count
+             role = message.get("role")
+             if role == "system":
+                 system_tokens += token_count
+             elif role == "user":
+                 user_tokens += token_count
+             elif role == "tool":
+                 tools_tokens += token_count
              else:
-                 # message can be counted by this method only if message contains a "content" key
-                 if "content" in message:
-                     if isinstance(message["content"], str):
-                         message_to_count = [
-                             {"type": "text", "text": message["content"]}
-                         ]
-                     elif isinstance(message["content"], list):
-                         message_to_count = [
-                             {"type": "text", "text": json.dumps(message["content"])}
-                         ]
-                     elif isinstance(message["content"], dict):
-                         if "type" not in message["content"]:
-                             message_to_count = [
-                                 {"type": "text", "text": json.dumps(message["content"])}
-                             ]
-                     token_count = litellm.token_counter(
-                         model=self.model, messages=message_to_count
-                     )
-                     message["token_count"] = token_count
-                     total_token_count += token_count
-         return total_token_count
+                 # although this should not be needed,
+                 # it is defensive code so that all tokens are accounted for
+                 # and can potentially make debugging easier
+                 other_tokens += token_count
+
+         messages_token_count_without_tools = litellm.token_counter(  # type: ignore
+             model=self.model, messages=messages
+         )
+
+         total_tokens = litellm.token_counter(  # type: ignore
+             model=self.model,
+             messages=messages,
+             tools=tools,  # type: ignore
+         )
+         tools_to_call_tokens = max(0, total_tokens - messages_token_count_without_tools)
+
+         return TokenCountMetadata(
+             total_tokens=total_tokens,
+             system_tokens=system_tokens,
+             user_tokens=user_tokens,
+             tools_tokens=tools_tokens,
+             tools_to_call_tokens=tools_to_call_tokens,
+             other_tokens=other_tokens,
+         )
+
+     def get_litellm_corrected_name_for_robusta_ai(self) -> str:
+         if self.is_robusta_model:
+             # For robusta models, self.model is the underlying provider/model used by Robusta AI
+             # To avoid litellm modifying the API URL according to the provider, the provider name
+             # is replaced with 'openai/' just before doing a completion() call
+             # Cf. https://docs.litellm.ai/docs/providers/openai_compatible
+             split_model_name = self.model.split("/")
+             return (
+                 split_model_name[0]
+                 if len(split_model_name) == 1
+                 else f"openai/{split_model_name[1]}"
+             )
+         else:
+             return self.model

      def completion(
          self,
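
Note: count_tokens now returns a per-role breakdown and derives tools_to_call_tokens as the difference between a count with and without the tool schemas, while get_litellm_corrected_name_for_robusta_ai rewrites the provider prefix so litellm treats Robusta-proxied models as OpenAI-compatible. A small standalone sketch of the name rewrite implied by the hunk above; the model names are examples only:

def corrected_name(model: str, is_robusta_model: bool) -> str:
    # Mirrors get_litellm_corrected_name_for_robusta_ai() above.
    if not is_robusta_model:
        return model
    parts = model.split("/")
    return parts[0] if len(parts) == 1 else f"openai/{parts[1]}"

print(corrected_name("anthropic/claude-sonnet-4", True))  # openai/claude-sonnet-4
print(corrected_name("gpt-4o", True))                     # gpt-4o (no provider prefix to replace)
print(corrected_name("bedrock/claude-3", False))          # bedrock/claude-3 (non-Robusta models unchanged)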
@@ -256,6 +357,9 @@
          if THINKING:
              self.args.setdefault("thinking", json.loads(THINKING))

+         if EXTRA_HEADERS:
+             self.args.setdefault("extra_headers", json.loads(EXTRA_HEADERS))
+
          if self.args.get("thinking", None):
              litellm.modify_params = True

@@ -271,8 +375,10 @@

          # Get the litellm module to use (wrapped or unwrapped)
          litellm_to_use = self.tracer.wrap_llm(litellm) if self.tracer else litellm
+
+         litellm_model_name = self.get_litellm_corrected_name_for_robusta_ai()
          result = litellm_to_use.completion(
-             model=self.model,
+             model=litellm_model_name,
              api_key=self.api_key,
              base_url=self.api_base,
              api_version=self.api_version,
@@ -328,6 +434,12 @@
          Add cache_control to the last non-user message for Anthropic prompt caching.
          Removes any existing cache_control from previous messages to avoid accumulation.
          """
+         # Skip cache_control for VertexAI/Gemini models as they don't support it with tools
+         if self.model and (
+             "vertex" in self.model.lower() or "gemini" in self.model.lower()
+         ):
+             return
+
          # First, remove any existing cache_control from all messages
          for msg in messages:
              content = msg.get("content")
@@ -382,7 +494,7 @@
  class LLMModelRegistry:
      def __init__(self, config: "Config", dal: SupabaseDal) -> None:
          self.config = config
-         self._llms: dict[str, dict[str, Any]] = {}
+         self._llms: dict[str, ModelEntry] = {}
          self._default_robusta_model = None
          self.dal = dal

@@ -404,6 +516,8 @@
              model_name=self.config.model,
              base_url=self.config.api_base,
              is_robusta_model=False,
+             api_key=self.config.api_key,
+             api_version=self.config.api_version,
          )

      def _should_load_config_model(self) -> bool:
@@ -414,7 +528,7 @@
          # so we need to check if the user has set an OPENAI_API_KEY to load the config model.
          has_openai_key = os.environ.get("OPENAI_API_KEY")
          if has_openai_key:
-             self.config.model = "gpt-4o"
+             self.config.model = "gpt-4.1"
              return True

          return False
@@ -437,16 +551,18 @@
                  self._load_default_robusta_config()
                  return

-             for model in robusta_models.models:
-                 logging.info(f"Loading Robusta AI model: {model}")
-                 args = robusta_models.models_args.get(model)
-                 self._llms[model] = self._create_robusta_model_entry(model, args)
-
-             if robusta_models.default_model:
-                 logging.info(
-                     f"Setting default Robusta AI model to: {robusta_models.default_model}"
+             default_model = None
+             for model_name, model_data in robusta_models.models.items():
+                 logging.info(f"Loading Robusta AI model: {model_name}")
+                 self._llms[model_name] = self._create_robusta_model_entry(
+                     model_name=model_name, model_data=model_data
                  )
-             self._default_robusta_model: str = robusta_models.default_model  # type: ignore
+                 if model_data.is_default:
+                     default_model = model_name
+
+             if default_model:
+                 logging.info(f"Setting default Robusta AI model to: {default_model}")
+                 self._default_robusta_model: str = default_model  # type: ignore

          except Exception:
              logging.exception("Failed to get all robusta models")
@@ -456,12 +572,12 @@
      def _load_default_robusta_config(self):
          if self._should_load_robusta_ai():
              logging.info("Loading default Robusta AI model")
-             self._llms[ROBUSTA_AI_MODEL_NAME] = {
-                 "name": ROBUSTA_AI_MODEL_NAME,
-                 "base_url": ROBUSTA_API_ENDPOINT,
-                 "is_robusta_model": True,
-                 "model": "gpt-4o",
-             }
+             self._llms[ROBUSTA_AI_MODEL_NAME] = ModelEntry(
+                 name=ROBUSTA_AI_MODEL_NAME,
+                 model="gpt-4o",  # TODO: tech debt, this isn't really
+                 base_url=ROBUSTA_API_ENDPOINT,
+                 is_robusta_model=True,
+             )
              self._default_robusta_model = ROBUSTA_AI_MODEL_NAME

      def _should_load_robusta_ai(self) -> bool:
@@ -483,7 +599,7 @@

          return True

-     def get_model_params(self, model_key: Optional[str] = None) -> dict:
+     def get_model_params(self, model_key: Optional[str] = None) -> ModelEntry:
          if not self._llms:
              raise Exception("No llm models were loaded")

@@ -515,26 +631,30 @@
          return self._llms[name]  # type: ignore

      @property
-     def models(self) -> dict[str, dict[str, Any]]:
+     def models(self) -> dict[str, ModelEntry]:
          return self._llms

-     def _parse_models_file(self, path: str):
+     def _parse_models_file(self, path: str) -> dict[str, ModelEntry]:
          models = load_yaml_file(path, raise_error=False, warn_not_found=False)
          for _, params in models.items():
              params = replace_env_vars_values(params)

-         return models
+         llms = {}
+         for model_name, params in models.items():
+             llms[model_name] = ModelEntry.model_validate(params)
+
+         return llms


      def _create_robusta_model_entry(
-         self, model_name: str, args: Optional[dict[str, Any]] = None
-     ) -> dict[str, Any]:
+         self, model_name: str, model_data: RobustaModel
+     ) -> ModelEntry:
          entry = self._create_model_entry(
-             model="gpt-4o",  # Robusta AI model is using openai like API.
+             model=model_data.model,
              model_name=model_name,
              base_url=f"{ROBUSTA_API_ENDPOINT}/llm/{model_name}",
              is_robusta_model=True,
          )
-         entry["custom_args"] = args or {}  # type: ignore[assignment]
+         entry.custom_args = model_data.holmes_args or {}  # type: ignore[assignment]
          return entry

@@ -543,13 +663,19 @@
          model_name: str,
          base_url: Optional[str] = None,
          is_robusta_model: Optional[bool] = None,
-     ) -> dict[str, Any]:
-         return {
-             "name": model_name,
-             "base_url": base_url,
-             "is_robusta_model": is_robusta_model,
-             "model": model,
-         }
+         api_key: Optional[SecretStr] = None,
+         api_base: Optional[str] = None,
+         api_version: Optional[str] = None,
+     ) -> ModelEntry:
+         return ModelEntry(
+             name=model_name,
+             model=model,
+             base_url=base_url,
+             is_robusta_model=is_robusta_model,
+             api_key=api_key,
+             api_base=api_base,
+             api_version=api_version,
+         )


  def get_llm_usage(
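
Note: with _parse_models_file returning ModelEntry objects, each entry in the model list file is validated by pydantic; unknown keys survive because of extra="allow", and api_key is wrapped as a SecretStr. A hypothetical entry and how it would validate; the field values are illustrative, not from any real configuration:

from holmes.core.llm import ModelEntry

params = {
    "model": "bedrock/anthropic.claude-sonnet-4",  # hypothetical YAML entry
    "api_key": "resolved-from-env-var",            # made-up value
    "custom_args": {"max_context_size": 200000},
    "aws_region_name": "us-east-1",                # extra key, kept because extra="allow"
}

entry = ModelEntry.model_validate(params)
print(type(entry.api_key).__name__)  # SecretStr
print(entry.custom_args)             # {'max_context_size': 200000}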
holmes/core/openai_formatting.py CHANGED
@@ -80,6 +80,19 @@ def format_tool_to_open_ai_standard(
              )
          if param_attributes.description is not None:
              tool_properties[param_name]["description"] = param_attributes.description
+         # Add enum constraint if specified
+         if hasattr(param_attributes, "enum") and param_attributes.enum:
+             enum_values = list(
+                 param_attributes.enum
+             )  # Create a copy to avoid modifying original
+             # In strict mode, optional parameters need None in their enum to match the type allowing null
+             if (
+                 strict_mode
+                 and not param_attributes.required
+                 and None not in enum_values
+             ):
+                 enum_values.append(None)
+             tool_properties[param_name]["enum"] = enum_values

      result: dict[str, Any] = {
          "type": "function",