lite-agent 0.6.0 → 0.9.0 (wheels: lite_agent-0.6.0-py3-none-any.whl → lite_agent-0.9.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lite-agent might be problematic; the original page linked to further details, but that link was not preserved in this text.

lite_agent/agent.py CHANGED
@@ -7,9 +7,21 @@ from funcall import Funcall
7
7
  from jinja2 import Environment, FileSystemLoader
8
8
 
9
9
  from lite_agent.client import BaseLLMClient, LiteLLMClient, ReasoningConfig
10
+ from lite_agent.constants import CompletionMode, ToolName
10
11
  from lite_agent.loggers import logger
11
12
  from lite_agent.response_handlers import CompletionResponseHandler, ResponsesAPIHandler
12
- from lite_agent.types import AgentChunk, FunctionCallEvent, FunctionCallOutputEvent, RunnerMessages, ToolCall, message_to_llm_dict, system_message_to_llm_dict
13
+ from lite_agent.types import (
14
+ AgentChunk,
15
+ AssistantTextContent,
16
+ AssistantToolCall,
17
+ AssistantToolCallResult,
18
+ FunctionCallEvent,
19
+ FunctionCallOutputEvent,
20
+ RunnerMessages,
21
+ ToolCall,
22
+ message_to_llm_dict,
23
+ system_message_to_llm_dict,
24
+ )
13
25
  from lite_agent.types.messages import NewAssistantMessage, NewSystemMessage, NewUserMessage
14
26
 
15
27
  TEMPLATES_DIR = Path(__file__).parent / "templates"
@@ -32,10 +44,24 @@ class Agent:
32
44
  message_transfer: Callable[[RunnerMessages], RunnerMessages] | None = None,
33
45
  completion_condition: str = "stop",
34
46
  reasoning: ReasoningConfig = None,
47
+ stop_before_tools: list[str] | list[Callable] | None = None,
35
48
  ) -> None:
36
49
  self.name = name
37
50
  self.instructions = instructions
38
51
  self.reasoning = reasoning
52
+ # Convert stop_before_functions to function names
53
+ if stop_before_tools:
54
+ self.stop_before_functions = set()
55
+ for func in stop_before_tools:
56
+ if isinstance(func, str):
57
+ self.stop_before_functions.add(func)
58
+ elif callable(func):
59
+ self.stop_before_functions.add(func.__name__)
60
+ else:
61
+ msg = f"stop_before_functions must contain strings or callables, got {type(func)}"
62
+ raise TypeError(msg)
63
+ else:
64
+ self.stop_before_functions = set()
39
65
 
40
66
  if isinstance(model, BaseLLMClient):
41
67
  # If model is a BaseLLMClient instance, use it directly
@@ -54,7 +80,7 @@ class Agent:
54
80
  self.fc = Funcall(tools)
55
81
 
56
82
  # Add wait_for_user tool if completion condition is "call"
57
- if completion_condition == "call":
83
+ if completion_condition == CompletionMode.CALL:
58
84
  self._add_wait_for_user_tool()
59
85
 
60
86
  # Set parent for handoff agents
@@ -99,7 +125,7 @@ class Agent:
99
125
 
100
126
  # Add single dynamic tool for all transfers
