litellm-adk 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
litellm_adk/__init__.py CHANGED
@@ -1,10 +1,12 @@
  from .agents import LiteLLMAgent
+ from .session import Session
  from .tools import tool, tool_registry
  from .config.settings import settings
  from .memory import BaseMemory, InMemoryMemory, FileMemory, MongoDBMemory

  __all__ = [
  "LiteLLMAgent",
+ "Session",
  "tool",
  "tool_registry",
  "settings",
litellm_adk/core/agent.py CHANGED
@@ -4,9 +4,11 @@ from typing import List, Dict, Any, Optional, Union, Generator, AsyncGenerator
  from .base import BaseAgent
  from ..observability.logger import adk_logger
  from ..config.settings import settings
+ from ..session import Session
  from ..tools.registry import tool_registry
  from ..memory.base import BaseMemory
  from ..memory.in_memory import InMemoryMemory
+ from .context import ContextManager
  import uuid

  # Global LiteLLM configuration for resilience
@@ -25,7 +27,7 @@ class LiteLLMAgent(BaseAgent):
  system_prompt: str = "You are a helpful assistant.",
  tools: Optional[List[Dict[str, Any]]] = None,
  memory: Optional[BaseMemory] = None,
- session_id: Optional[str] = None,
+ max_context_tokens: Optional[int] = None,
  **kwargs
  ):
  self.model = model or settings.model
@@ -57,71 +59,103 @@ class LiteLLMAgent(BaseAgent):

  self.extra_kwargs = kwargs

- # Ensure model-specific parameters # Default parallel_tool_calls if not explicitly provided
  if "parallel_tool_calls" not in self.extra_kwargs:
- self.extra_kwargs["parallel_tool_calls"] = False
+ self.extra_kwargs["parallel_tool_calls"] = True

  self.sequential_tool_execution = kwargs.get("sequential_tool_execution", settings.sequential_execution)

  # Memory Persistence
  self.memory = memory or InMemoryMemory()
- self.session_id = session_id or str(uuid.uuid4())
- self.history = self.memory.get_messages(self.session_id)
+ self.max_context_tokens = max_context_tokens

- if not self.history:
- self.history = [{"role": "system", "content": self.system_prompt}]
- self.memory.add_message(self.session_id, self.history[0])
-
- adk_logger.debug(f"Initialized LiteLLMAgent with session_id={self.session_id}, model={self.model}")
+ adk_logger.debug(f"Initialized LiteLLMAgent as a service for model={self.model}")

- def _prepare_messages(self, prompt: str) -> List[Dict[str, str]]:
- # Refresh from memory in case it was modified elsewhere
- self.history = self.memory.get_messages(self.session_id)
+ def save_session(self, session: Union[str, Session]):
+ """Persist session metadata and state to memory."""
+ actual_id = session.id if isinstance(session, Session) else session

- messages = self.history.copy()
+ # If it's a Session object, we dump the full metadata
+ if isinstance(session, Session):
+ self.memory.save_session_metadata(actual_id, session.model_dump())
+ # If it's just an ID, there's nothing to dump from the service layer
+
+ def _prepare_messages(self, prompt: str, actual_session_id: str) -> List[Dict[str, str]]:
+ # 2. Fetch/Initialize History from Memory
+ history = self.memory.get_messages(actual_session_id)
+ if not history:
+ history = [{"role": "system", "content": self.system_prompt}]
+ # Don't persist system prompt until first real turn to keep DB clean
+
+ messages = history.copy()
  user_msg = {"role": "user", "content": prompt}
  messages.append(user_msg)

- # Persist the user message immediately
- self.memory.add_message(self.session_id, user_msg)
- self.history.append(user_msg)
+ # 3. Persist turn start
+ # Ensure messages are sanitized and tokenized before first persistence
+ current_user_msg = self._sanitize_message(user_msg)
+ current_user_msg["token_count"] = ContextManager.count_tokens([current_user_msg], self.model)

+ if not history:
+ system_msg = self._sanitize_message(messages[0])
+ system_msg["token_count"] = ContextManager.count_tokens([system_msg], self.model)
+ self.memory.add_messages(actual_session_id, [system_msg, current_user_msg])
+ else:
+ self.memory.add_message(actual_session_id, current_user_msg)
+
+ # 4. Context Management (Truncation)
+ if self.max_context_tokens:
+ messages = ContextManager.truncate_history(
+ messages,
+ self.model,
+ self.max_context_tokens
+ )
+
  return messages

- def _update_history(self, final_messages: List[Dict[str, Any]]):
- """Sync internal history and memory with the final message state."""
- # Find which messages were added since we prepared (the user message was already added)
- # We assume messages order is preserved
- start_idx = len(self.history)
- new_messages = [self._sanitize_message(m) for m in final_messages[start_idx:]]
-
+ def _update_history(self, new_messages: List[Dict[str, Any]], actual_session_id: str):
+ """Persist new messages to memory with token counts."""
  if new_messages:
- self.memory.add_messages(self.session_id, new_messages)
- self.history.extend(new_messages)
+ sanitized = []
+ for m in new_messages:
+ s = self._sanitize_message(m)
+ # Compute token count if not already present (optimization for future turns)
+ if "token_count" not in s:
+ s["token_count"] = ContextManager.count_tokens([s], self.model)
+ sanitized.append(s)
+
+ self.memory.add_messages(actual_session_id, sanitized)

  def _sanitize_message(self, message: Any) -> Dict[str, Any]:
- """Convert LiteLLM message objects to plain dictionaries for serialization."""
- if isinstance(message, dict):
- # Still need to sanitize tool_calls inside if they are objects
- if "tool_calls" in message and message["tool_calls"]:
- message["tool_calls"] = [self._sanitize_tool_call(tc) for tc in message["tool_calls"]]
- return message
+ """
+ Convert LiteLLM message objects to strictly compliant dictionaries.
+ Ensures compatibility with strict providers like OCI.
+ """
+ # If it's already a dict, extract only what we need to avoid 'extra key' errors
+ role = getattr(message, "role", "assistant") if not isinstance(message, dict) else message.get("role", "assistant")
+ content = getattr(message, "content", "") if not isinstance(message, dict) else message.get("content", "")

- # Manually extract common fields to ensure clean JSON
+ # OCI/OpenAI standard: content cannot be None for assistant/user/system
+ if content is None:
+ content = ""
+
  msg_dict = {
- "role": getattr(message, "role", "assistant"),
- "content": getattr(message, "content", None)
+ "role": role,
+ "content": content
  }

- if hasattr(message, "name") and message.name:
- msg_dict["name"] = message.name
-
- if hasattr(message, "tool_calls") and message.tool_calls:
- msg_dict["tool_calls"] = [self._sanitize_tool_call(tc) for tc in message.tool_calls]
-
- if hasattr(message, "tool_call_id") and message.tool_call_id:
- msg_dict["tool_call_id"] = message.tool_call_id
+ # Handle Tool Calls (Assistant Message)
+ tool_calls = getattr(message, "tool_calls", None) if not isinstance(message, dict) else message.get("tool_calls")
+ if tool_calls:
+ msg_dict["tool_calls"] = [self._sanitize_tool_call(tc) for tc in tool_calls]

+ # Handle Tool Result (Tool Role)
+ if role == "tool":
+ msg_dict["tool_call_id"] = getattr(message, "tool_call_id", None) if not isinstance(message, dict) else message.get("tool_call_id")
+ # Name is optional but good practice
+ name = getattr(message, "name", None) if not isinstance(message, dict) else message.get("name")
+ if name:
+ msg_dict["name"] = name
+
  return msg_dict

  def _sanitize_tool_call(self, tc: Any) -> Dict[str, Any]:
@@ -149,9 +183,28 @@ class LiteLLMAgent(BaseAgent):
  """Determines if we should process tool calls one by one."""
  return self.sequential_tool_execution

- async def _aexecute_tool(self, tool_call) -> Any:
- # Same as _execute_tool but for async if needed in future
- return self._execute_tool(tool_call)
+ async def _aexecute_tool(self, tool_call) -> Dict[str, Any]:
+ """Helper to execute a tool call asynchronously and return formatted result."""
+ function_name = self._get_tc_val(tool_call, "function", "name")
+ raw_args = self._get_tc_val(tool_call, "function", "arguments") or "{}"
+
+ try:
+ if isinstance(raw_args, dict):
+ arguments = raw_args
+ else:
+ arguments = json.loads(raw_args)
+ except json.JSONDecodeError:
+ adk_logger.warning(f"Failed to parse tool arguments for {function_name}: {raw_args}")
+ arguments = {}
+
+ result = await tool_registry.aexecute(function_name, **arguments)
+
+ return {
+ "role": "tool",
+ "tool_call_id": self._get_tc_val(tool_call, "id"),
+ "name": function_name,
+ "content": str(result)
+ }

  def _get_tc_val(self, tool_call, attr, subattr=None):
  """Helper to get value from either object or dict tool call."""
@@ -193,12 +246,14 @@ class LiteLLMAgent(BaseAgent):

  return tool_registry.execute(function_name, **arguments)

- def invoke(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> str:
+ def invoke(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, session_id: Optional[Union[str, Session]] = None, **kwargs) -> str:
  """
  Execute a synchronous completion with automatic tool calling.
  """
- messages = self._prepare_messages(prompt)
+ actual_session_id = session_id.id if isinstance(session_id, Session) else (session_id or str(uuid.uuid4()))
+ messages = self._prepare_messages(prompt, actual_session_id=actual_session_id)
  tools = tools or self.tools
+ new_turns = [] # Track only what's new in this specific call

  adk_logger.info(f"Invoking completion for model: {self.model}")

@@ -216,39 +271,45 @@ class LiteLLMAgent(BaseAgent):

  # Check if the model wants to call tools
  if hasattr(message, "tool_calls") and message.tool_calls:
- # If sequential is enabled, we only process the FIRST tool call
  tool_calls_to_process = [message.tool_calls[0]] if self._should_handle_sequentially() else message.tool_calls

- # We update the original message to only include the calls we are handling
- # (to keep history clean for strict models)
  if self._should_handle_sequentially():
  message.tool_calls = tool_calls_to_process

  sanitized_msg = self._sanitize_message(message)
  messages.append(sanitized_msg)
+ new_turns.append(sanitized_msg)

  for tool_call in tool_calls_to_process:
  result = self._execute_tool(tool_call)

- messages.append({
+ tool_response = {
  "role": "tool",
  "tool_call_id": tool_call.id,
  "name": tool_call.function.name,
  "content": str(result)
- })
+ }
+ messages.append(tool_response)
+ new_turns.append(tool_response)

  continue

- messages.append(self._sanitize_message(message))
- self._update_history(messages)
- return message.content
+ final_msg = self._sanitize_message(message)
+ messages.append(final_msg)
+ new_turns.append(final_msg)
+
+ # Persist only the new assistant/tool turns
+ self._update_history(new_turns, actual_session_id=actual_session_id)
+ return final_msg.get("content")

- async def ainvoke(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> str:
+ async def ainvoke(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, session_id: Optional[Union[str, Session]] = None, **kwargs) -> str:
  """
  Execute an asynchronous completion with automatic tool calling.
  """
- messages = self._prepare_messages(prompt)
+ actual_session_id = session_id.id if isinstance(session_id, Session) else (session_id or str(uuid.uuid4()))
+ messages = self._prepare_messages(prompt, actual_session_id=actual_session_id)
  tools = tools or self.tools
+ new_turns = []

  adk_logger.info(f"Invoking async completion for model: {self.model}")

@@ -272,28 +333,39 @@ class LiteLLMAgent(BaseAgent):

  sanitized_msg = self._sanitize_message(message)
  messages.append(sanitized_msg)
+ new_turns.append(sanitized_msg)

- for tool_call in tool_calls_to_process:
- result = self._execute_tool(tool_call)
- messages.append({
- "role": "tool",
- "tool_call_id": tool_call.id,
- "name": tool_call.function.name,
- "content": str(result)
- })
+ if self._should_handle_sequentially():
+ for tool_call in tool_calls_to_process:
+ result = await self._aexecute_tool(tool_call)
+ messages.append(result)
+ new_turns.append(result)
+ else:
+ # Parallel Execution
+ import asyncio
+ results = await asyncio.gather(*[self._aexecute_tool(tc) for tc in tool_calls_to_process])
+ for res in results:
+ messages.append(res)
+ new_turns.append(res)
  continue

- messages.append(self._sanitize_message(message))
- self._update_history(messages)
- return message.content
+ final_msg = self._sanitize_message(message)
+ messages.append(final_msg)
+ new_turns.append(final_msg)
+
+ self._update_history(new_turns, actual_session_id=actual_session_id)
+ return final_msg.get("content")

- def stream(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> Generator[str, None, None]:
+ def stream(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, session_id: Optional[Union[str, Session]] = None, **kwargs) -> Generator[str, None, None]:
  """
  Execute a streaming completion with automatic tool calling.
  """
- messages = self._prepare_messages(prompt)
+ actual_session_id = session_id.id if isinstance(session_id, Session) else (session_id or str(uuid.uuid4()))
+ messages = self._prepare_messages(prompt, actual_session_id=actual_session_id)
  tools = tools or self.tools

+ new_turns = []
+
  while True:
  response = litellm.completion(
  model=self.model,
@@ -356,7 +428,7 @@ class LiteLLMAgent(BaseAgent):
  if last_tc.function.arguments is None:
  last_tc.function.arguments = ""
  last_tc.function.arguments += tc_delta.function.arguments
-
+
  # Build final flattened tool calls list (as dicts for history)
  tool_calls = []
  for idx in sorted(tool_calls_by_index.keys()):
@@ -370,39 +442,46 @@ class LiteLLMAgent(BaseAgent):
  "arguments": tc_obj.function.arguments
  }
  })
-
+
  if tool_calls:
  # If sequential, only keep the first tool call
  if self._should_handle_sequentially():
  tool_calls = [tool_calls[0]]
-
+
  # Add the assistant's composite tool call message to history
- assistant_msg = {"role": "assistant", "tool_calls": tool_calls, "content": full_content or None}
+ assistant_msg = self._sanitize_message({"role": "assistant", "tool_calls": tool_calls, "content": full_content})
  messages.append(assistant_msg)
+ new_turns.append(assistant_msg)

  for tool_call in tool_calls:
  result = self._execute_tool(tool_call)
- messages.append({
+ tool_resp = {
  "role": "tool",
  "tool_call_id": tool_call["id"],
  "name": tool_call["function"]["name"],
  "content": str(result)
- })
+ }
+ messages.append(tool_resp)
+ new_turns.append(tool_resp)

  # Loop back to continue the conversation with tool results
  continue

  # No tool calls, store final content and finish
- messages.append({"role": "assistant", "content": full_content})
- self._update_history(messages)
+ final_msg = self._sanitize_message({"role": "assistant", "content": full_content})
+ messages.append(final_msg)
+ new_turns.append(final_msg)
+ self._update_history(new_turns, actual_session_id=actual_session_id)
  return

- async def astream(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> AsyncGenerator[str, None]:
+ async def astream(self, prompt: str, tools: Optional[List[Dict[str, Any]]] = None, session_id: Optional[Union[str, Session]] = None, **kwargs) -> AsyncGenerator[str, None]:
  """
  Execute an asynchronous streaming completion with automatic tool calling.
  """
- messages = self._prepare_messages(prompt)
+ actual_session_id = session_id.id if isinstance(session_id, Session) else (session_id or str(uuid.uuid4()))
+ messages = self._prepare_messages(prompt, actual_session_id=actual_session_id)
  tools = tools or self.tools
+ new_turns = []

  while True:
  response = await litellm.acompletion(
@@ -460,7 +539,7 @@ class LiteLLMAgent(BaseAgent):
  if last_tc.function.arguments is None:
  last_tc.function.arguments = ""
  last_tc.function.arguments += tc_delta.function.arguments
-
+
  tool_calls = []
  for idx in sorted(tool_calls_by_index.keys()):
  for tc_obj in tool_calls_by_index[idx]:
@@ -473,24 +552,30 @@ class LiteLLMAgent(BaseAgent):
  "arguments": tc_obj.function.arguments
  }
  })
-
+
  if tool_calls:
  if self._should_handle_sequentially():
  tool_calls = [tool_calls[0]]
-
- assistant_msg = {"role": "assistant", "tool_calls": tool_calls, "content": full_content or None}
+
+ assistant_msg = self._sanitize_message({"role": "assistant", "tool_calls": tool_calls, "content": full_content})
  messages.append(assistant_msg)
+ new_turns.append(assistant_msg)

- for tool_call in tool_calls:
- result = self._execute_tool(tool_call)
- messages.append({
- "role": "tool",
- "tool_call_id": tool_call["id"],
- "name": tool_call["function"]["name"],
- "content": str(result)
- })
+ if self._should_handle_sequentially():
+ for tool_call in tool_calls:
+ result = await self._aexecute_tool(tool_call)
+ messages.append(result)
+ new_turns.append(result)
+ else:
+ import asyncio
+ results = await asyncio.gather(*[self._aexecute_tool(tc) for tc in tool_calls])
+ for res in results:
+ messages.append(res)
+ new_turns.append(res)
  continue

- messages.append({"role": "assistant", "content": full_content})
- self._update_history(messages)
+ final_msg = self._sanitize_message({"role": "assistant", "content": full_content})
+ messages.append(final_msg)
+ new_turns.append(final_msg)
+ self._update_history(new_turns, actual_session_id=actual_session_id)
  return
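
Note: the net effect of these agent.py changes is that the agent becomes a stateless service. It no longer holds a session_id or an internal history; callers pass a session (or plain ID string) per call, and only that call's new turns are persisted. A hedged usage sketch against the API shown above (model name and prompt are illustrative placeholders):

    from litellm_adk import LiteLLMAgent, Session

    agent = LiteLLMAgent(
        model="gpt-4o-mini",        # illustrative placeholder
        max_context_tokens=8000,    # new: opts in to ContextManager truncation
    )

    session = Session(user_id="u-123")
    reply = agent.invoke("Hello!", session_id=session)  # a plain str id also works
    session.update_state("greeted", True)
    agent.save_session(session)  # persists session.model_dump() as metadata

Omitting session_id still works, but a fresh uuid4 is generated on every call, so consecutive calls will not share history unless the caller keeps and reuses the id.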
litellm_adk/core/context.py ADDED
@@ -0,0 +1,84 @@
+ from typing import List, Dict, Any, Optional
+ import litellm
+ from ..observability.logger import adk_logger
+
+ class ContextManager:
+ """
+ Handles token counting and history truncation/summarization.
+ """
+
+ @staticmethod
+ def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
+ """
+ Calculate the number of tokens in a list of messages.
+ Uses cached 'token_count' if available, otherwise LiteLLM's token_counter.
+ """
+ # optimization: if it's a single message with a cached count, use it
+ if len(messages) == 1 and "token_count" in messages[0]:
+ return messages[0]["token_count"]
+
+ try:
+ # Pass a copy to avoid any potential in-place modifications by token_counter
+ return litellm.token_counter(model=model, messages=[m.copy() for m in messages])
+ except Exception as e:
+ adk_logger.warning(f"Token counting failed for model {model}: {e}. Falling back to estimate.")
+ # Rough estimate: 4 chars per token
+ return sum(len(str(m.get("content", ""))) for m in messages) // 4
+
+ @staticmethod
+ def truncate_history(
+ messages: List[Dict[str, Any]],
+ model: str,
+ max_tokens: int,
+ reserve_tokens: int = 500
+ ) -> List[Dict[str, Any]]:
+ """
+ Truncate history to fit within max_tokens, always preserving the system prompt
+ and the latest message.
+ """
+ if not messages:
+ return []
+
+ # 1. Separate System Prompt
+ system_prompt = None
+ if messages[0].get("role") == "system":
+ system_prompt = messages[0]
+ other_messages = messages[1:]
+ else:
+ other_messages = messages
+
+ # 2. Calculate Budget
+ actual_reserve = min(reserve_tokens, int(max_tokens * 0.2))
+ allowed_tokens = max_tokens - actual_reserve
+
+ if system_prompt:
+ allowed_tokens -= ContextManager.count_tokens([system_prompt], model)
+
+ # 3. Quick Check: Is truncation even needed?
+ # This one call avoids the N calls in the loop below for most turns
+ if ContextManager.count_tokens(other_messages, model) <= allowed_tokens:
+ return messages
+
+ # 4. Truncate (Keeping the LATEST messages)
+ truncated = []
+ current_tokens = 0
+
+ if other_messages:
+ last_msg = other_messages[-1]
+ truncated.append(last_msg)
+ current_tokens += ContextManager.count_tokens([last_msg], model)
+
+ for msg in reversed(other_messages[:-1]):
+ msg_tokens = ContextManager.count_tokens([msg], model)
+ if current_tokens + msg_tokens > allowed_tokens:
+ break
+ truncated.insert(0, msg)
+ current_tokens += msg_tokens
+
+ # 5. Reconstruct
+ result = []
+ if system_prompt:
+ result.append(system_prompt)
+ result.extend(truncated)
+
+ return result
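
Note: despite the docstring, only truncation is implemented here; there is no summarization path. A short sketch of the truncation contract, using only the methods above (messages and model name are illustrative):

    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "latest question"},
    ]

    # The system prompt and the newest message always survive; with a tight
    # enough budget, older middle turns are dropped oldest-first.
    trimmed = ContextManager.truncate_history(history, "gpt-4o-mini", max_tokens=100)
    assert trimmed[0]["role"] == "system"
    assert trimmed[-1] is history[-1]

One edge case worth noting: the newest message is appended before the budget check, so it is kept even when it alone exceeds the allowed token budget.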
litellm_adk/memory/base.py CHANGED
@@ -25,3 +25,13 @@ class BaseMemory(ABC):
  def clear(self, session_id: str):
  """Clear history for a session."""
  pass
+
+ @abstractmethod
+ def get_session_metadata(self, session_id: str) -> Dict[str, Any]:
+ """Retrieve metadata/state for a given session."""
+ pass
+
+ @abstractmethod
+ def save_session_metadata(self, session_id: str, metadata: Dict[str, Any]):
+ """Save/Update metadata/state for a given session."""
+ pass
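
Note: because both new methods are declared @abstractmethod, any third-party BaseMemory subclass written against 0.2.0 will stop instantiating under 0.2.1 until it implements them. A minimal sketch of the required overrides for a hypothetical custom backend (the class below is illustrative, not part of the package):

    from typing import Any, Dict

    class DictMetadataStore:
        """Illustrative metadata storage for a custom BaseMemory subclass."""
        def __init__(self):
            self._meta: Dict[str, Dict[str, Any]] = {}

        def get_session_metadata(self, session_id: str) -> Dict[str, Any]:
            # Return a copy so callers cannot mutate stored state in place
            return self._meta.get(session_id, {}).copy()

        def save_session_metadata(self, session_id: str, metadata: Dict[str, Any]):
            self._meta[session_id] = metadata.copy()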
litellm_adk/memory/file.py CHANGED
@@ -5,18 +5,24 @@ from .base import BaseMemory

  class FileMemory(BaseMemory):
  """
- JSON file-based persistence for conversation history.
+ JSON file-based persistence for conversation history and session metadata.
  """
  def __init__(self, file_path: str = "conversations.json"):
  self.file_path = file_path
- self._cache: Dict[str, List[Dict[str, Any]]] = {}
+ self._cache: Dict[str, Dict[str, Any]] = {}
  self._load()

  def _load(self):
  if os.path.exists(self.file_path):
  with open(self.file_path, "r", encoding="utf-8") as f:
  try:
- self._cache = json.load(f)
+ data = json.load(f)
+ # Support legacy format (where values were just lists)
+ for k, v in data.items():
+ if isinstance(v, list):
+ self._cache[k] = {"messages": v, "metadata": {}}
+ else:
+ self._cache[k] = v
  except json.JSONDecodeError:
  self._cache = {}
  else:
@@ -27,21 +33,30 @@ class FileMemory(BaseMemory):
  json.dump(self._cache, f, indent=2, ensure_ascii=False)

  def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
- return self._cache.get(session_id, []).copy()
+ return self._cache.get(session_id, {}).get("messages", []).copy()

  def add_message(self, session_id: str, message: Dict[str, Any]):
  if session_id not in self._cache:
- self._cache[session_id] = []
- self._cache[session_id].append(message)
+ self._cache[session_id] = {"messages": [], "metadata": {}}
+ self._cache[session_id]["messages"].append(message)
  self._save()

  def add_messages(self, session_id: str, messages: List[Dict[str, Any]]):
  if session_id not in self._cache:
- self._cache[session_id] = []
- self._cache[session_id].extend(messages)
+ self._cache[session_id] = {"messages": [], "metadata": {}}
+ self._cache[session_id]["messages"].extend(messages)
  self._save()

  def clear(self, session_id: str):
  if session_id in self._cache:
- self._cache[session_id] = []
+ self._cache[session_id] = {"messages": [], "metadata": {}}
  self._save()
+
+ def get_session_metadata(self, session_id: str) -> Dict[str, Any]:
+ return self._cache.get(session_id, {}).get("metadata", {}).copy()
+
+ def save_session_metadata(self, session_id: str, metadata: Dict[str, Any]):
+ if session_id not in self._cache:
+ self._cache[session_id] = {"messages": [], "metadata": {}}
+ self._cache[session_id]["metadata"] = metadata
+ self._save()
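
Note: the on-disk schema changes from session-id to message list into session-id to a {"messages", "metadata"} record; _load() migrates legacy files in memory, and the next _save() (triggered by any write) rewrites them in the new shape. Roughly, with illustrative ids and contents:

    # 0.2.0 layout: values are bare message lists
    legacy = {"abc-123": [{"role": "system", "content": "..."}]}

    # 0.2.1 layout, as produced by the _load() migration above
    migrated = {"abc-123": {"messages": [{"role": "system", "content": "..."}],
                            "metadata": {}}}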
litellm_adk/memory/in_memory.py CHANGED
@@ -3,10 +3,11 @@ from .base import BaseMemory

  class InMemoryMemory(BaseMemory):
  """
- Standard in-memory store for conversation history.
+ Standard in-memory store for conversation history and session metadata.
  """
  def __init__(self):
  self._storage: Dict[str, List[Dict[str, Any]]] = {}
+ self._metadata: Dict[str, Dict[str, Any]] = {}

  def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
  return self._storage.get(session_id, []).copy()
@@ -24,3 +25,11 @@ class InMemoryMemory(BaseMemory):
  def clear(self, session_id: str):
  if session_id in self._storage:
  self._storage[session_id] = []
+ if session_id in self._metadata:
+ self._metadata[session_id] = {}
+
+ def get_session_metadata(self, session_id: str) -> Dict[str, Any]:
+ return self._metadata.get(session_id, {}).copy()
+
+ def save_session_metadata(self, session_id: str, metadata: Dict[str, Any]):
+ self._metadata[session_id] = metadata.copy()
litellm_adk/memory/mongodb.py CHANGED
@@ -4,7 +4,7 @@ import pymongo

  class MongoDBMemory(BaseMemory):
  """
- MongoDB-based persistence for conversation history.
+ MongoDB-based persistence for conversation history and session metadata.
  """
  def __init__(
  self,
@@ -21,8 +21,6 @@ class MongoDBMemory(BaseMemory):
  def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
  doc = self.collection.find_one({"session_id": session_id})
  if doc:
- # MongoDB returns a list of dicts, but we need to ensure
- # we return a copy to prevent in-place modifications
  return list(doc.get("messages", []))
  return []

@@ -43,5 +41,18 @@ class MongoDBMemory(BaseMemory):
  def clear(self, session_id: str):
  self.collection.update_one(
  {"session_id": session_id},
- {"$set": {"messages": []}}
+ {"$set": {"messages": [], "metadata": {}}}
+ )
+
+ def get_session_metadata(self, session_id: str) -> Dict[str, Any]:
+ doc = self.collection.find_one({"session_id": session_id})
+ if doc:
+ return doc.get("metadata", {})
+ return {}
+
+ def save_session_metadata(self, session_id: str, metadata: Dict[str, Any]):
+ self.collection.update_one(
+ {"session_id": session_id},
+ {"$set": {"metadata": metadata}},
+ upsert=True
  )
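
Note: save_session_metadata upserts, so a session document can exist with metadata before any message is written for it. When written through LiteLLMAgent.save_session, metadata holds the full Session.model_dump(). An illustrative document shape under those assumptions (all values are placeholders):

    doc = {
        "session_id": "abc-123",
        "messages": [{"role": "user", "content": "...", "token_count": 12}],
        "metadata": {"id": "abc-123", "user_id": "u-123",
                     "state": {"greeted": True}},
    }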
litellm_adk/session/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .session import Session
+
+ __all__ = ["Session"]
litellm_adk/session/session.py ADDED
@@ -0,0 +1,24 @@
+ from typing import Dict, Any, Optional
+ from pydantic import BaseModel, Field
+ import uuid
+ import time
+
+ class Session(BaseModel):
+ """
+ Represents a conversation session with metadata and state.
+ """
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ user_id: Optional[str] = None
+ app_name: Optional[str] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+ state: Dict[str, Any] = Field(default_factory=dict)
+ created_at: float = Field(default_factory=time.time)
+ updated_at: float = Field(default_factory=time.time)
+
+ def update_state(self, key: str, value: Any):
+ self.state[key] = value
+ self.updated_at = time.time()
+
+ def update_metadata(self, key: str, value: Any):
+ self.metadata[key] = value
+ self.updated_at = time.time()
litellm_adk/tools/registry.py CHANGED
@@ -79,7 +79,37 @@ class ToolRegistry:
  raise ValueError(f"Tool '{name}' not found in registry.")

  adk_logger.info(f"Executing tool: {name} with args: {kwargs}")
- return self._tools[name]["func"](**kwargs)
+ func = self._tools[name]["func"]
+
+ # Handle both sync and async functions if called synchronously
+ if inspect.iscoroutinefunction(func):
+ import asyncio
+ try:
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # If already in a loop, we can't block. This is a fallback risk.
+ adk_logger.warning(f"Sync execution of async tool '{name}' in running loop.")
+ return asyncio.run_coroutine_threadsafe(func(**kwargs), loop).result()
+ return asyncio.run(func(**kwargs))
+ except RuntimeError:
+ return asyncio.run(func(**kwargs))
+
+ return func(**kwargs)
+
+ async def aexecute(self, name: str, **kwargs) -> Any:
+ """
+ Asynchronously executes a registered tool by name.
+ """
+ if name not in self._tools:
+ raise ValueError(f"Tool '{name}' not found in registry.")
+
+ adk_logger.info(f"Executing tool (async): {name} with args: {kwargs}")
+ func = self._tools[name]["func"]
+
+ if inspect.iscoroutinefunction(func):
+ return await func(**kwargs)
+ else:
+ return func(**kwargs)

  # Global tool registry
  tool_registry = ToolRegistry()
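
Note: execute() now drives async tools from synchronous callers, while the new aexecute() awaits them natively; ainvoke/astream route through aexecute via _aexecute_tool. A hedged sketch, assuming the package's tool decorator registers the function under its own name (the decorator body is not shown in this diff):

    import asyncio
    from litellm_adk import tool, tool_registry

    @tool  # assumed to register fetch_price in tool_registry
    async def fetch_price(symbol: str) -> float:
        await asyncio.sleep(0.1)  # stand-in for real I/O
        return 42.0

    # Async path: the coroutine is awaited directly.
    price = asyncio.run(tool_registry.aexecute("fetch_price", symbol="ACME"))

    # Sync path: execute() detects the coroutine function and runs it with
    # asyncio.run(); inside an already-running loop it falls back to
    # run_coroutine_threadsafe, which the code itself flags as a risk
    # (calling .result() from the loop's own thread would block it).
    price = tool_registry.execute("fetch_price", symbol="ACME")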
litellm_adk-0.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: litellm-adk
- Version: 0.2.0
+ Version: 0.2.1
  Summary: Production-grade multiservice Agent Development Kit
  License-Expression: MIT
  Classifier: Programming Language :: Python :: 3
litellm_adk-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
+ litellm_adk/__init__.py,sha256=k82anVFffmjIP32AcLCFsEtKjk6YXbEpZ8spy1UWFko,414
+ litellm_adk/agents/__init__.py,sha256=KsiCcyYyn0iJo0sZsd7n3nJ5ezVEqJqrrTP1b9ryG0M,69
+ litellm_adk/config/settings.py,sha256=sdI4PrJKzIRJPm5vEBKdiecRcjrHNASsYbmtOiBnY_c,1182
+ litellm_adk/core/agent.py,sha256=bCigFJCKAkbAeiMt68mbDSq-zoSlYCQgMp3MowksKok,26856
+ litellm_adk/core/base.py,sha256=ov2bZk_a15FFGsQSdKwHrQ1cvALdZM8ByK5hGvFWyL0,386
+ litellm_adk/core/context.py,sha256=7tPPRAQ79EumA-mnaG39bsG8PpLqwCpBT-f8zq0Ftn0,3162
+ litellm_adk/memory/__init__.py,sha256=ICPUbV0PsTHEQSm0S35_d1ToeyrgMVFs_hRokvRRJL4,212
+ litellm_adk/memory/base.py,sha256=gfwhcORWZJl4qqb0u2uIPOLos1MYwjwkqiLT242ExSM,1138
+ litellm_adk/memory/file.py,sha256=_9oFobFE3tbM1f7L3jBBEEVXl6SM4mMITnNxzTO5iKg,2519
+ litellm_adk/memory/in_memory.py,sha256=K5sR_z5W8Yyd8Uj3DDD6an6RggBLXiYAjYVJgG7z94o,1384
+ litellm_adk/memory/mongodb.py,sha256=Sxsuemfxhxm3NeyoYDWC6xnH56uZ2idSAViEGKK1cjo,2012
+ litellm_adk/observability/logger.py,sha256=PXr20D7gtDIrg6eZD8Hm1-tfAuTXyUVDUMD9-8Aw32E,619
+ litellm_adk/session/__init__.py,sha256=G1F-Xgmj-Aezewv977vZKflhmZnqEwrYkcwqMTyVpWI,55
+ litellm_adk/session/session.py,sha256=CE-bovngwsRMjjYCC1gdDjJKnewRPJO8Qw2akMJTdbw,840
+ litellm_adk/tools/__init__.py,sha256=J-Rkx-psP5sZXgcy5h4mygvQd-tZUONKLYt4LSOiEV8,82
+ litellm_adk/tools/registry.py,sha256=UTZrWuSWDhdu8arZt9bfkHz8V0sy-Odpiki6Sh6Up4E,4337
+ litellm_adk-0.2.1.dist-info/licenses/LICENSE,sha256=BfYjX2LxngGX9t6Dk1Y5ptJNAkKcQuGG-OAR9jsKUGM,1091
+ litellm_adk-0.2.1.dist-info/METADATA,sha256=PZFvhminGSfIw_pOMpFYcoRb6c3fIONvjFzcW6MYAjg,2981
+ litellm_adk-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ litellm_adk-0.2.1.dist-info/top_level.txt,sha256=30MPgkTEjMUe8z-jnjMM2vbtqdghK_isd_ufRQ1w2hM,12
+ litellm_adk-0.2.1.dist-info/RECORD,,
litellm_adk-0.2.0.dist-info/RECORD REMOVED
@@ -1,18 +0,0 @@
- litellm_adk/__init__.py,sha256=DfQsjG9D_5OvxX-8tOtGY3kuCxYb6Hn_I-2BvnTqf0M,368
- litellm_adk/agents/__init__.py,sha256=KsiCcyYyn0iJo0sZsd7n3nJ5ezVEqJqrrTP1b9ryG0M,69
- litellm_adk/config/settings.py,sha256=sdI4PrJKzIRJPm5vEBKdiecRcjrHNASsYbmtOiBnY_c,1182
- litellm_adk/core/agent.py,sha256=XE6dJZ6q3ZDnaMLtYHislnx6-IDTlXrrBx5q1_4Nt-4,22248
- litellm_adk/core/base.py,sha256=ov2bZk_a15FFGsQSdKwHrQ1cvALdZM8ByK5hGvFWyL0,386
- litellm_adk/memory/__init__.py,sha256=ICPUbV0PsTHEQSm0S35_d1ToeyrgMVFs_hRokvRRJL4,212
- litellm_adk/memory/base.py,sha256=Bm33oPaLNOJdG0RJGc38g387GSMyi_ymQjOMlDexTyk,788
- litellm_adk/memory/file.py,sha256=C0pB1pWJ4HtjCn6ICe54pL3cVmVnPI6D5jRvQaffgE4,1623
- litellm_adk/memory/in_memory.py,sha256=AVMV7iqb-UvbPE-CZmRi14LkV-7hEqrkwtkrwlxvy_w,951
- litellm_adk/memory/mongodb.py,sha256=M7IQsgahT6ALSNTQ2AKjSUWGR7uuz-KielBYIu_oLVk,1657
- litellm_adk/observability/logger.py,sha256=PXr20D7gtDIrg6eZD8Hm1-tfAuTXyUVDUMD9-8Aw32E,619
- litellm_adk/tools/__init__.py,sha256=J-Rkx-psP5sZXgcy5h4mygvQd-tZUONKLYt4LSOiEV8,82
- litellm_adk/tools/registry.py,sha256=M_48BpN0XSea_3msjGSyyDDRWu9uBNDLDtLh9Vh5yp8,3089
- litellm_adk-0.2.0.dist-info/licenses/LICENSE,sha256=BfYjX2LxngGX9t6Dk1Y5ptJNAkKcQuGG-OAR9jsKUGM,1091
- litellm_adk-0.2.0.dist-info/METADATA,sha256=x7FfuvUhi-u365L7AUzXP3qfF_9PQQ-dNwL7L2iNd14,2981
- litellm_adk-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- litellm_adk-0.2.0.dist-info/top_level.txt,sha256=30MPgkTEjMUe8z-jnjMM2vbtqdghK_isd_ufRQ1w2hM,12
- litellm_adk-0.2.0.dist-info/RECORD,,