holmesgpt 0.13.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- holmes/__init__.py +3 -5
- holmes/clients/robusta_client.py +20 -6
- holmes/common/env_vars.py +58 -3
- holmes/common/openshift.py +1 -1
- holmes/config.py +123 -148
- holmes/core/conversations.py +71 -15
- holmes/core/feedback.py +191 -0
- holmes/core/investigation.py +31 -39
- holmes/core/investigation_structured_output.py +3 -3
- holmes/core/issue.py +1 -1
- holmes/core/llm.py +508 -88
- holmes/core/models.py +108 -4
- holmes/core/openai_formatting.py +14 -1
- holmes/core/prompt.py +48 -3
- holmes/core/runbooks.py +1 -0
- holmes/core/safeguards.py +8 -6
- holmes/core/supabase_dal.py +295 -100
- holmes/core/tool_calling_llm.py +489 -428
- holmes/core/tools.py +325 -56
- holmes/core/tools_utils/token_counting.py +21 -0
- holmes/core/tools_utils/tool_context_window_limiter.py +40 -0
- holmes/core/tools_utils/tool_executor.py +0 -13
- holmes/core/tools_utils/toolset_utils.py +1 -0
- holmes/core/toolset_manager.py +191 -5
- holmes/core/tracing.py +19 -3
- holmes/core/transformers/__init__.py +23 -0
- holmes/core/transformers/base.py +63 -0
- holmes/core/transformers/llm_summarize.py +175 -0
- holmes/core/transformers/registry.py +123 -0
- holmes/core/transformers/transformer.py +32 -0
- holmes/core/truncation/compaction.py +94 -0
- holmes/core/truncation/dal_truncation_utils.py +23 -0
- holmes/core/truncation/input_context_window_limiter.py +219 -0
- holmes/interactive.py +228 -31
- holmes/main.py +23 -40
- holmes/plugins/interfaces.py +2 -1
- holmes/plugins/prompts/__init__.py +2 -1
- holmes/plugins/prompts/_fetch_logs.jinja2 +31 -6
- holmes/plugins/prompts/_general_instructions.jinja2 +1 -2
- holmes/plugins/prompts/_runbook_instructions.jinja2 +24 -12
- holmes/plugins/prompts/base_user_prompt.jinja2 +7 -0
- holmes/plugins/prompts/conversation_history_compaction.jinja2 +89 -0
- holmes/plugins/prompts/generic_ask.jinja2 +0 -4
- holmes/plugins/prompts/generic_ask_conversation.jinja2 +0 -1
- holmes/plugins/prompts/generic_ask_for_issue_conversation.jinja2 +0 -1
- holmes/plugins/prompts/generic_investigation.jinja2 +0 -1
- holmes/plugins/prompts/investigation_procedure.jinja2 +50 -1
- holmes/plugins/prompts/kubernetes_workload_ask.jinja2 +0 -1
- holmes/plugins/prompts/kubernetes_workload_chat.jinja2 +0 -1
- holmes/plugins/runbooks/__init__.py +145 -17
- holmes/plugins/runbooks/catalog.json +2 -0
- holmes/plugins/sources/github/__init__.py +4 -2
- holmes/plugins/sources/prometheus/models.py +1 -0
- holmes/plugins/toolsets/__init__.py +44 -27
- holmes/plugins/toolsets/aks-node-health.yaml +46 -0
- holmes/plugins/toolsets/aks.yaml +64 -0
- holmes/plugins/toolsets/atlas_mongodb/mongodb_atlas.py +38 -47
- holmes/plugins/toolsets/azure_sql/apis/alert_monitoring_api.py +3 -2
- holmes/plugins/toolsets/azure_sql/apis/azure_sql_api.py +2 -1
- holmes/plugins/toolsets/azure_sql/apis/connection_failure_api.py +3 -2
- holmes/plugins/toolsets/azure_sql/apis/connection_monitoring_api.py +3 -1
- holmes/plugins/toolsets/azure_sql/apis/storage_analysis_api.py +3 -1
- holmes/plugins/toolsets/azure_sql/azure_sql_toolset.py +12 -13
- holmes/plugins/toolsets/azure_sql/tools/analyze_connection_failures.py +15 -12
- holmes/plugins/toolsets/azure_sql/tools/analyze_database_connections.py +15 -12
- holmes/plugins/toolsets/azure_sql/tools/analyze_database_health_status.py +11 -11
- holmes/plugins/toolsets/azure_sql/tools/analyze_database_performance.py +11 -9
- holmes/plugins/toolsets/azure_sql/tools/analyze_database_storage.py +15 -12
- holmes/plugins/toolsets/azure_sql/tools/get_active_alerts.py +15 -15
- holmes/plugins/toolsets/azure_sql/tools/get_slow_queries.py +11 -8
- holmes/plugins/toolsets/azure_sql/tools/get_top_cpu_queries.py +11 -8
- holmes/plugins/toolsets/azure_sql/tools/get_top_data_io_queries.py +11 -8
- holmes/plugins/toolsets/azure_sql/tools/get_top_log_io_queries.py +11 -8
- holmes/plugins/toolsets/azure_sql/utils.py +0 -32
- holmes/plugins/toolsets/bash/argocd/__init__.py +3 -3
- holmes/plugins/toolsets/bash/aws/__init__.py +4 -4
- holmes/plugins/toolsets/bash/azure/__init__.py +4 -4
- holmes/plugins/toolsets/bash/bash_toolset.py +11 -15
- holmes/plugins/toolsets/bash/common/bash.py +23 -13
- holmes/plugins/toolsets/bash/common/bash_command.py +1 -1
- holmes/plugins/toolsets/bash/common/stringify.py +1 -1
- holmes/plugins/toolsets/bash/kubectl/__init__.py +2 -1
- holmes/plugins/toolsets/bash/kubectl/constants.py +0 -1
- holmes/plugins/toolsets/bash/kubectl/kubectl_get.py +3 -4
- holmes/plugins/toolsets/bash/parse_command.py +12 -13
- holmes/plugins/toolsets/cilium.yaml +284 -0
- holmes/plugins/toolsets/connectivity_check.py +124 -0
- holmes/plugins/toolsets/coralogix/api.py +132 -119
- holmes/plugins/toolsets/coralogix/coralogix.jinja2 +14 -0
- holmes/plugins/toolsets/coralogix/toolset_coralogix.py +219 -0
- holmes/plugins/toolsets/coralogix/utils.py +15 -79
- holmes/plugins/toolsets/datadog/datadog_api.py +525 -26
- holmes/plugins/toolsets/datadog/datadog_logs_instructions.jinja2 +55 -11
- holmes/plugins/toolsets/datadog/datadog_metrics_instructions.jinja2 +3 -3
- holmes/plugins/toolsets/datadog/datadog_models.py +59 -0
- holmes/plugins/toolsets/datadog/datadog_url_utils.py +213 -0
- holmes/plugins/toolsets/datadog/instructions_datadog_traces.jinja2 +165 -28
- holmes/plugins/toolsets/datadog/toolset_datadog_general.py +417 -241
- holmes/plugins/toolsets/datadog/toolset_datadog_logs.py +234 -214
- holmes/plugins/toolsets/datadog/toolset_datadog_metrics.py +167 -79
- holmes/plugins/toolsets/datadog/toolset_datadog_traces.py +374 -363
- holmes/plugins/toolsets/elasticsearch/__init__.py +6 -0
- holmes/plugins/toolsets/elasticsearch/elasticsearch.py +834 -0
- holmes/plugins/toolsets/elasticsearch/opensearch_ppl_query_docs.jinja2 +1616 -0
- holmes/plugins/toolsets/elasticsearch/opensearch_query_assist.py +78 -0
- holmes/plugins/toolsets/elasticsearch/opensearch_query_assist_instructions.jinja2 +223 -0
- holmes/plugins/toolsets/git.py +54 -50
- holmes/plugins/toolsets/grafana/base_grafana_toolset.py +16 -4
- holmes/plugins/toolsets/grafana/common.py +13 -29
- holmes/plugins/toolsets/grafana/grafana_tempo_api.py +455 -0
- holmes/plugins/toolsets/grafana/loki/instructions.jinja2 +25 -0
- holmes/plugins/toolsets/grafana/loki/toolset_grafana_loki.py +191 -0
- holmes/plugins/toolsets/grafana/loki_api.py +4 -0
- holmes/plugins/toolsets/grafana/toolset_grafana.py +293 -89
- holmes/plugins/toolsets/grafana/toolset_grafana_dashboard.jinja2 +49 -0
- holmes/plugins/toolsets/grafana/toolset_grafana_tempo.jinja2 +246 -11
- holmes/plugins/toolsets/grafana/toolset_grafana_tempo.py +820 -292
- holmes/plugins/toolsets/grafana/trace_parser.py +4 -3
- holmes/plugins/toolsets/internet/internet.py +15 -16
- holmes/plugins/toolsets/internet/notion.py +9 -11
- holmes/plugins/toolsets/investigator/core_investigation.py +44 -36
- holmes/plugins/toolsets/investigator/model.py +3 -1
- holmes/plugins/toolsets/json_filter_mixin.py +134 -0
- holmes/plugins/toolsets/kafka.py +36 -42
- holmes/plugins/toolsets/kubernetes.yaml +317 -113
- holmes/plugins/toolsets/kubernetes_logs.py +9 -9
- holmes/plugins/toolsets/kubernetes_logs.yaml +32 -0
- holmes/plugins/toolsets/logging_utils/logging_api.py +94 -8
- holmes/plugins/toolsets/mcp/toolset_mcp.py +218 -64
- holmes/plugins/toolsets/newrelic/new_relic_api.py +165 -0
- holmes/plugins/toolsets/newrelic/newrelic.jinja2 +65 -0
- holmes/plugins/toolsets/newrelic/newrelic.py +320 -0
- holmes/plugins/toolsets/openshift.yaml +283 -0
- holmes/plugins/toolsets/prometheus/prometheus.py +1202 -421
- holmes/plugins/toolsets/prometheus/prometheus_instructions.jinja2 +54 -5
- holmes/plugins/toolsets/prometheus/utils.py +28 -0
- holmes/plugins/toolsets/rabbitmq/api.py +23 -4
- holmes/plugins/toolsets/rabbitmq/toolset_rabbitmq.py +13 -14
- holmes/plugins/toolsets/robusta/robusta.py +239 -68
- holmes/plugins/toolsets/robusta/robusta_instructions.jinja2 +26 -9
- holmes/plugins/toolsets/runbook/runbook_fetcher.py +157 -27
- holmes/plugins/toolsets/service_discovery.py +1 -1
- holmes/plugins/toolsets/servicenow_tables/instructions.jinja2 +83 -0
- holmes/plugins/toolsets/servicenow_tables/servicenow_tables.py +426 -0
- holmes/plugins/toolsets/utils.py +88 -0
- holmes/utils/config_utils.py +91 -0
- holmes/utils/connection_utils.py +31 -0
- holmes/utils/console/result.py +10 -0
- holmes/utils/default_toolset_installation_guide.jinja2 +1 -22
- holmes/utils/env.py +7 -0
- holmes/utils/file_utils.py +2 -1
- holmes/utils/global_instructions.py +60 -11
- holmes/utils/holmes_status.py +6 -4
- holmes/utils/holmes_sync_toolsets.py +0 -2
- holmes/utils/krr_utils.py +188 -0
- holmes/utils/log.py +15 -0
- holmes/utils/markdown_utils.py +2 -3
- holmes/utils/memory_limit.py +58 -0
- holmes/utils/sentry_helper.py +64 -0
- holmes/utils/stream.py +69 -8
- holmes/utils/tags.py +4 -3
- holmes/version.py +37 -15
- holmesgpt-0.18.4.dist-info/LICENSE +178 -0
- {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/METADATA +35 -31
- holmesgpt-0.18.4.dist-info/RECORD +258 -0
- holmes/core/performance_timing.py +0 -72
- holmes/plugins/toolsets/aws.yaml +0 -80
- holmes/plugins/toolsets/coralogix/toolset_coralogix_logs.py +0 -112
- holmes/plugins/toolsets/datadog/datadog_traces_formatter.py +0 -310
- holmes/plugins/toolsets/datadog/toolset_datadog_rds.py +0 -739
- holmes/plugins/toolsets/grafana/grafana_api.py +0 -42
- holmes/plugins/toolsets/grafana/tempo_api.py +0 -124
- holmes/plugins/toolsets/grafana/toolset_grafana_loki.py +0 -110
- holmes/plugins/toolsets/newrelic.py +0 -231
- holmes/plugins/toolsets/opensearch/opensearch.py +0 -257
- holmes/plugins/toolsets/opensearch/opensearch_logs.py +0 -161
- holmes/plugins/toolsets/opensearch/opensearch_traces.py +0 -218
- holmes/plugins/toolsets/opensearch/opensearch_traces_instructions.jinja2 +0 -12
- holmes/plugins/toolsets/opensearch/opensearch_utils.py +0 -166
- holmes/plugins/toolsets/servicenow/install.md +0 -37
- holmes/plugins/toolsets/servicenow/instructions.jinja2 +0 -3
- holmes/plugins/toolsets/servicenow/servicenow.py +0 -219
- holmes/utils/keygen_utils.py +0 -6
- holmesgpt-0.13.2.dist-info/LICENSE.txt +0 -21
- holmesgpt-0.13.2.dist-info/RECORD +0 -234
- /holmes/plugins/toolsets/{opensearch → newrelic}/__init__.py +0 -0
- {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/WHEEL +0 -0
- {holmesgpt-0.13.2.dist-info → holmesgpt-0.18.4.dist-info}/entry_points.txt +0 -0
holmes/core/llm.py
CHANGED
@@ -1,32 +1,93 @@
 import json
 import logging
+import os
+import threading
 from abc import abstractmethod
-from …
+from math import floor
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

-…
+import litellm
 import sentry_sdk
-…
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from …
-import …
-import …
+from litellm.types.utils import ModelResponse, TextCompletionResponse
+from pydantic import BaseModel, ConfigDict, SecretStr
+from typing_extensions import Self
+
+from holmes.clients.robusta_client import (
+    RobustaModel,
+    RobustaModelsResponse,
+    fetch_robusta_models,
+)
 from holmes.common.env_vars import (
+    EXTRA_HEADERS,
+    FALLBACK_CONTEXT_WINDOW_SIZE,
+    LLM_REQUEST_TIMEOUT,
+    LOAD_ALL_ROBUSTA_MODELS,
     REASONING_EFFORT,
+    ROBUSTA_AI,
+    ROBUSTA_API_ENDPOINT,
     THINKING,
+    TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT,
+    TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS,
 )
+from holmes.core.supabase_dal import SupabaseDal
+from holmes.utils.env import environ_get_safe_int, replace_env_vars_values
+from holmes.utils.file_utils import load_yaml_file

+if TYPE_CHECKING:
+    from holmes.config import Config

-… (3 lines not preserved)
-    except ValueError:
-        return int(default)
+MODEL_LIST_FILE_LOCATION = os.environ.get(
+    "MODEL_LIST_FILE_LOCATION", "/etc/holmes/config/model_list.yaml"
+)


 OVERRIDE_MAX_OUTPUT_TOKEN = environ_get_safe_int("OVERRIDE_MAX_OUTPUT_TOKEN")
 OVERRIDE_MAX_CONTENT_SIZE = environ_get_safe_int("OVERRIDE_MAX_CONTENT_SIZE")


+def get_context_window_compaction_threshold_pct() -> int:
+    """Get the compaction threshold percentage at runtime to support test overrides."""
+    return environ_get_safe_int("CONTEXT_WINDOW_COMPACTION_THRESHOLD_PCT", default="95")
+
+
+ROBUSTA_AI_MODEL_NAME = "Robusta"
+
+
+class TokenCountMetadata(BaseModel):
+    total_tokens: int
+    tools_tokens: int
+    system_tokens: int
+    user_tokens: int
+    tools_to_call_tokens: int
+    assistant_tokens: int
+    other_tokens: int
+
+
+class ModelEntry(BaseModel):
+    """ModelEntry represents a single LLM model configuration."""
+
+    model: str
+    # TODO: the name field seems to be redundant, can we remove it?
+    name: Optional[str] = None
+    api_key: Optional[SecretStr] = None
+    base_url: Optional[str] = None
+    is_robusta_model: Optional[bool] = None
+    custom_args: Optional[Dict[str, Any]] = None
+
+    # LLM configurations used services like Azure OpenAI Service
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
+    @classmethod
+    def load_from_dict(cls, data: dict) -> Self:
+        return cls.model_validate(data)
+
+
 class LLM:
     @abstractmethod
     def __init__(self):
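
As a rough sketch of how the new ModelEntry model added above might be used (the field values are invented for illustration; only the class and its load_from_dict helper come from the diff):

    from holmes.core.llm import ModelEntry

    entry = ModelEntry.load_from_dict(
        {
            "name": "my-azure-gpt",                          # optional display name (example value)
            "model": "azure/my-gpt41-deployment",            # litellm-style model string (example value)
            "api_base": "https://example.openai.azure.com",  # assumed endpoint
            "api_version": "2024-02-01",                     # assumed Azure API version
        }
    )
    assert entry.api_key is None  # unset fields keep their defaults; extra keys are tolerated via extra="allow"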
@@ -40,8 +101,23 @@ class LLM:
     def get_maximum_output_token(self) -> int:
         pass

+    def get_max_token_count_for_single_tool(self) -> int:
+        if (
+            0 < TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT
+            and TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT <= 100
+        ):
+            context_window_size = self.get_context_window_size()
+            calculated_max_tokens = int(
+                context_window_size * TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT // 100
+            )
+            return min(calculated_max_tokens, TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS)
+        else:
+            return TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS
+
     @abstractmethod
-    def …
+    def count_tokens(
+        self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+    ) -> TokenCountMetadata:
         pass

     @abstractmethod
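
To make the new per-tool token cap concrete, here is the same min(percentage-of-context, absolute-cap) rule as plain arithmetic (the constant names come from the diff; the values below are assumed for illustration):

    # Assumed values: TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_PCT = 20, TOOL_MAX_ALLOCATED_CONTEXT_WINDOW_TOKENS = 64_000
    context_window_size = 200_000                         # e.g. a large-context model
    pct_based_cap = int(context_window_size * 20 // 100)  # 40_000 tokens
    effective_cap = min(pct_based_cap, 64_000)             # -> 40_000; an out-of-range pct falls back to the absolute cap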
@@ -61,31 +137,55 @@
 class DefaultLLM(LLM):
     model: str
     api_key: Optional[str]
-…
+    api_base: Optional[str]
+    api_version: Optional[str]
     args: Dict
+    is_robusta_model: bool

     def __init__(
         self,
         model: str,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         args: Optional[Dict] = None,
-        tracer=None,
+        tracer: Optional[Any] = None,
+        name: Optional[str] = None,
+        is_robusta_model: bool = False,
     ):
         self.model = model
         self.api_key = api_key
+        self.api_base = api_base
+        self.api_version = api_version
         self.args = args or {}
         self.tracer = tracer
+        self.name = name
+        self.is_robusta_model = is_robusta_model
+        self.update_custom_args()
+        self.check_llm(
+            self.model, self.api_key, self.api_base, self.api_version, self.args
+        )

-… (2 lines not preserved)
+    def update_custom_args(self):
+        self.max_context_size = self.args.get("custom_args", {}).get("max_context_size")
+        self.args.pop("custom_args", None)

-    def check_llm( …
+    def check_llm(
+        self,
+        model: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        args: Optional[dict] = None,
+    ):
+        if self.is_robusta_model:
+            # The model is assumed correctly configured if it is a robusta model
+            # For robusta models, this code would fail because Holmes has no knowledge of the API keys
+            # to azure or bedrock as all completion API calls go through robusta's LLM proxy
+            return
+        args = args or {}
         logging.debug(f"Checking LiteLLM model {model}")
-…
-        # so without this hack it always complains that the environment variable for the api key is missing
-        # to fix that, we always set an api key in the standard format that litellm expects (which is ${PROVIDER}_API_KEY)
-        # TODO: we can now handle this better - see https://github.com/BerriAI/litellm/issues/4375#issuecomment-2223684750
-        lookup = litellm.get_llm_provider(self.model)
+        lookup = litellm.get_llm_provider(model)
         if not lookup:
             raise Exception(f"Unknown provider for model {model}")
         provider = lookup[1]
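
For reference, constructing the updated DefaultLLM with its new keyword arguments would look roughly like this (endpoint, deployment name, and key are placeholders, not values from the package):

    from holmes.core.llm import DefaultLLM

    llm = DefaultLLM(
        model="azure/my-gpt41-deployment",            # placeholder deployment
        api_key="<azure-api-key>",                    # placeholder secret
        api_base="https://example.openai.azure.com",  # placeholder endpoint
        api_version="2024-02-01",                     # example API version
        name="azure-gpt-4.1",
    )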
@@ -119,85 +219,151 @@ class DefaultLLM(LLM):
                 "environment variable for proper functionality. For more information, refer to the documentation: "
                 "https://docs.litellm.ai/docs/providers/watsonx#usage---models-in-deployment-spaces"
             )
-        elif provider == "bedrock" …
-            os.environ.get("AWS_PROFILE") or os.environ.get(
-… (2 lines not preserved)
+        elif provider == "bedrock":
+            if os.environ.get("AWS_PROFILE") or os.environ.get(
+                "AWS_BEARER_TOKEN_BEDROCK"
+            ):
+                model_requirements = {"keys_in_environment": True, "missing_keys": []}
+            elif args.get("aws_access_key_id") and args.get("aws_secret_access_key"):
+                return  # break fast.
+            else:
+                model_requirements = litellm.validate_environment(
+                    model=model, api_key=api_key, api_base=api_base
+                )
         else:
-… (5 lines not preserved)
+            model_requirements = litellm.validate_environment(
+                model=model, api_key=api_key, api_base=api_base
+            )
+            # validate_environment does not accept api_version, and as a special case for Azure OpenAI Service,
+            # when all the other AZURE environments are set expect AZURE_API_VERSION, validate_environment complains
+            # the missing of it even after the api_version is set.
+            # TODO: There's an open PR in litellm to accept api_version in validate_environment, we can leverage this
+            # change if accepted to ignore the following check.
+            # https://github.com/BerriAI/litellm/pull/13808
+            if (
+                provider == "azure"
+                and ["AZURE_API_VERSION"] == model_requirements["missing_keys"]
+                and api_version is not None
+            ):
+                model_requirements["missing_keys"] = []
+                model_requirements["keys_in_environment"] = True

         if not model_requirements["keys_in_environment"]:
             raise Exception(
                 f"model {model} requires the following environment variables: {model_requirements['missing_keys']}"
             )

-    def …
+    def _get_model_name_variants_for_lookup(self) -> list[str]:
         """
-… (2 lines not preserved)
-        https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
+        Generate model name variants to try when looking up in litellm.model_cost.
+        Returns a list of names to try in order: exact, lowercase, without prefix, etc.
         """
-…
-        prefixes = ["openai/", "bedrock/", "vertex_ai/", "anthropic/"]
+        names_to_try = [self.model, self.model.lower()]

-… (3 lines not preserved)
+        # If there's a prefix, also try without it
+        if "/" in self.model:
+            base_model = self.model.split("/", 1)[1]
+            names_to_try.extend([base_model, base_model.lower()])

-… (2 lines not preserved)
-        # this unfortunately does not seem to work for azure if the deployment name is not a well-known model name
-        # if not litellm.supports_function_calling(model=model):
-        #     raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
+        # Remove duplicates while preserving order (dict.fromkeys maintains insertion order in Python 3.7+)
+        return list(dict.fromkeys(names_to_try))

     def get_context_window_size(self) -> int:
+        if self.max_context_size:
+            return self.max_context_size
+
         if OVERRIDE_MAX_CONTENT_SIZE:
             logging.debug(
                 f"Using override OVERRIDE_MAX_CONTENT_SIZE {OVERRIDE_MAX_CONTENT_SIZE}"
             )
             return OVERRIDE_MAX_CONTENT_SIZE

-… (8 lines not preserved)
+        # Try each name variant
+        for name in self._get_model_name_variants_for_lookup():
+            try:
+                return litellm.model_cost[name]["max_input_tokens"]
+            except Exception:
+                continue
+
+        # Log which lookups we tried
+        logging.warning(
+            f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+            f"using default {FALLBACK_CONTEXT_WINDOW_SIZE} tokens for max_input_tokens. "
+            f"To override, set OVERRIDE_MAX_CONTENT_SIZE environment variable to the correct value for your model."
+        )
+        return FALLBACK_CONTEXT_WINDOW_SIZE

     @sentry_sdk.trace
-    def …
-…
+    def count_tokens(
+        self, messages: list[dict], tools: Optional[list[dict[str, Any]]] = None
+    ) -> TokenCountMetadata:
+        # TODO: Add a recount:bool flag to save time. When the flag is false, reuse 'message["token_count"]' for individual messages.
+        # It's only necessary to recount message tokens at the beginning of a session because the LLM model may have changed.
+        # Changing the model requires recounting tokens because the tokenizer may be different
+        total_tokens = 0
+        tools_tokens = 0
+        system_tokens = 0
+        assistant_tokens = 0
+        user_tokens = 0
+        other_tokens = 0
+        tools_to_call_tokens = 0
         for message in messages:
-… (2 lines not preserved)
+            # count message tokens individually because it gives us fine grain information about each tool call/message etc.
+            # However be aware that the sum of individual message tokens is not equal to the overall messages token
+            token_count = litellm.token_counter(  # type: ignore
+                model=self.model, messages=[message]
+            )
+            message["token_count"] = token_count
+            role = message.get("role")
+            if role == "system":
+                system_tokens += token_count
+            elif role == "user":
+                user_tokens += token_count
+            elif role == "tool":
+                tools_tokens += token_count
+            elif role == "assistant":
+                assistant_tokens += token_count
             else:
-                # …
-… (20 lines not preserved)
+                # although this should not be needed,
+                # it is defensive code so that all tokens are accounted for
+                # and can potentially make debugging easier
+                other_tokens += token_count
+
+        messages_token_count_without_tools = litellm.token_counter(  # type: ignore
+            model=self.model, messages=messages
+        )
+
+        total_tokens = litellm.token_counter(  # type: ignore
+            model=self.model,
+            messages=messages,
+            tools=tools,  # type: ignore
+        )
+        tools_to_call_tokens = max(0, total_tokens - messages_token_count_without_tools)
+
+        return TokenCountMetadata(
+            total_tokens=total_tokens,
+            system_tokens=system_tokens,
+            user_tokens=user_tokens,
+            tools_tokens=tools_tokens,
+            tools_to_call_tokens=tools_to_call_tokens,
+            other_tokens=other_tokens,
+            assistant_tokens=assistant_tokens,
+        )
+
+    def get_litellm_corrected_name_for_robusta_ai(self) -> str:
+        if self.is_robusta_model:
+            # For robusta models, self.model is the underlying provider/model used by Robusta AI
+            # To avoid litellm modifying the API URL according to the provider, the provider name
+            # is replaced with 'openai/' just before doing a completion() call
+            # Cf. https://docs.litellm.ai/docs/providers/openai_compatible
+            split_model_name = self.model.split("/")
+            return (
+                split_model_name[0]
+                if len(split_model_name) == 1
+                else f"openai/{split_model_name[1]}"
+            )
+        else:
+            return self.model

     def completion(
         self,
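
The lookup-name fallback added above behaves roughly like this standalone sketch (re-implemented here for illustration rather than imported from the package):

    def name_variants(model: str) -> list[str]:
        # Mirrors _get_model_name_variants_for_lookup(): exact name, lowercase, then prefix-stripped variants.
        names = [model, model.lower()]
        if "/" in model:
            base = model.split("/", 1)[1]
            names.extend([base, base.lower()])
        return list(dict.fromkeys(names))  # de-duplicate while preserving order

    print(name_variants("bedrock/Claude-Sonnet"))
    # ['bedrock/Claude-Sonnet', 'bedrock/claude-sonnet', 'Claude-Sonnet', 'claude-sonnet']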
@@ -219,6 +385,9 @@ class DefaultLLM(LLM):
         if THINKING:
             self.args.setdefault("thinking", json.loads(THINKING))

+        if EXTRA_HEADERS:
+            self.args.setdefault("extra_headers", json.loads(EXTRA_HEADERS))
+
         if self.args.get("thinking", None):
             litellm.modify_params = True

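
Since the new EXTRA_HEADERS value is parsed with json.loads, it would presumably be supplied as a JSON object in the environment, along these lines (the header name is made up for the example):

    import json
    import os

    os.environ["EXTRA_HEADERS"] = '{"x-request-source": "holmes"}'  # hypothetical header
    extra_headers = json.loads(os.environ["EXTRA_HEADERS"])         # -> {'x-request-source': 'holmes'}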
@@ -228,20 +397,31 @@ class DefaultLLM(LLM):
             "reasoning_effort"
         ]  # can be removed after next litelm version

+        existing_allowed = self.args.pop("allowed_openai_params", None)
+        if existing_allowed:
+            if allowed_openai_params is None:
+                allowed_openai_params = []
+            allowed_openai_params.extend(existing_allowed)
+
         self.args.setdefault("temperature", temperature)

         self._add_cache_control_to_last_message(messages)

         # Get the litellm module to use (wrapped or unwrapped)
         litellm_to_use = self.tracer.wrap_llm(litellm) if self.tracer else litellm
+
+        litellm_model_name = self.get_litellm_corrected_name_for_robusta_ai()
         result = litellm_to_use.completion(
-            model=…
+            model=litellm_model_name,
             api_key=self.api_key,
+            base_url=self.api_base,
+            api_version=self.api_version,
             messages=messages,
             response_format=response_format,
             drop_params=drop_params,
             allowed_openai_params=allowed_openai_params,
             stream=stream,
+            timeout=LLM_REQUEST_TIMEOUT,
             **tools_args,
             **self.args,
         )
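
The model-name correction applied just before completion() can be summarised with this small sketch (re-implemented for illustration): for Robusta-proxied models the provider prefix is swapped for "openai/" so litellm treats the proxy as an OpenAI-compatible endpoint.

    def corrected_name(model: str, is_robusta_model: bool) -> str:
        # Mirrors get_litellm_corrected_name_for_robusta_ai() from the diff above.
        if not is_robusta_model:
            return model
        parts = model.split("/")
        return parts[0] if len(parts) == 1 else f"openai/{parts[1]}"

    print(corrected_name("bedrock/claude-sonnet", is_robusta_model=True))  # openai/claude-sonnet
    print(corrected_name("gpt-4o", is_robusta_model=False))                # gpt-4o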
@@ -254,20 +434,33 @@ class DefaultLLM(LLM):
             raise Exception(f"Unexpected type returned by the LLM {type(result)}")

     def get_maximum_output_token(self) -> int:
+        max_output_tokens = floor(min(64000, self.get_context_window_size() / 5))
+
         if OVERRIDE_MAX_OUTPUT_TOKEN:
             logging.debug(
                 f"Using OVERRIDE_MAX_OUTPUT_TOKEN {OVERRIDE_MAX_OUTPUT_TOKEN}"
             )
             return OVERRIDE_MAX_OUTPUT_TOKEN

-… (8 lines not preserved)
+        # Try each name variant
+        for name in self._get_model_name_variants_for_lookup():
+            try:
+                litellm_max_output_tokens = litellm.model_cost[name][
+                    "max_output_tokens"
+                ]
+                if litellm_max_output_tokens < max_output_tokens:
+                    max_output_tokens = litellm_max_output_tokens
+                return max_output_tokens
+            except Exception:
+                continue
+
+        # Log which lookups we tried
+        logging.warning(
+            f"Couldn't find model {self.model} in litellm's model list (tried: {', '.join(self._get_model_name_variants_for_lookup())}), "
+            f"using {max_output_tokens} tokens for max_output_tokens. "
+            f"To override, set OVERRIDE_MAX_OUTPUT_TOKEN environment variable to the correct value for your model."
+        )
+        return max_output_tokens

     def _add_cache_control_to_last_message(
         self, messages: List[Dict[str, Any]]
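
A worked example of the new default output-token rule (one fifth of the context window, floored, never above 64000, then further capped by litellm's reported max_output_tokens when the model is known); the numbers are illustrative:

    from math import floor

    context_window = 128_000                                     # assumed context window
    default_cap = floor(min(64_000, context_window / 5))         # -> 25_600
    litellm_reported_max = 16_384                                 # assumed litellm.model_cost value
    max_output_tokens = min(default_cap, litellm_reported_max)   # -> 16_384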
@@ -276,6 +469,12 @@ class DefaultLLM(LLM):
         Add cache_control to the last non-user message for Anthropic prompt caching.
         Removes any existing cache_control from previous messages to avoid accumulation.
         """
+        # Skip cache_control for VertexAI/Gemini models as they don't support it with tools
+        if self.model and (
+            "vertex" in self.model.lower() or "gemini" in self.model.lower()
+        ):
+            return
+
         # First, remove any existing cache_control from all messages
         for msg in messages:
             content = msg.get("content")
@@ -305,7 +504,7 @@ class DefaultLLM(LLM):
         if content is None:
             return

-        if isinstance(content, str):
+        if isinstance(content, str) and content:
             # Convert string to structured format with cache_control
             target_msg["content"] = [
                 {
@@ -325,3 +524,224 @@ class DefaultLLM(LLM):
             logging.debug(
                 f"Added cache_control to {target_msg.get('role')} message (structured content)"
             )
+
+
+class LLMModelRegistry:
+    def __init__(self, config: "Config", dal: SupabaseDal) -> None:
+        self.config = config
+        self._llms: dict[str, ModelEntry] = {}
+        self._default_robusta_model = None
+        self.dal = dal
+        self._lock = threading.RLock()
+
+        self._init_models()
+
+    @property
+    def default_robusta_model(self) -> Optional[str]:
+        return self._default_robusta_model
+
+    def _init_models(self):
+        self._llms = self._parse_models_file(MODEL_LIST_FILE_LOCATION)
+
+        if self._should_load_robusta_ai():
+            self.configure_robusta_ai_model()
+
+        if self._should_load_config_model():
+            self._llms[self.config.model] = self._create_model_entry(
+                model=self.config.model,
+                model_name=self.config.model,
+                base_url=self.config.api_base,
+                is_robusta_model=False,
+                api_key=self.config.api_key,
+                api_version=self.config.api_version,
+            )
+
+    def _should_load_config_model(self) -> bool:
+        if self.config.model is not None:
+            if self._llms and self.config.model in self._llms:
+                # model already loaded from file
+                return False
+            return True
+
+        # backward compatibility - in the past config.model was set by default to gpt-4o.
+        # so we need to check if the user has set an OPENAI_API_KEY to load the config model.
+        has_openai_key = os.environ.get("OPENAI_API_KEY")
+        if has_openai_key:
+            self.config.model = "gpt-4.1"
+            return True
+
+        return False
+
+    def configure_robusta_ai_model(self) -> None:
+        try:
+            if not self.config.cluster_name or not LOAD_ALL_ROBUSTA_MODELS:
+                self._load_default_robusta_config()
+                return
+
+            if not self.dal.account_id or not self.dal.enabled:
+                self._load_default_robusta_config()
+                return
+
+            account_id, token = self.dal.get_ai_credentials()
+
+            robusta_models: RobustaModelsResponse | None = fetch_robusta_models(
+                account_id, token
+            )
+            if not robusta_models or not robusta_models.models:
+                self._load_default_robusta_config()
+                return
+
+            default_model = None
+            for model_name, model_data in robusta_models.models.items():
+                logging.info(f"Loading Robusta AI model: {model_name}")
+                self._llms[model_name] = self._create_robusta_model_entry(
+                    model_name=model_name, model_data=model_data
+                )
+                if model_data.is_default:
+                    default_model = model_name
+
+            if default_model:
+                logging.info(f"Setting default Robusta AI model to: {default_model}")
+                self._default_robusta_model: str = default_model  # type: ignore
+
+        except Exception:
+            logging.exception("Failed to get all robusta models")
+            # fallback to default behavior
+            self._load_default_robusta_config()
+
+    def _load_default_robusta_config(self):
+        if self._should_load_robusta_ai():
+            logging.info("Loading default Robusta AI model")
+            self._llms[ROBUSTA_AI_MODEL_NAME] = ModelEntry(
+                name=ROBUSTA_AI_MODEL_NAME,
+                model="gpt-4o",  # TODO: tech debt, this isn't really
+                base_url=ROBUSTA_API_ENDPOINT,
+                is_robusta_model=True,
+            )
+            self._default_robusta_model = ROBUSTA_AI_MODEL_NAME
+
+    def _should_load_robusta_ai(self) -> bool:
+        if not self.config.should_try_robusta_ai:
+            return False
+
+        # ROBUSTA_AI were set in the env vars, so we can use it directly
+        if ROBUSTA_AI is not None:
+            return ROBUSTA_AI
+
+        # MODEL is set in the env vars, e.g. the user is using a custom model
+        # so we don't need to load the robusta AI model and keep the behavior backward compatible
+        if "MODEL" in os.environ:
+            return False
+
+        # if the user has provided a model list, we don't need to load the robusta AI model
+        if self._llms:
+            return False
+
+        return True
+
+    def get_model_params(self, model_key: Optional[str] = None) -> ModelEntry:
+        with self._lock:
+            if not self._llms:
+                raise Exception("No llm models were loaded")
+
+            if model_key:
+                model_params = self._llms.get(model_key)
+                if model_params:
+                    logging.info(f"Using selected model: {model_key}")
+                    return model_params.model_copy()
+
+                if model_key.startswith("Robusta/"):
+                    logging.warning("Resyncing Registry and Robusta models.")
+                    self._init_models()
+                    model_params = self._llms.get(model_key)
+                    if model_params:
+                        logging.info(f"Using selected model: {model_key}")
+                        return model_params.model_copy()
+
+                logging.error(f"Couldn't find model: {model_key} in model list")
+
+            if self._default_robusta_model:
+                model_params = self._llms.get(self._default_robusta_model)
+                if model_params is not None:
+                    logging.info(
+                        f"Using default Robusta AI model: {self._default_robusta_model}"
+                    )
+                    return model_params.model_copy()
+
+                logging.error(
+                    f"Couldn't find default Robusta AI model: {self._default_robusta_model} in model list"
+                )
+
+            model_key, first_model_params = next(iter(self._llms.items()))
+            logging.debug(f"Using first available model: {model_key}")
+            return first_model_params.model_copy()
+
+    @property
+    def models(self) -> dict[str, ModelEntry]:
+        with self._lock:
+            return self._llms
+
+    def _parse_models_file(self, path: str) -> dict[str, ModelEntry]:
+        models = load_yaml_file(path, raise_error=False, warn_not_found=False)
+        for _, params in models.items():
+            params = replace_env_vars_values(params)
+
+        llms = {}
+        for model_name, params in models.items():
+            llms[model_name] = ModelEntry.model_validate(params)
+
+        return llms
+
+    def _create_robusta_model_entry(
+        self, model_name: str, model_data: RobustaModel
+    ) -> ModelEntry:
+        entry = self._create_model_entry(
+            model=model_data.model,
+            model_name=model_name,
+            base_url=f"{ROBUSTA_API_ENDPOINT}/llm/{model_name}",
+            is_robusta_model=True,
+        )
+        entry.custom_args = model_data.holmes_args or {}  # type: ignore[assignment]
+        return entry
+
+    def _create_model_entry(
+        self,
+        model: str,
+        model_name: str,
+        base_url: Optional[str] = None,
+        is_robusta_model: Optional[bool] = None,
+        api_key: Optional[SecretStr] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+    ) -> ModelEntry:
+        return ModelEntry(
+            name=model_name,
+            model=model,
+            base_url=base_url,
+            is_robusta_model=is_robusta_model,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+        )
+
+
+def get_llm_usage(
+    llm_response: Union[ModelResponse, CustomStreamWrapper, TextCompletionResponse],
+) -> dict:
+    usage: dict = {}
+    if (
+        (
+            isinstance(llm_response, ModelResponse)
+            or isinstance(llm_response, TextCompletionResponse)
+        )
+        and hasattr(llm_response, "usage")
+        and llm_response.usage
+    ):  # type: ignore
+        usage["prompt_tokens"] = llm_response.usage.prompt_tokens  # type: ignore
+        usage["completion_tokens"] = llm_response.usage.completion_tokens  # type: ignore
+        usage["total_tokens"] = llm_response.usage.total_tokens  # type: ignore
+    elif isinstance(llm_response, CustomStreamWrapper):
+        complete_response = litellm.stream_chunk_builder(chunks=llm_response)  # type: ignore
+        if complete_response:
+            return get_llm_usage(complete_response)
+    return usage
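
Finally, a small sketch of how the new module-level get_llm_usage() helper could be applied to a litellm response (the model and prompt are placeholders and assume the relevant API key is configured):

    import litellm
    from holmes.core.llm import get_llm_usage

    response = litellm.completion(
        model="gpt-4.1",                                   # placeholder model
        messages=[{"role": "user", "content": "hello"}],
    )
    print(get_llm_usage(response))
    # e.g. {'prompt_tokens': 8, 'completion_tokens': 9, 'total_tokens': 17}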