janito 3.12.1__py3-none-any.whl → 3.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
janito/llm/driver.py CHANGED
@@ -1,254 +1,290 @@
1
- import threading
2
- from abc import ABC, abstractmethod
3
- from queue import Queue
4
- from janito.llm.driver_input import DriverInput
5
- from janito.driver_events import (
6
- RequestStarted,
7
- RequestFinished,
8
- ResponseReceived,
9
- RequestStatus,
10
- )
11
-
12
-
13
class LLMDriver(ABC):
    """
    Abstract base class for LLM drivers (threaded, queue-based).

    Subclasses must implement:
    - _prepare_api_kwargs: Prepare provider API kwargs, including tool schemas if needed.
    - _call_api: Call provider API with DriverInput.
    - _convert_completion_message_to_parts: Convert provider message to MessagePart objects.
    - convert_history_to_api_messages: Convert LLMConversationHistory to provider-specific
      messages format for API calls.
    - _get_message_from_result: Extract the message object from a provider result.

    Workflow:
    - Accept DriverInput via input_queue.
    - Put DriverEvents on output_queue.
    - Use start() to launch worker loop in a thread.

    The driver automatically creates its own input/output queues, accessible via
    .input_queue and .output_queue.
    """

    # Availability flags. When ``available`` is False, every request is answered
    # with an ERROR RequestFinished carrying ``unavailable_reason`` (reported as an
    # ImportError, see handle_driver_unavailable). Presumably a subclass sets these
    # when an optional provider dependency is missing — confirm in subclasses.
    available = True
    unavailable_reason = None

    def __init__(self, tools_adapter=None, provider_name=None):
        self.input_queue = Queue()
        self.output_queue = Queue()
        self._thread = None
        self.tools_adapter = tools_adapter
        self.provider_name = provider_name

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently in queue ``q`` without blocking."""
        from queue import Empty

        while True:
            try:
                q.get_nowait()
            except Empty:
                return

    def clear_output_queue(self):
        """Remove all items from the output queue."""
        self._drain_queue(self.output_queue)

    def clear_input_queue(self):
        """Remove all items from the input queue."""
        self._drain_queue(self.input_queue)

    def start(self):
        """Validate tool schemas (if any) and launch the driver's background thread to process DriverInput objects."""
        # Validate all tool schemas before starting the thread
        if self.tools_adapter is not None:
            from janito.tools.tools_schema import ToolSchemaBase

            validator = ToolSchemaBase()
            for tool in self.tools_adapter.get_tools():
                # Validate the tool's class (not instance)
                validator.validate_tool_class(tool.__class__)
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        """Worker loop: process queued DriverInput objects until a ``None`` sentinel arrives.

        Items that are not DriverInput instances are ignored. Any exception that
        escapes process_driver_input is reported as an ERROR RequestFinished so the
        worker thread keeps running.
        """
        while True:
            driver_input = self.input_queue.get()
            if driver_input is None:
                break  # Sentinel received, exit thread
            if not isinstance(driver_input, DriverInput):
                continue  # Unexpected input type: ignore it
            try:
                self.process_driver_input(driver_input)
            except Exception as e:
                self._emit_error(getattr(driver_input.config, "request_id", None), e)

    def _emit_error(self, request_id, exc):
        """Put an ERROR RequestFinished describing ``exc`` on the output queue."""
        import traceback

        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.ERROR,
                error=str(exc),
                exception=exc,
                traceback="".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                ),
            )
        )

    def _emit_cancelled(self, request_id, reason):
        """Put a CANCELLED RequestFinished with the given ``reason`` on the output queue."""
        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.CANCELLED,
                reason=reason,
            )
        )

    @staticmethod
    def _is_cancelled(driver_input):
        """Return True when ``driver_input`` has a cancel_event and it is set."""
        cancel_event = getattr(driver_input, "cancel_event", None)
        return cancel_event is not None and cancel_event.is_set()

    def _build_request_started_payload(self, config):
        """Build the RequestStarted payload: provider name plus model (``model`` or ``model_name``)."""
        payload = {"provider_name": self.provider_name}
        model = getattr(config, "model", None) or getattr(config, "model_name", None)
        if model:
            payload["model"] = model
        return payload

    def handle_driver_unavailable(self, request_id):
        """Report that this driver cannot serve requests (see ``unavailable_reason``)."""
        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.ERROR,
                error=self.unavailable_reason,
                exception=ImportError(self.unavailable_reason),
                traceback=None,
            )
        )

    def emit_response_received(
        self, driver_name, request_id, result, parts, timestamp=None, metadata=None
    ):
        """Put a ResponseReceived event with the converted ``parts`` on the output queue.

        ``result`` is accepted for interface compatibility but is not used here.
        """
        self.output_queue.put(
            ResponseReceived(
                driver_name=driver_name,
                request_id=request_id,
                parts=parts,
                tool_results=[],
                timestamp=timestamp,
                metadata=metadata or {},
            )
        )
        # Debug: print summary of parts by type
        if hasattr(self, "config") and getattr(self.config, "verbose_api", False):
            from collections import Counter

            type_counts = Counter(type(p).__name__ for p in parts)
            print(
                f"[verbose-api] Emitting ResponseReceived with parts: {dict(type_counts)}",
                flush=True,
            )

    def process_driver_input(self, driver_input: DriverInput):
        """Run one request and report its outcome as events on the output queue.

        Event sequence:
        - driver unavailable: ERROR RequestFinished only;
        - otherwise RequestStarted, then exactly one of:
          - CANCELLED RequestFinished (cancelled before start or during the API call);
          - ResponseReceived (success);
          - ERROR RequestFinished (an exception while calling or converting).
        """
        config = driver_input.config
        request_id = getattr(config, "request_id", None)
        if not self.available:
            self.handle_driver_unavailable(request_id)
            return

        self.output_queue.put(
            RequestStarted(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                payload=self._build_request_started_payload(config),
            )
        )
        if self._is_cancelled(driver_input):
            self._emit_cancelled(request_id, "Canceled before start")
            return

        try:
            result = self._call_api(driver_input)
            # Subclasses may also watch cancel_event during long calls (and then
            # typically return None); either way the request is reported as
            # cancelled here. This single check replaces three identical ones,
            # of which only the first could ever fire.
            if self._is_cancelled(driver_input):
                self._emit_cancelled(
                    request_id, "Cancelled during processing (post-API)"
                )
                return
            message = self._get_message_from_result(result)
            parts = (
                self._convert_completion_message_to_parts(message) if message else []
            )
            timestamp = getattr(result, "created", None)
            metadata = {"usage": getattr(result, "usage", None), "raw_response": result}
            self.emit_response_received(
                self.__class__.__name__, request_id, result, parts, timestamp, metadata
            )
        except Exception as ex:
            self._emit_error(request_id, ex)

    @abstractmethod
    def _prepare_api_kwargs(self, config, conversation):
        """
        Subclasses must implement: Prepare API kwargs for the provider, including any tool schemas if needed.
        """
        pass

    @abstractmethod
    def _call_api(self, driver_input: DriverInput):
        """Subclasses implement: Use driver_input to call provider and return result object."""
        pass

    @abstractmethod
    def _convert_completion_message_to_parts(self, message):
        """Subclasses implement: Convert provider message to list of MessagePart objects."""
        pass

    @abstractmethod
    def convert_history_to_api_messages(self, conversation_history):
        """
        Subclasses implement: Convert LLMConversationHistory to the messages object required by their provider API.
        :param conversation_history: LLMConversationHistory instance
        :return: Provider-specific messages object (e.g., list of dicts for OpenAI)
        """
        pass

    @abstractmethod
    def _get_message_from_result(self, result):
        """Extract the message object from the provider result. Subclasses must implement this."""
        raise NotImplementedError("Subclasses must implement _get_message_from_result.")
