groknroll-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groknroll/__init__.py +36 -0
- groknroll/__main__.py +9 -0
- groknroll/agents/__init__.py +18 -0
- groknroll/agents/agent_manager.py +187 -0
- groknroll/agents/base_agent.py +118 -0
- groknroll/agents/build_agent.py +231 -0
- groknroll/agents/plan_agent.py +215 -0
- groknroll/cli/__init__.py +7 -0
- groknroll/cli/enhanced_cli.py +372 -0
- groknroll/cli/large_codebase_cli.py +413 -0
- groknroll/cli/main.py +331 -0
- groknroll/cli/rlm_commands.py +258 -0
- groknroll/clients/__init__.py +63 -0
- groknroll/clients/anthropic.py +112 -0
- groknroll/clients/azure_openai.py +142 -0
- groknroll/clients/base_lm.py +33 -0
- groknroll/clients/gemini.py +162 -0
- groknroll/clients/litellm.py +105 -0
- groknroll/clients/openai.py +129 -0
- groknroll/clients/portkey.py +94 -0
- groknroll/core/__init__.py +9 -0
- groknroll/core/agent.py +339 -0
- groknroll/core/comms_utils.py +264 -0
- groknroll/core/context.py +251 -0
- groknroll/core/exceptions.py +181 -0
- groknroll/core/large_codebase.py +564 -0
- groknroll/core/lm_handler.py +206 -0
- groknroll/core/rlm.py +446 -0
- groknroll/core/rlm_codebase.py +448 -0
- groknroll/core/rlm_integration.py +256 -0
- groknroll/core/types.py +276 -0
- groknroll/environments/__init__.py +34 -0
- groknroll/environments/base_env.py +182 -0
- groknroll/environments/constants.py +32 -0
- groknroll/environments/docker_repl.py +336 -0
- groknroll/environments/local_repl.py +388 -0
- groknroll/environments/modal_repl.py +502 -0
- groknroll/environments/prime_repl.py +588 -0
- groknroll/logger/__init__.py +4 -0
- groknroll/logger/rlm_logger.py +63 -0
- groknroll/logger/verbose.py +393 -0
- groknroll/operations/__init__.py +15 -0
- groknroll/operations/bash_ops.py +447 -0
- groknroll/operations/file_ops.py +473 -0
- groknroll/operations/git_ops.py +620 -0
- groknroll/oracle/__init__.py +11 -0
- groknroll/oracle/codebase_indexer.py +238 -0
- groknroll/oracle/oracle_agent.py +278 -0
- groknroll/setup.py +34 -0
- groknroll/storage/__init__.py +14 -0
- groknroll/storage/database.py +272 -0
- groknroll/storage/models.py +128 -0
- groknroll/utils/__init__.py +0 -0
- groknroll/utils/parsing.py +168 -0
- groknroll/utils/prompts.py +146 -0
- groknroll/utils/rlm_utils.py +19 -0
- groknroll-2.0.0.dist-info/METADATA +246 -0
- groknroll-2.0.0.dist-info/RECORD +62 -0
- groknroll-2.0.0.dist-info/WHEEL +5 -0
- groknroll-2.0.0.dist-info/entry_points.txt +3 -0
- groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
- groknroll-2.0.0.dist-info/top_level.txt +1 -0
groknroll/clients/anthropic.py
@@ -0,0 +1,112 @@
from collections import defaultdict
from typing import Any

import anthropic

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary


class AnthropicClient(BaseLM):
    """
    LM Client for running models with the Anthropic API.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str | None = None,
        max_tokens: int = 32768,
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)
        self.client = anthropic.Anthropic(api_key=api_key)
        self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        messages, system = self._prepare_messages(prompt)

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Anthropic client.")

        kwargs = {"model": model, "max_tokens": self.max_tokens, "messages": messages}
        if system:
            kwargs["system"] = system

        response = self.client.messages.create(**kwargs)
        self._track_cost(response, model)
        return response.content[0].text

    async def acompletion(
        self, prompt: str | list[dict[str, Any]], model: str | None = None
    ) -> str:
        messages, system = self._prepare_messages(prompt)

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Anthropic client.")

        kwargs = {"model": model, "max_tokens": self.max_tokens, "messages": messages}
        if system:
            kwargs["system"] = system

        response = await self.async_client.messages.create(**kwargs)
        self._track_cost(response, model)
        return response.content[0].text

    def _prepare_messages(
        self, prompt: str | list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Prepare messages and extract system prompt for Anthropic API."""
        system = None

        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            # Extract system message if present (Anthropic handles system separately)
            messages = []
            for msg in prompt:
                if msg.get("role") == "system":
                    system = msg.get("content")
                else:
                    messages.append(msg)
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        return messages, system

    def _track_cost(self, response: anthropic.types.Message, model: str):
        self.model_call_counts[model] += 1
        self.model_input_tokens[model] += response.usage.input_tokens
        self.model_output_tokens[model] += response.usage.output_tokens
        self.model_total_tokens[model] += response.usage.input_tokens + response.usage.output_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = response.usage.input_tokens
        self.last_completion_tokens = response.usage.output_tokens

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
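A minimal usage sketch of the client above, not part of the published diff: the API key and model name are placeholders, and it simply exercises the message-list path (including the system-message extraction) plus the usage summary.

```python
# Hypothetical example; assumes a valid Anthropic API key and model name.
from groknroll.clients.anthropic import AnthropicClient

client = AnthropicClient(api_key="sk-ant-...", model_name="claude-sonnet-4-20250514")
reply = client.completion([
    {"role": "system", "content": "You are terse."},   # routed to the `system` kwarg
    {"role": "user", "content": "Say hello."},
])
print(reply)
print(client.get_usage_summary())  # per-model call and token counts
```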
groknroll/clients/azure_openai.py
@@ -0,0 +1,142 @@
import os
from collections import defaultdict
from typing import Any

import openai
from dotenv import load_dotenv

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary

load_dotenv()

# Load API key from environment variable
DEFAULT_AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")


class AzureOpenAIClient(BaseLM):
    """
    LM Client for running models with the Azure OpenAI API.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model_name: str | None = None,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        azure_deployment: str | None = None,
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)

        if api_key is None:
            api_key = DEFAULT_AZURE_OPENAI_API_KEY

        if azure_endpoint is None:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

        if api_version is None:
            api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")

        if azure_deployment is None:
            azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")

        if azure_endpoint is None:
            raise ValueError(
                "azure_endpoint is required for Azure OpenAI client. "
                "Set it via argument or AZURE_OPENAI_ENDPOINT environment variable."
            )

        self.client = openai.AzureOpenAI(
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            azure_deployment=azure_deployment,
        )
        self.async_client = openai.AsyncAzureOpenAI(
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            azure_deployment=azure_deployment,
        )
        self.model_name = model_name
        self.azure_deployment = azure_deployment

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Azure OpenAI client.")

        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
        )
        self._track_cost(response, model)
        return response.choices[0].message.content

    async def acompletion(
        self, prompt: str | list[dict[str, Any]], model: str | None = None
    ) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Azure OpenAI client.")

        response = await self.async_client.chat.completions.create(
            model=model,
            messages=messages,
        )
        self._track_cost(response, model)
        return response.choices[0].message.content

    def _track_cost(self, response: openai.ChatCompletion, model: str):
        self.model_call_counts[model] += 1

        usage = getattr(response, "usage", None)
        if usage is None:
            raise ValueError("No usage data received. Tracking tokens not possible.")

        self.model_input_tokens[model] += usage.prompt_tokens
        self.model_output_tokens[model] += usage.completion_tokens
        self.model_total_tokens[model] += usage.total_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = usage.prompt_tokens
        self.last_completion_tokens = usage.completion_tokens

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
groknroll/clients/base_lm.py
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from typing import Any

from groknroll.core.types import UsageSummary


class BaseLM(ABC):
    """
    Base class for all language model routers / clients. When the RLM makes sub-calls, it currently
    does so in a model-agnostic way, so this class provides a base interface for all language models.
    """

    def __init__(self, model_name: str, **kwargs):
        self.model_name = model_name
        self.kwargs = kwargs

    @abstractmethod
    def completion(self, prompt: str | dict[str, Any]) -> str:
        raise NotImplementedError

    @abstractmethod
    async def acompletion(self, prompt: str | dict[str, Any]) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_usage_summary(self) -> UsageSummary:
        """Get cost summary for all model calls."""
        raise NotImplementedError

    @abstractmethod
    def get_last_usage(self) -> UsageSummary:
        """Get the last cost summary of the model."""
        raise NotImplementedError
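Because each concrete client in this package subclasses BaseLM, calling code can stay provider-agnostic. A hypothetical sketch, not part of the diff, with illustrative names only:

```python
# Any concrete client (AnthropicClient, OpenAIClient, GeminiClient, LiteLLMClient, ...)
# satisfies this interface, so the caller never touches provider-specific APIs.
from groknroll.clients.base_lm import BaseLM

def summarize(lm: BaseLM, text: str) -> str:
    reply = lm.completion(f"Summarize in one sentence:\n{text}")
    print(lm.get_last_usage())  # token counts for this single call
    return reply
```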
groknroll/clients/gemini.py
@@ -0,0 +1,162 @@
import os
from collections import defaultdict
from typing import Any

from dotenv import load_dotenv
from google import genai
from google.genai import types

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary

load_dotenv()

DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")


class GeminiClient(BaseLM):
    """
    LM Client for running models with the Google Gemini API.
    Uses the official google-genai SDK.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model_name: str | None = "gemini-2.5-flash",
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)

        if api_key is None:
            api_key = DEFAULT_GEMINI_API_KEY

        if api_key is None:
            raise ValueError(
                "Gemini API key is required. Set GEMINI_API_KEY env var or pass api_key."
            )

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

        # Last call tracking
        self.last_prompt_tokens = 0
        self.last_completion_tokens = 0

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        contents, system_instruction = self._prepare_contents(prompt)

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Gemini client.")

        config = None
        if system_instruction:
            config = types.GenerateContentConfig(system_instruction=system_instruction)

        response = self.client.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        self._track_cost(response, model)
        return response.text

    async def acompletion(
        self, prompt: str | list[dict[str, Any]], model: str | None = None
    ) -> str:
        contents, system_instruction = self._prepare_contents(prompt)

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Gemini client.")

        config = None
        if system_instruction:
            config = types.GenerateContentConfig(system_instruction=system_instruction)

        # google-genai SDK supports async via aio interface
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        self._track_cost(response, model)
        return response.text

    def _prepare_contents(
        self, prompt: str | list[dict[str, Any]]
    ) -> tuple[list[types.Content] | str, str | None]:
        """Prepare contents and extract system instruction for Gemini API."""
        system_instruction = None

        if isinstance(prompt, str):
            return prompt, None

        if isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            # Convert OpenAI-style messages to Gemini format
            contents = []
            for msg in prompt:
                role = msg.get("role")
                content = msg.get("content", "")

                if role == "system":
                    # Gemini handles system instruction separately
                    system_instruction = content
                elif role == "user":
                    contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
                elif role == "assistant":
                    # Gemini uses "model" instead of "assistant"
                    contents.append(types.Content(role="model", parts=[types.Part(text=content)]))
                else:
                    # Default to user role for unknown roles
                    contents.append(types.Content(role="user", parts=[types.Part(text=content)]))

            return contents, system_instruction

        raise ValueError(f"Invalid prompt type: {type(prompt)}")

    def _track_cost(self, response: types.GenerateContentResponse, model: str):
        self.model_call_counts[model] += 1

        # Extract token usage from response
        usage = response.usage_metadata
        if usage:
            input_tokens = usage.prompt_token_count or 0
            output_tokens = usage.candidates_token_count or 0

            self.model_input_tokens[model] += input_tokens
            self.model_output_tokens[model] += output_tokens
            self.model_total_tokens[model] += input_tokens + output_tokens

            # Track last call for handler to read
            self.last_prompt_tokens = input_tokens
            self.last_completion_tokens = output_tokens
        else:
            self.last_prompt_tokens = 0
            self.last_completion_tokens = 0

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
groknroll/clients/litellm.py
@@ -0,0 +1,105 @@
from collections import defaultdict
from typing import Any

import litellm

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary


class LiteLLMClient(BaseLM):
    """
    LM Client for running models with LiteLLM.
    LiteLLM provides a unified interface to 100+ LLM providers.
    """

    def __init__(
        self,
        model_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)
        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for LiteLLM client.")

        kwargs = {"model": model, "messages": messages}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base

        response = litellm.completion(**kwargs)
        self._track_cost(response, model)
        return response.choices[0].message.content

    async def acompletion(
        self, prompt: str | list[dict[str, Any]], model: str | None = None
    ) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for LiteLLM client.")

        kwargs = {"model": model, "messages": messages}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base

        response = await litellm.acompletion(**kwargs)
        self._track_cost(response, model)
        return response.choices[0].message.content

    def _track_cost(self, response, model: str):
        self.model_call_counts[model] += 1
        self.model_input_tokens[model] += response.usage.prompt_tokens
        self.model_output_tokens[model] += response.usage.completion_tokens
        self.model_total_tokens[model] += response.usage.total_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = response.usage.prompt_tokens
        self.last_completion_tokens = response.usage.completion_tokens

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
groknroll/clients/openai.py
@@ -0,0 +1,129 @@
import os
from collections import defaultdict
from typing import Any

import openai
from dotenv import load_dotenv

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary

load_dotenv()

# Load API keys from environment variables
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
DEFAULT_OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
DEFAULT_VERCEL_API_KEY = os.getenv("AI_GATEWAY_API_KEY")
DEFAULT_PRIME_INTELLECT_BASE_URL = "https://api.pinference.ai/api/v1/"


class OpenAIClient(BaseLM):
    """
    LM Client for running models with the OpenAI API. Works with vLLM as well.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model_name: str | None = None,
        base_url: str | None = None,
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)

        if api_key is None:
            if base_url == "https://api.openai.com/v1" or base_url is None:
                api_key = DEFAULT_OPENAI_API_KEY
            elif base_url == "https://openrouter.ai/api/v1":
                api_key = DEFAULT_OPENROUTER_API_KEY
            elif base_url == "https://ai-gateway.vercel.sh/v1":
                api_key = DEFAULT_VERCEL_API_KEY

        # For vLLM, set base_url to local vLLM server address.
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
        self.async_client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.model_name = model_name

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for OpenAI client.")

        extra_body = {}
        if self.client.base_url == DEFAULT_PRIME_INTELLECT_BASE_URL:
            extra_body["usage"] = {"include": True}

        response = self.client.chat.completions.create(
            model=model, messages=messages, extra_body=extra_body
        )
        self._track_cost(response, model)
        return response.choices[0].message.content

    async def acompletion(
        self, prompt: str | list[dict[str, Any]], model: str | None = None
    ) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for OpenAI client.")

        extra_body = {}
        if self.client.base_url == DEFAULT_PRIME_INTELLECT_BASE_URL:
            extra_body["usage"] = {"include": True}

        response = await self.async_client.chat.completions.create(
            model=model, messages=messages, extra_body=extra_body
        )
        self._track_cost(response, model)
        return response.choices[0].message.content

    def _track_cost(self, response: openai.ChatCompletion, model: str):
        self.model_call_counts[model] += 1

        usage = getattr(response, "usage", None)
        if usage is None:
            raise ValueError("No usage data received. Tracking tokens not possible.")

        self.model_input_tokens[model] += usage.prompt_tokens
        self.model_output_tokens[model] += usage.completion_tokens
        self.model_total_tokens[model] += usage.total_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = usage.prompt_tokens
        self.last_completion_tokens = usage.completion_tokens

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
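The module's inline comment notes that OpenAIClient also works against a vLLM server by overriding base_url. A hypothetical sketch, not part of the diff; the URL, key, and model name are placeholders for a locally running vLLM OpenAI-compatible endpoint.

```python
from groknroll.clients.openai import OpenAIClient

client = OpenAIClient(
    api_key="EMPTY",                      # local vLLM servers typically ignore the key
    base_url="http://localhost:8000/v1",  # vLLM OpenAI-compatible endpoint
    model_name="meta-llama/Llama-3.1-8B-Instruct",
)
print(client.completion("Hello from vLLM"))
```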