droidrun 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. droidrun/__init__.py +22 -10
  2. droidrun/__main__.py +1 -2
  3. droidrun/adb/__init__.py +3 -3
  4. droidrun/adb/device.py +2 -2
  5. droidrun/adb/manager.py +2 -2
  6. droidrun/agent/__init__.py +5 -15
  7. droidrun/agent/codeact/__init__.py +11 -0
  8. droidrun/agent/codeact/codeact_agent.py +420 -0
  9. droidrun/agent/codeact/events.py +28 -0
  10. droidrun/agent/codeact/prompts.py +26 -0
  11. droidrun/agent/common/default.py +5 -0
  12. droidrun/agent/common/events.py +4 -0
  13. droidrun/agent/context/__init__.py +23 -0
  14. droidrun/agent/context/agent_persona.py +15 -0
  15. droidrun/agent/context/context_injection_manager.py +66 -0
  16. droidrun/agent/context/episodic_memory.py +15 -0
  17. droidrun/agent/context/personas/__init__.py +11 -0
  18. droidrun/agent/context/personas/app_starter.py +44 -0
  19. droidrun/agent/context/personas/default.py +95 -0
  20. droidrun/agent/context/personas/extractor.py +52 -0
  21. droidrun/agent/context/personas/ui_expert.py +107 -0
  22. droidrun/agent/context/reflection.py +20 -0
  23. droidrun/agent/context/task_manager.py +124 -0
  24. droidrun/agent/context/todo.txt +4 -0
  25. droidrun/agent/droid/__init__.py +13 -0
  26. droidrun/agent/droid/droid_agent.py +357 -0
  27. droidrun/agent/droid/events.py +28 -0
  28. droidrun/agent/oneflows/reflector.py +265 -0
  29. droidrun/agent/planner/__init__.py +13 -0
  30. droidrun/agent/planner/events.py +16 -0
  31. droidrun/agent/planner/planner_agent.py +268 -0
  32. droidrun/agent/planner/prompts.py +124 -0
  33. droidrun/agent/utils/__init__.py +3 -0
  34. droidrun/agent/utils/async_utils.py +17 -0
  35. droidrun/agent/utils/chat_utils.py +312 -0
  36. droidrun/agent/utils/executer.py +132 -0
  37. droidrun/agent/utils/llm_picker.py +147 -0
  38. droidrun/agent/utils/trajectory.py +184 -0
  39. droidrun/cli/__init__.py +1 -1
  40. droidrun/cli/logs.py +283 -0
  41. droidrun/cli/main.py +358 -149
  42. droidrun/run.py +105 -0
  43. droidrun/tools/__init__.py +4 -30
  44. droidrun/tools/adb.py +879 -0
  45. droidrun/tools/ios.py +594 -0
  46. droidrun/tools/tools.py +99 -0
  47. droidrun-0.3.0.dist-info/METADATA +149 -0
  48. droidrun-0.3.0.dist-info/RECORD +52 -0
  49. droidrun/agent/llm_reasoning.py +0 -567
  50. droidrun/agent/react_agent.py +0 -556
  51. droidrun/llm/__init__.py +0 -24
  52. droidrun/tools/actions.py +0 -854
  53. droidrun/tools/device.py +0 -29
  54. droidrun-0.1.0.dist-info/METADATA +0 -276
  55. droidrun-0.1.0.dist-info/RECORD +0 -20
  56. {droidrun-0.1.0.dist-info → droidrun-0.3.0.dist-info}/WHEEL +0 -0
  57. {droidrun-0.1.0.dist-info → droidrun-0.3.0.dist-info}/entry_points.txt +0 -0
  58. {droidrun-0.1.0.dist-info → droidrun-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,312 @@
1
+ import base64
2
+ import re
3
+ import inspect
4
+
5
+
6
+ import json
7
+ import logging
8
+ from typing import List, TYPE_CHECKING, Optional, Tuple
9
+ from droidrun.agent.context import Reflection
10
+ from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
11
+
12
+ if TYPE_CHECKING:
13
+ from droidrun.tools import Tools
14
+
15
+ logger = logging.getLogger("droidrun")
16
+
17
def message_copy(message: ChatMessage, deep=True) -> ChatMessage:
    """Return a copy of ``message`` whose ``blocks`` list is independent of the original.

    Args:
        message: The chat message to copy (copied via its ``model_copy()``).
        deep: When True (default), every block is also duplicated with its own
            ``model_copy()``. When False, the new list holds the same block
            objects as the original.

    Returns:
        The copied message. Appending to its ``blocks`` never affects ``message``.
    """
    duplicate = message.model_copy()
    if deep:
        new_blocks = [blk.model_copy() for blk in message.blocks]
    else:
        # Same block references, but a separate list object.
        new_blocks = list(message.blocks)
    duplicate.blocks = new_blocks
    return duplicate
29
+
30
async def add_reflection_summary(reflection: Reflection, chat_history: List[ChatMessage]) -> List[ChatMessage]:
    """Add reflection summary and advice to help the planner understand what went wrong and what to do differently.

    The reflection text is appended as a new TextBlock to a copy of the last
    message. Neither ``chat_history`` nor its messages are mutated.

    Args:
        reflection: Reflection from the previous failed attempt. Its ``summary``
            and ``advice`` are each included only when truthy.
        chat_history: Conversation so far. Must be non-empty, because the last
            message is indexed with ``[-1]``.

    Returns:
        A new list whose last message carries the extra reflection block.
    """
    reflection_text = "\n### The last task failed. You have additional information about what happened. \nThe Reflection from Previous Attempt:\n"

    if reflection.summary:
        reflection_text += f"**What happened:** {reflection.summary}\n\n"

    if reflection.advice:
        reflection_text += f"**Recommended approach for this retry:** {reflection.advice}\n"

    reflection_block = TextBlock(text=reflection_text)

    # Copy the list and the last message so the caller's history stays untouched.
    chat_history = chat_history.copy()
    chat_history[-1] = message_copy(chat_history[-1])
    chat_history[-1].blocks.append(reflection_block)

    return chat_history
49
+
50
+ def _format_ui_elements(ui_data, level=0) -> str:
51
+ """Format UI elements in natural language: index. className: resourceId, text - bounds"""
52
+ if not ui_data:
53
+ return ""
54
+
55
+ formatted_lines = []
56
+ indent = " " * level # Indentation for nested elements
57
+
58
+ # Handle both list and single element
59
+ elements = ui_data if isinstance(ui_data, list) else [ui_data]
60
+
61
+ for element in elements:
62
+ if not isinstance(element, dict):
63
+ continue
64
+
65
+ # Extract element properties
66
+ index = element.get('index', '')
67
+ class_name = element.get('className', '')
68
+ resource_id = element.get('resourceId', '')
69
+ text = element.get('text', '')
70
+ bounds = element.get('bounds', '')
71
+ children = element.get('children', [])
72
+
73
+
74
+ # Format the line: index. className: resourceId, text - bounds
75
+ line_parts = []
76
+ if index != '':
77
+ line_parts.append(f"{index}.")
78
+ if class_name:
79
+ line_parts.append(class_name + ":")
80
+
81
+ details = []
82
+ if resource_id:
83
+ details.append(f'"{resource_id}"')
84
+ if text:
85
+ details.append(f'"{text}"')
86
+ if details:
87
+ line_parts.append(", ".join(details))
88
+
89
+ if bounds:
90
+ line_parts.append(f"- ({bounds})")
91
+
92
+ formatted_line = f"{indent}{' '.join(line_parts)}"
93
+ formatted_lines.append(formatted_line)
94
+
95
+ # Recursively format children with increased indentation
96
+ if children:
97
+ child_formatted = _format_ui_elements(children, level + 1)
98
+ if child_formatted:
99
+ formatted_lines.append(child_formatted)
100
+
101
+ return "\n".join(formatted_lines)
102
+
103
async def add_ui_text_block(ui_state: str, chat_history: List[ChatMessage], copy=True) -> List[ChatMessage]:
    """Append the device's clickable UI elements as a text block on the last message.

    Args:
        ui_state: UI element tree, either as a JSON string or already parsed.
            A falsy value leaves ``chat_history`` untouched.
        chat_history: Conversation so far (must be non-empty when ``ui_state``
            is truthy).
        copy: When True (default), the list and its last message are copied
            before appending. When False, the last message is modified in place.

    Returns:
        The chat history carrying the UI block (a new list when ``copy`` is True).
    """
    if not ui_state:
        return chat_history

    try:
        parsed = json.loads(ui_state) if isinstance(ui_state, str) else ui_state
        described = _format_ui_elements(parsed)
        ui_block = TextBlock(text=f"\nCurrent Clickable UI elements from the device in the schema 'index. className: resourceId, text - bounds(x1,y1,x2,y2)':\n{described}\n")
    except (json.JSONDecodeError, TypeError):
        # Parsing/formatting failed: fall back to dumping the raw state as JSON.
        ui_block = TextBlock(text="\nCurrent Clickable UI elements from the device using the custom TopViewService:\n```json\n" + json.dumps(ui_state) + "\n```\n")

    history = chat_history
    if copy:
        history = [*chat_history[:-1], message_copy(chat_history[-1])]
    history[-1].blocks.append(ui_block)
    return history
120
+
121
async def add_screenshot_image_block(screenshot, chat_history: List[ChatMessage], copy=True) -> List[ChatMessage]:
    """Attach a screenshot as a base64-encoded ImageBlock to the last chat message.

    Args:
        screenshot: Image data as a bytes-like object (passed to
            ``base64.b64encode``). A falsy value leaves ``chat_history`` untouched.
        chat_history: Conversation so far (must be non-empty when a screenshot
            is given).
        copy: When True (default), the list and its last message are copied
            before appending. When False, the last message is modified in place.

    Returns:
        The chat history carrying the image block (a new list when ``copy`` is
        True). The previous ``-> None`` annotation was wrong: the history is
        always returned.
    """
    if screenshot:
        image_block = ImageBlock(image=base64.b64encode(screenshot))
        if copy:
            # Copy so the caller's list and message are not modified.
            chat_history = chat_history.copy()
            chat_history[-1] = message_copy(chat_history[-1])
        chat_history[-1].blocks.append(image_block)
    return chat_history
129
+
130
+
131
async def add_phone_state_block(phone_state, chat_history: List[ChatMessage]) -> List[ChatMessage]:
    """Append a human-readable summary of the phone state to a copy of the last message.

    Three shapes are handled:
    - a dict without an ``'error'`` key: app, package, keyboard visibility and
      focused element are listed;
    - a dict with ``'error'``: its ``'message'`` is reported as an error;
    - anything else: the value is interpolated as-is.

    Args:
        phone_state: Phone state payload. Expected keys are presumably those
            read below (currentApp, packageName, keyboardVisible,
            focusedElement) — confirm against the producer.
        chat_history: Non-empty conversation history; it is not mutated.

    Returns:
        A new list whose last message carries the phone-state text block.
    """
    if not isinstance(phone_state, dict):
        text = f"\n📱 **Phone State:** {phone_state}\n"
    elif 'error' in phone_state:
        text = f"\n📱 **Phone State Error:** {phone_state.get('message', 'Unknown error')}\n"
    else:
        app = phone_state.get('currentApp', 'Unknown')
        pkg = phone_state.get('packageName', 'Unknown')
        keyboard = 'Visible' if phone_state.get('keyboardVisible', False) else 'Hidden'

        focused = phone_state.get('focusedElement')
        if not focused:
            focus_text = "None"
        else:
            parts = [f"'{focused.get('text', 'No text')}' ({focused.get('className', 'Unknown')})"]
            resource_id = focused.get('resourceId', '')
            if resource_id:
                parts.append(f"ID: {resource_id}")
            parts.append(f"Bounds: {focused.get('bounds', 'Unknown')}")
            parts.append(f"Type: {focused.get('type', 'unknown')}")
            focus_text = " | ".join(parts)

        text = (
            "\n**Current Phone State:**\n"
            f"• **App:** {app} ({pkg})\n"
            f"• **Keyboard:** {keyboard}\n"
            f"• **Focused Element:** {focus_text}\n"
        )

    history = chat_history.copy()
    history[-1] = message_copy(history[-1])
    history[-1].blocks.append(TextBlock(text=text))
    return history
174
+
175
async def add_packages_block(packages, chat_history: List[ChatMessage]) -> List[ChatMessage]:
    """Append the installed-package listing to a copy of the last chat message.

    Args:
        packages: Package listing, interpolated into the text via an f-string.
        chat_history: Non-empty conversation history; it is not mutated.

    Returns:
        A new list whose last message carries the packages text block.
    """
    # No trailing "```": a closing code fence without an opening one leaves
    # unbalanced markdown in the prompt.
    packages_block = TextBlock(text=f"\nInstalled packages: {packages}\n")
    chat_history = chat_history.copy()
    chat_history[-1] = message_copy(chat_history[-1])
    chat_history[-1].blocks.append(packages_block)
    return chat_history
182
+
183
async def add_memory_block(memory: List[str], chat_history: List[ChatMessage]) -> List[ChatMessage]:
    """Prepend remembered facts to the first user message in the history.

    The memory is rendered as a numbered "### Remembered Information" list.
    It is inserted as a TextBlock in front of the existing blocks of a copy of
    the first user message. Non-text blocks (e.g. images) and the message's
    other fields are therefore kept. Previously the message was rebuilt from
    its text content alone, and skipped entirely when that content was not a
    string.

    Args:
        memory: Items to remember, numbered from 1 in order.
        chat_history: Conversation history. It is updated in place: the first
            user message is replaced by the modified copy.

    Returns:
        The same ``chat_history`` list object.
    """
    memory_text = "\n### Remembered Information:\n" + "".join(
        f"{idx}. {item}\n" for idx, item in enumerate(memory, 1)
    )

    for i, msg in enumerate(chat_history):
        if msg.role == "user":
            updated = message_copy(msg)
            updated.blocks.insert(0, TextBlock(text=memory_text))
            chat_history[i] = updated
            break
    return chat_history
199
+
200
async def get_reflection_block(reflections: List[Reflection]) -> ChatMessage:
    """Build a user message listing the advice gathered from earlier reflections.

    Args:
        reflections: Past reflections. Each contributes its ``advice`` as one
            markdown bullet, in order.

    Returns:
        A ``ChatMessage`` with role ``"user"`` whose content is a header
        followed by the advice bullets.
    """
    header = "\n### You also have additional Knowledge to help you guide your current task from previous experiences:\n"
    # One bullet per item. An unclosed "**" prefix would turn every following
    # line into bold markdown.
    bullets = "".join(f"- {reflection.advice}\n" for reflection in reflections)
    return ChatMessage(role="user", content=header + bullets)
206
+
207
async def add_task_history_block(completed_tasks: list[dict], failed_tasks: list[dict], chat_history: List[ChatMessage]) -> List[ChatMessage]:
    """Append a numbered task history to a copy of the last chat message.

    Completed tasks are listed first, then failed tasks. Each entry is tagged:
    - objects with a ``description`` attribute are ``[success]`` iff
      ``task.status == "completed"``;
    - dicts and other values are tagged by the list they came from. Tagging no
      longer uses an equality lookup in ``completed_tasks``, which was O(n²)
      and mislabeled a failed task equal to a completed one.

    Args:
        completed_tasks: Tasks that finished successfully.
        failed_tasks: Tasks that failed.
        chat_history: Non-empty conversation history; it is not mutated.

    Returns:
        A new list whose last message carries the history block. When both
        task lists are empty, ``chat_history`` is returned unchanged instead of
        appending an empty text block.
    """
    entries = [(task, True) for task in completed_tasks] + [(task, False) for task in failed_tasks]
    if not entries:
        return chat_history

    # NOTE(review): the header says "chronological", but entries are ordered
    # completed-then-failed; only the order within each list is preserved.
    lines = ["Task History (chronological order):\n"]
    for number, (task, from_completed) in enumerate(entries, 1):
        if hasattr(task, 'description'):
            succeeded = getattr(task, 'status', None) == "completed"
            label = task.description
        elif isinstance(task, dict):
            # For backward compatibility with dict format
            succeeded = from_completed
            label = task.get('description', str(task))
        else:
            succeeded = from_completed
            label = task
        status_indicator = "[success]" if succeeded else "[failed]"
        lines.append(f"{number}. {status_indicator} {label}\n")

    task_block = TextBlock(text="".join(lines))

    chat_history = chat_history.copy()
    chat_history[-1] = message_copy(chat_history[-1])
    chat_history[-1].blocks.append(task_block)
    return chat_history
235
+
236
def parse_tool_descriptions(tool_list) -> str:
    """Render every tool as a Python stub (signature plus docstring) for the system prompt.

    Args:
        tool_list: Mapping whose values are the tool callables. Only
            ``.values()`` is used.

    Returns:
        The stubs joined with newlines.

    Raises:
        AssertionError: If a tool is not callable (only when assertions are enabled).
    """
    logger.info("🛠️ Parsing tool descriptions...")
    stubs = []

    for fn in tool_list.values():
        assert callable(fn), f"Tool {fn} is not callable."
        name = fn.__name__
        signature = inspect.signature(fn)
        doc = fn.__doc__ or "No description available."
        stubs.append(f"def {name}{signature}:\n \"\"\"{doc}\"\"\"\n...")
        logger.debug(f" - Parsed tool: {name}")

    logger.info(f"🔩 Found {len(stubs)} tools.")
    return "\n".join(stubs)
252
+
253
+
254
def parse_persona_description(personas) -> str:
    """Render the available agent personas as a markdown list for the planner's system prompt.

    Args:
        personas: Iterable of persona objects exposing ``name``,
            ``description`` and ``expertise_areas``.

    Returns:
        One "- **name**: description" entry per persona, each followed by its
        expertise line. Returns a fixed fallback sentence when ``personas``
        is falsy.
    """
    logger.debug("👥 Parsing agent persona descriptions for Planner Agent...")

    if not personas:
        logger.warning("No agent personas provided to Planner Agent")
        return "No specialized agents available."

    entries = []
    for persona in personas:
        areas = persona.expertise_areas
        expertise = ", ".join(areas) if areas else "General tasks"
        entries.append(f"- **{persona.name}**: {persona.description}\n Expertise: {expertise}")
        logger.debug(f" - Parsed persona: {persona.name}")

    logger.debug(f"👤 Found {len(entries)} agent personas.")
    return "\n".join(entries)
274
+
275
+
276
def extract_code_and_thought(response_text: str) -> Tuple[Optional[str], str]:
    """
    Extracts code from Markdown blocks (```python ... ```) and the surrounding text (thought),
    handling indented code blocks.

    Each block body is dedented, so an indented fence still yields valid
    top-level Python. Multiple blocks are joined with a blank line. All text
    outside the fences is concatenated and stripped to form the thought.

    Returns:
        Tuple[Optional[code_string], thought_string]: the code is None when no
        python block is found; in that case the whole stripped response is
        the thought.
    """
    import textwrap

    logger.debug("✂️ Extracting code and thought from response...")
    code_pattern = r"^\s*```python\s*\n(.*?)\n^\s*```\s*?$"
    code_matches = list(re.finditer(code_pattern, response_text, re.DOTALL | re.MULTILINE))

    if not code_matches:
        logger.debug(" - No code block found. Entire response is thought.")
        return None, response_text.strip()

    # The pattern accepts indented fences, but exec() raises IndentationError
    # on code whose lines all share a leading indent, so strip that indent.
    extracted_code = "\n\n".join(textwrap.dedent(match.group(1)) for match in code_matches)

    # The thought is everything between, before and after the code blocks.
    thought_parts = []
    last_end = 0
    for match in code_matches:
        start, end = match.span(0)
        thought_parts.append(response_text[last_end:start])
        last_end = end
    thought_parts.append(response_text[last_end:])

    thought_text = "".join(thought_parts).strip()
    thought_preview = (thought_text[:100] + '...') if len(thought_text) > 100 else thought_text
    logger.debug(f" - Extracted thought: {thought_preview}")

    return extracted_code, thought_text
@@ -0,0 +1,132 @@
1
+ import io
2
+ import contextlib
3
+ import ast
4
+ import traceback
5
+ import logging
6
+ from typing import Any, Dict
7
+ from droidrun.agent.utils.async_utils import async_to_sync
8
+ from llama_index.core.workflow import Context
9
+ import asyncio
10
+ from asyncio import AbstractEventLoop
11
+ import threading
12
+
13
+ logger = logging.getLogger("droidrun")
14
+
15
+
16
+ class SimpleCodeExecutor:
17
+ """
18
+ A simple code executor that runs Python code with state persistence.
19
+
20
+ This executor maintains a global and local state between executions,
21
+ allowing for variables to persist across multiple code runs.
22
+
23
+ NOTE: not safe for production use! Use with caution.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ loop: AbstractEventLoop,
29
+ locals: Dict[str, Any] = {},
30
+ globals: Dict[str, Any] = {},
31
+ tools={},
32
+ use_same_scope: bool = True,
33
+ ):
34
+ """
35
+ Initialize the code executor.
36
+
37
+ Args:
38
+ locals: Local variables to use in the execution context
39
+ globals: Global variables to use in the execution context
40
+ tools: List of tools available for execution
41
+ """
42
+
43
+ # loop throught tools and add them to globals, but before that check if tool value is async, if so convert it to sync. tools is a dictionary of tool name: function
44
+ # e.g. tools = {'tool_name': tool_function}
45
+
46
+ # check if tools is a dictionary
47
+ if isinstance(tools, dict):
48
+ logger.debug(
49
+ f"🔧 Initializing SimpleCodeExecutor with tools: {tools.items()}"
50
+ )
51
+ for tool_name, tool_function in tools.items():
52
+ if asyncio.iscoroutinefunction(tool_function):
53
+ # If the function is async, convert it to sync
54
+ tool_function = async_to_sync(tool_function)
55
+ # Add the tool to globals
56
+ globals[tool_name] = tool_function
57
+ elif isinstance(tools, list):
58
+ logger.debug(f"🔧 Initializing SimpleCodeExecutor with tools: {tools}")
59
+ # If tools is a list, convert it to a dictionary with tool name as key and function as value
60
+ for tool in tools:
61
+ if asyncio.iscoroutinefunction(tool):
62
+ # If the function is async, convert it to sync
63
+ tool = async_to_sync(tool)
64
+ # Add the tool to globals
65
+ globals[tool.__name__] = tool
66
+ else:
67
+ raise ValueError("Tools must be a dictionary or a list of functions.")
68
+
69
+ import time
70
+
71
+ globals["time"] = time
72
+
73
+ self.globals = globals
74
+ self.locals = locals
75
+ self.loop = loop
76
+ self.use_same_scope = use_same_scope
77
+ if self.use_same_scope:
78
+ # If using the same scope, set the globals and locals to the same dictionary
79
+ self.globals = self.locals = {
80
+ **self.locals,
81
+ **{k: v for k, v in self.globals.items() if k not in self.locals},
82
+ }
83
+
84
+ async def execute(self, ctx: Context, code: str) -> str:
85
+ """
86
+ Execute Python code and capture output and return values.
87
+
88
+ Args:
89
+ code: Python code to execute
90
+
91
+ Returns:
92
+ str: Output from the execution, including print statements.
93
+ """
94
+ # Update UI elements before execution
95
+ self.globals['ui_state'] = await ctx.get("ui_state", None)
96
+
97
+ # Capture stdout and stderr
98
+ stdout = io.StringIO()
99
+ stderr = io.StringIO()
100
+
101
+ output = ""
102
+ try:
103
+ # Execute with captured output
104
+ thread_exception = []
105
+ with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
106
+
107
+ def execute_code():
108
+ try:
109
+ exec(code, self.globals, self.locals)
110
+ except Exception as e:
111
+ import traceback
112
+
113
+ thread_exception.append((e, traceback.format_exc()))
114
+
115
+ t = threading.Thread(target=execute_code)
116
+ t.start()
117
+ t.join()
118
+
119
+ # Get output
120
+ output = stdout.getvalue()
121
+ if stderr.getvalue():
122
+ output += "\n" + stderr.getvalue()
123
+ if thread_exception:
124
+ e, tb = thread_exception[0]
125
+ output += f"\nError: {type(e).__name__}: {str(e)}\n{tb}"
126
+
127
+ except Exception as e:
128
+ # Capture exception information
129
+ output = f"Error: {type(e).__name__}: {str(e)}\n"
130
+ output += traceback.format_exc()
131
+
132
+ return output
@@ -0,0 +1,147 @@
1
+ import importlib
2
+ import logging
3
+ from typing import Any
4
+ from llama_index.core.llms.llm import LLM
5
+ # Configure logging
6
+ logger = logging.getLogger("droidrun")
7
+
8
def load_llm(provider_name: str, **kwargs: Any) -> LLM:
    """
    Dynamically loads and initializes a LlamaIndex LLM.

    Imports `llama_index.llms.<provider_name_lower>`, finds the class named
    `provider_name` within that module, verifies it's an LLM subclass,
    and initializes it with kwargs (keyword arguments whose value is None are
    dropped before the call).

    Args:
        provider_name: The case-sensitive name of the provider and the class
            (e.g., "OpenAI", "Ollama", "HuggingFaceLLM").
        **kwargs: Keyword arguments for the LLM class constructor.

    Returns:
        An initialized LLM instance.

    Raises:
        ValueError: If `provider_name` is empty.
        ModuleNotFoundError: If the provider's module cannot be found.
        AttributeError: If the class `provider_name` is not found in the module.
            Errors raised inside the constructor are re-raised as they are.
        TypeError: If the found class is not a subclass of LLM or if kwargs are invalid.
        RuntimeError: If the constructor returns a falsy instance.
    """
    if not provider_name:
        raise ValueError("provider_name cannot be empty.")
    if provider_name == "OpenAILike":
        module_provider_part = "openai_like"
    elif provider_name == "GoogleGenAI":
        module_provider_part = "google_genai"
    else:
        # Use lowercase for module path, handle hyphens for package name suggestion
        lower_provider_name = provider_name.lower()
        # Special case common variations like HuggingFaceLLM -> huggingface module
        if lower_provider_name.endswith("llm"):
            module_provider_part = lower_provider_name[:-3].replace("-", "_")
        else:
            module_provider_part = lower_provider_name.replace("-", "_")
    module_path = f"llama_index.llms.{module_provider_part}"
    install_package_name = f"llama-index-llms-{module_provider_part.replace('_', '-')}"

    try:
        logger.debug(f"Attempting to import module: {module_path}")
        llm_module = importlib.import_module(module_path)
        logger.debug(f"Successfully imported module: {module_path}")
    except ModuleNotFoundError:
        logger.error(f"Module '{module_path}' not found. Try: pip install {install_package_name}")
        raise ModuleNotFoundError(
            f"Could not import '{module_path}'. Is '{install_package_name}' installed?"
        ) from None

    # Scope the "class not found" handling to the lookup only. Otherwise an
    # AttributeError raised inside the LLM constructor would be misreported.
    try:
        logger.debug(f"Attempting to get class '{provider_name}' from module {module_path}")
        llm_class = getattr(llm_module, provider_name)
    except AttributeError:
        logger.error(f"Class '{provider_name}' not found in module '{module_path}'.")
        raise AttributeError(
            f"Could not find class '{provider_name}' in module '{module_path}'. Check spelling and capitalization."
        ) from None

    try:
        # Verify the class is a subclass of LLM
        if not isinstance(llm_class, type) or not issubclass(llm_class, LLM):
            raise TypeError(f"Class '{provider_name}' found in '{module_path}' is not a valid LLM subclass.")
        logger.debug(f"Found class: {llm_class.__name__}")

        # Filter out None values so the class falls back to its own defaults
        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}

        logger.debug(f"Initializing {llm_class.__name__} with kwargs: {list(filtered_kwargs.keys())}")
        llm_instance = llm_class(**filtered_kwargs)
        logger.debug(f"Successfully loaded and initialized LLM: {provider_name}")
        if not llm_instance:
            raise RuntimeError(f"Failed to initialize LLM instance for {provider_name}.")
        return llm_instance

    except TypeError as e:
        logger.error(f"Error initializing {provider_name}: {e}")
        raise  # Re-raise TypeError (could be from issubclass check or __init__)
    except Exception as e:
        logger.error(f"An unexpected error occurred initializing {provider_name}: {e}")
        raise
89
+
90
# --- Example Usage ---
if __name__ == "__main__":
    import os

    # Install the specific LLM integrations you want to test:
    # pip install \
    #   llama-index-llms-anthropic \
    #   llama-index-llms-deepseek \
    #   llama-index-llms-gemini \
    #   llama-index-llms-openai

    # Example 1: Load Anthropic (requires ANTHROPIC_API_KEY env var or kwarg)
    print("\n--- Loading Anthropic ---")
    try:
        anthropic_llm = load_llm(
            "Anthropic",
            model="claude-3-7-sonnet-latest",
        )
        print(f"Loaded LLM: {type(anthropic_llm)}")
        print(f"Model: {anthropic_llm.metadata}")
    except Exception as e:
        print(f"Failed to load Anthropic: {e}")

    # Example 2: Load DeepSeek (requires DEEPSEEK_API_KEY env var or kwarg)
    print("\n--- Loading DeepSeek ---")
    try:
        deepseek_llm = load_llm(
            "DeepSeek",
            model="deepseek-reasoner",
            # Read the key from the environment. A hard-coded placeholder
            # would override it; if unset, load_llm drops the None kwarg.
            api_key=os.environ.get("DEEPSEEK_API_KEY"),
        )
        print(f"Loaded LLM: {type(deepseek_llm)}")
        print(f"Model: {deepseek_llm.metadata}")
    except Exception as e:
        print(f"Failed to load DeepSeek: {e}")

    # Example 3: Load Gemini (requires GOOGLE_APPLICATION_CREDENTIALS or kwarg)
    print("\n--- Loading Gemini ---")
    try:
        gemini_llm = load_llm(
            "Gemini",
            model="gemini-2.0-flash",
        )
        print(f"Loaded LLM: {type(gemini_llm)}")
        print(f"Model: {gemini_llm.metadata}")
    except Exception as e:
        print(f"Failed to load Gemini: {e}")

    # Example 4: Load OpenAI (requires OPENAI_API_KEY env var or kwarg)
    print("\n--- Loading OpenAI ---")
    try:
        openai_llm = load_llm(
            "OpenAI",
            model="gpt-4o",
            temperature=0.5,
        )
        print(f"Loaded LLM: {type(openai_llm)}")
        print(f"Model: {openai_llm.metadata}")
    except Exception as e:
        print(f"Failed to load OpenAI: {e}")