1
+ import threading
2
+ from abc import ABC, abstractmethod
3
+ from queue import Queue
4
+ from janito.llm.driver_input import DriverInput
5
+ from janito.driver_events import (
6
+ RequestStarted,
7
+ RequestFinished,
8
+ ResponseReceived,
9
+ RequestStatus,
10
+ )
11
+ from janito.llm.response_cache import ResponseCache
12
+
13
+
14
class LLMDriver(ABC):
    """
    Abstract base class for LLM drivers (threaded, queue-based).

    Subclasses must implement:
    - _prepare_api_kwargs: Prepare provider API kwargs, including tool schemas if needed.
    - _call_api: Call provider API with DriverInput.
    - _convert_completion_message_to_parts: Convert provider message to MessagePart objects.
    - convert_history_to_api_messages: Convert LLMConversationHistory to provider-specific
      messages format for API calls.
    - _get_message_from_result: Extract the message object from a provider result.

    Workflow:
    - Accept DriverInput via input_queue.
    - Put DriverEvents on output_queue.
    - Use start() to launch worker loop in a thread.

    The driver automatically creates its own input/output queues, accessible via
    .input_queue and .output_queue.

    Response caching: with ``enable_cache=True`` (the default), results returned by
    ``_call_api`` are stored in a ResponseCache keyed on the DriverInput. A later
    request with the same key is answered from the cache without calling the
    provider, and its ResponseReceived metadata carries ``"cached": True``.
    """

    # Availability flags. When ``available`` is False, every request is answered
    # with an ERROR RequestFinished carrying ``unavailable_reason`` (reported as an
    # ImportError, see handle_driver_unavailable). Presumably a subclass sets these
    # when an optional provider dependency is missing — confirm in subclasses.
    available = True
    unavailable_reason = None

    def __init__(self, tools_adapter=None, provider_name=None, enable_cache=True):
        self.input_queue = Queue()
        self.output_queue = Queue()
        self._thread = None
        self.tools_adapter = tools_adapter
        self.provider_name = provider_name
        self.enable_cache = enable_cache
        self.response_cache = ResponseCache() if enable_cache else None

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently in queue ``q`` without blocking."""
        from queue import Empty

        while True:
            try:
                q.get_nowait()
            except Empty:
                return

    def clear_output_queue(self):
        """Remove all items from the output queue."""
        self._drain_queue(self.output_queue)

    def clear_input_queue(self):
        """Remove all items from the input queue."""
        self._drain_queue(self.input_queue)

    def start(self):
        """Validate tool schemas (if any) and launch the driver's background thread to process DriverInput objects."""
        # Validate all tool schemas before starting the thread
        if self.tools_adapter is not None:
            from janito.tools.tools_schema import ToolSchemaBase

            validator = ToolSchemaBase()
            for tool in self.tools_adapter.get_tools():
                # Validate the tool's class (not instance)
                validator.validate_tool_class(tool.__class__)
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        """Worker loop: process queued DriverInput objects until a ``None`` sentinel arrives.

        Items that are not DriverInput instances are ignored. Any exception that
        escapes process_driver_input is reported as an ERROR RequestFinished so the
        worker thread keeps running.
        """
        while True:
            driver_input = self.input_queue.get()
            if driver_input is None:
                break  # Sentinel received, exit thread
            if not isinstance(driver_input, DriverInput):
                continue  # Unexpected input type: ignore it
            try:
                self.process_driver_input(driver_input)
            except Exception as e:
                self._emit_error(getattr(driver_input.config, "request_id", None), e)

    def _emit_error(self, request_id, exc):
        """Put an ERROR RequestFinished describing ``exc`` on the output queue."""
        import traceback

        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.ERROR,
                error=str(exc),
                exception=exc,
                traceback="".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                ),
            )
        )

    def _emit_cancelled(self, request_id, reason):
        """Put a CANCELLED RequestFinished with the given ``reason`` on the output queue."""
        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.CANCELLED,
                reason=reason,
            )
        )

    @staticmethod
    def _is_cancelled(driver_input):
        """Return True when ``driver_input`` has a cancel_event and it is set."""
        cancel_event = getattr(driver_input, "cancel_event", None)
        return cancel_event is not None and cancel_event.is_set()

    def _build_request_started_payload(self, config):
        """Build the RequestStarted payload: provider name plus model (``model`` or ``model_name``)."""
        payload = {"provider_name": self.provider_name}
        model = getattr(config, "model", None) or getattr(config, "model_name", None)
        if model:
            payload["model"] = model
        return payload

    def _cache_lookup(self, driver_input):
        """Return the cached result for ``driver_input``, or None on a miss or when caching is off.

        A failure to derive the cache key (e.g. a conversation history that cannot
        be JSON-serialized) is treated as a miss, so caching can never fail a request.
        """
        if self.response_cache is None:
            return None
        try:
            return self.response_cache.get(driver_input)
        except (TypeError, ValueError):
            return None

    def _cache_store(self, driver_input, result):
        """Best-effort store of ``result`` in the response cache (no-op when caching is off)."""
        if self.response_cache is None:
            return
        try:
            self.response_cache.set(driver_input, result)
        except (TypeError, ValueError):
            # Uncacheable input: skip caching instead of failing the request.
            pass

    def handle_driver_unavailable(self, request_id):
        """Report that this driver cannot serve requests (see ``unavailable_reason``)."""
        self.output_queue.put(
            RequestFinished(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                status=RequestStatus.ERROR,
                error=self.unavailable_reason,
                exception=ImportError(self.unavailable_reason),
                traceback=None,
            )
        )

    def emit_response_received(
        self, driver_name, request_id, result, parts, timestamp=None, metadata=None
    ):
        """Put a ResponseReceived event with the converted ``parts`` on the output queue.

        ``result`` is accepted for interface compatibility but is not used here.
        """
        self.output_queue.put(
            ResponseReceived(
                driver_name=driver_name,
                request_id=request_id,
                parts=parts,
                tool_results=[],
                timestamp=timestamp,
                metadata=metadata or {},
            )
        )
        # Debug: print summary of parts by type
        if hasattr(self, "config") and getattr(self.config, "verbose_api", False):
            from collections import Counter

            type_counts = Counter(type(p).__name__ for p in parts)
            print(
                f"[verbose-api] Emitting ResponseReceived with parts: {dict(type_counts)}",
                flush=True,
            )

    def process_driver_input(self, driver_input: DriverInput):
        """Run one request, from the cache or the provider, and report its outcome as events.

        Event sequence:
        - driver unavailable: ERROR RequestFinished only;
        - otherwise RequestStarted, then exactly one of:
          - CANCELLED RequestFinished (cancelled before start or during the API call);
          - ResponseReceived (success; metadata["cached"] is True on a cache hit);
          - ERROR RequestFinished (an exception while calling or converting).

        Cached and uncached requests follow the same sequence. In particular, a
        cache hit is emitted only after RequestStarted and only if the request
        was not cancelled before start.
        """
        config = driver_input.config
        request_id = getattr(config, "request_id", None)
        if not self.available:
            self.handle_driver_unavailable(request_id)
            return

        self.output_queue.put(
            RequestStarted(
                driver_name=self.__class__.__name__,
                request_id=request_id,
                payload=self._build_request_started_payload(config),
            )
        )
        if self._is_cancelled(driver_input):
            self._emit_cancelled(request_id, "Canceled before start")
            return

        try:
            result = self._cache_lookup(driver_input)
            cached = result is not None
            if not cached:
                result = self._call_api(driver_input)
                # Subclasses may also watch cancel_event during long calls (and
                # then typically return None); either way the request is reported
                # as cancelled here. This single check replaces three identical
                # ones, of which only the first could ever fire.
                if self._is_cancelled(driver_input):
                    self._emit_cancelled(
                        request_id, "Cancelled during processing (post-API)"
                    )
                    return
            message = self._get_message_from_result(result)
            parts = (
                self._convert_completion_message_to_parts(message) if message else []
            )
            timestamp = getattr(result, "created", None)
            metadata = {"usage": getattr(result, "usage", None), "raw_response": result}
            if cached:
                metadata["cached"] = True
            else:
                # Only fresh, successfully converted results are cached.
                self._cache_store(driver_input, result)
            self.emit_response_received(
                self.__class__.__name__, request_id, result, parts, timestamp, metadata
            )
        except Exception as ex:
            self._emit_error(request_id, ex)

    def clear_cache(self):
        """Clear the response cache if caching is enabled."""
        if self.response_cache:
            self.response_cache.clear()

    def get_cache_stats(self):
        """Get cache statistics if caching is enabled."""
        if self.response_cache:
            return self.response_cache.get_stats()
        return {"total_entries": 0, "total_size": 0}

    @abstractmethod
    def _prepare_api_kwargs(self, config, conversation):
        """
        Subclasses must implement: Prepare API kwargs for the provider, including any tool schemas if needed.
        """
        pass

    @abstractmethod
    def _call_api(self, driver_input: DriverInput):
        """Subclasses implement: Use driver_input to call provider and return result object."""
        pass

    @abstractmethod
    def _convert_completion_message_to_parts(self, message):
        """Subclasses implement: Convert provider message to list of MessagePart objects."""
        pass

    @abstractmethod
    def convert_history_to_api_messages(self, conversation_history):
        """
        Subclasses implement: Convert LLMConversationHistory to the messages object required by their provider API.
        :param conversation_history: LLMConversationHistory instance
        :return: Provider-specific messages object (e.g., list of dicts for OpenAI)
        """
        pass

    @abstractmethod
    def _get_message_from_result(self, result):
        """Extract the message object from the provider result. Subclasses must implement this."""
        raise NotImplementedError("Subclasses must implement _get_message_from_result.")