101
127
  self.fc.add_dynamic_tool(
102
- name="transfer_to_agent",
128
+ name=ToolName.TRANSFER_TO_AGENT,
103
129
  description="Transfer conversation to another agent.",
104
130
  parameters={
105
131
  "name": {
@@ -129,7 +155,7 @@ class Agent:
129
155
 
130
156
  # Add dynamic tool for parent transfer
131
157
  self.fc.add_dynamic_tool(
132
- name="transfer_to_parent",
158
+ name=ToolName.TRANSFER_TO_PARENT,
133
159
  description="Transfer conversation back to parent agent when current task is completed or cannot be solved by current agent",
134
160
  parameters={},
135
161
  required=[],
@@ -160,7 +186,7 @@ class Agent:
160
186
  try:
161
187
  # Try to remove the existing transfer tool
162
188
  if hasattr(self.fc, "remove_dynamic_tool"):
163
- self.fc.remove_dynamic_tool("transfer_to_agent")
189
+ self.fc.remove_dynamic_tool(ToolName.TRANSFER_TO_AGENT)
164
190
  except Exception as e:
165
191
  # If removal fails, log and continue anyway
166
192
  logger.debug(f"Failed to remove existing transfer tool: {e}")
@@ -205,31 +231,30 @@ class Agent:
205
231
  for message in messages:
206
232
  if isinstance(message, NewAssistantMessage):
207
233
  for item in message.content:
208
- match item.type:
209
- case "text":
210
- res.append(
211
- {
212
- "role": "assistant",
213
- "content": item.text,
214
- },
215
- )
216
- case "tool_call":
217
- res.append(
218
- {
219
- "type": "function_call",
220
- "call_id": item.call_id,
221
- "name": item.name,
222
- "arguments": item.arguments,
223
- },
224
- )
225
- case "tool_call_result":
226
- res.append(
227
- {
228
- "type": "function_call_output",
229
- "call_id": item.call_id,
230
- "output": item.output,
231
- },
232
- )
234
+ if isinstance(item, AssistantTextContent):
235
+ res.append(
236
+ {
237
+ "role": "assistant",
238
+ "content": item.text,
239
+ },
240
+ )
241
+ elif isinstance(item, AssistantToolCall):
242
+ res.append(
243
+ {
244
+ "type": "function_call",
245
+ "call_id": item.call_id,
246
+ "name": item.name,
247
+ "arguments": item.arguments,
248
+ },
249
+ )
250
+ elif isinstance(item, AssistantToolCallResult):
251
+ res.append(
252
+ {
253
+ "type": "function_call_output",
254
+ "call_id": item.call_id,
255
+ "output": item.output,
256
+ },
257
+ )
233
258
  elif isinstance(message, NewSystemMessage):
234
259
  res.append(
235
260
  {
@@ -269,9 +294,6 @@ class Agent:
269
294
  "content": contents,
270
295
  },
271
296
  )
272
- # Handle dict messages (legacy format)
273
- elif isinstance(message, dict):
274
- res.append(message)
275
297
  return res
276
298
 
277
299
  async def completion(
@@ -279,6 +301,7 @@ class Agent:
279
301
  messages: RunnerMessages,
280
302
  record_to_file: Path | None = None,
281
303
  reasoning: ReasoningConfig = None,
304
+ *,
282
305
  streaming: bool = True,
283
306
  ) -> AsyncGenerator[AgentChunk, None]:
284
307
  # Apply message transfer callback if provided - always use legacy format for LLM compatibility
@@ -301,13 +324,14 @@ class Agent:
301
324
 
302
325
  # Use response handler for unified processing
303
326
  handler = CompletionResponseHandler()
304
- return handler.handle(resp, streaming, record_to_file)
327
+ return handler.handle(resp, streaming=streaming, record_to=record_to_file)
305
328
 
306
329
  async def responses(
307
330
  self,
308
331
  messages: RunnerMessages,
309
332
  record_to_file: Path | None = None,
310
333
  reasoning: ReasoningConfig = None,
334
+ *,
311
335
  streaming: bool = True,
312
336
  ) -> AsyncGenerator[AgentChunk, None]:
313
337
  # Apply message transfer callback if provided - always use legacy format for LLM compatibility
@@ -328,20 +352,29 @@ class Agent:
328
352
  )
329
353
  # Use response handler for unified processing
330
354
  handler = ResponsesAPIHandler()
331
- return handler.handle(resp, streaming, record_to_file)
355
+ return handler.handle(resp, streaming=streaming, record_to=record_to_file)
332
356
 
333
357
  async def list_require_confirm_tools(self, tool_calls: Sequence[ToolCall] | None) -> Sequence[ToolCall]:
334
358
  if not tool_calls:
335
359
  return []
336
360
  results = []
337
361
  for tool_call in tool_calls:
338
- tool_func = self.fc.function_registry.get(tool_call.function.name)
362
+ function_name = tool_call.function.name
363
+
364
+ # Check if function is in dynamic stop_before_functions list
365
+ if function_name in self.stop_before_functions:
366
+ logger.debug('Tool call "%s" requires confirmation (stop_before_functions)', tool_call.id)
367
+ results.append(tool_call)
368
+ continue
369
+
370
+ # Check decorator-based require_confirmation
371
+ tool_func = self.fc.function_registry.get(function_name)
339
372
  if not tool_func:
340
- logger.warning("Tool function %s not found in registry", tool_call.function.name)
373
+ logger.warning("Tool function %s not found in registry", function_name)
341
374
  continue
342
- tool_meta = self.fc.get_tool_meta(tool_call.function.name)
375
+ tool_meta = self.fc.get_tool_meta(function_name)
343
376
  if tool_meta["require_confirm"]:
344
- logger.debug('Tool call "%s" requires confirmation', tool_call.id)
377
+ logger.debug('Tool call "%s" requires confirmation (decorator)', tool_call.id)
345
378
  results.append(tool_call)
346
379
  return results
347
380
 
@@ -396,10 +429,79 @@ class Agent:
396
429
  role = message_dict.get("role")
397
430
 
398
431
  if role == "assistant":
399
- # Look ahead for function_call messages
432
+ # For NewAssistantMessage, extract directly from the message object
400
433
  tool_calls = []
434
+ tool_results = []
435
+
436
+ if isinstance(message, NewAssistantMessage):
437
+ # Process content directly from NewAssistantMessage
438
+ for item in message.content:
439
+ if item.type == "tool_call":
440
+ tool_call = {
441
+ "id": item.call_id,
442
+ "type": "function",
443
+ "function": {
444
+ "name": item.name,
445
+ "arguments": item.arguments,
446
+ },
447
+ "index": len(tool_calls),
448
+ }
449
+ tool_calls.append(tool_call)
450
+ elif item.type == "tool_call_result":
451
+ # Collect tool call results to be added as separate tool messages
452
+ tool_results.append({
453
+ "call_id": item.call_id,
454
+ "output": item.output,
455
+ })
456
+
457
+ # Create assistant message with only text content and tool calls
458
+ text_content = " ".join([item.text for item in message.content if item.type == "text"])
459
+ message_dict = {
460
+ "role": "assistant",
461
+ "content": text_content if text_content else None,
462
+ }
463
+ if tool_calls:
464
+ message_dict["tool_calls"] = tool_calls
465
+ else:
466
+ # Legacy handling for dict messages
467
+ content = message_dict.get("content", [])
468
+ # Handle both string and array content
469
+ if isinstance(content, list):
470
+ # Extract tool_calls and tool_call_results from content array and filter out non-text content
471
+ filtered_content = []
472
+ for item in content:
473
+ if isinstance(item, dict):
474
+ if item.get("type") == "tool_call":
475
+ tool_call = {
476
+ "id": item.get("call_id", ""),
477
+ "type": "function",
478
+ "function": {
479
+ "name": item.get("name", ""),
480
+ "arguments": item.get("arguments", "{}"),
481
+ },
482
+ "index": len(tool_calls),
483
+ }
484
+ tool_calls.append(tool_call)
485
+ elif item.get("type") == "tool_call_result":
486
+ # Collect tool call results to be added as separate tool messages
487
+ tool_results.append({
488
+ "call_id": item.get("call_id", ""),
489
+ "output": item.get("output", ""),
490
+ })
491
+ elif item.get("type") == "text":
492
+ filtered_content.append(item)
493
+
494
+ # Update content to only include text items
495
+ if filtered_content:
496
+ message_dict = message_dict.copy()
497
+ message_dict["content"] = filtered_content
498
+ elif tool_calls:
499
+ # If we have tool_calls but no text content, set content to None per OpenAI API spec
500
+ message_dict = message_dict.copy()
501
+ message_dict["content"] = None
502
+
503
+ # Look ahead for function_call messages (legacy support)
401
504
  j = i + 1
402
-
403
505
  while j < len(messages):
404
506
  next_message = messages[j]
405
507
  next_dict = message_to_llm_dict(next_message) if isinstance(next_message, (NewUserMessage, NewSystemMessage, NewAssistantMessage)) else next_message
@@ -419,12 +521,33 @@ class Agent:
419
521
  else:
420
522
  break
421
523
 
422
- # Create assistant message with tool_calls if any
423
- assistant_msg = message_dict.copy()
424
- if tool_calls:
425
- assistant_msg["tool_calls"] = tool_calls # type: ignore
524
+ # For legacy dict messages, create assistant message with tool_calls if any
525
+ if not isinstance(message, NewAssistantMessage):
526
+ assistant_msg = message_dict.copy()
527
+ if tool_calls:
528
+ assistant_msg["tool_calls"] = tool_calls # type: ignore
529
+
530
+ # Convert content format for OpenAI API compatibility
531
+ content = assistant_msg.get("content", [])
532
+ if isinstance(content, list):
533
+ # Extract text content and convert to string using list comprehension
534
+ text_parts = [item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "text"]
535
+ assistant_msg["content"] = " ".join(text_parts) if text_parts else None
536
+
537
+ message_dict = assistant_msg
538
+
539
+ converted_messages.append(message_dict)
540
+
541
+ # Add tool messages for any tool_call_results found in the assistant message
542
+ converted_messages.extend([
543
+ {
544
+ "role": "tool",
545
+ "tool_call_id": tool_result["call_id"],
546
+ "content": tool_result["output"],
547
+ }
548
+ for tool_result in tool_results
549
+ ])
426
550
 
427
- converted_messages.append(assistant_msg)
428
551
  i = j # Skip the function_call messages we've processed
429
552
 
430
553
  elif message_type == "function_call_output":
@@ -536,10 +659,73 @@ class Agent:
536
659
 
537
660
  # Add dynamic tool for task completion
538
661
  self.fc.add_dynamic_tool(
539
- name="wait_for_user",
662
+ name=ToolName.WAIT_FOR_USER,
540
663
  description="Call this function when you have completed your assigned task or need more information from the user.",
541
664
  parameters={},
542
665
  required=[],
543
666
  handler=wait_for_user_handler,
544
667
  )
545
668
 
669
+ def set_stop_before_functions(self, functions: list[str] | list[Callable]) -> None:
670
+ """Set the list of functions that require confirmation before execution.
671
+
672
+ Args:
673
+ functions: List of function names (str) or callable objects
674
+ """
675
+ self.stop_before_functions = set()
676
+ for func in functions:
677
+ if isinstance(func, str):
678
+ self.stop_before_functions.add(func)
679
+ elif callable(func):
680
+ self.stop_before_functions.add(func.__name__)
681
+ else:
682
+ msg = f"stop_before_functions must contain strings or callables, got {type(func)}"
683
+ raise TypeError(msg)
684
+ logger.debug(f"Set stop_before_functions to: {self.stop_before_functions}")
685
+
686
+ def add_stop_before_function(self, function: str | Callable) -> None:
687
+ """Add a function to the stop_before_functions list.
688
+
689
+ Args:
690
+ function: Function name (str) or callable object to add
691
+ """
692
+ if isinstance(function, str):
693
+ function_name = function
694
+ elif callable(function):
695
+ function_name = function.__name__
696
+ else:
697
+ msg = f"function must be a string or callable, got {type(function)}"
698
+ raise TypeError(msg)
699
+
700
+ self.stop_before_functions.add(function_name)
701
+ logger.debug(f"Added '{function_name}' to stop_before_functions")
702
+
703
+ def remove_stop_before_function(self, function: str | Callable) -> None:
704
+ """Remove a function from the stop_before_functions list.
705
+
706
+ Args:
707
+ function: Function name (str) or callable object to remove
708
+ """
709
+ if isinstance(function, str):
710
+ function_name = function
711
+ elif callable(function):
712
+ function_name = function.__name__
713
+ else:
714
+ msg = f"function must be a string or callable, got {type(function)}"
715
+ raise TypeError(msg)
716
+
717
+ self.stop_before_functions.discard(function_name)
718
+ logger.debug(f"Removed '{function_name}' from stop_before_functions")
719
+
720
+ def clear_stop_before_functions(self) -> None:
721
+ """Clear all function names from the stop_before_functions list."""
722
+ self.stop_before_functions.clear()
723
+ logger.debug("Cleared all stop_before_functions")
724
+
725
+ def get_stop_before_functions(self) -> set[str]:
726
+ """Get the current set of function names that require confirmation.
727
+
728
+ Returns:
729
+ Set of function names
730
+ """
731
+ return self.stop_before_functions.copy()