holmesgpt 0.13.2__py3-none-any.whl → 0.18.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. holmes/__init__.py +3 -5
  2. holmes/clients/robusta_client.py +20 -6
  3. holmes/common/env_vars.py +58 -3
  4. holmes/common/openshift.py +1 -1
  5. holmes/config.py +123 -148
  6. holmes/core/conversations.py +71 -15
  7. holmes/core/feedback.py +191 -0
  8. holmes/core/investigation.py +31 -39
  9. holmes/core/investigation_structured_output.py +3 -3
  10. holmes/core/issue.py +1 -1
  11. holmes/core/llm.py +508 -88
  12. holmes/core/models.py +108 -4
  13. holmes/core/openai_formatting.py +14 -1
  14. holmes/core/prompt.py +48 -3
  15. holmes/core/runbooks.py +1 -0
  16. holmes/core/safeguards.py +8 -6
  17. holmes/core/supabase_dal.py +295 -100
  18. holmes/core/tool_calling_llm.py +489 -428
  19. holmes/core/tools.py +325 -56
  20. holmes/core/tools_utils/token_counting.py +21 -0
  21. holmes/core/tools_utils/tool_context_window_limiter.py +40 -0
  22. holmes/core/tools_utils/tool_executor.py +0 -13
  23. holmes/core/tools_utils/toolset_utils.py +1 -0
  24. holmes/core/toolset_manager.py +191 -5
  25. holmes/core/tracing.py +19 -3
  26. holmes/core/transformers/__init__.py +23 -0
  27. holmes/core/transformers/base.py +63 -0
  28. holmes/core/transformers/llm_summarize.py +175 -0
  29. holmes/core/transformers/registry.py +123 -0
  30. holmes/core/transformers/transformer.py +32 -0
  31. holmes/core/truncation/compaction.py +94 -0
  32. holmes/core/truncation/dal_truncation_utils.py +23 -0
  33. holmes/core/truncation/input_context_window_limiter.py +219 -0
  34. holmes/interactive.py +228 -31
  35. holmes/main.py +23 -40
  36. holmes/plugins/interfaces.py +2 -1
  37. holmes/plugins/prompts/__init__.py +2 -1
  38. holmes/plugins/prompts/_fetch_logs.jinja2 +31 -6
  39. holmes/plugins/prompts/_general_instructions.jinja2 +1 -2
  40. holmes/plugins/prompts/_runbook_instructions.jinja2 +24 -12
  41. holmes/plugins/prompts/base_user_prompt.jinja2 +7 -0
  42. holmes/plugins/prompts/conversation_history_compaction.jinja2 +89 -0
  43. holmes/plugins/prompts/generic_ask.jinja2 +0 -4
  44. holmes/plugins/prompts/generic_ask_conversation.jinja2 +0 -1
  45. holmes/plugins/prompts/generic_ask_for_issue_conversation.jinja2 +0 -1
  46. holmes/plugins/prompts/generic_investigation.jinja2 +0 -1
  47. holmes/plugins/prompts/investigation_procedure.jinja2 +50 -1
  48. holmes/plugins/prompts/kubernetes_workload_ask.jinja2 +0 -1
  49. holmes/plugins/prompts/kubernetes_workload_chat.jinja2 +0 -1
  50. holmes/plugins/runbooks/__init__.py +145 -17
  51. holmes/plugins/runbooks/catalog.json +2 -0
  52. holmes/plugins/sources/github/__init__.py +4 -2
  53. holmes/plugins/sources/prometheus/models.py +1 -0
  54. holmes/plugins/toolsets/__init__.py +44 -27
  55. holmes/plugins/toolsets/aks-node-health.yaml +46 -0
  56. holmes/plugins/toolsets/aks.yaml +64 -0
  57. holmes/plugins/toolsets/atlas_mongodb/mongodb_atlas.py +38 -47
  58. holmes/plugins/toolsets/azure_sql/apis/alert_monitoring_api.py +3 -2
  59. holmes/plugins/toolsets/azure_sql/apis/azure_sql_api.py +2 -1
  60. holmes/plugins/toolsets/azure_sql/apis/connection_failure_api.py +3 -2
  61. holmes/plugins/toolsets/azure_sql/apis/connection_monitoring_api.py +3 -1
  62. holmes/plugins/toolsets/azure_sql/apis/storage_analysis_api.py +3 -1
  63. holmes/plugins/toolsets/azure_sql/azure_sql_toolset.py +12 -13
  64. holmes/plugins/toolsets/azure_sql/tools/analyze_connection_failures.py +15 -12
  65. holmes/plugins/toolsets/azure_sql/tools/analyze_database_connections.py +15 -12
  66. holmes/plugins/toolsets/azure_sql/tools/analyze_database_health_status.py +11 -11
  67. holmes/plugins/toolsets/azure_sql/tools/analyze_database_performance.py +11 -9
  68. holmes/plugins/toolsets/azure_sql/tools/analyze_database_storage.py +15 -12
  69. holmes/plugins/toolsets/azure_sql/tools/get_active_alerts.py +15 -15
  70. holmes/plugins/toolsets/azure_sql/tools/get_slow_queries.py +11 -8
  71. holmes/plugins/toolsets/azure_sql/tools/get_top_cpu_queries.py +11 -8
  72. holmes/plugins/toolsets/azure_sql/tools/get_top_data_io_queries.py +11 -8
  73. holmes/plugins/toolsets/azure_sql/tools/get_top_log_io_queries.py +11 -8
  74. holmes/plugins/toolsets/azure_sql/utils.py +0 -32
  75. holmes/plugins/toolsets/bash/argocd/__init__.py +3 -3
  76. holmes/plugins/toolsets/bash/aws/__init__.py +4 -4
  77. holmes/plugins/toolsets/bash/azure/__init__.py +4 -4
  78. holmes/plugins/toolsets/bash/bash_toolset.py +11 -15
  79. holmes/plugins/toolsets/bash/common/bash.py +23 -13
  80. holmes/plugins/toolsets/bash/common/bash_command.py +1 -1
  81. holmes/plugins/toolsets/bash/common/stringify.py +1 -1
  82. holmes/plugins/toolsets/bash/kubectl/__init__.py +2 -1
  83. holmes/plugins/toolsets/bash/kubectl/constants.py +0 -1
  84. holmes/plugins/toolsets/bash/kubectl/kubectl_get.py +3 -4
  85. holmes/plugins/toolsets/bash/parse_command.py +12 -13
  86. holmes/plugins/toolsets/cilium.yaml +284 -0
  87. holmes/plugins/toolsets/connectivity_check.py +124 -0
  88. holmes/plugins/toolsets/coralogix/api.py +132 -119
  89. holmes/plugins/toolsets/coralogix/coralogix.jinja2 +14 -0
  90. holmes/plugins/toolsets/coralogix/toolset_coralogix.py +219 -0
  91. holmes/plugins/toolsets/coralogix/utils.py +15 -79
  92. holmes/plugins/toolsets/datadog/datadog_api.py +525 -26
  93. holmes/plugins/toolsets/datadog/datadog_logs_instructions.jinja2 +55 -11
  94. holmes/plugins/toolsets/datadog/datadog_metrics_instructions.jinja2 +3 -3
  95. holmes/plugins/toolsets/datadog/datadog_models.py +59 -0
  96. holmes/plugins/toolsets/datadog/datadog_url_utils.py +213 -0
  97. holmes/plugins/toolsets/datadog/instructions_datadog_traces.jinja2 +165 -28
  98. holmes/plugins/toolsets/datadog/toolset_datadog_general.py +417 -241
  99. holmes/plugins/toolsets/datadog/toolset_datadog_logs.py +234 -214
  100. holmes/plugins/toolsets/datadog/toolset_datadog_metrics.py +167 -79
  101. holmes/plugins/toolsets/datadog/toolset_datadog_traces.py +374 -363
  102. holmes/plugins/toolsets/elasticsearch/__init__.py +6 -0
  103. holmes/plugins/toolsets/elasticsearch/elasticsearch.py +834 -0
  104. holmes/plugins/toolsets/elasticsearch/opensearch_ppl_query_docs.jinja2 +1616 -0
  105. holmes/plugins/toolsets/elasticsearch/opensearch_query_assist.py +78 -0
  106. holmes/plugins/toolsets/elasticsearch/opensearch_query_assist_instructions.jinja2 +223 -0
  107. holmes/plugins/toolsets/git.py +54 -50
  108. holmes/plugins/toolsets/grafana/base_grafana_toolset.py +16 -4
  109. holmes/plugins/toolsets/grafana/common.py +13 -29
  110. holmes/plugins/toolsets/grafana/grafana_tempo_api.py +455 -0
  111. holmes/plugins/toolsets/grafana/loki/instructions.jinja2 +25 -0
  112. holmes/plugins/toolsets/grafana/loki/toolset_grafana_loki.py +191 -0
  113. holmes/plugins/toolsets/grafana/loki_api.py +4 -0
  114. holmes/plugins/toolsets/grafana/toolset_grafana.py +293 -89
  115. holmes/plugins/toolsets/grafana/toolset_grafana_dashboard.jinja2 +49 -0
  116. holmes/plugins/toolsets/grafana/toolset_grafana_tempo.jinja2 +246 -11
  117. holmes/plugins/toolsets/grafana/toolset_grafana_tempo.py +820 -292
  118. holmes/plugins/toolsets/grafana/trace_parser.py +4 -3
  119. holmes/plugins/toolsets/internet/internet.py +15 -16
  120. holmes/plugins/toolsets/internet/notion.py +9 -11
  121. holmes/plugins/toolsets/investigator/core_investigation.py +44 -36
  122. holmes/plugins/toolsets/investigator/model.py +3 -1
  123. holmes/plugins/toolsets/json_filter_mixin.py +134 -0
  124. holmes/plugins/toolsets/kafka.py +36 -42
  125. holmes/plugins/toolsets/kubernetes.yaml +317 -113
  126. holmes/plugins/toolsets/kubernetes_logs.py +9 -9
  127. holmes/plugins/toolsets/kubernetes_logs.yaml +32 -0
  128. holmes/plugins/toolsets/logging_utils/logging_api.py +94 -8
  129. holmes/plugins/toolsets/mcp/toolset_mcp.py +218 -64
  130. holmes/plugins/toolsets/newrelic/new_relic_api.py +165 -0
  131. holmes/plugins/toolsets/newrelic/newrelic.jinja2 +65 -0
  132. holmes/plugins/toolsets/newrelic/newrelic.py +320 -0
  133. holmes/plugins/toolsets/openshift.yaml +283 -0
  134. holmes/plugins/toolsets/prometheus/prometheus.py +1202 -421
  135. holmes/plugins/toolsets/prometheus/prometheus_instructions.jinja2 +54 -5
  136. holmes/plugins/toolsets/prometheus/utils.py +28 -0
  137. holmes/plugins/toolsets/rabbitmq/api.py +23 -4
  138. holmes/plugins/toolsets/rabbitmq/toolset_rabbitmq.py +13 -14
  139. holmes/plugins/toolsets/robusta/robusta.py +239 -68
  140. holmes/plugins/toolsets/robusta/robusta_instructions.jinja2 +26 -9
  141. holmes/plugins/toolsets/runbook/runbook_fetcher.py +157 -27
  142. holmes/plugins/toolsets/service_discovery.py +1 -1
  143. holmes/plugins/toolsets/servicenow_tables/instructions.jinja2 +83 -0
  144. holmes/plugins/toolsets/servicenow_tables/servicenow_tables.py +426 -0
  145. holmes/plugins/toolsets/utils.py +88 -0
  146. holmes/utils/config_utils.py +91 -0
  147. holmes/utils/connection_utils.py +31 -0
  148. holmes/utils/console/result.py +10 -0
  149. holmes/utils/default_toolset_installation_guide.jinja2 +1 -22
  150. holmes/utils/env.py +7 -0
  151. holmes/utils/file_utils.py +2 -1
  152. holmes/utils/global_instructions.py +60 -11
  153. holmes/utils/holmes_status.py +6 -4
  154. holmes/utils/holmes_sync_toolsets.py +0 -2
  155. holmes/utils/krr_utils.py +188 -0
  156. holmes/utils/log.py +15 -0
  157. holmes/utils/markdown_utils.py +2 -3
  158. holmes/utils/memory_limit.py +58 -0
  159. holmes/utils/sentry_helper.py +64 -0
  160. holmes/utils/stream.py +69 -8
  161. holmes/utils/tags.py +4 -3
  162. holmes/version.py +37 -15
  163. holmesgpt-0.18.4.dist-info/LICENSE +178 -0
  164. {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/METADATA +35 -31
  165. holmesgpt-0.18.4.dist-info/RECORD +258 -0
  166. holmes/core/performance_timing.py +0 -72
  167. holmes/plugins/toolsets/aws.yaml +0 -80
  168. holmes/plugins/toolsets/coralogix/toolset_coralogix_logs.py +0 -112
  169. holmes/plugins/toolsets/datadog/datadog_traces_formatter.py +0 -310
  170. holmes/plugins/toolsets/datadog/toolset_datadog_rds.py +0 -739
  171. holmes/plugins/toolsets/grafana/grafana_api.py +0 -42
  172. holmes/plugins/toolsets/grafana/tempo_api.py +0 -124
  173. holmes/plugins/toolsets/grafana/toolset_grafana_loki.py +0 -110
  174. holmes/plugins/toolsets/newrelic.py +0 -231
  175. holmes/plugins/toolsets/opensearch/opensearch.py +0 -257
  176. holmes/plugins/toolsets/opensearch/opensearch_logs.py +0 -161
  177. holmes/plugins/toolsets/opensearch/opensearch_traces.py +0 -218
  178. holmes/plugins/toolsets/opensearch/opensearch_traces_instructions.jinja2 +0 -12
  179. holmes/plugins/toolsets/opensearch/opensearch_utils.py +0 -166
  180. holmes/plugins/toolsets/servicenow/install.md +0 -37
  181. holmes/plugins/toolsets/servicenow/instructions.jinja2 +0 -3
  182. holmes/plugins/toolsets/servicenow/servicenow.py +0 -219
  183. holmes/utils/keygen_utils.py +0 -6
  184. holmesgpt-0.13.2.dist-info/LICENSE.txt +0 -21
  185. holmesgpt-0.13.2.dist-info/RECORD +0 -234
  186. /holmes/plugins/toolsets/{opensearch → newrelic}/__init__.py +0 -0
  187. {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/WHEEL +0 -0
  188. {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/entry_points.txt +0 -0
holmes/core/llm.py CHANGED
@@ -1,32 +1,93 @@
  import json
  import logging
+ import os
+ import threading
  from abc import abstractmethod
- from typing import Any, Dict, List, Optional, Type, Union
+ from math import floor
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

- from litellm.types.utils import ModelResponse
+ import litellm
  import sentry_sdk
-
  from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
- from pydantic import BaseModel
- import litellm
- import os
+ from litellm.types.utils import ModelResponse, TextCompletionResponse
+ from pydantic import BaseModel, ConfigDict, SecretStr
+ from typing_extensions import Self
+
+ from holmes.clients.robusta_client import (
+     RobustaModel,
+     RobustaModelsResponse,
+     fetch_robusta_models,
+ )
  from holmes.common.env_vars import (
+     EXTRA_HEADERS,
+     FALLBACK_CONTEXT_WINDOW_SIZE,
+     LLM_REQUEST_TIMEOUT,
+     LOAD_ALL_ROBUSTA_MODELS,
      REASONING_EFFORT,
+     ROBUSTA_AI,
+     ROBUSTA_API_ENDPOINT,
      THINKING,
+     TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT,
+     TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS,
  )
+ from holmes.core.supabase_dal import SupabaseDal
+ from holmes.utils.env import environ_get_safe_int, replace_env_vars_values
+ from holmes.utils.file_utils import load_yaml_file

+ if TYPE_CHECKING:
+     from holmes.config import Config

- def environ_get_safe_int(env_var, default="0"):
-     try:
-         return max(int(os.environ.get(env_var, default)), 0)
-     except ValueError:
-         return int(default)
+ MODEL_LIST_FILE_LOCATION = os.environ.get(
+     "MODEL_LIST_FILE_LOCATION", "/etc/holmes/config/model_list.yaml"
+ )


  OVERRIDE_MAX_OUTPUT_TOKEN = environ_get_safe_int("OVERRIDE_MAX_OUTPUT_TOKEN")
  OVERRIDE_MAX_CONTENT_SIZE = environ_get_safe_int("OVERRIDE_MAX_CONTENT_SIZE")


+ def get_context_window_compaction_threshold_pct() -> int:
+     """Get the compaction threshold percentage at runtime to support test overrides."""
+     return environ_get_safe_int("CONTEXT_WINDOW_COMPACTION_THRESHOLD_PCT", default="95")
+
+
+ ROBUSTA_AI_MODEL_NAME = "Robusta"
+
+
+ class TokenCountMetadata(BaseModel):
+     total_tokens: int
+     tools_tokens: int
+     system_tokens: int
+     user_tokens: int
+     tools_to_call_tokens: int
+     assistant_tokens: int
+     other_tokens: int
+
+
+ class ModelEntry(BaseModel):
+     """ModelEntry represents a single LLM model configuration."""
+
+     model: str
+     # TODO: the name field seems to be redundant, can we remove it?
+     name: Optional[str] = None
+     api_key: Optional[SecretStr] = None
+     base_url: Optional[str] = None
+     is_robusta_model: Optional[bool] = None
+     custom_args: Optional[Dict[str, Any]] = None
+
+     # LLM configurations used services like Azure OpenAI Service
+     api_base: Optional[str] = None
+     api_version: Optional[str] = None
+
+     model_config = ConfigDict(
+         extra="allow",
+     )
+
+     @classmethod
+     def load_from_dict(cls, data: dict) -> Self:
+         return cls.model_validate(data)
+
+
  class LLM:
      @abstractmethod
      def __init__(self):
@@ -40,8 +101,23 @@ class LLM:
      def get_maximum_output_token(self) -> int:
          pass

+     def get_max_token_count_for_single_tool(self) -> int:
+         if (
+             0 < TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT
+             and TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT <= 100
+         ):
+             context_window_size = self.get_context_window_size()
+             calculated_max_tokens = int(
+                 context_window_size * TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT // 100
+             )
+             return min(calculated_max_tokens, TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS)
+         else:
+             return TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS
+
      @abstractmethod
-     def count_tokens_for_message(self, messages: list[dict]) -> int:
+     def count_tokens(
+         self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+     ) -> TokenCountMetadata:
          pass

      @abstractmethod
@@ -61,31 +137,55 @@
  class DefaultLLM(LLM):
      model: str
      api_key: Optional[str]
-     base_url: Optional[str]
+     api_base: Optional[str]
+     api_version: Optional[str]
      args: Dict
+     is_robusta_model: bool

      def __init__(
          self,
          model: str,
          api_key: Optional[str] = None,
+         api_base: Optional[str] = None,
+         api_version: Optional[str] = None,
          args: Optional[Dict] = None,
-         tracer=None,
+         tracer: Optional[Any] = None,
+         name: Optional[str] = None,
+         is_robusta_model: bool = False,
      ):
          self.model = model
          self.api_key = api_key
+         self.api_base = api_base
+         self.api_version = api_version
          self.args = args or {}
          self.tracer = tracer
+         self.name = name
+         self.is_robusta_model = is_robusta_model
+         self.update_custom_args()
+         self.check_llm(
+             self.model, self.api_key, self.api_base, self.api_version, self.args
+         )

-         if not self.args:
-             self.check_llm(self.model, self.api_key)
+     def update_custom_args(self):
+         self.max_context_size = self.args.get("custom_args", {}).get("max_context_size")
+         self.args.pop("custom_args", None)

-     def check_llm(self, model: str, api_key: Optional[str]):
+     def check_llm(
+         self,
+         model: str,
+         api_key: Optional[str],
+         api_base: Optional[str],
+         api_version: Optional[str],
+         args: Optional[dict] = None,
+     ):
+         if self.is_robusta_model:
+             # The model is assumed correctly configured if it is a robusta model
+             # For robusta models, this code would fail because Holmes has no knowledge of the API keys
+             # to azure or bedrock as all completion API calls go through robusta's LLM proxy
+             return
+         args = args or {}
          logging.debug(f"Checking LiteLLM model {model}")
-         # TODO: this WAS a hack to get around the fact that we can't pass in an api key to litellm.validate_environment
-         # so without this hack it always complains that the environment variable for the api key is missing
-         # to fix that, we always set an api key in the standard format that litellm expects (which is ${PROVIDER}_API_KEY)
-         # TODO: we can now handle this better - see https://github.com/BerriAI/litellm/issues/4375#issuecomment-2223684750
-         lookup = litellm.get_llm_provider(self.model)
+         lookup = litellm.get_llm_provider(model)
          if not lookup:
              raise Exception(f"Unknown provider for model {model}")
          provider = lookup[1]
@@ -119,85 +219,151 @@ class DefaultLLM(LLM):
                      "environment variable for proper functionality. For more information, refer to the documentation: "
                      "https://docs.litellm.ai/docs/providers/watsonx#usage---models-in-deployment-spaces"
                  )
-         elif provider == "bedrock" and (
-             os.environ.get("AWS_PROFILE") or os.environ.get("AWS_BEARER_TOKEN_BEDROCK")
-         ):
-             model_requirements = {"keys_in_environment": True, "missing_keys": []}
+         elif provider == "bedrock":
+             if os.environ.get("AWS_PROFILE") or os.environ.get(
+                 "AWS_BEARER_TOKEN_BEDROCK"
+             ):
+                 model_requirements = {"keys_in_environment": True, "missing_keys": []}
+             elif args.get("aws_access_key_id") and args.get("aws_secret_access_key"):
+                 return  # break fast.
+             else:
+                 model_requirements = litellm.validate_environment(
+                     model=model, api_key=api_key, api_base=api_base
+                 )
          else:
-             #
-             api_key_env_var = f"{provider.upper()}_API_KEY"
-             if api_key:
-                 os.environ[api_key_env_var] = api_key
-             model_requirements = litellm.validate_environment(model=model)
+             model_requirements = litellm.validate_environment(
+                 model=model, api_key=api_key, api_base=api_base
+             )
+             # validate_environment does not accept api_version, and as a special case for Azure OpenAI Service,
+             # when all the other AZURE environments are set expect AZURE_API_VERSION, validate_environment complains
+             # the missing of it even after the api_version is set.
+             # TODO: There's an open PR in litellm to accept api_version in validate_environment, we can leverage this
+             # change if accepted to ignore the following check.
+             # https://github.com/BerriAI/litellm/pull/13808
+             if (
+                 provider == "azure"
+                 and ["AZURE_API_VERSION"] == model_requirements["missing_keys"]
+                 and api_version is not None
+             ):
+                 model_requirements["missing_keys"] = []
+                 model_requirements["keys_in_environment"] = True

          if not model_requirements["keys_in_environment"]:
              raise Exception(
                  f"model {model} requires the following environment variables: {model_requirements['missing_keys']}"
              )

-     def _strip_model_prefix(self) -> str:
+     def _get_model_name_variants_for_lookup(self) -> list[str]:
          """
-         Helper function to strip 'openai/' prefix from model name if it exists.
-         model cost is taken from here which does not have the openai prefix
-         https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
+         Generate model name variants to try when looking up in litellm.model_cost.
+         Returns a list of names to try in order: exact, lowercase, without prefix, etc.
          """
-         model_name = self.model
-         prefixes = ["openai/", "bedrock/", "vertex_ai/", "anthropic/"]
+         names_to_try = [self.model, self.model.lower()]

-         for prefix in prefixes:
-             if model_name.startswith(prefix):
-                 return model_name[len(prefix) :]
+         # If there's a prefix, also try without it
+         if "/" in self.model:
+             base_model = self.model.split("/", 1)[1]
+             names_to_try.extend([base_model, base_model.lower()])

-         return model_name
-
-     # this unfortunately does not seem to work for azure if the deployment name is not a well-known model name
-     # if not litellm.supports_function_calling(model=model):
-     # raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
+         # Remove duplicates while preserving order (dict.fromkeys maintains insertion order in Python 3.7+)
+         return list(dict.fromkeys(names_to_try))

      def get_context_window_size(self) -> int:
+         if self.max_context_size:
+             return self.max_context_size
+
          if OVERRIDE_MAX_CONTENT_SIZE:
              logging.debug(
                  f"Using override OVERRIDE_MAX_CONTENT_SIZE {OVERRIDE_MAX_CONTENT_SIZE}"
              )
              return OVERRIDE_MAX_CONTENT_SIZE

-         model_name = os.environ.get("MODEL_TYPE", self._strip_model_prefix())
-         try:
-             return litellm.model_cost[model_name]["max_input_tokens"]
-         except Exception:
-             logging.warning(
-                 f"Couldn't find model's name {model_name} in litellm's model list, fallback to 128k tokens for max_input_tokens"
-             )
-             return 128000
+         # Try each name variant
+         for name in self._get_model_name_variants_for_lookup():
+             try:
+                 return litellm.model_cost[name]["max_input_tokens"]
+             except Exception:
+                 continue
+
+         # Log which lookups we tried
+         logging.warning(
+             f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+             f"using default {FALLBACK_CONTEXT_WINDOW_SIZE} tokens for max_input_tokens. "
+             f"To override, set OVERRIDE_MAX_CONTENT_SIZE environment variable to the correct value for your model."
+         )
+         return FALLBACK_CONTEXT_WINDOW_SIZE

      @sentry_sdk.trace
-     def count_tokens_for_message(self, messages: list[dict]) -> int:
-         total_token_count = 0
+     def count_tokens(
+         self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+     ) -> TokenCountMetadata:
+         # TODO: Add a recount:bool flag to save time. When the flag is false, reuse 'message["token_count"]' for individual messages.
+         # It's only necessary to recount message tokens at the beginning of a session because the LLM model may have changed.
+         # Changing the model requires recounting tokens because the tokenizer may be different
+         total_tokens = 0
+         tools_tokens = 0
+         system_tokens = 0
+         assistant_tokens = 0
+         user_tokens = 0
+         other_tokens = 0
+         tools_to_call_tokens = 0
          for message in messages:
-             if "token_count" in message and message["token_count"]:
-                 total_token_count += message["token_count"]
+             # count message tokens individually because it gives us fine grain information about each tool call/message etc.
+             # However be aware that the sum of individual message tokens is not equal to the overall messages token
+             token_count = litellm.token_counter(  # type: ignore
+                 model=self.model, messages=[message]
+             )
+             message["token_count"] = token_count
+             role = message.get("role")
+             if role == "system":
+                 system_tokens += token_count
+             elif role == "user":
+                 user_tokens += token_count
+             elif role == "tool":
+                 tools_tokens += token_count
+             elif role == "assistant":
+                 assistant_tokens += token_count
              else:
-                 # message can be counted by this method only if message contains a "content" key
-                 if "content" in message:
-                     if isinstance(message["content"], str):
-                         message_to_count = [
-                             {"type": "text", "text": message["content"]}
-                         ]
-                     elif isinstance(message["content"], list):
-                         message_to_count = [
-                             {"type": "text", "text": json.dumps(message["content"])}
-                         ]
-                     elif isinstance(message["content"], dict):
-                         if "type" not in message["content"]:
-                             message_to_count = [
-                                 {"type": "text", "text": json.dumps(message["content"])}
-                             ]
-                     token_count = litellm.token_counter(
-                         model=self.model, messages=message_to_count
-                     )
-                     message["token_count"] = token_count
-                     total_token_count += token_count
-         return total_token_count
+                 # although this should not be needed,
+                 # it is defensive code so that all tokens are accounted for
+                 # and can potentially make debugging easier
+                 other_tokens += token_count
+
+         messages_token_count_without_tools = litellm.token_counter(  # type: ignore
+             model=self.model, messages=messages
+         )
+
+         total_tokens = litellm.token_counter(  # type: ignore
+             model=self.model,
+             messages=messages,
+             tools=tools,  # type: ignore
+         )
+         tools_to_call_tokens = max(0, total_tokens - messages_token_count_without_tools)
+
+         return TokenCountMetadata(
+             total_tokens=total_tokens,
+             system_tokens=system_tokens,
+             user_tokens=user_tokens,
+             tools_tokens=tools_tokens,
+             tools_to_call_tokens=tools_to_call_tokens,
+             other_tokens=other_tokens,
+             assistant_tokens=assistant_tokens,
+         )
+
+     def get_litellm_corrected_name_for_robusta_ai(self) -> str:
+         if self.is_robusta_model:
+             # For robusta models, self.model is the underlying provider/model used by Robusta AI
+             # To avoid litellm modifying the API URL according to the provider, the provider name
+             # is replaced with 'openai/' just before doing a completion() call
+             # Cf. https://docs.litellm.ai/docs/providers/openai_compatible
+             split_model_name = self.model.split("/")
+             return (
+                 split_model_name[0]
+                 if len(split_model_name) == 1
+                 else f"openai/{split_model_name[1]}"
+             )
+         else:
+             return self.model

      def completion(
          self,
@@ -219,6 +385,9 @@ class DefaultLLM(LLM):
          if THINKING:
              self.args.setdefault("thinking", json.loads(THINKING))

+         if EXTRA_HEADERS:
+             self.args.setdefault("extra_headers", json.loads(EXTRA_HEADERS))
+
          if self.args.get("thinking", None):
              litellm.modify_params = True

@@ -228,20 +397,31 @@ class DefaultLLM(LLM):
                  "reasoning_effort"
              ]  # can be removed after next litelm version

+         existing_allowed = self.args.pop("allowed_openai_params", None)
+         if existing_allowed:
+             if allowed_openai_params is None:
+                 allowed_openai_params = []
+             allowed_openai_params.extend(existing_allowed)
+
          self.args.setdefault("temperature", temperature)

          self._add_cache_control_to_last_message(messages)

          # Get the litellm module to use (wrapped or unwrapped)
          litellm_to_use = self.tracer.wrap_llm(litellm) if self.tracer else litellm
+
+         litellm_model_name = self.get_litellm_corrected_name_for_robusta_ai()
          result = litellm_to_use.completion(
-             model=self.model,
+             model=litellm_model_name,
              api_key=self.api_key,
+             base_url=self.api_base,
+             api_version=self.api_version,
              messages=messages,
              response_format=response_format,
              drop_params=drop_params,
              allowed_openai_params=allowed_openai_params,
              stream=stream,
+             timeout=LLM_REQUEST_TIMEOUT,
              **tools_args,
              **self.args,
          )
@@ -254,20 +434,33 @@ class DefaultLLM(LLM):
          raise Exception(f"Unexpected type returned by the LLM {type(result)}")

      def get_maximum_output_token(self) -> int:
+         max_output_tokens = floor(min(64000, self.get_context_window_size() / 5))
+
          if OVERRIDE_MAX_OUTPUT_TOKEN:
              logging.debug(
                  f"Using OVERRIDE_MAX_OUTPUT_TOKEN {OVERRIDE_MAX_OUTPUT_TOKEN}"
              )
              return OVERRIDE_MAX_OUTPUT_TOKEN

-         model_name = os.environ.get("MODEL_TYPE", self._strip_model_prefix())
-         try:
-             return litellm.model_cost[model_name]["max_output_tokens"]
-         except Exception:
-             logging.warning(
-                 f"Couldn't find model's name {model_name} in litellm's model list, fallback to 4096 tokens for max_output_tokens"
-             )
-             return 4096
+         # Try each name variant
+         for name in self._get_model_name_variants_for_lookup():
+             try:
+                 litellm_max_output_tokens = litellm.model_cost[name][
+                     "max_output_tokens"
+                 ]
+                 if litellm_max_output_tokens < max_output_tokens:
+                     max_output_tokens = litellm_max_output_tokens
+                 return max_output_tokens
+             except Exception:
+                 continue
+
+         # Log which lookups we tried
+         logging.warning(
+             f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+             f"using {max_output_tokens} tokens for max_output_tokens. "
+             f"To override, set OVERRIDE_MAX_OUTPUT_TOKEN environment variable to the correct value for your model."
+         )
+         return max_output_tokens

      def _add_cache_control_to_last_message(
          self, messages: List[Dict[str, Any]]
@@ -276,6 +469,12 @@ class DefaultLLM(LLM):
          Add cache_control to the last non-user message for Anthropic prompt caching.
          Removes any existing cache_control from previous messages to avoid accumulation.
          """
+         # Skip cache_control for VertexAI/Gemini models as they don't support it with tools
+         if self.model and (
+             "vertex" in self.model.lower() or "gemini" in self.model.lower()
+         ):
+             return
+
          # First, remove any existing cache_control from all messages
          for msg in messages:
              content = msg.get("content")
@@ -305,7 +504,7 @@ class DefaultLLM(LLM):
          if content is None:
              return

-         if isinstance(content, str):
+         if isinstance(content, str) and content:
              # Convert string to structured format with cache_control
              target_msg["content"] = [
                  {
@@ -325,3 +524,224 @@ class DefaultLLM(LLM):
              logging.debug(
                  f"Added cache_control to {target_msg.get('role')} message (structured content)"
              )
+
+
+ class LLMModelRegistry:
+     def __init__(self, config: "Config", dal: SupabaseDal) -> None:
+         self.config = config
+         self._llms: dict[str, ModelEntry] = {}
+         self._default_robusta_model = None
+         self.dal = dal
+         self._lock = threading.RLock()
+
+         self._init_models()
+
+     @property
+     def default_robusta_model(self) -> Optional[str]:
+         return self._default_robusta_model
+
+     def _init_models(self):
+         self._llms = self._parse_models_file(MODEL_LIST_FILE_LOCATION)
+
+         if self._should_load_robusta_ai():
+             self.configure_robusta_ai_model()
+
+         if self._should_load_config_model():
+             self._llms[self.config.model] = self._create_model_entry(
+                 model=self.config.model,
+                 model_name=self.config.model,
+                 base_url=self.config.api_base,
+                 is_robusta_model=False,
+                 api_key=self.config.api_key,
+                 api_version=self.config.api_version,
+             )
+
+     def _should_load_config_model(self) -> bool:
+         if self.config.model is not None:
+             if self._llms and self.config.model in self._llms:
+                 # model already loaded from file
+                 return False
+             return True
+
+         # backward compatibility - in the past config.model was set by default to gpt-4o.
+         # so we need to check if the user has set an OPENAI_API_KEY to load the config model.
+         has_openai_key = os.environ.get("OPENAI_API_KEY")
+         if has_openai_key:
+             self.config.model = "gpt-4.1"
+             return True
+
+         return False
+
+     def configure_robusta_ai_model(self) -> None:
+         try:
+             if not self.config.cluster_name or not LOAD_ALL_ROBUSTA_MODELS:
+                 self._load_default_robusta_config()
+                 return
+
+             if not self.dal.account_id or not self.dal.enabled:
+                 self._load_default_robusta_config()
+                 return
+
+             account_id, token = self.dal.get_ai_credentials()
+
+             robusta_models: RobustaModelsResponse | None = fetch_robusta_models(
+                 account_id, token
+             )
+             if not robusta_models or not robusta_models.models:
+                 self._load_default_robusta_config()
+                 return
+
+             default_model = None
+             for model_name, model_data in robusta_models.models.items():
+                 logging.info(f"Loading Robusta AI model: {model_name}")
+                 self._llms[model_name] = self._create_robusta_model_entry(
+                     model_name=model_name, model_data=model_data
+                 )
+                 if model_data.is_default:
+                     default_model = model_name
+
+             if default_model:
+                 logging.info(f"Setting default Robusta AI model to: {default_model}")
+                 self._default_robusta_model: str = default_model  # type: ignore
+
+         except Exception:
+             logging.exception("Failed to get all robusta models")
+             # fallback to default behavior
+             self._load_default_robusta_config()
+
+     def _load_default_robusta_config(self):
+         if self._should_load_robusta_ai():
+             logging.info("Loading default Robusta AI model")
+             self._llms[ROBUSTA_AI_MODEL_NAME] = ModelEntry(
+                 name=ROBUSTA_AI_MODEL_NAME,
+                 model="gpt-4o",  # TODO: tech debt, this isn't really
+                 base_url=ROBUSTA_API_ENDPOINT,
+                 is_robusta_model=True,
+             )
+             self._default_robusta_model = ROBUSTA_AI_MODEL_NAME
+
+     def _should_load_robusta_ai(self) -> bool:
+         if not self.config.should_try_robusta_ai:
+             return False
+
+         # ROBUSTA_AI were set in the env vars, so we can use it directly
+         if ROBUSTA_AI is not None:
+             return ROBUSTA_AI
+
+         # MODEL is set in the env vars, e.g. the user is using a custom model
+         # so we don't need to load the robusta AI model and keep the behavior backward compatible
+         if "MODEL" in os.environ:
+             return False
+
+         # if the user has provided a model list, we don't need to load the robusta AI model
+         if self._llms:
+             return False
+
+         return True
+
+     def get_model_params(self, model_key: Optional[str] = None) -> ModelEntry:
+         with self._lock:
+             if not self._llms:
+                 raise Exception("No llm models were loaded")
+
+             if model_key:
+                 model_params = self._llms.get(model_key)
+                 if model_params:
+                     logging.info(f"Using selected model: {model_key}")
+                     return model_params.model_copy()
+
+                 if model_key.startswith("Robusta/"):
+                     logging.warning("Resyncing Registry and Robusta models.")
+                     self._init_models()
+                     model_params = self._llms.get(model_key)
+                     if model_params:
+                         logging.info(f"Using selected model: {model_key}")
+                         return model_params.model_copy()
+
+                 logging.error(f"Couldn't find model: {model_key} in model list")
+
+             if self._default_robusta_model:
+                 model_params = self._llms.get(self._default_robusta_model)
+                 if model_params is not None:
+                     logging.info(
+                         f"Using default Robusta AI model: {self._default_robusta_model}"
+                     )
+                     return model_params.model_copy()
+
+                 logging.error(
+                     f"Couldn't find default Robusta AI model: {self._default_robusta_model} in model list"
+                 )
+
+             model_key, first_model_params = next(iter(self._llms.items()))
+             logging.debug(f"Using first available model: {model_key}")
+             return first_model_params.model_copy()
+
+     @property
+     def models(self) -> dict[str, ModelEntry]:
+         with self._lock:
+             return self._llms
+
+     def _parse_models_file(self, path: str) -> dict[str, ModelEntry]:
+         models = load_yaml_file(path, raise_error=False, warn_not_found=False)
+         for _, params in models.items():
+             params = replace_env_vars_values(params)
+
+         llms = {}
+         for model_name, params in models.items():
+             llms[model_name] = ModelEntry.model_validate(params)
+
+         return llms
+
+     def _create_robusta_model_entry(
+         self, model_name: str, model_data: RobustaModel
+     ) -> ModelEntry:
+         entry = self._create_model_entry(
+             model=model_data.model,
+             model_name=model_name,
+             base_url=f"{ROBUSTA_API_ENDPOINT}/llm/{model_name}",
+             is_robusta_model=True,
+         )
+         entry.custom_args = model_data.holmes_args or {}  # type: ignore[assignment]
+         return entry
+
+     def _create_model_entry(
+         self,
+         model: str,
+         model_name: str,
+         base_url: Optional[str] = None,
+         is_robusta_model: Optional[bool] = None,
+         api_key: Optional[SecretStr] = None,
+         api_base: Optional[str] = None,
+         api_version: Optional[str] = None,
+     ) -> ModelEntry:
+         return ModelEntry(
+             name=model_name,
+             model=model,
+             base_url=base_url,
+             is_robusta_model=is_robusta_model,
+             api_key=api_key,
+             api_base=api_base,
+             api_version=api_version,
+         )
+
+
+ def get_llm_usage(
+     llm_response: Union[ModelResponse, CustomStreamWrapper, TextCompletionResponse],
+ ) -> dict:
+     usage: dict = {}
+     if (
+         (
+             isinstance(llm_response, ModelResponse)
+             or isinstance(llm_response, TextCompletionResponse)
+         )
+         and hasattr(llm_response, "usage")
+         and llm_response.usage
+     ):  # type: ignore
+         usage["prompt_tokens"] = llm_response.usage.prompt_tokens  # type: ignore
+         usage["completion_tokens"] = llm_response.usage.completion_tokens  # type: ignore
+         usage["total_tokens"] = llm_response.usage.total_tokens  # type: ignore
+     elif isinstance(llm_response, CustomStreamWrapper):
+         complete_response = litellm.stream_chunk_builder(chunks=llm_response)  # type: ignore
+         if complete_response:
+             return get_llm_usage(complete_response)
+     return usage
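
A minimal usage sketch of the reworked token-counting API above (not part of the published diff): it assumes the 0.18.4 wheel is installed and that OPENAI_API_KEY is set so the provider check in check_llm() passes; the class, method, and field names are taken from holmes/core/llm.py as shown in the diff.

    from holmes.core.llm import DefaultLLM, TokenCountMetadata

    # DefaultLLM now accepts api_base/api_version and validates the provider environment on construction.
    llm = DefaultLLM(model="gpt-4o")

    messages = [
        {"role": "system", "content": "You are a troubleshooting assistant."},
        {"role": "user", "content": "Why is my pod in CrashLoopBackOff?"},
    ]

    # count_tokens() replaces count_tokens_for_message() and returns a per-role breakdown.
    counts: TokenCountMetadata = llm.count_tokens(messages=messages, tools=None)
    print(counts.total_tokens, counts.system_tokens, counts.user_tokens)

    # Context-window lookups now try several model-name variants and fall back to
    # FALLBACK_CONTEXT_WINDOW_SIZE instead of the old hard-coded 128000/4096 defaults.
    print(llm.get_context_window_size(), llm.get_maximum_output_token())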