@@ -0,0 +1,57 @@
1
+ """
2
+ Simple in-memory cache for LLM responses based on input hash.
3
+ No expiration - cache lives for the duration of the process.
4
+ """
5
+
6
+ import hashlib
7
+ import json
8
+ from typing import Any, Dict, Optional
9
+ from janito.llm.driver_input import DriverInput
10
+
11
+
12
class ResponseCache:
    """In-memory cache of LLM provider responses, keyed by a hash of the driver input.

    Entries never expire with time: they live until ``clear()`` is called or the
    cache object is discarded. An optional ``max_entries`` bound evicts the
    least-recently-used entry once it is exceeded. The default (``None``) keeps
    the original unbounded behaviour.

    If no cache key can be derived for an input, that input is simply not cached:
    ``get`` reports a miss and ``set`` does nothing, instead of raising into the
    caller. This happens, for example, when the conversation history contains
    objects ``json.dumps`` cannot serialize, or a circular reference.
    """

    # Config attributes that can change the provider's answer, and so belong in
    # the key. ``model_name`` is listed alongside ``model`` because a config may
    # carry either. Keying on ``model`` alone made every ``model_name``-style
    # config hash as ``"model": None``, so different models shared responses.
    _CONFIG_KEY_FIELDS = (
        "model",
        "model_name",
        "temperature",
        "max_tokens",
        "top_p",
        "presence_penalty",
        "frequency_penalty",
        "stop",
    )

    def __init__(self, max_entries: Optional[int] = None):
        """Create an empty cache.

        :param max_entries: optional upper bound on stored responses. ``None``
            means unbounded. Otherwise the least-recently-used entry is evicted
            whenever the bound is exceeded.
        :raises ValueError: if ``max_entries`` is not None and is smaller than 1.
        """
        import threading

        if max_entries is not None and max_entries < 1:
            raise ValueError("max_entries must be a positive integer or None")
        self._max_entries = max_entries
        # Plain dicts keep insertion order, which doubles as recency order here:
        # hits and updates are re-inserted at the end, so the first key is the LRU.
        self._cache: Dict[str, Any] = {}
        # The LRU bookkeeping spans several dict operations. The lock keeps them
        # consistent when clear()/get_stats() run on a different thread than get()/set().
        self._lock = threading.Lock()

    def _generate_key(self, driver_input: "DriverInput") -> Optional[str]:
        """Return a SHA-256 hex key for ``driver_input``, or None if it cannot be serialized.

        The key covers the conversation history (as returned by
        ``conversation_history.get_history()``) and the config fields in
        ``_CONFIG_KEY_FIELDS``; missing config attributes are keyed as None.
        """
        config = driver_input.config
        cache_data = {
            # Assumes get_history() returns JSON-friendly data (lists/dicts/str);
            # anything else makes the input uncacheable rather than failing.
            "conversation_history": driver_input.conversation_history.get_history(),
            "config": {
                field: getattr(config, field, None) for field in self._CONFIG_KEY_FIELDS
            },
        }
        try:
            # sort_keys + compact separators give a deterministic representation.
            cache_str = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
        except (TypeError, ValueError):
            return None
        return hashlib.sha256(cache_str.encode("utf-8")).hexdigest()

    def get(self, driver_input: "DriverInput") -> Optional[Any]:
        """Return the cached response for ``driver_input``, or None on a miss."""
        key = self._generate_key(driver_input)
        if key is None:
            return None
        with self._lock:
            if key not in self._cache:
                return None
            # Re-insert to mark the entry as most recently used.
            response = self._cache.pop(key)
            self._cache[key] = response
            return response

    def set(self, driver_input: "DriverInput", response: Any) -> None:
        """Cache ``response`` for ``driver_input``; a no-op when no key can be derived."""
        key = self._generate_key(driver_input)
        if key is None:
            return
        with self._lock:
            self._cache.pop(key, None)
            self._cache[key] = response
            if self._max_entries is not None:
                while len(self._cache) > self._max_entries:
                    del self._cache[next(iter(self._cache))]

    def clear(self) -> None:
        """Clear all cached responses."""
        with self._lock:
            self._cache.clear()

    def get_stats(self) -> Dict[str, int]:
        """Return the entry count and the summed ``len(str(response))`` of all entries."""
        with self._lock:
            values = list(self._cache.values())
        return {
            "total_entries": len(values),
            "total_size": sum(len(str(v)) for v in values),
        }