hanzo-mcp 0.1.36__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hanzo-mcp might be problematic.
- hanzo_mcp/__init__.py +1 -1
- hanzo_mcp/cli.py +136 -31
- hanzo_mcp/server.py +34 -13
- hanzo_mcp/tools/__init__.py +4 -2
- hanzo_mcp/tools/agent/base_provider.py +73 -0
- hanzo_mcp/tools/agent/litellm_provider.py +45 -0
- hanzo_mcp/tools/agent/lmstudio_agent.py +385 -0
- hanzo_mcp/tools/agent/lmstudio_provider.py +219 -0
- hanzo_mcp/tools/agent/provider_registry.py +120 -0
- hanzo_mcp/tools/common/__init__.py +0 -1
- hanzo_mcp/tools/common/context.py +6 -8
- hanzo_mcp/tools/common/error_handling.py +86 -0
- hanzo_mcp/tools/common/logging_config.py +84 -0
- hanzo_mcp/tools/common/permissions.py +6 -8
- hanzo_mcp/tools/filesystem/__init__.py +6 -1
- hanzo_mcp/tools/filesystem/directory_tree.py +7 -46
- hanzo_mcp/tools/jupyter/__init__.py +6 -1
- hanzo_mcp/tools/shell/command_executor.py +7 -6
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/METADATA +1 -1
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/RECORD +24 -18
- hanzo_mcp/tools/common/path_utils.py +0 -51
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/WHEEL +0 -0
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/entry_points.txt +0 -0
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {hanzo_mcp-0.1.36.dist-info → hanzo_mcp-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,385 @@
+"""LM Studio agent tool for parallel model execution.
+
+This module provides a tool for running tasks in parallel across multiple LM Studio models.
+"""
+
+import asyncio
+import json
+import logging
+import time
+from typing import Any, Dict, List, Optional, final, override
+
+from mcp.server.fastmcp import Context as MCPContext
+from mcp.server.fastmcp import FastMCP
+
+from hanzo_mcp.tools.common.base import BaseTool
+from hanzo_mcp.tools.common.context import DocumentContext, create_tool_context
+from hanzo_mcp.tools.common.permissions import PermissionManager
+from hanzo_mcp.tools.agent.lmstudio_provider import LMStudioProvider
+
+logger = logging.getLogger(__name__)
+
+
+@final
+class LMStudioAgentTool(BaseTool):
+    """Tool for parallel execution of tasks across multiple LM Studio models."""
+
+    @property
+    @override
+    def name(self) -> str:
+        """Get the tool name.
+
+        Returns:
+            Tool name
+        """
+        return "lmstudio_dispatch"
+
+    @property
+    @override
+    def description(self) -> str:
+        """Get the tool description.
+
+        Returns:
+            Tool description
+        """
+        return """Run tasks in parallel across multiple LM Studio models.
+
+        This tool allows you to dispatch the same task or different tasks to multiple locally available
+        LM Studio models and execute them in parallel. This is useful for comparing model responses,
+        leveraging different model strengths, or simply speeding up processing by distributing tasks.
+
+        The task prompts can be the same for all models or different per model.
+
+        Args:
+            model_tasks: A list of configurations, each with a 'model' name, and a 'prompt'.
+                Optionally can include 'system_prompt', 'temperature', 'max_tokens' and 'top_p'
+
+        Returns:
+            Combined results from all model executions with performance metrics
+        """
+
+    @property
+    @override
+    def parameters(self) -> dict[str, Any]:
+        """Get the parameter specifications for the tool.
+
+        Returns:
+            Parameter specifications
+        """
+        return {
+            "properties": {
+                "model_tasks": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "model": {
+                                "type": "string",
+                                "description": "Name of the LM Studio model to use"
+                            },
+                            "identifier": {
+                                "type": "string",
+                                "description": "Optional identifier for the model instance"
+                            },
+                            "prompt": {
+                                "type": "string",
+                                "description": "Task prompt for the model"
+                            },
+                            "system_prompt": {
+                                "type": "string",
+                                "description": "Optional system prompt for the model"
+                            },
+                            "temperature": {
+                                "type": "number",
+                                "description": "Sampling temperature (defaults to 0.7)"
+                            },
+                            "max_tokens": {
+                                "type": "integer",
+                                "description": "Maximum tokens to generate (defaults to 2048)"
+                            },
+                            "top_p": {
+                                "type": "number",
+                                "description": "Top-p sampling parameter (defaults to 0.95)"
+                            }
+                        },
+                        "required": ["model", "prompt"]
+                    },
+                    "description": "List of model task configurations to execute in parallel"
+                }
+            },
+            "required": ["model_tasks"],
+            "type": "object"
+        }
+
+    @property
+    @override
+    def required(self) -> list[str]:
+        """Get the list of required parameter names.
+
+        Returns:
+            List of required parameter names
+        """
+        return ["model_tasks"]
+
+    def __init__(self, document_context: DocumentContext, permission_manager: PermissionManager) -> None:
+        """Initialize the LM Studio agent tool.
+
+        Args:
+            document_context: Document context for tracking file contents
+            permission_manager: Permission manager for access control
+        """
+        self.document_context = document_context
+        self.permission_manager = permission_manager
+        self.provider = LMStudioProvider()
+
+    @override
+    async def call(self, ctx: MCPContext, **params: Any) -> str:
+        """Execute the tool with the given parameters.
+
+        Args:
+            ctx: MCP context
+            **params: Tool parameters
+
+        Returns:
+            Tool execution result
+        """
+        start_time = time.time()
+
+        # Create tool context
+        tool_ctx = create_tool_context(ctx)
+        tool_ctx.set_tool_info(self.name)
+
+        # Extract parameters
+        model_tasks = params.get("model_tasks")
+        if not model_tasks:
+            await tool_ctx.error("Parameter 'model_tasks' is required but was not provided")
+            return "Error: Parameter 'model_tasks' is required but was not provided"
+
+        if not isinstance(model_tasks, list):
+            await tool_ctx.error("Parameter 'model_tasks' must be an array")
+            return "Error: Parameter 'model_tasks' must be an array"
+
+        if not model_tasks:
+            await tool_ctx.error("At least one model task must be provided")
+            return "Error: At least one model task must be provided"
+
+        # Validate each model task
+        for i, task in enumerate(model_tasks):
+            if not isinstance(task, dict):
+                await tool_ctx.error(f"Model task at index {i} must be an object")
+                return f"Error: Model task at index {i} must be an object"
+
+            if "model" not in task:
+                await tool_ctx.error(f"Model task at index {i} must have a 'model' property")
+                return f"Error: Model task at index {i} must have a 'model' property"
+
+            if "prompt" not in task:
+                await tool_ctx.error(f"Model task at index {i} must have a 'prompt' property")
+                return f"Error: Model task at index {i} must have a 'prompt' property"
+
+        # Initialize the provider if needed
+        await self.provider.initialize()
+
+        # Execute the tasks in parallel
+        await tool_ctx.info(f"Executing {len(model_tasks)} tasks across LM Studio models")
+        result = await self._execute_parallel_tasks(model_tasks, tool_ctx)
+
+        # Calculate execution time
+        execution_time = time.time() - start_time
+
+        # Format the result
+        formatted_result = self._format_result(result, execution_time)
+
+        # Log completion
+        await tool_ctx.info(f"LM Studio model execution completed in {execution_time:.2f}s")
+
+        return formatted_result
+
+    async def _execute_parallel_tasks(self, model_tasks: List[Dict[str, Any]], tool_ctx: Any) -> List[Dict[str, Any]]:
+        """Execute multiple model tasks in parallel.
+
+        Args:
+            model_tasks: List of model task configurations
+            tool_ctx: Tool context for logging
+
+        Returns:
+            List of task results
+        """
+        # Create tasks for loading models
+        load_tasks = []
+
+        for task in model_tasks:
+            model_name = task["model"]
+            identifier = task.get("identifier")
+
+            await tool_ctx.info(f"Loading model: {model_name}" + (f" as {identifier}" if identifier else ""))
+            load_tasks.append(self.provider.load_model(model_name, identifier))
+
+        # Wait for all models to load
+        try:
+            model_ids = await asyncio.gather(*load_tasks)
+        except Exception as e:
+            await tool_ctx.error(f"Failed to load models: {str(e)}")
+            return [{"error": f"Failed to load models: {str(e)}"}]
+
+        # Create tasks for generating responses
+        generation_tasks = []
+
+        for i, (task, model_id) in enumerate(zip(model_tasks, model_ids)):
+            prompt = task["prompt"]
+            system_prompt = task.get("system_prompt")
+            max_tokens = task.get("max_tokens", 2048)
+            temperature = task.get("temperature", 0.7)
+            top_p = task.get("top_p", 0.95)
+
+            await tool_ctx.info(f"Generating with model {model_id}")
+            generation_tasks.append(
+                self._execute_single_task(
+                    model_id=model_id,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    task_index=i,
+                    tool_ctx=tool_ctx,
+                    original_task=task
+                )
+            )
+
+        # Wait for all generation tasks to complete
+        results = await asyncio.gather(*generation_tasks, return_exceptions=True)
+
+        # Process results, handling any exceptions
+        processed_results = []
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                processed_results.append({
+                    "model": model_tasks[i]["model"],
+                    "error": str(result),
+                    "success": False
+                })
+            else:
+                processed_results.append(result)
+
+        return processed_results
+
+    async def _execute_single_task(
+        self,
+        model_id: str,
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        task_index: int,
+        tool_ctx: Any,
+        original_task: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Execute a single model task.
+
+        Args:
+            model_id: Model identifier
+            prompt: Prompt for the model
+            system_prompt: Optional system prompt
+            max_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            task_index: Index of the task
+            tool_ctx: Tool context for logging
+            original_task: Original task configuration
+
+        Returns:
+            Task result
+        """
+        task_start_time = time.time()
+
+        try:
+            # Generate response
+            generated_text, metadata = await self.provider.generate(
+                model_id=model_id,
+                prompt=prompt,
+                system_prompt=system_prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p
+            )
+
+            # Calculate execution time
+            task_execution_time = time.time() - task_start_time
+
+            await tool_ctx.info(f"Task {task_index} completed in {task_execution_time:.2f}s")
+
+            # Return result
+            return {
+                "model": original_task["model"],
+                "identifier": model_id,
+                "result": generated_text,
+                "execution_time": task_execution_time,
+                "success": True,
+                "metadata": metadata
+            }
+        except Exception as e:
+            await tool_ctx.error(f"Error executing task {task_index}: {str(e)}")
+            return {
+                "model": original_task["model"],
+                "identifier": model_id,
+                "error": str(e),
+                "execution_time": time.time() - task_start_time,
+                "success": False
+            }
+
+    def _format_result(self, results: List[Dict[str, Any]], total_execution_time: float) -> str:
+        """Format the task results.
+
+        Args:
+            results: List of task results
+            total_execution_time: Total execution time
+
+        Returns:
+            Formatted results
+        """
+        # Calculate summary statistics
+        successful = [r for r in results if r.get("success", False)]
+        failed = [r for r in results if not r.get("success", False)]
+
+        # Create the result string
+        output = [f"### LM Studio Dispatch Results\n"]
+        output.append(f"**Total execution time:** {total_execution_time:.2f}s")
+        output.append(f"**Models used:** {len(results)}")
+        output.append(f"**Successful:** {len(successful)}")
+        output.append(f"**Failed:** {len(failed)}\n")
+
+        # Add the results for each model
+        for i, result in enumerate(results):
+            model_name = result.get("model", "Unknown model")
+            model_id = result.get("identifier", model_name)
+
+            output.append(f"## Model {i+1}: {model_name}")
+
+            if result.get("success", False):
+                exec_time = result.get("execution_time", 0)
+                output.append(f"**Execution time:** {exec_time:.2f}s")
+
+                # Add the result
+                output.append("\n**Result:**\n")
+                output.append(result.get("result", "No result"))
+            else:
+                output.append(f"**Error:** {result.get('error', 'Unknown error')}")
+
+            output.append("\n" + "-" * 40 + "\n")
+
+        return "\n".join(output)
+
+    @override
+    def register(self, mcp_server: FastMCP) -> None:
+        """Register this tool with the MCP server.
+
+        Args:
+            mcp_server: The FastMCP server instance
+        """
+        tool_self = self  # Create a reference to self for use in the closure
+
+        @mcp_server.tool(name=self.name, description=self.mcp_description)
+        async def lmstudio_dispatch(ctx: MCPContext, model_tasks: List[Dict[str, Any]]) -> str:
+            return await tool_self.call(ctx, model_tasks=model_tasks)
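For orientation, a model_tasks payload accepted by the lmstudio_dispatch tool above might look like the sketch below. This is an illustrative example based on the parameter schema in the diff; the model names are placeholders and are not shipped with the package.

    model_tasks = [
        {
            "model": "llama-3.1-8b-instruct",   # placeholder model name
            "prompt": "Explain asyncio.gather in two sentences.",
            "temperature": 0.2,
        },
        {
            "model": "qwen2.5-7b-instruct",     # placeholder model name
            "identifier": "qwen-fast",
            "prompt": "Explain asyncio.gather in two sentences.",
            "system_prompt": "You are a terse technical writer.",
            "max_tokens": 512,
        },
    ]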
@@ -0,0 +1,219 @@
+"""LM Studio provider for agent delegation.
+
+Enables the use of local LLMs via LM Studio's Python SDK.
+"""
+
+import logging
+import asyncio
+import functools
+from typing import Any, Dict, List, Optional, Tuple
+
+from hanzo_mcp.tools.agent.base_provider import BaseModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class LMStudioProvider(BaseModelProvider):
+    """Provider for local models via LM Studio Python SDK."""
+
+    def __init__(self):
+        """Initialize the LM Studio provider."""
+        self.models = {}
+        self.initialized = False
+
+    async def initialize(self) -> None:
+        """Initialize the LM Studio provider.
+
+        Import is done here to avoid dependency issues if LM Studio SDK is not installed.
+        """
+        if self.initialized:
+            return
+
+        try:
+            # Dynamic import to avoid dependency issues if LM Studio is not installed
+            from importlib.util import find_spec
+            if find_spec("lmstudio") is None:
+                logger.warning("LM Studio Python SDK not installed. Install with 'pip install lmstudio'")
+                return
+
+            # Import the LM Studio module
+            import lmstudio as lms
+            self.lms = lms
+            self.initialized = True
+            logger.info("LM Studio provider initialized successfully")
+        except ImportError as e:
+            logger.error(f"Failed to import LM Studio Python SDK: {str(e)}")
+            logger.error("Install LM Studio Python SDK with 'pip install lmstudio'")
+        except Exception as e:
+            logger.error(f"Failed to initialize LM Studio provider: {str(e)}")
+
+    async def load_model(self, model_name: str, identifier: Optional[str] = None) -> str:
+        """Load a model from LM Studio.
+
+        Args:
+            model_name: The name of the model to load
+            identifier: Optional identifier for the model instance
+
+        Returns:
+            The identifier for the loaded model
+        """
+        if not self.initialized:
+            await self.initialize()
+            if not self.initialized:
+                raise RuntimeError("LM Studio provider failed to initialize")
+
+        model_id = identifier or model_name
+
+        try:
+            if model_id in self.models:
+                logger.info(f"Model {model_id} already loaded")
+                return model_id
+
+            logger.info(f"Loading model {model_name}")
+
+            # Use the thread pool to run the blocking operation
+            loop = asyncio.get_event_loop()
+            model = await loop.run_in_executor(
+                None,
+                functools.partial(self.lms.llm, model_name)
+            )
+
+            # Store the model with its identifier
+            self.models[model_id] = model
+            logger.info(f"Model {model_name} loaded successfully as {model_id}")
+            return model_id
+
+        except Exception as e:
+            logger.error(f"Failed to load model {model_name}: {str(e)}")
+            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
+
+    async def generate(
+        self,
+        model_id: str,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        top_p: float = 0.95,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> Tuple[str, Dict[str, Any]]:
+        """Generate a response from the model.
+
+        Args:
+            model_id: The identifier of the model to use
+            prompt: The prompt to send to the model
+            system_prompt: Optional system prompt to send to the model
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            stop_sequences: Optional list of strings that will stop generation
+
+        Returns:
+            A tuple of (generated text, metadata)
+        """
+        if not self.initialized:
+            await self.initialize()
+            if not self.initialized:
+                raise RuntimeError("LM Studio provider failed to initialize")
+
+        if model_id not in self.models:
+            raise ValueError(f"Model {model_id} not loaded")
+
+        model = self.models[model_id]
+
+        try:
+            logger.debug(f"Generating with model {model_id}")
+
+            # Prepare generation parameters
+            params = {
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p
+            }
+
+            if stop_sequences:
+                params["stop"] = stop_sequences
+
+            # Generate response
+            loop = asyncio.get_event_loop()
+
+            if system_prompt:
+                # For chat models with system prompt
+                response_future = loop.run_in_executor(
+                    None,
+                    functools.partial(
+                        model.chat,
+                        system=system_prompt,
+                        message=prompt,
+                        **params
+                    )
+                )
+            else:
+                # For completion models without system prompt
+                response_future = loop.run_in_executor(
+                    None,
+                    functools.partial(
+                        model.respond,
+                        prompt,
+                        **params
+                    )
+                )
+
+            response = await response_future
+
+            # Extract the generated text
+            if isinstance(response, dict) and "text" in response:
+                generated_text = response["text"]
+            elif isinstance(response, str):
+                generated_text = response
+            else:
+                generated_text = str(response)
+
+            # Metadata
+            metadata = {
+                "model": model_id,
+                "usage": {
+                    "prompt_tokens": -1,  # LM Studio Python SDK doesn't provide token counts
+                    "completion_tokens": -1,
+                    "total_tokens": -1
+                }
+            }
+
+            logger.debug(f"Generated {len(generated_text)} chars with model {model_id}")
+            return generated_text, metadata
+        except Exception as e:
+            logger.error(f"Failed to generate with model {model_id}: {str(e)}")
+            raise RuntimeError(f"Failed to generate with model {model_id}: {str(e)}")
+
+    async def unload_model(self, model_id: str) -> None:
+        """Unload a model from LM Studio.
+
+        Args:
+            model_id: The identifier of the model to unload
+        """
+        if not self.initialized:
+            return
+
+        if model_id not in self.models:
+            logger.warning(f"Model {model_id} not loaded")
+            return
+
+        try:
+            # Just remove the model reference, Python garbage collection will handle it
+            del self.models[model_id]
+            logger.info(f"Model {model_id} unloaded")
+        except Exception as e:
+            logger.error(f"Failed to unload model {model_id}: {str(e)}")
+
+    async def shutdown(self) -> None:
+        """Shutdown the LM Studio provider."""
+        if not self.initialized:
+            return
+
+        try:
+            # Clear all model references
+            self.models = {}
+            self.initialized = False
+            logger.info("LM Studio provider shut down")
+        except Exception as e:
+            logger.error(f"Failed to shut down LM Studio provider: {str(e)}")