tactus-0.31.0-py3-none-any.whl → tactus-0.34.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. tactus/__init__.py +1 -1
  2. tactus/adapters/__init__.py +18 -1
  3. tactus/adapters/broker_log.py +127 -34
  4. tactus/adapters/channels/__init__.py +153 -0
  5. tactus/adapters/channels/base.py +174 -0
  6. tactus/adapters/channels/broker.py +179 -0
  7. tactus/adapters/channels/cli.py +448 -0
  8. tactus/adapters/channels/host.py +225 -0
  9. tactus/adapters/channels/ipc.py +297 -0
  10. tactus/adapters/channels/sse.py +305 -0
  11. tactus/adapters/cli_hitl.py +223 -1
  12. tactus/adapters/control_loop.py +879 -0
  13. tactus/adapters/file_storage.py +35 -2
  14. tactus/adapters/ide_log.py +7 -1
  15. tactus/backends/http_backend.py +0 -1
  16. tactus/broker/client.py +31 -1
  17. tactus/broker/server.py +416 -92
  18. tactus/cli/app.py +270 -7
  19. tactus/cli/control.py +393 -0
  20. tactus/core/config_manager.py +33 -6
  21. tactus/core/dsl_stubs.py +102 -18
  22. tactus/core/execution_context.py +265 -8
  23. tactus/core/lua_sandbox.py +8 -9
  24. tactus/core/registry.py +19 -2
  25. tactus/core/runtime.py +235 -27
  26. tactus/docker/Dockerfile.pypi +49 -0
  27. tactus/docs/__init__.py +33 -0
  28. tactus/docs/extractor.py +326 -0
  29. tactus/docs/html_renderer.py +72 -0
  30. tactus/docs/models.py +121 -0
  31. tactus/docs/templates/base.html +204 -0
  32. tactus/docs/templates/index.html +58 -0
  33. tactus/docs/templates/module.html +96 -0
  34. tactus/dspy/agent.py +403 -22
  35. tactus/dspy/broker_lm.py +57 -6
  36. tactus/dspy/config.py +14 -3
  37. tactus/dspy/history.py +2 -1
  38. tactus/dspy/module.py +136 -11
  39. tactus/dspy/signature.py +0 -1
  40. tactus/ide/config_server.py +536 -0
  41. tactus/ide/server.py +345 -21
  42. tactus/primitives/human.py +619 -47
  43. tactus/primitives/system.py +0 -1
  44. tactus/protocols/__init__.py +25 -0
  45. tactus/protocols/control.py +427 -0
  46. tactus/protocols/notification.py +207 -0
  47. tactus/sandbox/container_runner.py +79 -11
  48. tactus/sandbox/docker_manager.py +23 -0
  49. tactus/sandbox/entrypoint.py +26 -0
  50. tactus/sandbox/protocol.py +3 -0
  51. tactus/stdlib/README.md +77 -0
  52. tactus/stdlib/__init__.py +27 -1
  53. tactus/stdlib/classify/__init__.py +165 -0
  54. tactus/stdlib/classify/classify.spec.tac +195 -0
  55. tactus/stdlib/classify/classify.tac +257 -0
  56. tactus/stdlib/classify/fuzzy.py +282 -0
  57. tactus/stdlib/classify/llm.py +319 -0
  58. tactus/stdlib/classify/primitive.py +287 -0
  59. tactus/stdlib/core/__init__.py +57 -0
  60. tactus/stdlib/core/base.py +320 -0
  61. tactus/stdlib/core/confidence.py +211 -0
  62. tactus/stdlib/core/models.py +161 -0
  63. tactus/stdlib/core/retry.py +171 -0
  64. tactus/stdlib/core/validation.py +274 -0
  65. tactus/stdlib/extract/__init__.py +125 -0
  66. tactus/stdlib/extract/llm.py +330 -0
  67. tactus/stdlib/extract/primitive.py +256 -0
  68. tactus/stdlib/tac/tactus/classify/base.tac +51 -0
  69. tactus/stdlib/tac/tactus/classify/fuzzy.tac +87 -0
  70. tactus/stdlib/tac/tactus/classify/index.md +77 -0
  71. tactus/stdlib/tac/tactus/classify/init.tac +29 -0
  72. tactus/stdlib/tac/tactus/classify/llm.tac +150 -0
  73. tactus/stdlib/tac/tactus/classify.spec.tac +191 -0
  74. tactus/stdlib/tac/tactus/extract/base.tac +138 -0
  75. tactus/stdlib/tac/tactus/extract/index.md +96 -0
  76. tactus/stdlib/tac/tactus/extract/init.tac +27 -0
  77. tactus/stdlib/tac/tactus/extract/llm.tac +201 -0
  78. tactus/stdlib/tac/tactus/extract.spec.tac +153 -0
  79. tactus/stdlib/tac/tactus/generate/base.tac +142 -0
  80. tactus/stdlib/tac/tactus/generate/index.md +195 -0
  81. tactus/stdlib/tac/tactus/generate/init.tac +28 -0
  82. tactus/stdlib/tac/tactus/generate/llm.tac +169 -0
  83. tactus/stdlib/tac/tactus/generate.spec.tac +210 -0
  84. tactus/testing/behave_integration.py +171 -7
  85. tactus/testing/context.py +0 -1
  86. tactus/testing/evaluation_runner.py +0 -1
  87. tactus/testing/gherkin_parser.py +0 -1
  88. tactus/testing/mock_hitl.py +0 -1
  89. tactus/testing/mock_tools.py +0 -1
  90. tactus/testing/models.py +0 -1
  91. tactus/testing/steps/builtin.py +0 -1
  92. tactus/testing/steps/custom.py +81 -22
  93. tactus/testing/steps/registry.py +0 -1
  94. tactus/testing/test_runner.py +7 -1
  95. tactus/validation/semantic_visitor.py +11 -5
  96. tactus/validation/validator.py +0 -1
  97. {tactus-0.31.0.dist-info → tactus-0.34.1.dist-info}/METADATA +16 -2
  98. {tactus-0.31.0.dist-info → tactus-0.34.1.dist-info}/RECORD +101 -49
  99. {tactus-0.31.0.dist-info → tactus-0.34.1.dist-info}/WHEEL +0 -0
  100. {tactus-0.31.0.dist-info → tactus-0.34.1.dist-info}/entry_points.txt +0 -0
  101. {tactus-0.31.0.dist-info → tactus-0.34.1.dist-info}/licenses/LICENSE +0 -0
tactus/broker/server.py CHANGED
@@ -62,27 +62,19 @@ class OpenAIChatBackend:
     """
     Minimal OpenAI chat-completions backend used by the broker.
 
-    Credentials are read from the broker process environment.
+    Credentials can be provided directly or read from the broker process environment.
     """
 
-    def __init__(self, config: Optional[OpenAIChatConfig] = None):
+    def __init__(self, config: Optional[OpenAIChatConfig] = None, api_key: Optional[str] = None):
         self._config = config or OpenAIChatConfig()
+        self._api_key = api_key  # Direct API key (bypasses environment)
 
         # Lazy-init the client so unit tests can run without OpenAI installed/configured.
         self._client = None
 
     def _get_client(self):
-        if self._client is not None:
-            return self._client
-
-        from openai import AsyncOpenAI
-
-        api_key = os.environ.get(self._config.api_key_env)
-        if not api_key:
-            raise RuntimeError(f"Missing OpenAI API key in environment: {self._config.api_key_env}")
-
-        self._client = AsyncOpenAI(api_key=api_key)
-        return self._client
+        # We don't need to maintain a client - LiteLLM handles that
+        return None
 
     async def chat(
         self,
@@ -92,19 +84,51 @@ class OpenAIChatBackend:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         stream: bool,
+        tools: Optional[list[dict[str, Any]]] = None,
+        tool_choice: Optional[str] = None,
     ):
-        client = self._get_client()
+        # Use LiteLLM instead of raw OpenAI SDK for provider-agnostic support
+        import litellm
 
-        kwargs: dict[str, Any] = {"model": model, "messages": messages}
+        # Set API key from environment if configured
+        api_key = self._api_key or os.environ.get(self._config.api_key_env)
+        if api_key:
+            os.environ[self._config.api_key_env] = api_key
+
+        kwargs: dict[str, Any] = {"model": model, "messages": messages, "stream": stream}
         if temperature is not None:
             kwargs["temperature"] = temperature
         if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
+        if tools is not None:
+            kwargs["tools"] = tools
+            logger.info(f"[LITELLM_BACKEND] Sending {len(tools)} tools to LiteLLM")
+            logger.info(f"[LITELLM_BACKEND] Tool schemas: {tools}")
+        if tool_choice is not None:
+            kwargs["tool_choice"] = tool_choice
+            logger.info(f"[LITELLM_BACKEND] Setting tool_choice={tool_choice}")
+
+        # Always use acompletion for consistency, LiteLLM handles both sync/async
+        result = await litellm.acompletion(**kwargs)
 
         if stream:
-            return await client.chat.completions.create(**kwargs, stream=True)
+            logger.info("[LITELLM_BACKEND] LiteLLM streaming response started")
+        else:
+            logger.info(
+                f"[LITELLM_BACKEND] LiteLLM response: finish_reason={result.choices[0].finish_reason if result.choices else 'NO_CHOICES'}"
+            )
+            if (
+                result.choices
+                and hasattr(result.choices[0].message, "tool_calls")
+                and result.choices[0].message.tool_calls
+            ):
+                logger.info(
+                    f"[LITELLM_BACKEND] LiteLLM returned {len(result.choices[0].message.tool_calls)} tool calls"
+                )
+            else:
+                logger.info("[LITELLM_BACKEND] LiteLLM returned NO tool calls")
 
-        return await client.chat.completions.create(**kwargs)
+        return result
 
 
 class HostToolRegistry:
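
The backend now routes every request through LiteLLM's async completion API instead of the OpenAI SDK, which is what makes the new tools/tool_choice parameters provider-agnostic. A minimal standalone sketch of the same call shape — the model name and the get_weather tool schema are invented for illustration, not taken from the package:

# Sketch only: an acompletion call shaped like the new backend's.
# The model name and the get_weather tool schema are invented examples.
import asyncio

import litellm

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]


async def main() -> None:
    result = await litellm.acompletion(
        model="gpt-4o-mini",  # any LiteLLM-supported model string
        messages=[{"role": "user", "content": "Weather in Oslo?"}],
        tools=TOOLS,
        tool_choice="auto",
        stream=False,
    )
    message = result.choices[0].message
    # With tools attached, the model may reply with tool_calls instead of text.
    print(message.content, getattr(message, "tool_calls", None))


asyncio.run(main())
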
@@ -140,12 +164,14 @@ class _BaseBrokerServer:
         openai_backend: Optional[OpenAIChatBackend] = None,
         tool_registry: Optional[HostToolRegistry] = None,
         event_handler: Optional[Callable[[dict[str, Any]], None]] = None,
+        control_handler: Optional[Callable[[dict], Awaitable[dict]]] = None,
     ):
         self._listener = None
         self._serve_task: asyncio.Task[None] | None = None
         self._openai = openai_backend or OpenAIChatBackend()
         self._tools = tool_registry or HostToolRegistry.default()
         self._event_handler = event_handler
+        self._control_handler = control_handler
 
     async def start(self) -> None:
         raise NotImplementedError
@@ -220,6 +246,10 @@ class _BaseBrokerServer:
             await self._handle_events_emit(req_id, params, byte_stream)
             return
 
+        if method == "control.request":
+            await self._handle_control_request(req_id, params, byte_stream)
+            return
+
         if method == "llm.chat":
             await self._handle_llm_chat(req_id, params, byte_stream)
             return
@@ -367,6 +397,8 @@ class _BaseBrokerServer:
         stream = bool(params.get("stream", False))
         temperature = params.get("temperature")
         max_tokens = params.get("max_tokens")
+        tools = params.get("tools")
+        tool_choice = params.get("tool_choice")
 
         if not isinstance(model, str) or not model:
             await _write_event_asyncio(
@@ -397,37 +429,80 @@ class _BaseBrokerServer:
                     temperature=temperature,
                     max_tokens=max_tokens,
                     stream=True,
+                    tools=tools,
+                    tool_choice=tool_choice,
                 )
 
                 full_text = ""
+                tool_calls_data = []
                 async for chunk in stream_iter:
                     try:
                         delta = chunk.choices[0].delta
                         text = getattr(delta, "content", None)
+                        delta_tool_calls = getattr(delta, "tool_calls", None)
                     except Exception:
                         text = None
-
-                    if not text:
-                        continue
-
-                    full_text += text
-                    await _write_event_asyncio(
-                        writer, {"id": req_id, "event": "delta", "data": {"text": text}}
-                    )
+                        delta_tool_calls = None
+
+                    if text:
+                        full_text += text
+                        await _write_event_asyncio(
+                            writer, {"id": req_id, "event": "delta", "data": {"text": text}}
+                        )
+
+                    # Accumulate tool calls from deltas
+                    if delta_tool_calls:
+                        logger.info(
+                            f"[LITELLM_BACKEND] Received delta_tool_calls: {delta_tool_calls}"
+                        )
+                        for tc_delta in delta_tool_calls:
+                            idx = tc_delta.index
+                            # Extend tool_calls_data list if needed
+                            while len(tool_calls_data) <= idx:
+                                tool_calls_data.append(
+                                    {
+                                        "id": "",
+                                        "type": "function",
+                                        "function": {"name": "", "arguments": ""},
+                                    }
+                                )
+
+                            # Merge delta into accumulated tool call
+                            if tc_delta.id:
+                                tool_calls_data[idx]["id"] = tc_delta.id
+                            if tc_delta.type:
+                                tool_calls_data[idx]["type"] = tc_delta.type
+                            if hasattr(tc_delta, "function") and tc_delta.function:
+                                if tc_delta.function.name:
+                                    tool_calls_data[idx]["function"][
+                                        "name"
+                                    ] += tc_delta.function.name
+                                if tc_delta.function.arguments:
+                                    tool_calls_data[idx]["function"][
+                                        "arguments"
+                                    ] += tc_delta.function.arguments
+
+                # Build final response data
+                logger.info(
+                    f"[LITELLM_BACKEND] Streaming complete. tool_calls_data={tool_calls_data}, full_text length={len(full_text)}"
+                )
+                done_data = {
+                    "text": full_text,
+                    "usage": {
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                        "total_tokens": 0,
+                    },
+                }
+                if tool_calls_data:
+                    done_data["tool_calls"] = tool_calls_data
 
                 await _write_event_asyncio(
                     writer,
                     {
                         "id": req_id,
                         "event": "done",
-                        "data": {
-                            "text": full_text,
-                            "usage": {
-                                "prompt_tokens": 0,
-                                "completion_tokens": 0,
-                                "total_tokens": 0,
-                            },
-                        },
+                        "data": done_data,
                     },
                 )
                 return
@@ -438,22 +513,47 @@ class _BaseBrokerServer:
                 temperature=temperature,
                 max_tokens=max_tokens,
                 stream=False,
+                tools=tools,
+                tool_choice=tool_choice,
             )
             text = ""
+            tool_calls_data = None
             try:
-                text = resp.choices[0].message.content or ""
+                message = resp.choices[0].message
+                text = message.content or ""
+
+                # Extract tool calls if present
+                if hasattr(message, "tool_calls") and message.tool_calls:
+                    tool_calls_data = []
+                    for tc in message.tool_calls:
+                        tool_calls_data.append(
+                            {
+                                "id": tc.id,
+                                "type": tc.type,
+                                "function": {
+                                    "name": tc.function.name,
+                                    "arguments": tc.function.arguments,
+                                },
+                            }
+                        )
             except Exception:
                 text = ""
+                tool_calls_data = None
+
+            # Build response data
+            done_data = {
+                "text": text,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+            }
+            if tool_calls_data:
+                done_data["tool_calls"] = tool_calls_data
 
             await _write_event_asyncio(
                 writer,
                 {
                     "id": req_id,
                     "event": "done",
-                    "data": {
-                        "text": text,
-                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
-                    },
+                    "data": done_data,
                 },
             )
         except Exception as e:
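
Both paths now end in a single done event whose data gains an optional tool_calls list next to text; usage is still hard-coded to zeros. A sketch of the two payload shapes a client should expect (all values invented):

# Sketch of the "done" event payloads produced above (values invented).
done_text_only = {
    "id": "req-123",
    "event": "done",
    "data": {
        "text": "Hello!",
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    },
}

done_with_tools = {
    "id": "req-124",
    "event": "done",
    "data": {
        "text": "",
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        "tool_calls": [
            {
                "id": "call_abc",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Oslo"}'},
            }
        ],
    },
}
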
@@ -548,6 +648,62 @@ class _BaseBrokerServer:
 
         await _write_event_anyio(byte_stream, {"id": req_id, "event": "done", "data": {"ok": True}})
 
+    async def _handle_control_request(
+        self, req_id: str, params: dict[str, Any], byte_stream: anyio.abc.ByteStream
+    ) -> None:
+        """Handle control.request method for HITL requests from container."""
+        request_data = params.get("request")
+        if not isinstance(request_data, dict):
+            await _write_event_anyio(
+                byte_stream,
+                {
+                    "id": req_id,
+                    "event": "error",
+                    "error": {"type": "BadRequest", "message": "params.request must be an object"},
+                },
+            )
+            return
+
+        if self._control_handler is None:
+            await _write_event_anyio(
+                byte_stream,
+                {
+                    "id": req_id,
+                    "event": "error",
+                    "error": {
+                        "type": "NoControlHandler",
+                        "message": "No control handler configured",
+                    },
+                },
+            )
+            return
+
+        try:
+            # Send delivered event
+            await _write_event_anyio(byte_stream, {"id": req_id, "event": "delivered"})
+
+            # Call control handler and await response
+            response_data = await self._control_handler(request_data)
+
+            # Send response event
+            await _write_event_anyio(
+                byte_stream, {"id": req_id, "event": "response", "data": response_data}
+            )
+        except asyncio.TimeoutError:
+            await _write_event_anyio(
+                byte_stream, {"id": req_id, "event": "timeout", "data": {"timed_out": True}}
+            )
+        except Exception as e:
+            logger.debug("[BROKER] control.request handler raised", exc_info=True)
+            await _write_event_anyio(
+                byte_stream,
+                {
+                    "id": req_id,
+                    "event": "error",
+                    "error": {"type": type(e).__name__, "message": str(e)},
+                },
+            )
+
     async def _handle_llm_chat(
         self, req_id: str, params: dict[str, Any], byte_stream: anyio.abc.ByteStream
     ) -> None:
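
_handle_control_request defines a small per-request event sequence: delivered once the request reaches the host, then exactly one of response, timeout, or error. A hedged sketch of a handler matching the Callable[[dict], Awaitable[dict]] shape it awaits — the "kind" field and the auto-approval policy are invented:

# Sketch only: a control handler with the Callable[[dict], Awaitable[dict]]
# shape that _handle_control_request awaits. The "kind" field and the
# auto-approval policy are invented for illustration.
import asyncio
from typing import Any


async def approve_all_handler(request: dict[str, Any]) -> dict[str, Any]:
    kind = request.get("kind", "approval")  # assumed field name
    print(f"control request received: {kind}")
    await asyncio.sleep(0)  # stand-in for real user interaction
    # The returned dict becomes the "response" event's data verbatim.
    return {"approved": True}
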
@@ -571,6 +727,8 @@ class _BaseBrokerServer:
         stream = bool(params.get("stream", False))
         temperature = params.get("temperature")
         max_tokens = params.get("max_tokens")
+        tools = params.get("tools")
+        tool_choice = params.get("tool_choice")
 
         if not isinstance(model, str) or not model:
             await _write_event_anyio(
@@ -601,37 +759,80 @@ class _BaseBrokerServer:
                     temperature=temperature,
                     max_tokens=max_tokens,
                     stream=True,
+                    tools=tools,
+                    tool_choice=tool_choice,
                 )
 
                 full_text = ""
+                tool_calls_data = []
                 async for chunk in stream_iter:
                     try:
                         delta = chunk.choices[0].delta
                         text = getattr(delta, "content", None)
+                        delta_tool_calls = getattr(delta, "tool_calls", None)
                     except Exception:
                         text = None
-
-                    if not text:
-                        continue
-
-                    full_text += text
-                    await _write_event_anyio(
-                        byte_stream, {"id": req_id, "event": "delta", "data": {"text": text}}
-                    )
+                        delta_tool_calls = None
+
+                    if text:
+                        full_text += text
+                        await _write_event_anyio(
+                            byte_stream, {"id": req_id, "event": "delta", "data": {"text": text}}
+                        )
+
+                    # Accumulate tool calls from deltas
+                    if delta_tool_calls:
+                        logger.info(
+                            f"[LITELLM_BACKEND] Received delta_tool_calls: {delta_tool_calls}"
+                        )
+                        for tc_delta in delta_tool_calls:
+                            idx = tc_delta.index
+                            # Extend tool_calls_data list if needed
+                            while len(tool_calls_data) <= idx:
+                                tool_calls_data.append(
+                                    {
+                                        "id": "",
+                                        "type": "function",
+                                        "function": {"name": "", "arguments": ""},
+                                    }
+                                )
+
+                            # Merge delta into accumulated tool call
+                            if tc_delta.id:
+                                tool_calls_data[idx]["id"] = tc_delta.id
+                            if tc_delta.type:
+                                tool_calls_data[idx]["type"] = tc_delta.type
+                            if hasattr(tc_delta, "function") and tc_delta.function:
+                                if tc_delta.function.name:
+                                    tool_calls_data[idx]["function"][
+                                        "name"
+                                    ] += tc_delta.function.name
+                                if tc_delta.function.arguments:
+                                    tool_calls_data[idx]["function"][
+                                        "arguments"
+                                    ] += tc_delta.function.arguments
+
+                # Build final response data
+                logger.info(
+                    f"[LITELLM_BACKEND] Streaming complete. tool_calls_data={tool_calls_data}, full_text length={len(full_text)}"
+                )
+                done_data = {
+                    "text": full_text,
+                    "usage": {
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                        "total_tokens": 0,
+                    },
+                }
+                if tool_calls_data:
+                    done_data["tool_calls"] = tool_calls_data
 
                 await _write_event_anyio(
                     byte_stream,
                     {
                         "id": req_id,
                         "event": "done",
-                        "data": {
-                            "text": full_text,
-                            "usage": {
-                                "prompt_tokens": 0,
-                                "completion_tokens": 0,
-                                "total_tokens": 0,
-                            },
-                        },
+                        "data": done_data,
                     },
                 )
                 return
@@ -642,22 +843,47 @@ class _BaseBrokerServer:
                 temperature=temperature,
                 max_tokens=max_tokens,
                 stream=False,
+                tools=tools,
+                tool_choice=tool_choice,
             )
             text = ""
+            tool_calls_data = None
             try:
-                text = resp.choices[0].message.content or ""
+                message = resp.choices[0].message
+                text = message.content or ""
+
+                # Extract tool calls if present
+                if hasattr(message, "tool_calls") and message.tool_calls:
+                    tool_calls_data = []
+                    for tc in message.tool_calls:
+                        tool_calls_data.append(
+                            {
+                                "id": tc.id,
+                                "type": tc.type,
+                                "function": {
+                                    "name": tc.function.name,
+                                    "arguments": tc.function.arguments,
+                                },
+                            }
+                        )
             except Exception:
                 text = ""
+                tool_calls_data = None
+
+            # Build response data
+            done_data = {
+                "text": text,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+            }
+            if tool_calls_data:
+                done_data["tool_calls"] = tool_calls_data
 
             await _write_event_anyio(
                 byte_stream,
                 {
                     "id": req_id,
                     "event": "done",
-                    "data": {
-                        "text": text,
-                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
-                    },
+                    "data": done_data,
                 },
             )
         except Exception as e:
@@ -968,6 +1194,8 @@ class BrokerServer(_BaseBrokerServer):
         stream = bool(params.get("stream", False))
         temperature = params.get("temperature")
         max_tokens = params.get("max_tokens")
+        tools = params.get("tools")
+        tool_choice = params.get("tool_choice")
 
         if not isinstance(model, str) or not model:
             await write_event(
@@ -990,65 +1218,157 @@ class BrokerServer(_BaseBrokerServer):
 
         try:
             if stream:
-                stream_iter = await self._openai.chat(
-                    model=model,
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    stream=True,
+                # Build kwargs for OpenAI chat call
+                chat_kwargs = {
+                    "model": model,
+                    "messages": messages,
+                    "stream": True,
+                }
+                if temperature is not None:
+                    chat_kwargs["temperature"] = temperature
+                if max_tokens is not None:
+                    chat_kwargs["max_tokens"] = max_tokens
+                if tools is not None:
+                    chat_kwargs["tools"] = tools
+                    logger.info(f"[BROKER_SERVER] Added {len(tools)} tools to chat_kwargs")
+                else:
+                    logger.warning("[BROKER_SERVER] No tools to add to chat_kwargs")
+                if tool_choice is not None:
+                    chat_kwargs["tool_choice"] = tool_choice
+                    logger.info(f"[BROKER_SERVER] Added tool_choice={tool_choice} to chat_kwargs")
+                else:
+                    logger.warning("[BROKER_SERVER] No tool_choice to add")
+
+                logger.info(
+                    f"[BROKER_SERVER] Calling backend.chat() with {len(chat_kwargs)} kwargs: {list(chat_kwargs.keys())}"
                 )
+                stream_iter = await self._openai.chat(**chat_kwargs)
 
                 full_text = ""
+                tool_calls_data = []
                 async for chunk in stream_iter:
                     try:
                         delta = chunk.choices[0].delta
                         text = getattr(delta, "content", None)
+                        delta_tool_calls = getattr(delta, "tool_calls", None)
                     except Exception:
                         text = None
-
-                    if not text:
-                        continue
-
-                    full_text += text
-                    await write_event({"id": req_id, "event": "delta", "data": {"text": text}})
+                        delta_tool_calls = None
+
+                    if text:
+                        full_text += text
+                        await write_event({"id": req_id, "event": "delta", "data": {"text": text}})
+
+                    # Accumulate tool calls from deltas
+                    if delta_tool_calls:
+                        logger.info(
+                            f"[LITELLM_BACKEND] Received delta_tool_calls: {delta_tool_calls}"
+                        )
+                        for tc_delta in delta_tool_calls:
+                            idx = tc_delta.index
+                            # Extend tool_calls_data list if needed
+                            while len(tool_calls_data) <= idx:
+                                tool_calls_data.append(
+                                    {
+                                        "id": "",
+                                        "type": "function",
+                                        "function": {"name": "", "arguments": ""},
+                                    }
+                                )
+
+                            # Merge delta into accumulated tool call
+                            if tc_delta.id:
+                                tool_calls_data[idx]["id"] = tc_delta.id
+                            if tc_delta.type:
+                                tool_calls_data[idx]["type"] = tc_delta.type
+                            if hasattr(tc_delta, "function") and tc_delta.function:
+                                if tc_delta.function.name:
+                                    tool_calls_data[idx]["function"][
+                                        "name"
+                                    ] += tc_delta.function.name
+                                if tc_delta.function.arguments:
+                                    tool_calls_data[idx]["function"][
+                                        "arguments"
+                                    ] += tc_delta.function.arguments
+
+                # Build final response data
+                logger.info(
+                    f"[LITELLM_BACKEND] Streaming complete. tool_calls_data={tool_calls_data}, full_text length={len(full_text)}"
+                )
+                done_data = {
+                    "text": full_text,
+                    "usage": {
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                        "total_tokens": 0,
+                    },
+                }
+                if tool_calls_data:
+                    done_data["tool_calls"] = tool_calls_data
 
                 await write_event(
                     {
                         "id": req_id,
                         "event": "done",
-                        "data": {
-                            "text": full_text,
-                            "usage": {
-                                "prompt_tokens": 0,
-                                "completion_tokens": 0,
-                                "total_tokens": 0,
-                            },
-                        },
+                        "data": done_data,
                     }
                 )
                 return
 
-            resp = await self._openai.chat(
-                model=model,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                stream=False,
-            )
+            # Build kwargs for OpenAI chat call
+            chat_kwargs = {
+                "model": model,
+                "messages": messages,
+                "stream": False,
+            }
+            if temperature is not None:
+                chat_kwargs["temperature"] = temperature
+            if max_tokens is not None:
+                chat_kwargs["max_tokens"] = max_tokens
+            if tools is not None:
+                chat_kwargs["tools"] = tools
+            if tool_choice is not None:
+                chat_kwargs["tool_choice"] = tool_choice
+
+            resp = await self._openai.chat(**chat_kwargs)
+
             text = ""
+            tool_calls_data = None
             try:
-                text = resp.choices[0].message.content or ""
+                message = resp.choices[0].message
+                text = message.content or ""
+
+                # Extract tool calls if present
+                if hasattr(message, "tool_calls") and message.tool_calls:
+                    tool_calls_data = []
+                    for tc in message.tool_calls:
+                        tool_calls_data.append(
+                            {
+                                "id": tc.id,
+                                "type": tc.type,
+                                "function": {
+                                    "name": tc.function.name,
+                                    "arguments": tc.function.arguments,
+                                },
+                            }
+                        )
             except Exception:
                 text = ""
+                tool_calls_data = None
+
+            # Build response data
+            done_data = {
+                "text": text,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+            }
+            if tool_calls_data:
+                done_data["tool_calls"] = tool_calls_data
 
             await write_event(
                 {
                     "id": req_id,
                     "event": "done",
-                    "data": {
-                        "text": text,
-                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
-                    },
+                    "data": done_data,
                 }
             )
         except Exception as e:
@@ -1078,9 +1398,13 @@ class TcpBrokerServer(_BaseBrokerServer):
         openai_backend: Optional[OpenAIChatBackend] = None,
         tool_registry: Optional[HostToolRegistry] = None,
         event_handler: Optional[Callable[[dict[str, Any]], None]] = None,
+        control_handler: Optional[Callable[[dict], Awaitable[dict]]] = None,
    ):
         super().__init__(
-            openai_backend=openai_backend, tool_registry=tool_registry, event_handler=event_handler
+            openai_backend=openai_backend,
+            tool_registry=tool_registry,
+            event_handler=event_handler,
+            control_handler=control_handler,
         )
         self.host = host
         self.port = port
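
With the constructor threading control_handler through to _BaseBrokerServer, a host can wire HITL handling in at server construction. A minimal wiring sketch, assuming only the arguments visible in this hunk; the serve-forever and teardown pattern is illustrative:

# Sketch only: wiring a control handler into TcpBrokerServer. host/port and
# control_handler match this hunk; the run pattern below is an assumption.
import asyncio

from tactus.broker.server import TcpBrokerServer


async def main() -> None:
    async def on_control(request: dict) -> dict:
        # Decide on the HITL request; the dict is echoed back as the
        # "response" event's data.
        return {"approved": True, "note": "auto-approved in dev"}

    server = TcpBrokerServer(
        host="127.0.0.1",
        port=8765,
        control_handler=on_control,
    )
    await server.start()  # start() is declared on _BaseBrokerServer
    try:
        await asyncio.Event().wait()  # serve until cancelled
    finally:
        pass  # teardown API not shown in this hunk


asyncio.run(main())
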