ummaya 0.2.0 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -82,7 +82,7 @@ def _assemble_tool_calls(
82
82
 
83
83
 
84
84
  def _tool_definition_name(tool_def: ToolDefinition | dict[str, object]) -> str | None:
85
- """Extract a root primitive name from an OpenAI tool definition."""
85
+ """Extract a function name from an OpenAI tool definition."""
86
86
 
87
87
  if isinstance(tool_def, ToolDefinition):
88
88
  return tool_def.function.name
@@ -93,6 +93,27 @@ def _tool_definition_name(tool_def: ToolDefinition | dict[str, object]) -> str |
93
93
  return name if isinstance(name, str) else None
94
94
 
95
95
 
96
def _export_turn_tool_definitions(
    tool_registry: ToolRegistry,
    tool_ids: tuple[str, ...],
) -> list[dict[str, object]]:
    """Build the OpenAI tool schemas for the tools selected for a turn.

    Every distinct id in ``tool_ids`` is resolved through
    ``tool_registry.find`` and exported with ``to_openai_tool()``. The result
    keeps the order in which ids first appear in ``tool_ids`` (the caller is
    expected to pass them already ranked). Ids the registry cannot resolve
    are logged at warning level and omitted instead of raising.
    """

    exported: list[dict[str, object]] = []
    # dict.fromkeys collapses repeated ids while preserving first-seen order.
    for selected_id in dict.fromkeys(tool_ids):
        try:
            adapter = tool_registry.find(selected_id)
        except ToolNotFoundError:
            logger.warning("Selected turn tool disappeared from registry: %s", selected_id)
        else:
            # Export outside the try so only the lookup failure is swallowed.
            exported.append(adapter.to_openai_tool())
    return exported
115
+
116
+
96
117
  def _latest_successful_tool_payload(messages: list[ChatMessage]) -> dict[str, object] | None:
97
118
  """Return the latest non-error tool-result JSON payload, if present."""
98
119
 
@@ -296,14 +317,38 @@ async def dispatch_tool_calls( # noqa: C901
296
317
 
297
318
  async def _dispatch_one(tc: ToolCall) -> ToolResult:
298
319
  """Dispatch a single tool call via the executor."""
299
- if tc.function.name in {"find", "locate"}:
320
+ if tc.function.name in {"find", "locate", "check", "send"}:
300
321
  return await _dispatch_root_primitive(
301
322
  tc,
302
323
  tool_registry,
303
324
  tool_executor,
304
325
  session_context=session_context,
305
326
  )
306
- return await tool_executor.dispatch(tc.function.name, tc.function.arguments)
327
+ try:
328
+ tool = tool_registry.find(tc.function.name)
329
+ except ToolNotFoundError:
330
+ return await tool_executor.dispatch(tc.function.name, tc.function.arguments)
331
+ gate = tool.policy.citizen_facing_gate if tool.policy is not None else None
332
+ if gate in {None, "read-only"}:
333
+ return await tool_executor.dispatch(
334
+ tc.function.name,
335
+ tc.function.arguments,
336
+ tool_call_id=tc.id,
337
+ )
338
+ primitive = tool.primitive
339
+ if primitive is None:
340
+ return ToolResult(
341
+ tool_id=tc.function.name,
342
+ success=False,
343
+ error=f"{tc.function.name} is missing primitive metadata for gated dispatch.",
344
+ error_type="schema_mismatch",
345
+ )
346
+ return await _dispatch_concrete_adapter(
347
+ tc,
348
+ primitive,
349
+ tool_executor,
350
+ session_context=session_context,
351
+ )
307
352
 
308
353
  async def _flush_group(items: list[tuple[int, ToolCall]], safe: bool) -> None:
309
354
  """Execute a group of tool calls, concurrently if safe."""
@@ -407,6 +452,59 @@ async def _dispatch_root_primitive(
407
452
  return ToolResult(tool_id=primitive, success=True, data=data)
408
453
 
409
454
 
455
async def _dispatch_concrete_adapter(
    tc: ToolCall,
    primitive: str,
    tool_executor: ToolExecutor,
    *,
    session_context: SessionContext | None,
) -> ToolResult:
    """Run a tool call that names a concrete adapter directly.

    The call's argument string is decoded as JSON and must be a JSON object.
    Adapters whose ``primitive`` is ``"find"`` are run through
    ``tool_executor.invoke``; every other primitive is run through
    ``tool_executor.invoke_raw``. The request id is the call id, or
    ``"<adapter name>-call"`` when the call id is empty.

    Returns:
        A failed ``ToolResult`` with ``error_type="validation"`` when the
        arguments are not valid JSON or not a JSON object; a failed one with
        ``error_type="execution"`` when the converted output has
        ``kind == "error"``; otherwise a successful result whose ``data`` is
        the converted output.
    """

    adapter_name = tc.function.name

    def _failure(message: str, category: str) -> ToolResult:
        """Build a failed result attributed to this adapter."""
        return ToolResult(
            tool_id=adapter_name,
            success=False,
            error=message,
            error_type=category,
        )

    try:
        parsed_args = json.loads(tc.function.arguments)
    except (TypeError, json.JSONDecodeError) as exc:
        return _failure(str(exc), "validation")
    if not isinstance(parsed_args, dict):
        return _failure(f"{adapter_name} requires a JSON object argument.", "validation")

    call_id = tc.id or f"{adapter_name}-call"
    # Both executor entry points take the same arguments; pick one, call once.
    run_adapter = tool_executor.invoke if primitive == "find" else tool_executor.invoke_raw
    output = await run_adapter(
        adapter_name,
        parsed_args,
        request_id=call_id,
        session_identity=session_context,
    )

    payload = _primitive_output_dict(output)
    if payload.get("kind") != "error":
        return ToolResult(tool_id=adapter_name, success=True, data=payload)
    # Fall back to the whole payload when no message is present.
    return _failure(str(payload.get("message") or payload), "execution")
506
+
507
+
410
508
  def _primitive_output_dict(output: object) -> dict[str, object]:
411
509
  """Convert primitive facade output to ToolResult data."""
412
510
 
@@ -532,8 +630,13 @@ async def _query_inner(ctx: QueryContext) -> AsyncIterator[QueryEvent]: # noqa:
532
630
  tool_defs: list[ToolDefinition | dict[str, object]] | None = None
533
631
  force_no_tools_next_turn = False
534
632
  else:
535
- raw_defs = ctx.tool_registry.export_core_tools_openai()
536
- if ctx.allowed_core_tool_ids is not None:
633
+ raw_defs = _export_turn_tool_definitions(
634
+ ctx.tool_registry,
635
+ ctx.turn_tool_ids,
636
+ )
637
+ if not raw_defs:
638
+ raw_defs = ctx.tool_registry.export_core_tools_openai()
639
+ if ctx.allowed_core_tool_ids is not None and not ctx.turn_tool_ids:
537
640
  raw_defs = [
538
641
  tool_def
539
642
  for tool_def in raw_defs