holmesgpt 0.13.2__py3-none-any.whl → 0.16.2a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. holmes/__init__.py +1 -1
  2. holmes/clients/robusta_client.py +17 -4
  3. holmes/common/env_vars.py +40 -1
  4. holmes/config.py +114 -144
  5. holmes/core/conversations.py +53 -14
  6. holmes/core/feedback.py +191 -0
  7. holmes/core/investigation.py +18 -22
  8. holmes/core/llm.py +489 -88
  9. holmes/core/models.py +103 -1
  10. holmes/core/openai_formatting.py +13 -0
  11. holmes/core/prompt.py +1 -1
  12. holmes/core/safeguards.py +4 -4
  13. holmes/core/supabase_dal.py +293 -100
  14. holmes/core/tool_calling_llm.py +423 -323
  15. holmes/core/tools.py +311 -33
  16. holmes/core/tools_utils/token_counting.py +14 -0
  17. holmes/core/tools_utils/tool_context_window_limiter.py +57 -0
  18. holmes/core/tools_utils/tool_executor.py +13 -8
  19. holmes/core/toolset_manager.py +155 -4
  20. holmes/core/tracing.py +6 -1
  21. holmes/core/transformers/__init__.py +23 -0
  22. holmes/core/transformers/base.py +62 -0
  23. holmes/core/transformers/llm_summarize.py +174 -0
  24. holmes/core/transformers/registry.py +122 -0
  25. holmes/core/transformers/transformer.py +31 -0
  26. holmes/core/truncation/compaction.py +59 -0
  27. holmes/core/truncation/dal_truncation_utils.py +23 -0
  28. holmes/core/truncation/input_context_window_limiter.py +218 -0
  29. holmes/interactive.py +177 -24
  30. holmes/main.py +7 -4
  31. holmes/plugins/prompts/_fetch_logs.jinja2 +26 -1
  32. holmes/plugins/prompts/_general_instructions.jinja2 +1 -2
  33. holmes/plugins/prompts/_runbook_instructions.jinja2 +23 -12
  34. holmes/plugins/prompts/conversation_history_compaction.jinja2 +88 -0
  35. holmes/plugins/prompts/generic_ask.jinja2 +2 -4
  36. holmes/plugins/prompts/generic_ask_conversation.jinja2 +2 -1
  37. holmes/plugins/prompts/generic_ask_for_issue_conversation.jinja2 +2 -1
  38. holmes/plugins/prompts/generic_investigation.jinja2 +2 -1
  39. holmes/plugins/prompts/investigation_procedure.jinja2 +48 -0
  40. holmes/plugins/prompts/kubernetes_workload_ask.jinja2 +2 -1
  41. holmes/plugins/prompts/kubernetes_workload_chat.jinja2 +2 -1
  42. holmes/plugins/runbooks/__init__.py +117 -18
  43. holmes/plugins/runbooks/catalog.json +2 -0
  44. holmes/plugins/toolsets/__init__.py +21 -8
  45. holmes/plugins/toolsets/aks-node-health.yaml +46 -0
  46. holmes/plugins/toolsets/aks.yaml +64 -0
  47. holmes/plugins/toolsets/atlas_mongodb/mongodb_atlas.py +26 -36
  48. holmes/plugins/toolsets/azure_sql/azure_sql_toolset.py +0 -1
  49. holmes/plugins/toolsets/azure_sql/tools/analyze_connection_failures.py +10 -7
  50. holmes/plugins/toolsets/azure_sql/tools/analyze_database_connections.py +9 -6
  51. holmes/plugins/toolsets/azure_sql/tools/analyze_database_health_status.py +8 -6
  52. holmes/plugins/toolsets/azure_sql/tools/analyze_database_performance.py +8 -6
  53. holmes/plugins/toolsets/azure_sql/tools/analyze_database_storage.py +9 -6
  54. holmes/plugins/toolsets/azure_sql/tools/get_active_alerts.py +9 -7
  55. holmes/plugins/toolsets/azure_sql/tools/get_slow_queries.py +9 -6
  56. holmes/plugins/toolsets/azure_sql/tools/get_top_cpu_queries.py +9 -6
  57. holmes/plugins/toolsets/azure_sql/tools/get_top_data_io_queries.py +9 -6
  58. holmes/plugins/toolsets/azure_sql/tools/get_top_log_io_queries.py +9 -6
  59. holmes/plugins/toolsets/bash/bash_toolset.py +10 -13
  60. holmes/plugins/toolsets/bash/common/bash.py +7 -7
  61. holmes/plugins/toolsets/cilium.yaml +284 -0
  62. holmes/plugins/toolsets/coralogix/toolset_coralogix_logs.py +5 -3
  63. holmes/plugins/toolsets/datadog/datadog_api.py +490 -24
  64. holmes/plugins/toolsets/datadog/datadog_logs_instructions.jinja2 +21 -10
  65. holmes/plugins/toolsets/datadog/toolset_datadog_general.py +349 -216
  66. holmes/plugins/toolsets/datadog/toolset_datadog_logs.py +190 -19
  67. holmes/plugins/toolsets/datadog/toolset_datadog_metrics.py +101 -44
  68. holmes/plugins/toolsets/datadog/toolset_datadog_rds.py +13 -16
  69. holmes/plugins/toolsets/datadog/toolset_datadog_traces.py +25 -31
  70. holmes/plugins/toolsets/git.py +51 -46
  71. holmes/plugins/toolsets/grafana/common.py +15 -3
  72. holmes/plugins/toolsets/grafana/grafana_api.py +46 -24
  73. holmes/plugins/toolsets/grafana/grafana_tempo_api.py +454 -0
  74. holmes/plugins/toolsets/grafana/loki/instructions.jinja2 +9 -0
  75. holmes/plugins/toolsets/grafana/loki/toolset_grafana_loki.py +117 -0
  76. holmes/plugins/toolsets/grafana/toolset_grafana.py +211 -91
  77. holmes/plugins/toolsets/grafana/toolset_grafana_dashboard.jinja2 +27 -0
  78. holmes/plugins/toolsets/grafana/toolset_grafana_tempo.jinja2 +246 -11
  79. holmes/plugins/toolsets/grafana/toolset_grafana_tempo.py +653 -293
  80. holmes/plugins/toolsets/grafana/trace_parser.py +1 -1
  81. holmes/plugins/toolsets/internet/internet.py +6 -7
  82. holmes/plugins/toolsets/internet/notion.py +5 -6
  83. holmes/plugins/toolsets/investigator/core_investigation.py +42 -34
  84. holmes/plugins/toolsets/kafka.py +25 -36
  85. holmes/plugins/toolsets/kubernetes.yaml +58 -84
  86. holmes/plugins/toolsets/kubernetes_logs.py +6 -6
  87. holmes/plugins/toolsets/kubernetes_logs.yaml +32 -0
  88. holmes/plugins/toolsets/logging_utils/logging_api.py +80 -4
  89. holmes/plugins/toolsets/mcp/toolset_mcp.py +181 -55
  90. holmes/plugins/toolsets/newrelic/__init__.py +0 -0
  91. holmes/plugins/toolsets/newrelic/new_relic_api.py +125 -0
  92. holmes/plugins/toolsets/newrelic/newrelic.jinja2 +41 -0
  93. holmes/plugins/toolsets/newrelic/newrelic.py +163 -0
  94. holmes/plugins/toolsets/opensearch/opensearch.py +10 -17
  95. holmes/plugins/toolsets/opensearch/opensearch_logs.py +7 -7
  96. holmes/plugins/toolsets/opensearch/opensearch_ppl_query_docs.jinja2 +1616 -0
  97. holmes/plugins/toolsets/opensearch/opensearch_query_assist.py +78 -0
  98. holmes/plugins/toolsets/opensearch/opensearch_query_assist_instructions.jinja2 +223 -0
  99. holmes/plugins/toolsets/opensearch/opensearch_traces.py +13 -16
  100. holmes/plugins/toolsets/openshift.yaml +283 -0
  101. holmes/plugins/toolsets/prometheus/prometheus.py +915 -390
  102. holmes/plugins/toolsets/prometheus/prometheus_instructions.jinja2 +43 -2
  103. holmes/plugins/toolsets/prometheus/utils.py +28 -0
  104. holmes/plugins/toolsets/rabbitmq/toolset_rabbitmq.py +9 -10
  105. holmes/plugins/toolsets/robusta/robusta.py +236 -65
  106. holmes/plugins/toolsets/robusta/robusta_instructions.jinja2 +26 -9
  107. holmes/plugins/toolsets/runbook/runbook_fetcher.py +137 -26
  108. holmes/plugins/toolsets/service_discovery.py +1 -1
  109. holmes/plugins/toolsets/servicenow_tables/instructions.jinja2 +83 -0
  110. holmes/plugins/toolsets/servicenow_tables/servicenow_tables.py +426 -0
  111. holmes/plugins/toolsets/utils.py +88 -0
  112. holmes/utils/config_utils.py +91 -0
  113. holmes/utils/default_toolset_installation_guide.jinja2 +1 -22
  114. holmes/utils/env.py +7 -0
  115. holmes/utils/global_instructions.py +75 -10
  116. holmes/utils/holmes_status.py +2 -1
  117. holmes/utils/holmes_sync_toolsets.py +0 -2
  118. holmes/utils/krr_utils.py +188 -0
  119. holmes/utils/sentry_helper.py +41 -0
  120. holmes/utils/stream.py +61 -7
  121. holmes/version.py +34 -14
  122. holmesgpt-0.16.2a0.dist-info/LICENSE +178 -0
  123. {holmesgpt-0.13.2.dist-info → holmesgpt-0.16.2a0.dist-info}/METADATA +29 -27
  124. {holmesgpt-0.13.2.dist-info → holmesgpt-0.16.2a0.dist-info}/RECORD +126 -102
  125. holmes/core/performance_timing.py +0 -72
  126. holmes/plugins/toolsets/grafana/tempo_api.py +0 -124
  127. holmes/plugins/toolsets/grafana/toolset_grafana_loki.py +0 -110
  128. holmes/plugins/toolsets/newrelic.py +0 -231
  129. holmes/plugins/toolsets/servicenow/install.md +0 -37
  130. holmes/plugins/toolsets/servicenow/instructions.jinja2 +0 -3
  131. holmes/plugins/toolsets/servicenow/servicenow.py +0 -219
  132. holmesgpt-0.13.2.dist-info/LICENSE.txt +0 -21
  133. {holmesgpt-0.13.2.dist-info → holmesgpt-0.16.2a0.dist-info}/WHEEL +0 -0
  134. {holmesgpt-0.13.2.dist-info → holmesgpt-0.16.2a0.dist-info}/entry_points.txt +0 -0
holmes/core/llm.py CHANGED
@@ -1,32 +1,92 @@
 import json
 import logging
+import os
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Type, Union
+from math import floor
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
 
-from litellm.types.utils import ModelResponse
+import litellm
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.types.utils import ModelResponse, TextCompletionResponse
 import sentry_sdk
+from pydantic import BaseModel, ConfigDict, SecretStr
+from typing_extensions import Self
+
+from holmes.clients.robusta_client import (
+    RobustaModel,
+    RobustaModelsResponse,
+    fetch_robusta_models,
+)
 
-from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from pydantic import BaseModel
-import litellm
-import os
 from holmes.common.env_vars import (
+    FALLBACK_CONTEXT_WINDOW_SIZE,
+    LOAD_ALL_ROBUSTA_MODELS,
     REASONING_EFFORT,
+    ROBUSTA_AI,
+    ROBUSTA_API_ENDPOINT,
     THINKING,
+    EXTRA_HEADERS,
+    TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT,
+    TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS,
 )
+from holmes.core.supabase_dal import SupabaseDal
+from holmes.utils.env import environ_get_safe_int, replace_env_vars_values
+from holmes.utils.file_utils import load_yaml_file
 
+if TYPE_CHECKING:
+    from holmes.config import Config
 
-def environ_get_safe_int(env_var, default="0"):
-    try:
-        return max(int(os.environ.get(env_var, default)), 0)
-    except ValueError:
-        return int(default)
+MODEL_LIST_FILE_LOCATION = os.environ.get(
+    "MODEL_LIST_FILE_LOCATION", "/etc/holmes/config/model_list.yaml"
+)
 
 
 OVERRIDE_MAX_OUTPUT_TOKEN = environ_get_safe_int("OVERRIDE_MAX_OUTPUT_TOKEN")
 OVERRIDE_MAX_CONTENT_SIZE = environ_get_safe_int("OVERRIDE_MAX_CONTENT_SIZE")
 
 
+def get_context_window_compaction_threshold_pct() -> int:
+    """Get the compaction threshold percentage at runtime to support test overrides."""
+    return environ_get_safe_int("CONTEXT_WINDOW_COMPACTION_THRESHOLD_PCT", default="95")
+
+
+ROBUSTA_AI_MODEL_NAME = "Robusta"
+
+
+class TokenCountMetadata(BaseModel):
+    total_tokens: int
+    tools_tokens: int
+    system_tokens: int
+    user_tokens: int
+    tools_to_call_tokens: int
+    assistant_tokens: int
+    other_tokens: int
+
+
+class ModelEntry(BaseModel):
+    """ModelEntry represents a single LLM model configuration."""
+
+    model: str
+    # TODO: the name field seems to be redundant, can we remove it?
+    name: Optional[str] = None
+    api_key: Optional[SecretStr] = None
+    base_url: Optional[str] = None
+    is_robusta_model: Optional[bool] = None
+    custom_args: Optional[Dict[str, Any]] = None
+
+    # LLM configurations used services like Azure OpenAI Service
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
+    @classmethod
+    def load_from_dict(cls, data: dict) -> Self:
+        return cls.model_validate(data)
+
+
 class LLM:
     @abstractmethod
     def __init__(self):
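
The ModelEntry schema added above is a plain Pydantic model with extra="allow", so a configuration entry can be validated straight from a dict. A minimal sketch of that, with every field value invented for illustration (none of these values ship with the package):

    from holmes.core.llm import ModelEntry

    entry = ModelEntry.load_from_dict(
        {
            "model": "azure/my-gpt4o-deployment",  # hypothetical Azure deployment name
            "name": "azure-gpt-4o",
            "api_base": "https://example.openai.azure.com",  # hypothetical endpoint
            "api_version": "2024-02-15-preview",
            "custom_args": {"max_context_size": 128000},
        }
    )
    print(entry.model, entry.api_version)
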
@@ -40,8 +100,23 @@ class LLM:
     def get_maximum_output_token(self) -> int:
         pass
 
+    def get_max_token_count_for_single_tool(self) -> int:
+        if (
+            0 < TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT
+            and TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT <= 100
+        ):
+            context_window_size = self.get_context_window_size()
+            calculated_max_tokens = int(
+                context_window_size * TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT // 100
+            )
+            return min(calculated_max_tokens, TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS)
+        else:
+            return TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS
+
     @abstractmethod
-    def count_tokens_for_message(self, messages: list[dict]) -> int:
+    def count_tokens(
+        self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+    ) -> TokenCountMetadata:
         pass
 
     @abstractmethod
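
get_max_token_count_for_single_tool() above caps the tokens a single tool result may occupy at TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT percent of the model's context window, bounded by the absolute TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS limit. A standalone restatement of that formula with made-up values for both constants (the real values come from holmes/common/env_vars.py):

    def max_tokens_for_single_tool(context_window: int, pct: int = 10, cap: int = 16000) -> int:
        # Mirrors the logic above: a percentage of the window, clamped by an absolute cap.
        if 0 < pct <= 100:
            return min(int(context_window * pct // 100), cap)
        return cap

    assert max_tokens_for_single_tool(200_000) == 16_000  # 20,000 clamped to the cap
    assert max_tokens_for_single_tool(100_000) == 10_000
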
@@ -61,31 +136,55 @@ class DefaultLLM(LLM):
 class DefaultLLM(LLM):
     model: str
     api_key: Optional[str]
-    base_url: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
     args: Dict
+    is_robusta_model: bool
 
     def __init__(
         self,
         model: str,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         args: Optional[Dict] = None,
-        tracer=None,
+        tracer: Optional[Any] = None,
+        name: Optional[str] = None,
+        is_robusta_model: bool = False,
     ):
         self.model = model
         self.api_key = api_key
+        self.api_base = api_base
+        self.api_version = api_version
         self.args = args or {}
         self.tracer = tracer
+        self.name = name
+        self.is_robusta_model = is_robusta_model
+        self.update_custom_args()
+        self.check_llm(
+            self.model, self.api_key, self.api_base, self.api_version, self.args
+        )
 
-        if not self.args:
-            self.check_llm(self.model, self.api_key)
+    def update_custom_args(self):
+        self.max_context_size = self.args.get("custom_args", {}).get("max_context_size")
+        self.args.pop("custom_args", None)
 
-    def check_llm(self, model: str, api_key: Optional[str]):
+    def check_llm(
+        self,
+        model: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        args: Optional[dict] = None,
+    ):
+        if self.is_robusta_model:
+            # The model is assumed correctly configured if it is a robusta model
+            # For robusta models, this code would fail because Holmes has no knowledge of the API keys
+            # to azure or bedrock as all completion API calls go through robusta's LLM proxy
+            return
+        args = args or {}
         logging.debug(f"Checking LiteLLM model {model}")
-        # TODO: this WAS a hack to get around the fact that we can't pass in an api key to litellm.validate_environment
-        # so without this hack it always complains that the environment variable for the api key is missing
-        # to fix that, we always set an api key in the standard format that litellm expects (which is ${PROVIDER}_API_KEY)
-        # TODO: we can now handle this better - see https://github.com/BerriAI/litellm/issues/4375#issuecomment-2223684750
-        lookup = litellm.get_llm_provider(self.model)
+        lookup = litellm.get_llm_provider(model)
         if not lookup:
             raise Exception(f"Unknown provider for model {model}")
         provider = lookup[1]
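
The constructor now pulls max_context_size out of args["custom_args"] via update_custom_args() and validates credentials in check_llm(), which is skipped entirely for Robusta-hosted models. A minimal sketch of how that plumbing behaves (the model name and context size are made up; is_robusta_model=True is used here precisely because it avoids any credential check):

    from holmes.core.llm import DefaultLLM

    llm = DefaultLLM(
        model="anthropic/claude-sonnet",  # hypothetical underlying model
        args={"custom_args": {"max_context_size": 200_000}},
        is_robusta_model=True,
    )
    # update_custom_args() popped "custom_args" and stored max_context_size,
    # so get_context_window_size() (further down in this diff) returns it
    # before ever consulting litellm.model_cost.
    print(llm.get_context_window_size())  # 200000
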
@@ -119,85 +218,151 @@ class DefaultLLM(LLM):
                 "environment variable for proper functionality. For more information, refer to the documentation: "
                 "https://docs.litellm.ai/docs/providers/watsonx#usage---models-in-deployment-spaces"
             )
-        elif provider == "bedrock" and (
-            os.environ.get("AWS_PROFILE") or os.environ.get("AWS_BEARER_TOKEN_BEDROCK")
-        ):
-            model_requirements = {"keys_in_environment": True, "missing_keys": []}
+        elif provider == "bedrock":
+            if os.environ.get("AWS_PROFILE") or os.environ.get(
+                "AWS_BEARER_TOKEN_BEDROCK"
+            ):
+                model_requirements = {"keys_in_environment": True, "missing_keys": []}
+            elif args.get("aws_access_key_id") and args.get("aws_secret_access_key"):
+                return  # break fast.
+            else:
+                model_requirements = litellm.validate_environment(
+                    model=model, api_key=api_key, api_base=api_base
+                )
         else:
-            #
-            api_key_env_var = f"{provider.upper()}_API_KEY"
-            if api_key:
-                os.environ[api_key_env_var] = api_key
-            model_requirements = litellm.validate_environment(model=model)
+            model_requirements = litellm.validate_environment(
+                model=model, api_key=api_key, api_base=api_base
+            )
+            # validate_environment does not accept api_version, and as a special case for Azure OpenAI Service,
+            # when all the other AZURE environments are set expect AZURE_API_VERSION, validate_environment complains
+            # the missing of it even after the api_version is set.
+            # TODO: There's an open PR in litellm to accept api_version in validate_environment, we can leverage this
+            # change if accepted to ignore the following check.
+            # https://github.com/BerriAI/litellm/pull/13808
+            if (
+                provider == "azure"
+                and ["AZURE_API_VERSION"] == model_requirements["missing_keys"]
+                and api_version is not None
+            ):
+                model_requirements["missing_keys"] = []
+                model_requirements["keys_in_environment"] = True
 
         if not model_requirements["keys_in_environment"]:
             raise Exception(
                 f"model {model} requires the following environment variables: {model_requirements['missing_keys']}"
             )
 
-    def _strip_model_prefix(self) -> str:
+    def _get_model_name_variants_for_lookup(self) -> list[str]:
         """
-        Helper function to strip 'openai/' prefix from model name if it exists.
-        model cost is taken from here which does not have the openai prefix
-        https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
+        Generate model name variants to try when looking up in litellm.model_cost.
+        Returns a list of names to try in order: exact, lowercase, without prefix, etc.
         """
-        model_name = self.model
-        prefixes = ["openai/", "bedrock/", "vertex_ai/", "anthropic/"]
+        names_to_try = [self.model, self.model.lower()]
 
-        for prefix in prefixes:
-            if model_name.startswith(prefix):
-                return model_name[len(prefix) :]
+        # If there's a prefix, also try without it
+        if "/" in self.model:
+            base_model = self.model.split("/", 1)[1]
+            names_to_try.extend([base_model, base_model.lower()])
 
-        return model_name
-
-    # this unfortunately does not seem to work for azure if the deployment name is not a well-known model name
-    # if not litellm.supports_function_calling(model=model):
-    #     raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
+        # Remove duplicates while preserving order (dict.fromkeys maintains insertion order in Python 3.7+)
+        return list(dict.fromkeys(names_to_try))
 
     def get_context_window_size(self) -> int:
+        if self.max_context_size:
+            return self.max_context_size
+
         if OVERRIDE_MAX_CONTENT_SIZE:
             logging.debug(
                 f"Using override OVERRIDE_MAX_CONTENT_SIZE {OVERRIDE_MAX_CONTENT_SIZE}"
             )
             return OVERRIDE_MAX_CONTENT_SIZE
 
-        model_name = os.environ.get("MODEL_TYPE", self._strip_model_prefix())
-        try:
-            return litellm.model_cost[model_name]["max_input_tokens"]
-        except Exception:
-            logging.warning(
-                f"Couldn't find model's name {model_name} in litellm's model list, fallback to 128k tokens for max_input_tokens"
-            )
-            return 128000
+        # Try each name variant
+        for name in self._get_model_name_variants_for_lookup():
+            try:
+                return litellm.model_cost[name]["max_input_tokens"]
+            except Exception:
+                continue
+
+        # Log which lookups we tried
+        logging.warning(
+            f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+            f"using default {FALLBACK_CONTEXT_WINDOW_SIZE} tokens for max_input_tokens. "
+            f"To override, set OVERRIDE_MAX_CONTENT_SIZE environment variable to the correct value for your model."
+        )
+        return FALLBACK_CONTEXT_WINDOW_SIZE
 
     @sentry_sdk.trace
-    def count_tokens_for_message(self, messages: list[dict]) -> int:
-        total_token_count = 0
+    def count_tokens(
+        self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+    ) -> TokenCountMetadata:
+        # TODO: Add a recount:bool flag to save time. When the flag is false, reuse 'message["token_count"]' for individual messages.
+        # It's only necessary to recount message tokens at the beginning of a session because the LLM model may have changed.
+        # Changing the model requires recounting tokens because the tokenizer may be different
+        total_tokens = 0
+        tools_tokens = 0
+        system_tokens = 0
+        assistant_tokens = 0
+        user_tokens = 0
+        other_tokens = 0
+        tools_to_call_tokens = 0
         for message in messages:
-            if "token_count" in message and message["token_count"]:
-                total_token_count += message["token_count"]
+            # count message tokens individually because it gives us fine grain information about each tool call/message etc.
+            # However be aware that the sum of individual message tokens is not equal to the overall messages token
+            token_count = litellm.token_counter(  # type: ignore
+                model=self.model, messages=[message]
+            )
+            message["token_count"] = token_count
+            role = message.get("role")
+            if role == "system":
+                system_tokens += token_count
+            elif role == "user":
+                user_tokens += token_count
+            elif role == "tool":
+                tools_tokens += token_count
+            elif role == "assistant":
+                assistant_tokens += token_count
             else:
-                # message can be counted by this method only if message contains a "content" key
-                if "content" in message:
-                    if isinstance(message["content"], str):
-                        message_to_count = [
-                            {"type": "text", "text": message["content"]}
-                        ]
-                    elif isinstance(message["content"], list):
-                        message_to_count = [
-                            {"type": "text", "text": json.dumps(message["content"])}
-                        ]
-                    elif isinstance(message["content"], dict):
-                        if "type" not in message["content"]:
-                            message_to_count = [
-                                {"type": "text", "text": json.dumps(message["content"])}
-                            ]
-                    token_count = litellm.token_counter(
-                        model=self.model, messages=message_to_count
-                    )
-                    message["token_count"] = token_count
-                    total_token_count += token_count
-        return total_token_count
+                # although this should not be needed,
+                # it is defensive code so that all tokens are accounted for
+                # and can potentially make debugging easier
+                other_tokens += token_count
+
+        messages_token_count_without_tools = litellm.token_counter(  # type: ignore
+            model=self.model, messages=messages
+        )
+
+        total_tokens = litellm.token_counter(  # type: ignore
+            model=self.model,
+            messages=messages,
+            tools=tools,  # type: ignore
+        )
+        tools_to_call_tokens = max(0, total_tokens - messages_token_count_without_tools)
+
+        return TokenCountMetadata(
+            total_tokens=total_tokens,
+            system_tokens=system_tokens,
+            user_tokens=user_tokens,
+            tools_tokens=tools_tokens,
+            tools_to_call_tokens=tools_to_call_tokens,
+            other_tokens=other_tokens,
+            assistant_tokens=assistant_tokens,
+        )
+
+    def get_litellm_corrected_name_for_robusta_ai(self) -> str:
+        if self.is_robusta_model:
+            # For robusta models, self.model is the underlying provider/model used by Robusta AI
+            # To avoid litellm modifying the API URL according to the provider, the provider name
+            # is replaced with 'openai/' just before doing a completion() call
+            # Cf. https://docs.litellm.ai/docs/providers/openai_compatible
+            split_model_name = self.model.split("/")
+            return (
+                split_model_name[0]
+                if len(split_model_name) == 1
+                else f"openai/{split_model_name[1]}"
+            )
+        else:
+            return self.model
 
     def completion(
         self,
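
count_tokens() replaces count_tokens_for_message() and returns a per-role breakdown instead of a single integer; the cost of the tool definitions themselves is derived by counting the messages with and without tools and taking the difference. A hedged usage sketch (the tool schema and prompts are invented for illustration, and an OPENAI_API_KEY is assumed to be available so check_llm() passes):

    import os
    from holmes.core.llm import DefaultLLM

    llm = DefaultLLM(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
    tools = [
        {
            "type": "function",
            "function": {
                "name": "kubectl_describe",  # hypothetical tool, for illustration only
                "description": "Describe a Kubernetes resource",
                "parameters": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}},
                },
            },
        }
    ]
    usage = llm.count_tokens(
        messages=[
            {"role": "system", "content": "You are a helpful SRE assistant."},
            {"role": "user", "content": "Why is my pod crash looping?"},
        ],
        tools=tools,
    )
    print(usage.total_tokens, usage.system_tokens, usage.user_tokens, usage.tools_to_call_tokens)
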
@@ -219,6 +384,9 @@ class DefaultLLM(LLM):
         if THINKING:
             self.args.setdefault("thinking", json.loads(THINKING))
 
+        if EXTRA_HEADERS:
+            self.args.setdefault("extra_headers", json.loads(EXTRA_HEADERS))
+
         if self.args.get("thinking", None):
             litellm.modify_params = True
 
@@ -234,9 +402,13 @@ class DefaultLLM(LLM):
 
         # Get the litellm module to use (wrapped or unwrapped)
         litellm_to_use = self.tracer.wrap_llm(litellm) if self.tracer else litellm
+
+        litellm_model_name = self.get_litellm_corrected_name_for_robusta_ai()
         result = litellm_to_use.completion(
-            model=self.model,
+            model=litellm_model_name,
             api_key=self.api_key,
+            base_url=self.api_base,
+            api_version=self.api_version,
             messages=messages,
             response_format=response_format,
             drop_params=drop_params,
@@ -254,20 +426,33 @@ class DefaultLLM(LLM):
             raise Exception(f"Unexpected type returned by the LLM {type(result)}")
 
     def get_maximum_output_token(self) -> int:
+        max_output_tokens = floor(min(64000, self.get_context_window_size() / 5))
+
         if OVERRIDE_MAX_OUTPUT_TOKEN:
             logging.debug(
                 f"Using OVERRIDE_MAX_OUTPUT_TOKEN {OVERRIDE_MAX_OUTPUT_TOKEN}"
             )
             return OVERRIDE_MAX_OUTPUT_TOKEN
 
-        model_name = os.environ.get("MODEL_TYPE", self._strip_model_prefix())
-        try:
-            return litellm.model_cost[model_name]["max_output_tokens"]
-        except Exception:
-            logging.warning(
-                f"Couldn't find model's name {model_name} in litellm's model list, fallback to 4096 tokens for max_output_tokens"
-            )
-            return 4096
+        # Try each name variant
+        for name in self._get_model_name_variants_for_lookup():
+            try:
+                litellm_max_output_tokens = litellm.model_cost[name][
+                    "max_output_tokens"
+                ]
+                if litellm_max_output_tokens < max_output_tokens:
+                    max_output_tokens = litellm_max_output_tokens
+                return max_output_tokens
+            except Exception:
+                continue
+
+        # Log which lookups we tried
+        logging.warning(
+            f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+            f"using {max_output_tokens} tokens for max_output_tokens. "
+            f"To override, set OVERRIDE_MAX_OUTPUT_TOKEN environment variable to the correct value for your model."
+        )
+        return max_output_tokens
 
     def _add_cache_control_to_last_message(
         self, messages: List[Dict[str, Any]]
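
Worked through with illustrative numbers: a model with a 200,000-token context window gets floor(min(64000, 200000 / 5)) = 40,000 output tokens; a 1,000,000-token window is clamped to 64,000; and if litellm.model_cost reports a smaller max_output_tokens for the model, that smaller value wins. OVERRIDE_MAX_OUTPUT_TOKEN still short-circuits all of this.

    from math import floor

    # Same formula as above, evaluated for a few hypothetical context sizes.
    for ctx in (128_000, 200_000, 1_000_000):
        print(ctx, floor(min(64_000, ctx / 5)))  # 25600, 40000, 64000
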
@@ -276,6 +461,12 @@ class DefaultLLM(LLM):
         Add cache_control to the last non-user message for Anthropic prompt caching.
         Removes any existing cache_control from previous messages to avoid accumulation.
         """
+        # Skip cache_control for VertexAI/Gemini models as they don't support it with tools
+        if self.model and (
+            "vertex" in self.model.lower() or "gemini" in self.model.lower()
+        ):
+            return
+
         # First, remove any existing cache_control from all messages
         for msg in messages:
             content = msg.get("content")
@@ -305,7 +496,7 @@ class DefaultLLM(LLM):
         if content is None:
             return
 
-        if isinstance(content, str):
+        if isinstance(content, str) and content:
             # Convert string to structured format with cache_control
             target_msg["content"] = [
                 {
@@ -325,3 +516,213 @@ class DefaultLLM(LLM):
             logging.debug(
                 f"Added cache_control to {target_msg.get('role')} message (structured content)"
             )
+
+
+class LLMModelRegistry:
+    def __init__(self, config: "Config", dal: SupabaseDal) -> None:
+        self.config = config
+        self._llms: dict[str, ModelEntry] = {}
+        self._default_robusta_model = None
+        self.dal = dal
+
+        self._init_models()
+
+    @property
+    def default_robusta_model(self) -> Optional[str]:
+        return self._default_robusta_model
+
+    def _init_models(self):
+        self._llms = self._parse_models_file(MODEL_LIST_FILE_LOCATION)
+
+        if self._should_load_robusta_ai():
+            self.configure_robusta_ai_model()
+
+        if self._should_load_config_model():
+            self._llms[self.config.model] = self._create_model_entry(
+                model=self.config.model,
+                model_name=self.config.model,
+                base_url=self.config.api_base,
+                is_robusta_model=False,
+                api_key=self.config.api_key,
+                api_version=self.config.api_version,
+            )
+
+    def _should_load_config_model(self) -> bool:
+        if self.config.model is not None:
+            return True
+
+        # backward compatibility - in the past config.model was set by default to gpt-4o.
+        # so we need to check if the user has set an OPENAI_API_KEY to load the config model.
+        has_openai_key = os.environ.get("OPENAI_API_KEY")
+        if has_openai_key:
+            self.config.model = "gpt-4.1"
+            return True
+
+        return False
+
+    def configure_robusta_ai_model(self) -> None:
+        try:
+            if not self.config.cluster_name or not LOAD_ALL_ROBUSTA_MODELS:
+                self._load_default_robusta_config()
+                return
+
+            if not self.dal.account_id or not self.dal.enabled:
+                self._load_default_robusta_config()
+                return
+
+            account_id, token = self.dal.get_ai_credentials()
+
+            robusta_models: RobustaModelsResponse | None = fetch_robusta_models(
+                account_id, token
+            )
+            if not robusta_models or not robusta_models.models:
+                self._load_default_robusta_config()
+                return
+
+            default_model = None
+            for model_name, model_data in robusta_models.models.items():
+                logging.info(f"Loading Robusta AI model: {model_name}")
+                self._llms[model_name] = self._create_robusta_model_entry(
+                    model_name=model_name, model_data=model_data
+                )
+                if model_data.is_default:
+                    default_model = model_name
+
+            if default_model:
+                logging.info(f"Setting default Robusta AI model to: {default_model}")
+                self._default_robusta_model: str = default_model  # type: ignore
+
+        except Exception:
+            logging.exception("Failed to get all robusta models")
+            # fallback to default behavior
+            self._load_default_robusta_config()
+
+    def _load_default_robusta_config(self):
+        if self._should_load_robusta_ai():
+            logging.info("Loading default Robusta AI model")
+            self._llms[ROBUSTA_AI_MODEL_NAME] = ModelEntry(
+                name=ROBUSTA_AI_MODEL_NAME,
+                model="gpt-4o",  # TODO: tech debt, this isn't really
+                base_url=ROBUSTA_API_ENDPOINT,
+                is_robusta_model=True,
+            )
+            self._default_robusta_model = ROBUSTA_AI_MODEL_NAME
+
+    def _should_load_robusta_ai(self) -> bool:
+        if not self.config.should_try_robusta_ai:
+            return False
+
+        # ROBUSTA_AI were set in the env vars, so we can use it directly
+        if ROBUSTA_AI is not None:
+            return ROBUSTA_AI
+
+        # MODEL is set in the env vars, e.g. the user is using a custom model
+        # so we don't need to load the robusta AI model and keep the behavior backward compatible
+        if "MODEL" in os.environ:
+            return False
+
+        # if the user has provided a model list, we don't need to load the robusta AI model
+        if self._llms:
+            return False
+
+        return True
+
+    def get_model_params(self, model_key: Optional[str] = None) -> ModelEntry:
+        if not self._llms:
+            raise Exception("No llm models were loaded")
+
+        if model_key:
+            model_params = self._llms.get(model_key)
+            if model_params is not None:
+                logging.info(f"Using selected model: {model_key}")
+                return model_params.copy()
+
+            logging.error(f"Couldn't find model: {model_key} in model list")
+
+        if self._default_robusta_model:
+            model_params = self._llms.get(self._default_robusta_model)
+            if model_params is not None:
+                logging.info(
+                    f"Using default Robusta AI model: {self._default_robusta_model}"
+                )
+                return model_params.copy()
+
+            logging.error(
+                f"Couldn't find default Robusta AI model: {self._default_robusta_model} in model list"
+            )
+
+        model_key, first_model_params = next(iter(self._llms.items()))
+        logging.debug(f"Using first available model: {model_key}")
+        return first_model_params.copy()
+
+    def get_llm(self, name: str) -> LLM:  # TODO: fix logic
+        return self._llms[name]  # type: ignore
+
+    @property
+    def models(self) -> dict[str, ModelEntry]:
+        return self._llms
+
+    def _parse_models_file(self, path: str) -> dict[str, ModelEntry]:
+        models = load_yaml_file(path, raise_error=False, warn_not_found=False)
+        for _, params in models.items():
+            params = replace_env_vars_values(params)
+
+        llms = {}
+        for model_name, params in models.items():
+            llms[model_name] = ModelEntry.model_validate(params)
+
+        return llms
+
+    def _create_robusta_model_entry(
+        self, model_name: str, model_data: RobustaModel
+    ) -> ModelEntry:
+        entry = self._create_model_entry(
+            model=model_data.model,
+            model_name=model_name,
+            base_url=f"{ROBUSTA_API_ENDPOINT}/llm/{model_name}",
+            is_robusta_model=True,
+        )
+        entry.custom_args = model_data.holmes_args or {}  # type: ignore[assignment]
+        return entry
+
+    def _create_model_entry(
+        self,
+        model: str,
+        model_name: str,
+        base_url: Optional[str] = None,
+        is_robusta_model: Optional[bool] = None,
+        api_key: Optional[SecretStr] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+    ) -> ModelEntry:
+        return ModelEntry(
+            name=model_name,
+            model=model,
+            base_url=base_url,
+            is_robusta_model=is_robusta_model,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+        )
+
+
+def get_llm_usage(
+    llm_response: Union[ModelResponse, CustomStreamWrapper, TextCompletionResponse],
+) -> dict:
+    usage: dict = {}
+    if (
+        (
+            isinstance(llm_response, ModelResponse)
+            or isinstance(llm_response, TextCompletionResponse)
+        )
+        and hasattr(llm_response, "usage")
+        and llm_response.usage
+    ):  # type: ignore
+        usage["prompt_tokens"] = llm_response.usage.prompt_tokens  # type: ignore
+        usage["completion_tokens"] = llm_response.usage.completion_tokens  # type: ignore
+        usage["total_tokens"] = llm_response.usage.total_tokens  # type: ignore
+    elif isinstance(llm_response, CustomStreamWrapper):
+        complete_response = litellm.stream_chunk_builder(chunks=llm_response)  # type: ignore
+        if complete_response:
+            return get_llm_usage(complete_response)
+    return usage
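
_parse_models_file() reads MODEL_LIST_FILE_LOCATION (default /etc/holmes/config/model_list.yaml) as a mapping of model names to ModelEntry fields, expanding environment-variable references with replace_env_vars_values() before validation. A hedged sketch of such a configuration, shown here as the dict that loading the YAML file would produce and validated the same way the registry does; every name, URL and key below is a placeholder, not a shipped default:

    from holmes.core.llm import ModelEntry

    raw_models = {
        "azure-prod": {
            "model": "azure/my-gpt4o-deployment",
            "api_base": "https://example.openai.azure.com",
            "api_version": "2024-02-15-preview",
            "api_key": "placeholder-key",  # real files substitute this from the environment
        },
        "bedrock-sonnet": {
            "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        },
    }
    models = {name: ModelEntry.model_validate(params) for name, params in raw_models.items()}
    print(sorted(models))  # ['azure-prod', 'bedrock-sonnet']

get_llm_usage() at the bottom of the file extracts prompt_tokens, completion_tokens and total_tokens from a ModelResponse or TextCompletionResponse, and for a streaming CustomStreamWrapper it first reassembles the chunks with litellm.stream_chunk_builder() and recurses on the result.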