droidrun 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,213 @@
+ import contextlib
+ from llama_index.core.callbacks import CallbackManager
+ from llama_index.core.callbacks.base_handler import BaseCallbackHandler
+ from llama_index.core.callbacks.schema import CBEventType, EventPayload
+ from llama_index.core.llms import LLM, ChatResponse
+ from pydantic import BaseModel
+ from typing import Any, Dict, List, Optional
+ from uuid import uuid4
+ import logging
+
+ logger = logging.getLogger("droidrun")
+ SUPPORTED_PROVIDERS = [
+     "Gemini",
+     "GoogleGenAI",
+     "OpenAI",
+     "Anthropic",
+     "Ollama",
+     "DeepSeek",
+ ]
+
+
+ class UsageResult(BaseModel):
+     request_tokens: int
+     response_tokens: int
+     total_tokens: int
+     requests: int
+
+ def get_usage_from_response(provider: str, chat_rsp: ChatResponse) -> UsageResult:
+     rsp = chat_rsp.raw
+     if not rsp:
+         raise ValueError("No raw response in chat response")
+
+     logger.debug(f"Raw response type: {rsp.__class__.__name__}")
+
+     if provider == "Gemini" or provider == "GoogleGenAI":
+         return UsageResult(
+             request_tokens=rsp["usage_metadata"]["prompt_token_count"],
+             response_tokens=rsp["usage_metadata"]["candidates_token_count"],
+             total_tokens=rsp["usage_metadata"]["total_token_count"],
+             requests=1,
+         )
+     elif provider == "OpenAI":
+         from openai.types import CompletionUsage as OpenAIUsage
+
+         usage: OpenAIUsage = rsp.usage
+         return UsageResult(
+             request_tokens=usage.prompt_tokens,
+             response_tokens=usage.completion_tokens,
+             total_tokens=usage.total_tokens,
+             requests=1,
+         )
+     elif provider == "Anthropic":
+         from anthropic.types import Usage as AnthropicUsage
+
+         usage: AnthropicUsage = rsp["usage"]
+         return UsageResult(
+             request_tokens=usage.input_tokens,
+             response_tokens=usage.output_tokens,
+             total_tokens=usage.input_tokens + usage.output_tokens,
+             requests=1,
+         )
+     elif provider == "Ollama":
+         # Ollama reports usage under different field names
+         prompt_eval_count = rsp.get("prompt_eval_count", 0)
+         eval_count = rsp.get("eval_count", 0)
+         return UsageResult(
+             request_tokens=prompt_eval_count,
+             response_tokens=eval_count,
+             total_tokens=prompt_eval_count + eval_count,
+             requests=1,
+         )
+     elif provider == "DeepSeek":
+         # DeepSeek follows the OpenAI-compatible format
+         usage = rsp.usage
+         if not usage:
+             # No usage block on the response; count the request only
+             return UsageResult(
+                 request_tokens=0, response_tokens=0, total_tokens=0, requests=1
+             )
+         return UsageResult(
+             request_tokens=usage.prompt_tokens or 0,
+             response_tokens=usage.completion_tokens or 0,
+             total_tokens=usage.total_tokens or 0,
+             requests=1,
+         )
+
+     raise ValueError(f"Unsupported provider: {provider}")
+
+ class TokenCountingHandler(BaseCallbackHandler):
+     """Token counting handler for LlamaIndex LLM calls."""
+
+     def __init__(self, provider: str):
+         super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
+         self.provider = provider
+         self.request_tokens: int = 0
+         self.response_tokens: int = 0
+         self.total_tokens: int = 0
+         self.requests: int = 0
+
+     @classmethod
+     def class_name(cls) -> str:
+         """Class name."""
+         return "TokenCountingHandler"
+
+     @property
+     def usage(self) -> UsageResult:
+         return UsageResult(
+             request_tokens=self.request_tokens,
+             response_tokens=self.response_tokens,
+             total_tokens=self.total_tokens,
+             requests=self.requests,
+         )
+
+     def _get_event_usage(self, payload: Dict[str, Any]) -> UsageResult:
+         if EventPayload.RESPONSE not in payload:
+             raise ValueError("No response in payload")
+
+         chat_rsp: ChatResponse = payload[EventPayload.RESPONSE]
+         return get_usage_from_response(self.provider, chat_rsp)
+
+     def on_event_start(
+         self,
+         event_type: CBEventType,
+         payload: Optional[Dict[str, Any]] = None,
+         event_id: str = "",
+         parent_id: str = "",
+         **kwargs: Any,
+     ) -> str:
+         """Run when an event starts and return id of event."""
+         return event_id or str(uuid4())
+
+     def on_event_end(
+         self,
+         event_type: CBEventType,
+         payload: Optional[Dict[str, Any]] = None,
+         event_id: str = "",
+         **kwargs: Any,
+     ) -> None:
+         """Run when an event ends."""
+         try:
+             usage = self._get_event_usage(payload)
+
+             self.request_tokens += usage.request_tokens
+             self.response_tokens += usage.response_tokens
+             self.total_tokens += usage.total_tokens
+             self.requests += usage.requests
+         except Exception as e:
+             self.requests += 1
+             logger.warning(
+                 f"Error tracking usage for provider {self.provider}: {e}",
+                 extra={"provider": self.provider},
+             )
+
+     def start_trace(self, trace_id: Optional[str] = None) -> None:
+         """Run when an overall trace is launched."""
+         pass
+
+     def end_trace(
+         self,
+         trace_id: Optional[str] = None,
+         trace_map: Optional[Dict[str, List[str]]] = None,
+     ) -> None:
+         """Run when an overall trace is exited."""
+         pass
+
+ @contextlib.contextmanager
+ def llm_callback(llm: LLM, *handlers: BaseCallbackHandler):
+     for handler in handlers:
+         llm.callback_manager.add_handler(handler)
+     try:
+         yield
+     finally:
+         # Remove handlers even if the body raises
+         for handler in handlers:
+             llm.callback_manager.remove_handler(handler)
+
+ def create_tracker(llm: LLM) -> TokenCountingHandler:
+     provider = llm.__class__.__name__
+     if provider not in SUPPORTED_PROVIDERS:
+         raise ValueError(f"Tracking not yet supported for provider: {provider}")
+
+     return TokenCountingHandler(provider)
+
+
+ def track_usage(llm: LLM) -> TokenCountingHandler:
+     """Track token usage for an LLM instance across all requests.
+
+     This function:
+     - Creates a new TokenCountingHandler for the LLM provider
+     - Registers that handler as an LLM callback to monitor all requests
+     - Returns the handler for accessing cumulative usage statistics
+
+     The handler counts tokens for total LLM usage across all requests. For
+     fine-grained per-request counting, use either:
+     - `create_tracker()` with the `llm_callback()` context manager for temporary tracking
+     - `get_usage_from_response()` to extract usage from individual responses
+
+     Args:
+         llm: The LlamaIndex LLM instance to track usage for
+
+     Returns:
+         TokenCountingHandler: The registered handler that accumulates usage statistics
+
+     Raises:
+         ValueError: If the LLM provider is not supported for tracking
+
+     Example:
+         >>> llm = OpenAI()
+         >>> tracker = track_usage(llm)
+         >>> # ... make LLM calls ...
+         >>> print(f"Total tokens used: {tracker.usage.total_tokens}")
+     """
+     tracker = create_tracker(llm)
+     llm.callback_manager.add_handler(tracker)
+     return tracker
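
The new module also supports scoped tracking via `create_tracker()` and the `llm_callback()` context manager, as the `track_usage()` docstring notes. A minimal sketch of that pattern, assuming `llm` is an already-initialized LlamaIndex LLM from a supported provider and `messages` is a list of ChatMessage objects (both hypothetical here, not part of the packaged code):

    # Sketch only: count tokens just for calls made inside the with-block;
    # llm_callback removes the handler again when the block exits.
    tracker = create_tracker(llm)
    with llm_callback(llm, tracker):
        llm.chat(messages)
    print(tracker.usage)
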
@@ -143,6 +143,6 @@ class SimpleCodeExecutor:
          result = {
              'output': output,
              'screenshots': self.globals['step_screenshots'],
-             'ui_states': self.globals['step_ui_states']
+             'ui_states': self.globals['step_ui_states'],
          }
          return result
@@ -2,9 +2,12 @@ import importlib
  import logging
  from typing import Any
  from llama_index.core.llms.llm import LLM
+ from droidrun.agent.usage import track_usage
+
  # Configure logging
  logger = logging.getLogger("droidrun")

+
  def load_llm(provider_name: str, **kwargs: Any) -> LLM:
      """
      Dynamically loads and initializes a LlamaIndex LLM.
@@ -51,29 +54,39 @@ def load_llm(provider_name: str, **kwargs: Any) -> LLM:
          logger.debug(f"Successfully imported module: {module_path}")

      except ModuleNotFoundError:
-         logger.error(f"Module '{module_path}' not found. Try: pip install {install_package_name}")
+         logger.error(
+             f"Module '{module_path}' not found. Try: pip install {install_package_name}"
+         )
          raise ModuleNotFoundError(
              f"Could not import '{module_path}'. Is '{install_package_name}' installed?"
          ) from None

      try:
-         logger.debug(f"Attempting to get class '{provider_name}' from module {module_path}")
+         logger.debug(
+             f"Attempting to get class '{provider_name}' from module {module_path}"
+         )
          llm_class = getattr(llm_module, provider_name)
          logger.debug(f"Found class: {llm_class.__name__}")

          # Verify the class is a subclass of LLM
          if not isinstance(llm_class, type) or not issubclass(llm_class, LLM):
-             raise TypeError(f"Class '{provider_name}' found in '{module_path}' is not a valid LLM subclass.")
+             raise TypeError(
+                 f"Class '{provider_name}' found in '{module_path}' is not a valid LLM subclass."
+             )

          # Filter out None values from kwargs
          filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
-
+
          # Initialize
-         logger.debug(f"Initializing {llm_class.__name__} with kwargs: {list(filtered_kwargs.keys())}")
+         logger.debug(
+             f"Initializing {llm_class.__name__} with kwargs: {list(filtered_kwargs.keys())}"
+         )
          llm_instance = llm_class(**filtered_kwargs)
          logger.debug(f"Successfully loaded and initialized LLM: {provider_name}")
          if not llm_instance:
-             raise RuntimeError(f"Failed to initialize LLM instance for {provider_name}.")
+             raise RuntimeError(
+                 f"Failed to initialize LLM instance for {provider_name}."
+             )
          return llm_instance

      except AttributeError:
@@ -83,11 +96,12 @@ def load_llm(provider_name: str, **kwargs: Any) -> LLM:
          ) from None
      except TypeError as e:
          logger.error(f"Error initializing {provider_name}: {e}")
-         raise # Re-raise TypeError (could be from issubclass check or __init__)
+         raise  # Re-raise TypeError (could be from issubclass check or __init__)
      except Exception as e:
          logger.error(f"An unexpected error occurred initializing {provider_name}: {e}")
          raise e
-
+
+
  # --- Example Usage ---
  if __name__ == "__main__":
      # Install the specific LLM integrations you want to test:
@@ -97,52 +111,75 @@ if __name__ == "__main__":
      # llama-index-llms-gemini \
      # llama-index-llms-openai

-     # Example 1: Load Anthropic (requires ANTHROPIC_API_KEY env var or kwarg)
-     print("\n--- Loading Anthropic ---")
-     try:
-         anthropic_llm = load_llm(
-             "Anthropic",
-             model="claude-3-7-sonnet-latest",
-         )
-         print(f"Loaded LLM: {type(anthropic_llm)}")
-         print(f"Model: {anthropic_llm.metadata}")
-     except Exception as e:
-         print(f"Failed to load Anthropic: {e}")
+     from llama_index.core.base.llms.types import ChatMessage

-     # Example 2: Load DeepSeek (requires DEEPSEEK_API_KEY env var or kwarg)
-     print("\n--- Loading DeepSeek ---")
-     try:
-         deepseek_llm = load_llm(
-             "DeepSeek",
-             model="deepseek-reasoner",
-             api_key="your api",  # or set DEEPSEEK_API_KEY
-         )
-         print(f"Loaded LLM: {type(deepseek_llm)}")
-         print(f"Model: {deepseek_llm.metadata}")
-     except Exception as e:
-         print(f"Failed to load DeepSeek: {e}")
+     providers = [
+         {
+             "name": "Anthropic",
+             "model": "claude-3-7-sonnet-latest",
+         },
+         {
+             "name": "DeepSeek",
+             "model": "deepseek-reasoner",
+         },
+         {
+             "name": "GoogleGenAI",
+             "model": "gemini-2.5-flash",
+         },
+         {
+             "name": "OpenAI",
+             "model": "gpt-4",
+         },
+         {
+             "name": "Ollama",
+             "model": "llama3.2:1b",
+             "base_url": "http://localhost:11434",
+         },
+     ]

-     # Example 3: Load Gemini (requires GOOGLE_APPLICATION_CREDENTIALS or kwarg)
-     print("\n--- Loading Gemini ---")
-     try:
-         gemini_llm = load_llm(
-             "Gemini",
-             model="gemini-2.0-fash",
-         )
-         print(f"Loaded LLM: {type(gemini_llm)}")
-         print(f"Model: {gemini_llm.metadata}")
-     except Exception as e:
-         print(f"Failed to load Gemini: {e}")
+     system_prompt = ChatMessage(
+         role="system",
+         content="You are a personal health and food coach. You are given a user's health and food preferences and you need to recommend a meal plan for them. Only output the meal plan, no other text.",
+     )

-     # Example 4: Load OpenAI (requires OPENAI_API_KEY env var or kwarg)
-     print("\n--- Loading OpenAI ---")
-     try:
-         openai_llm = load_llm(
-             "OpenAI",
-             model="gp-4o",
-             temperature=0.5,
-         )
-         print(f"Loaded LLM: {type(openai_llm)}")
-         print(f"Model: {openai_llm.metadata}")
-     except Exception as e:
-         print(f"Failed to load OpenAI: {e}")
+     user_prompt = ChatMessage(
+         role="user",
+         content="I am a 25 year old male. I am 5'10 and 180 pounds. I am a vegetarian. I am allergic to peanuts and tree nuts. I am allergic to shellfish. I am allergic to eggs. I am allergic to dairy. I am allergic to soy. I am allergic to wheat. I am allergic to corn. I am allergic to oats. I am allergic to rice. I am allergic to barley. I am allergic to rye.",
+     )
+
+     messages = [system_prompt, user_prompt]
+
+     for provider in providers:
+         provider_name = provider["name"]
+         print(f"\n{'#' * 35} Loading {provider_name} {'#' * 35}")
+         print("-" * 100)
+
+         try:
+             # Pass everything except the display name through to load_llm
+             kwargs = {k: v for k, v in provider.items() if k != "name"}
+             llm = load_llm(provider_name, **kwargs)
+             print(f"Loaded LLM: {type(llm)}")
+             print(f"Model: {llm.metadata}")
+             print("-" * 100)
+
+             tracker = track_usage(llm)
+             print(f"Tracker: {type(tracker)}")
+             print(f"Usage: {tracker.usage}")
+             print("-" * 100)
+
+             assert tracker.usage.requests == 0
+             assert tracker.usage.request_tokens == 0
+             assert tracker.usage.response_tokens == 0
+             assert tracker.usage.total_tokens == 0
+
+             res = llm.chat(messages)
+             print(f"Response: {res.message.content}")
+             print("-" * 100)
+             print(f"Usage: {tracker.usage}")
+
+             assert tracker.usage.requests == 1
+             assert tracker.usage.request_tokens > 0
+             assert tracker.usage.response_tokens > 0
+             assert tracker.usage.total_tokens > tracker.usage.request_tokens
+             assert tracker.usage.total_tokens > tracker.usage.response_tokens
+         except Exception as e:
+             print(f"Failed to load and track usage for {provider_name}: {e}")