groknroll-2.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. groknroll/__init__.py +36 -0
  2. groknroll/__main__.py +9 -0
  3. groknroll/agents/__init__.py +18 -0
  4. groknroll/agents/agent_manager.py +187 -0
  5. groknroll/agents/base_agent.py +118 -0
  6. groknroll/agents/build_agent.py +231 -0
  7. groknroll/agents/plan_agent.py +215 -0
  8. groknroll/cli/__init__.py +7 -0
  9. groknroll/cli/enhanced_cli.py +372 -0
  10. groknroll/cli/large_codebase_cli.py +413 -0
  11. groknroll/cli/main.py +331 -0
  12. groknroll/cli/rlm_commands.py +258 -0
  13. groknroll/clients/__init__.py +63 -0
  14. groknroll/clients/anthropic.py +112 -0
  15. groknroll/clients/azure_openai.py +142 -0
  16. groknroll/clients/base_lm.py +33 -0
  17. groknroll/clients/gemini.py +162 -0
  18. groknroll/clients/litellm.py +105 -0
  19. groknroll/clients/openai.py +129 -0
  20. groknroll/clients/portkey.py +94 -0
  21. groknroll/core/__init__.py +9 -0
  22. groknroll/core/agent.py +339 -0
  23. groknroll/core/comms_utils.py +264 -0
  24. groknroll/core/context.py +251 -0
  25. groknroll/core/exceptions.py +181 -0
  26. groknroll/core/large_codebase.py +564 -0
  27. groknroll/core/lm_handler.py +206 -0
  28. groknroll/core/rlm.py +446 -0
  29. groknroll/core/rlm_codebase.py +448 -0
  30. groknroll/core/rlm_integration.py +256 -0
  31. groknroll/core/types.py +276 -0
  32. groknroll/environments/__init__.py +34 -0
  33. groknroll/environments/base_env.py +182 -0
  34. groknroll/environments/constants.py +32 -0
  35. groknroll/environments/docker_repl.py +336 -0
  36. groknroll/environments/local_repl.py +388 -0
  37. groknroll/environments/modal_repl.py +502 -0
  38. groknroll/environments/prime_repl.py +588 -0
  39. groknroll/logger/__init__.py +4 -0
  40. groknroll/logger/rlm_logger.py +63 -0
  41. groknroll/logger/verbose.py +393 -0
  42. groknroll/operations/__init__.py +15 -0
  43. groknroll/operations/bash_ops.py +447 -0
  44. groknroll/operations/file_ops.py +473 -0
  45. groknroll/operations/git_ops.py +620 -0
  46. groknroll/oracle/__init__.py +11 -0
  47. groknroll/oracle/codebase_indexer.py +238 -0
  48. groknroll/oracle/oracle_agent.py +278 -0
  49. groknroll/setup.py +34 -0
  50. groknroll/storage/__init__.py +14 -0
  51. groknroll/storage/database.py +272 -0
  52. groknroll/storage/models.py +128 -0
  53. groknroll/utils/__init__.py +0 -0
  54. groknroll/utils/parsing.py +168 -0
  55. groknroll/utils/prompts.py +146 -0
  56. groknroll/utils/rlm_utils.py +19 -0
  57. groknroll-2.0.0.dist-info/METADATA +246 -0
  58. groknroll-2.0.0.dist-info/RECORD +62 -0
  59. groknroll-2.0.0.dist-info/WHEEL +5 -0
  60. groknroll-2.0.0.dist-info/entry_points.txt +3 -0
  61. groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
  62. groknroll-2.0.0.dist-info/top_level.txt +1 -0
groknroll/clients/anthropic.py
@@ -0,0 +1,112 @@
+ from collections import defaultdict
+ from typing import Any
+
+ import anthropic
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.types import ModelUsageSummary, UsageSummary
+
+
+ class AnthropicClient(BaseLM):
+     """
+     LM Client for running models with the Anthropic API.
+     """
+
+     def __init__(
+         self,
+         api_key: str,
+         model_name: str | None = None,
+         max_tokens: int = 32768,
+         **kwargs,
+     ):
+         super().__init__(model_name=model_name, **kwargs)
+         self.client = anthropic.Anthropic(api_key=api_key)
+         self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
+         self.model_name = model_name
+         self.max_tokens = max_tokens
+
+         # Per-model usage tracking
+         self.model_call_counts: dict[str, int] = defaultdict(int)
+         self.model_input_tokens: dict[str, int] = defaultdict(int)
+         self.model_output_tokens: dict[str, int] = defaultdict(int)
+         self.model_total_tokens: dict[str, int] = defaultdict(int)
+
+     def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+         messages, system = self._prepare_messages(prompt)
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Anthropic client.")
+
+         kwargs = {"model": model, "max_tokens": self.max_tokens, "messages": messages}
+         if system:
+             kwargs["system"] = system
+
+         response = self.client.messages.create(**kwargs)
+         self._track_cost(response, model)
+         return response.content[0].text
+
+     async def acompletion(
+         self, prompt: str | list[dict[str, Any]], model: str | None = None
+     ) -> str:
+         messages, system = self._prepare_messages(prompt)
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Anthropic client.")
+
+         kwargs = {"model": model, "max_tokens": self.max_tokens, "messages": messages}
+         if system:
+             kwargs["system"] = system
+
+         response = await self.async_client.messages.create(**kwargs)
+         self._track_cost(response, model)
+         return response.content[0].text
+
+     def _prepare_messages(
+         self, prompt: str | list[dict[str, Any]]
+     ) -> tuple[list[dict[str, Any]], str | None]:
+         """Prepare messages and extract system prompt for Anthropic API."""
+         system = None
+
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             # Extract system message if present (Anthropic handles system separately)
+             messages = []
+             for msg in prompt:
+                 if msg.get("role") == "system":
+                     system = msg.get("content")
+                 else:
+                     messages.append(msg)
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         return messages, system
+
+     def _track_cost(self, response: anthropic.types.Message, model: str):
+         self.model_call_counts[model] += 1
+         self.model_input_tokens[model] += response.usage.input_tokens
+         self.model_output_tokens[model] += response.usage.output_tokens
+         self.model_total_tokens[model] += response.usage.input_tokens + response.usage.output_tokens
+
+         # Track last call for handler to read
+         self.last_prompt_tokens = response.usage.input_tokens
+         self.last_completion_tokens = response.usage.output_tokens
+
+     def get_usage_summary(self) -> UsageSummary:
+         model_summaries = {}
+         for model in self.model_call_counts:
+             model_summaries[model] = ModelUsageSummary(
+                 total_calls=self.model_call_counts[model],
+                 total_input_tokens=self.model_input_tokens[model],
+                 total_output_tokens=self.model_output_tokens[model],
+             )
+         return UsageSummary(model_usage_summaries=model_summaries)
+
+     def get_last_usage(self) -> ModelUsageSummary:
+         return ModelUsageSummary(
+             total_calls=1,
+             total_input_tokens=self.last_prompt_tokens,
+             total_output_tokens=self.last_completion_tokens,
+         )
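A minimal usage sketch for the AnthropicClient added above (not part of the wheel; the API key and model id below are placeholders):

# Hypothetical usage; key and model id are placeholders, not package defaults.
from groknroll.clients.anthropic import AnthropicClient

client = AnthropicClient(api_key="sk-ant-...", model_name="claude-sonnet-4-20250514")
reply = client.completion([
    {"role": "system", "content": "Answer in one sentence."},
    {"role": "user", "content": "What does this client track per call?"},
])
print(reply)
print(client.get_usage_summary())  # per-model call and token counts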
groknroll/clients/azure_openai.py
@@ -0,0 +1,142 @@
+ import os
+ from collections import defaultdict
+ from typing import Any
+
+ import openai
+ from dotenv import load_dotenv
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.types import ModelUsageSummary, UsageSummary
+
+ load_dotenv()
+
+ # Load API key from environment variable
+ DEFAULT_AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+
+
+ class AzureOpenAIClient(BaseLM):
+     """
+     LM Client for running models with the Azure OpenAI API.
+     """
+
+     def __init__(
+         self,
+         api_key: str | None = None,
+         model_name: str | None = None,
+         azure_endpoint: str | None = None,
+         api_version: str | None = None,
+         azure_deployment: str | None = None,
+         **kwargs,
+     ):
+         super().__init__(model_name=model_name, **kwargs)
+
+         if api_key is None:
+             api_key = DEFAULT_AZURE_OPENAI_API_KEY
+
+         if azure_endpoint is None:
+             azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+         if api_version is None:
+             api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")
+
+         if azure_deployment is None:
+             azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+
+         if azure_endpoint is None:
+             raise ValueError(
+                 "azure_endpoint is required for Azure OpenAI client. "
+                 "Set it via argument or AZURE_OPENAI_ENDPOINT environment variable."
+             )
+
+         self.client = openai.AzureOpenAI(
+             api_key=api_key,
+             azure_endpoint=azure_endpoint,
+             api_version=api_version,
+             azure_deployment=azure_deployment,
+         )
+         self.async_client = openai.AsyncAzureOpenAI(
+             api_key=api_key,
+             azure_endpoint=azure_endpoint,
+             api_version=api_version,
+             azure_deployment=azure_deployment,
+         )
+         self.model_name = model_name
+         self.azure_deployment = azure_deployment
+
+         # Per-model usage tracking
+         self.model_call_counts: dict[str, int] = defaultdict(int)
+         self.model_input_tokens: dict[str, int] = defaultdict(int)
+         self.model_output_tokens: dict[str, int] = defaultdict(int)
+         self.model_total_tokens: dict[str, int] = defaultdict(int)
+
+     def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Azure OpenAI client.")
+
+         response = self.client.chat.completions.create(
+             model=model,
+             messages=messages,
+         )
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     async def acompletion(
+         self, prompt: str | list[dict[str, Any]], model: str | None = None
+     ) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Azure OpenAI client.")
+
+         response = await self.async_client.chat.completions.create(
+             model=model,
+             messages=messages,
+         )
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     def _track_cost(self, response: openai.ChatCompletion, model: str):
+         self.model_call_counts[model] += 1
+
+         usage = getattr(response, "usage", None)
+         if usage is None:
+             raise ValueError("No usage data received. Tracking tokens not possible.")
+
+         self.model_input_tokens[model] += usage.prompt_tokens
+         self.model_output_tokens[model] += usage.completion_tokens
+         self.model_total_tokens[model] += usage.total_tokens
+
+         # Track last call for handler to read
+         self.last_prompt_tokens = usage.prompt_tokens
+         self.last_completion_tokens = usage.completion_tokens
+
+     def get_usage_summary(self) -> UsageSummary:
+         model_summaries = {}
+         for model in self.model_call_counts:
+             model_summaries[model] = ModelUsageSummary(
+                 total_calls=self.model_call_counts[model],
+                 total_input_tokens=self.model_input_tokens[model],
+                 total_output_tokens=self.model_output_tokens[model],
+             )
+         return UsageSummary(model_usage_summaries=model_summaries)
+
+     def get_last_usage(self) -> ModelUsageSummary:
+         return ModelUsageSummary(
+             total_calls=1,
+             total_input_tokens=self.last_prompt_tokens,
+             total_output_tokens=self.last_completion_tokens,
+         )
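A rough usage sketch for AzureOpenAIClient (not part of the wheel; endpoint, deployment, and model names are illustrative, and the model/deployment naming depends on your Azure resource):

# Hypothetical usage; configuration can also come from the AZURE_OPENAI_* env vars.
from groknroll.clients.azure_openai import AzureOpenAIClient

client = AzureOpenAIClient(
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
    azure_deployment="gpt-4o",                              # placeholder deployment name
    model_name="gpt-4o",
)
print(client.completion("Say hello."))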
groknroll/clients/base_lm.py
@@ -0,0 +1,33 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from groknroll.core.types import UsageSummary
+
+
+ class BaseLM(ABC):
+     """
+     Base class for all language model routers / clients. When the RLM makes sub-calls, it currently
+     does so in a model-agnostic way, so this class provides a base interface for all language models.
+     """
+
+     def __init__(self, model_name: str, **kwargs):
+         self.model_name = model_name
+         self.kwargs = kwargs
+
+     @abstractmethod
+     def completion(self, prompt: str | dict[str, Any]) -> str:
+         raise NotImplementedError
+
+     @abstractmethod
+     async def acompletion(self, prompt: str | dict[str, Any]) -> str:
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_usage_summary(self) -> UsageSummary:
+         """Get cost summary for all model calls."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_last_usage(self) -> UsageSummary:
+         """Get the last cost summary of the model."""
+         raise NotImplementedError
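To illustrate the contract BaseLM defines, a toy subclass sketch (EchoLM is invented for illustration and is not in the package):

from typing import Any

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary


class EchoLM(BaseLM):
    """Toy client that echoes the prompt back; only shows the required interface."""

    def completion(self, prompt: str | dict[str, Any]) -> str:
        return str(prompt)

    async def acompletion(self, prompt: str | dict[str, Any]) -> str:
        return str(prompt)

    def get_usage_summary(self) -> UsageSummary:
        return UsageSummary(model_usage_summaries={})

    def get_last_usage(self) -> ModelUsageSummary:
        # The concrete clients in this package return a ModelUsageSummary here.
        return ModelUsageSummary(total_calls=0, total_input_tokens=0, total_output_tokens=0)


lm = EchoLM(model_name="echo")
print(lm.completion("hello"))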
groknroll/clients/gemini.py
@@ -0,0 +1,162 @@
+ import os
+ from collections import defaultdict
+ from typing import Any
+
+ from dotenv import load_dotenv
+ from google import genai
+ from google.genai import types
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.types import ModelUsageSummary, UsageSummary
+
+ load_dotenv()
+
+ DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+
+ class GeminiClient(BaseLM):
+     """
+     LM Client for running models with the Google Gemini API.
+     Uses the official google-genai SDK.
+     """
+
+     def __init__(
+         self,
+         api_key: str | None = None,
+         model_name: str | None = "gemini-2.5-flash",
+         **kwargs,
+     ):
+         super().__init__(model_name=model_name, **kwargs)
+
+         if api_key is None:
+             api_key = DEFAULT_GEMINI_API_KEY
+
+         if api_key is None:
+             raise ValueError(
+                 "Gemini API key is required. Set GEMINI_API_KEY env var or pass api_key."
+             )
+
+         self.client = genai.Client(api_key=api_key)
+         self.model_name = model_name
+
+         # Per-model usage tracking
+         self.model_call_counts: dict[str, int] = defaultdict(int)
+         self.model_input_tokens: dict[str, int] = defaultdict(int)
+         self.model_output_tokens: dict[str, int] = defaultdict(int)
+         self.model_total_tokens: dict[str, int] = defaultdict(int)
+
+         # Last call tracking
+         self.last_prompt_tokens = 0
+         self.last_completion_tokens = 0
+
+     def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+         contents, system_instruction = self._prepare_contents(prompt)
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Gemini client.")
+
+         config = None
+         if system_instruction:
+             config = types.GenerateContentConfig(system_instruction=system_instruction)
+
+         response = self.client.models.generate_content(
+             model=model,
+             contents=contents,
+             config=config,
+         )
+
+         self._track_cost(response, model)
+         return response.text
+
+     async def acompletion(
+         self, prompt: str | list[dict[str, Any]], model: str | None = None
+     ) -> str:
+         contents, system_instruction = self._prepare_contents(prompt)
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for Gemini client.")
+
+         config = None
+         if system_instruction:
+             config = types.GenerateContentConfig(system_instruction=system_instruction)
+
+         # google-genai SDK supports async via aio interface
+         response = await self.client.aio.models.generate_content(
+             model=model,
+             contents=contents,
+             config=config,
+         )
+
+         self._track_cost(response, model)
+         return response.text
+
+     def _prepare_contents(
+         self, prompt: str | list[dict[str, Any]]
+     ) -> tuple[list[types.Content] | str, str | None]:
+         """Prepare contents and extract system instruction for Gemini API."""
+         system_instruction = None
+
+         if isinstance(prompt, str):
+             return prompt, None
+
+         if isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             # Convert OpenAI-style messages to Gemini format
+             contents = []
+             for msg in prompt:
+                 role = msg.get("role")
+                 content = msg.get("content", "")
+
+                 if role == "system":
+                     # Gemini handles system instruction separately
+                     system_instruction = content
+                 elif role == "user":
+                     contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
+                 elif role == "assistant":
+                     # Gemini uses "model" instead of "assistant"
+                     contents.append(types.Content(role="model", parts=[types.Part(text=content)]))
+                 else:
+                     # Default to user role for unknown roles
+                     contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
+
+             return contents, system_instruction
+
+         raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+     def _track_cost(self, response: types.GenerateContentResponse, model: str):
+         self.model_call_counts[model] += 1
+
+         # Extract token usage from response
+         usage = response.usage_metadata
+         if usage:
+             input_tokens = usage.prompt_token_count or 0
+             output_tokens = usage.candidates_token_count or 0
+
+             self.model_input_tokens[model] += input_tokens
+             self.model_output_tokens[model] += output_tokens
+             self.model_total_tokens[model] += input_tokens + output_tokens
+
+             # Track last call for handler to read
+             self.last_prompt_tokens = input_tokens
+             self.last_completion_tokens = output_tokens
+         else:
+             self.last_prompt_tokens = 0
+             self.last_completion_tokens = 0
+
+     def get_usage_summary(self) -> UsageSummary:
+         model_summaries = {}
+         for model in self.model_call_counts:
+             model_summaries[model] = ModelUsageSummary(
+                 total_calls=self.model_call_counts[model],
+                 total_input_tokens=self.model_input_tokens[model],
+                 total_output_tokens=self.model_output_tokens[model],
+             )
+         return UsageSummary(model_usage_summaries=model_summaries)
+
+     def get_last_usage(self) -> ModelUsageSummary:
+         return ModelUsageSummary(
+             total_calls=1,
+             total_input_tokens=self.last_prompt_tokens,
+             total_output_tokens=self.last_completion_tokens,
+         )
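A minimal usage sketch for GeminiClient (not part of the wheel; assumes GEMINI_API_KEY is set in the environment):

# Hypothetical usage; relies on the default model id baked into the constructor.
from groknroll.clients.gemini import GeminiClient

client = GeminiClient()  # api_key falls back to the GEMINI_API_KEY env var
answer = client.completion([
    {"role": "system", "content": "Answer in one word."},
    {"role": "user", "content": "What colour is a clear daytime sky?"},
])
print(answer)
print(client.get_last_usage())  # token counts from the last call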
groknroll/clients/litellm.py
@@ -0,0 +1,105 @@
+ from collections import defaultdict
+ from typing import Any
+
+ import litellm
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.types import ModelUsageSummary, UsageSummary
+
+
+ class LiteLLMClient(BaseLM):
+     """
+     LM Client for running models with LiteLLM.
+     LiteLLM provides a unified interface to 100+ LLM providers.
+     """
+
+     def __init__(
+         self,
+         model_name: str | None = None,
+         api_key: str | None = None,
+         api_base: str | None = None,
+         **kwargs,
+     ):
+         super().__init__(model_name=model_name, **kwargs)
+         self.model_name = model_name
+         self.api_key = api_key
+         self.api_base = api_base
+
+         # Per-model usage tracking
+         self.model_call_counts: dict[str, int] = defaultdict(int)
+         self.model_input_tokens: dict[str, int] = defaultdict(int)
+         self.model_output_tokens: dict[str, int] = defaultdict(int)
+         self.model_total_tokens: dict[str, int] = defaultdict(int)
+
+     def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for LiteLLM client.")
+
+         kwargs = {"model": model, "messages": messages}
+         if self.api_key:
+             kwargs["api_key"] = self.api_key
+         if self.api_base:
+             kwargs["api_base"] = self.api_base
+
+         response = litellm.completion(**kwargs)
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     async def acompletion(
+         self, prompt: str | list[dict[str, Any]], model: str | None = None
+     ) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for LiteLLM client.")
+
+         kwargs = {"model": model, "messages": messages}
+         if self.api_key:
+             kwargs["api_key"] = self.api_key
+         if self.api_base:
+             kwargs["api_base"] = self.api_base
+
+         response = await litellm.acompletion(**kwargs)
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     def _track_cost(self, response, model: str):
+         self.model_call_counts[model] += 1
+         self.model_input_tokens[model] += response.usage.prompt_tokens
+         self.model_output_tokens[model] += response.usage.completion_tokens
+         self.model_total_tokens[model] += response.usage.total_tokens
+
+         # Track last call for handler to read
+         self.last_prompt_tokens = response.usage.prompt_tokens
+         self.last_completion_tokens = response.usage.completion_tokens
+
+     def get_usage_summary(self) -> UsageSummary:
+         model_summaries = {}
+         for model in self.model_call_counts:
+             model_summaries[model] = ModelUsageSummary(
+                 total_calls=self.model_call_counts[model],
+                 total_input_tokens=self.model_input_tokens[model],
+                 total_output_tokens=self.model_output_tokens[model],
+             )
+         return UsageSummary(model_usage_summaries=model_summaries)
+
+     def get_last_usage(self) -> ModelUsageSummary:
+         return ModelUsageSummary(
+             total_calls=1,
+             total_input_tokens=self.last_prompt_tokens,
+             total_output_tokens=self.last_completion_tokens,
+         )
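Since LiteLLM routes by model string, a sketch of switching providers is just changing that string (not part of the wheel; model ids below are placeholders and assume the matching provider keys are set in the environment):

# Hypothetical usage; any provider/model id that LiteLLM recognises should work.
from groknroll.clients.litellm import LiteLLMClient

client = LiteLLMClient(model_name="anthropic/claude-3-5-sonnet-20241022")
print(client.completion("ping"))

# A per-call override routes a single request to a different model.
print(client.completion("ping", model="gpt-4o-mini"))
print(client.get_usage_summary())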
groknroll/clients/openai.py
@@ -0,0 +1,129 @@
+ import os
+ from collections import defaultdict
+ from typing import Any
+
+ import openai
+ from dotenv import load_dotenv
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.types import ModelUsageSummary, UsageSummary
+
+ load_dotenv()
+
+ # Load API keys from environment variables
+ DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ DEFAULT_OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
+ DEFAULT_VERCEL_API_KEY = os.getenv("AI_GATEWAY_API_KEY")
+ DEFAULT_PRIME_INTELLECT_BASE_URL = "https://api.pinference.ai/api/v1/"
+
+
+ class OpenAIClient(BaseLM):
+     """
+     LM Client for running models with the OpenAI API. Works with vLLM as well.
+     """
+
+     def __init__(
+         self,
+         api_key: str | None = None,
+         model_name: str | None = None,
+         base_url: str | None = None,
+         **kwargs,
+     ):
+         super().__init__(model_name=model_name, **kwargs)
+
+         if api_key is None:
+             if base_url == "https://api.openai.com/v1" or base_url is None:
+                 api_key = DEFAULT_OPENAI_API_KEY
+             elif base_url == "https://openrouter.ai/api/v1":
+                 api_key = DEFAULT_OPENROUTER_API_KEY
+             elif base_url == "https://ai-gateway.vercel.sh/v1":
+                 api_key = DEFAULT_VERCEL_API_KEY
+
+         # For vLLM, set base_url to local vLLM server address.
+         self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
+         self.async_client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
+         self.model_name = model_name
+
+         # Per-model usage tracking
+         self.model_call_counts: dict[str, int] = defaultdict(int)
+         self.model_input_tokens: dict[str, int] = defaultdict(int)
+         self.model_output_tokens: dict[str, int] = defaultdict(int)
+         self.model_total_tokens: dict[str, int] = defaultdict(int)
+
+     def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for OpenAI client.")
+
+         extra_body = {}
+         if self.client.base_url == DEFAULT_PRIME_INTELLECT_BASE_URL:
+             extra_body["usage"] = {"include": True}
+
+         response = self.client.chat.completions.create(
+             model=model, messages=messages, extra_body=extra_body
+         )
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     async def acompletion(
+         self, prompt: str | list[dict[str, Any]], model: str | None = None
+     ) -> str:
+         if isinstance(prompt, str):
+             messages = [{"role": "user", "content": prompt}]
+         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+             messages = prompt
+         else:
+             raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+         model = model or self.model_name
+         if not model:
+             raise ValueError("Model name is required for OpenAI client.")
+
+         extra_body = {}
+         if self.client.base_url == DEFAULT_PRIME_INTELLECT_BASE_URL:
+             extra_body["usage"] = {"include": True}
+
+         response = await self.async_client.chat.completions.create(
+             model=model, messages=messages, extra_body=extra_body
+         )
+         self._track_cost(response, model)
+         return response.choices[0].message.content
+
+     def _track_cost(self, response: openai.ChatCompletion, model: str):
+         self.model_call_counts[model] += 1
+
+         usage = getattr(response, "usage", None)
+         if usage is None:
+             raise ValueError("No usage data received. Tracking tokens not possible.")
+
+         self.model_input_tokens[model] += usage.prompt_tokens
+         self.model_output_tokens[model] += usage.completion_tokens
+         self.model_total_tokens[model] += usage.total_tokens
+
+         # Track last call for handler to read
+         self.last_prompt_tokens = usage.prompt_tokens
+         self.last_completion_tokens = usage.completion_tokens
+
+     def get_usage_summary(self) -> UsageSummary:
+         model_summaries = {}
+         for model in self.model_call_counts:
+             model_summaries[model] = ModelUsageSummary(
+                 total_calls=self.model_call_counts[model],
+                 total_input_tokens=self.model_input_tokens[model],
+                 total_output_tokens=self.model_output_tokens[model],
+             )
+         return UsageSummary(model_usage_summaries=model_summaries)
+
+     def get_last_usage(self) -> ModelUsageSummary:
+         return ModelUsageSummary(
+             total_calls=1,
+             total_input_tokens=self.last_prompt_tokens,
+             total_output_tokens=self.last_completion_tokens,
+         )
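A sketch of the two setups the OpenAIClient docstring mentions, the hosted OpenAI API and a local vLLM server exposing an OpenAI-compatible endpoint (not part of the wheel; URLs and model ids are illustrative):

# Hypothetical usage; the hosted key comes from OPENAI_API_KEY when api_key is omitted.
from groknroll.clients.openai import OpenAIClient

hosted = OpenAIClient(model_name="gpt-4o-mini")
print(hosted.completion("Hello"))

# Local vLLM: point base_url at the server; the key only matters if vLLM enforces one.
local = OpenAIClient(
    api_key="EMPTY",
    base_url="http://localhost:8000/v1",
    model_name="meta-llama/Llama-3.1-8B-Instruct",
)
print(local.completion("Hello"))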