aethergraph 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aethergraph/__main__.py +3 -0
- aethergraph/api/v1/artifacts.py +23 -4
- aethergraph/api/v1/schemas.py +7 -0
- aethergraph/api/v1/session.py +123 -4
- aethergraph/config/config.py +2 -0
- aethergraph/config/search.py +49 -0
- aethergraph/contracts/services/channel.py +18 -1
- aethergraph/contracts/services/execution.py +58 -0
- aethergraph/contracts/services/llm.py +26 -0
- aethergraph/contracts/services/memory.py +10 -4
- aethergraph/contracts/services/planning.py +53 -0
- aethergraph/contracts/storage/event_log.py +8 -0
- aethergraph/contracts/storage/search_backend.py +47 -0
- aethergraph/contracts/storage/vector_index.py +73 -0
- aethergraph/core/graph/action_spec.py +76 -0
- aethergraph/core/graph/graph_fn.py +75 -2
- aethergraph/core/graph/graphify.py +74 -2
- aethergraph/core/runtime/graph_runner.py +2 -1
- aethergraph/core/runtime/node_context.py +66 -3
- aethergraph/core/runtime/node_services.py +8 -0
- aethergraph/core/runtime/run_manager.py +263 -271
- aethergraph/core/runtime/run_types.py +54 -1
- aethergraph/core/runtime/runtime_env.py +35 -14
- aethergraph/core/runtime/runtime_services.py +308 -18
- aethergraph/plugins/agents/default_chat_agent.py +266 -74
- aethergraph/plugins/agents/default_chat_agent_v2.py +487 -0
- aethergraph/plugins/channel/adapters/webui.py +69 -21
- aethergraph/plugins/channel/routes/webui_routes.py +8 -48
- aethergraph/runtime/__init__.py +12 -0
- aethergraph/server/app_factory.py +10 -1
- aethergraph/server/ui_static/assets/index-CFktGdbW.js +4913 -0
- aethergraph/server/ui_static/assets/index-DcfkFlTA.css +1 -0
- aethergraph/server/ui_static/index.html +2 -2
- aethergraph/services/artifacts/facade.py +157 -21
- aethergraph/services/artifacts/types.py +35 -0
- aethergraph/services/artifacts/utils.py +42 -0
- aethergraph/services/channel/channel_bus.py +3 -1
- aethergraph/services/channel/event_hub copy.py +55 -0
- aethergraph/services/channel/event_hub.py +81 -0
- aethergraph/services/channel/factory.py +3 -2
- aethergraph/services/channel/session.py +709 -74
- aethergraph/services/container/default_container.py +69 -7
- aethergraph/services/execution/__init__.py +0 -0
- aethergraph/services/execution/local_python.py +118 -0
- aethergraph/services/indices/__init__.py +0 -0
- aethergraph/services/indices/global_indices.py +21 -0
- aethergraph/services/indices/scoped_indices.py +292 -0
- aethergraph/services/llm/generic_client.py +342 -46
- aethergraph/services/llm/generic_embed_client.py +359 -0
- aethergraph/services/llm/types.py +3 -1
- aethergraph/services/memory/distillers/llm_long_term.py +60 -109
- aethergraph/services/memory/distillers/llm_long_term_v1.py +180 -0
- aethergraph/services/memory/distillers/llm_meta_summary.py +57 -266
- aethergraph/services/memory/distillers/llm_meta_summary_v1.py +342 -0
- aethergraph/services/memory/distillers/long_term.py +48 -131
- aethergraph/services/memory/distillers/long_term_v1.py +170 -0
- aethergraph/services/memory/facade/chat.py +18 -8
- aethergraph/services/memory/facade/core.py +159 -19
- aethergraph/services/memory/facade/distillation.py +86 -31
- aethergraph/services/memory/facade/retrieval.py +100 -1
- aethergraph/services/memory/factory.py +4 -1
- aethergraph/services/planning/__init__.py +0 -0
- aethergraph/services/planning/action_catalog.py +271 -0
- aethergraph/services/planning/bindings.py +56 -0
- aethergraph/services/planning/dependency_index.py +65 -0
- aethergraph/services/planning/flow_validator.py +263 -0
- aethergraph/services/planning/graph_io_adapter.py +150 -0
- aethergraph/services/planning/input_parser.py +312 -0
- aethergraph/services/planning/missing_inputs.py +28 -0
- aethergraph/services/planning/node_planner.py +613 -0
- aethergraph/services/planning/orchestrator.py +112 -0
- aethergraph/services/planning/plan_executor.py +506 -0
- aethergraph/services/planning/plan_types.py +321 -0
- aethergraph/services/planning/planner.py +617 -0
- aethergraph/services/planning/planner_service.py +369 -0
- aethergraph/services/planning/planning_context_builder.py +43 -0
- aethergraph/services/planning/quick_actions.py +29 -0
- aethergraph/services/planning/routers/__init__.py +0 -0
- aethergraph/services/planning/routers/simple_router.py +26 -0
- aethergraph/services/rag/facade.py +0 -3
- aethergraph/services/scope/scope.py +30 -30
- aethergraph/services/scope/scope_factory.py +15 -7
- aethergraph/services/skills/__init__.py +0 -0
- aethergraph/services/skills/skill_registry.py +465 -0
- aethergraph/services/skills/skills.py +220 -0
- aethergraph/services/skills/utils.py +194 -0
- aethergraph/storage/artifacts/artifact_index_jsonl.py +16 -10
- aethergraph/storage/artifacts/artifact_index_sqlite.py +12 -2
- aethergraph/storage/docstore/sqlite_doc_sync.py +1 -1
- aethergraph/storage/memory/event_persist.py +42 -2
- aethergraph/storage/memory/fs_persist.py +32 -2
- aethergraph/storage/search_backend/__init__.py +0 -0
- aethergraph/storage/search_backend/generic_vector_backend.py +230 -0
- aethergraph/storage/search_backend/null_backend.py +34 -0
- aethergraph/storage/search_backend/sqlite_lexical_backend.py +387 -0
- aethergraph/storage/search_backend/utils.py +31 -0
- aethergraph/storage/search_factory.py +75 -0
- aethergraph/storage/vector_index/faiss_index.py +72 -4
- aethergraph/storage/vector_index/sqlite_index.py +521 -52
- aethergraph/storage/vector_index/sqlite_index_vanila.py +311 -0
- aethergraph/storage/vector_index/utils.py +22 -0
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/METADATA +1 -1
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/RECORD +108 -64
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/WHEEL +1 -1
- aethergraph/plugins/agents/default_chat_agent copy.py +0 -90
- aethergraph/server/ui_static/assets/index-BR5GtXcZ.css +0 -1
- aethergraph/server/ui_static/assets/index-CQ0HZZ83.js +0 -400
- aethergraph/services/eventhub/event_hub.py +0 -76
- aethergraph/services/llm/generic_client copy.py +0 -691
- aethergraph/services/prompts/file_store.py +0 -41
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/licenses/NOTICE +0 -0
- {aethergraph-0.1.0a2.dist-info → aethergraph-0.1.0a4.dist-info}/top_level.txt +0 -0
aethergraph/services/llm/generic_client.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+from collections.abc import Awaitable, Callable
 import json
 import logging
 import os
@@ -37,6 +38,8 @@ from aethergraph.services.llm.utils import (
     _validate_json_schema,
 )
 
+DeltaCallback = Callable[[str], Awaitable[None]]
+
 
 # ---- Helpers --------------------------------------------------------------
 class _Retry:
@@ -99,6 +102,7 @@ class GenericLLMClient(LLMClientProtocol):
         self._retry = _Retry()
         self._client = httpx.AsyncClient(timeout=timeout)
         self._bound_loop = None
+        self._timeout = timeout
 
         # Resolve creds/base
         self.api_key = (
@@ -236,19 +240,16 @@ class GenericLLMClient(LLMClientProtocol):
             logger.warning(f"llm_metering_failed: {e}")
 
     async def _ensure_client(self):
-        """Ensure the httpx client is bound to the current event loop.
-        This allows safe usage across multiple async contexts.
-        """
         loop = asyncio.get_running_loop()
-
-
-
-
-
-
-
-
-        self._client = httpx.AsyncClient(timeout=self.
+
+        if self._client is None:
+            self._client = httpx.AsyncClient(timeout=self.timeout)
+            self._bound_loop = loop
+            return
+
+        if self._bound_loop is not loop:
+            # Don't attempt to close the old client here; it belongs to the old loop.
+            self._client = httpx.AsyncClient(timeout=self._timeout)
         self._bound_loop = loop
 
     async def chat(
@@ -298,6 +299,10 @@ class GenericLLMClient(LLMClientProtocol):
             validate_json: If True, validate JSON output against schema.
             fail_on_unsupported: If True, raise error for unsupported features.
             **kw: Additional provider-specific keyword arguments.
+                Common cross-provider options include:
+                - model: override default model name.
+                - tools: OpenAI-style tools / functions description.
+                - tool_choice: tool selection strategy (e.g., "auto", "none", or provider-specific dict).
 
         Returns:
             tuple[str, dict[str, int]]: The model response (text or structured output) and usage statistics.
@@ -313,7 +318,7 @@ class GenericLLMClient(LLMClientProtocol):
             - Rate limiting and metering help manage resource usage effectively.
         """
         await self._ensure_client()
-        model = kw.
+        model = kw.pop("model", self.model)
 
         start = time.perf_counter()
 
@@ -355,6 +360,250 @@ class GenericLLMClient(LLMClientProtocol):
 
         return text, usage
 
+    async def chat_stream(
+        self,
+        messages: list[dict[str, Any]],
+        *,
+        reasoning_effort: str | None = None,
+        max_output_tokens: int | None = None,
+        output_format: ChatOutputFormat = "text",
+        json_schema: dict[str, Any] | None = None,
+        schema_name: str = "output",
+        strict_schema: bool = True,
+        validate_json: bool = True,
+        fail_on_unsupported: bool = True,
+        on_delta: DeltaCallback | None = None,
+        **kw: Any,
+    ) -> tuple[str, dict[str, int]]:
+        """
+        Stream a chat request to the LLM provider and return the accumulated response.
+
+        This method handles provider-specific streaming paths, falling back to non-streaming
+        chat() if streaming is not implemented. It supports real-time delta updates via
+        a callback function and returns the full response text and usage statistics at the end.
+
+        Examples:
+            Basic usage with a list of messages:
+            ```python
+            response, usage = await context.llm().chat_stream(
+                messages=[{"role": "user", "content": "Hello, assistant!"}]
+            )
+            ```
+
+            Using a delta callback for real-time updates:
+            ```python
+            async def on_delta(delta):
+                print(delta, end="")
+
+            response, usage = await context.llm().chat_stream(
+                messages=[{"role": "user", "content": "Tell me a joke."}],
+                on_delta=on_delta
+            )
+            ```
+
+        Args:
+            messages: List of message dicts, each with "role" and "content" keys.
+            reasoning_effort: Optional string to control model reasoning depth.
+            max_output_tokens: Optional maximum number of output tokens.
+            output_format: Output format, e.g., "text" or "json".
+            json_schema: Optional JSON schema for validating structured output.
+            schema_name: Name for the root schema object (default: "output").
+            strict_schema: If True, enforce strict schema validation.
+            validate_json: If True, validate JSON output against schema.
+            fail_on_unsupported: If True, raise error for unsupported features.
+            on_delta: Optional callback function to handle real-time text deltas.
+            **kw: Additional provider-specific keyword arguments.
+
+        Returns:
+            tuple[str, dict[str, int]]: The accumulated response text and usage statistics.
+
+        Raises:
+            NotImplementedError: If the provider is not supported.
+            RuntimeError: For various errors including invalid JSON output or rate limit violations.
+            LLMUnsupportedFeatureError: If a requested feature is unsupported by the provider.
+
+        Notes:
+            - This method centralizes handling of streaming and non-streaming paths for LLM providers.
+            - The `on_delta` callback allows for real-time updates, making it suitable for interactive applications.
+            - Rate limiting and usage metering are applied consistently across providers.
+            - Currently, only OpenAI's Responses API streaming is implemented; other providers will fall back to the non-streaming `chat()` method.
+        """
+
+        await self._ensure_client()
+        model = kw.pop("model", self.model)
+        start = time.perf_counter()
+
+        # For now, only OpenAI Responses streaming is implemented.
+        if self.provider == "openai":
+            text, usage = await self._chat_openai_responses_stream(
+                messages,
+                model=model,
+                reasoning_effort=reasoning_effort,
+                max_output_tokens=max_output_tokens,
+                output_format=output_format,
+                json_schema=json_schema,
+                schema_name=schema_name,
+                strict_schema=strict_schema,
+                fail_on_unsupported=fail_on_unsupported,
+                on_delta=on_delta,
+                **kw,
+            )
+        else:
+            # Fallback: just call normal chat() and send a single delta.
+            text, usage = await self.chat(
+                messages,
+                reasoning_effort=reasoning_effort,
+                max_output_tokens=max_output_tokens,
+                output_format=output_format,
+                json_schema=json_schema,
+                schema_name=schema_name,
+                strict_schema=strict_schema,
+                validate_json=validate_json,
+                fail_on_unsupported=fail_on_unsupported,
+                **kw,
+            )
+            if on_delta is not None and text:
+                await on_delta(text)
+
+        # Postprocess (JSON modes etc.)
+        text = self._postprocess_structured_output(
+            text=text,
+            output_format=output_format,
+            json_schema=json_schema,
+            strict_schema=strict_schema,
+            validate_json=validate_json,
+        )
+
+        latency_ms = int((time.perf_counter() - start) * 1000)
+
+        # Rate limits + metering as usual
+        self._enforce_llm_limits_for_run(usage=usage)
+        await self._record_llm_usage(model=model, usage=usage, latency_ms=latency_ms)
+
+        return text, usage
+
+    async def _chat_openai_responses_stream(
+        self,
+        messages: list[dict[str, Any]],
+        *,
+        model: str,
+        reasoning_effort: str | None,
+        max_output_tokens: int | None,
+        output_format: ChatOutputFormat,
+        json_schema: dict[str, Any] | None,
+        schema_name: str,
+        strict_schema: bool,
+        fail_on_unsupported: bool,
+        on_delta: DeltaCallback | None = None,
+        **kw: Any,
+    ) -> tuple[str, dict[str, int]]:
+        """
+        Stream text using OpenAI Responses API.
+
+        - We only support text / json_object / json_schema here.
+        - We look for `response.output_text.delta` events and call on_delta(delta).
+        - We accumulate full text and best-effort usage from the final event.
+        """
+        await self._ensure_client()
+        assert self._client is not None
+
+        url = f"{self.base_url}/responses"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        input_messages = _normalize_openai_responses_input(messages)
+
+        body: dict[str, Any] = {
+            "model": model,
+            "input": input_messages,
+            "stream": True,
+        }
+
+        if reasoning_effort is not None:
+            body["reasoning"] = {"effort": reasoning_effort}
+        if max_output_tokens is not None:
+            body["max_output_tokens"] = max_output_tokens
+
+        # Structured output config (same as non-streaming path)
+        if output_format == "json_object":
+            body["text"] = {"format": {"type": "json_object"}}
+        elif output_format == "json_schema":
+            if json_schema is None:
+                raise ValueError("output_format='json_schema' requires json_schema")
+            body["text"] = {
+                "format": {
+                    "type": "json_schema",
+                    "name": schema_name,
+                    "schema": json_schema,
+                    "strict": bool(strict_schema),
+                }
+            }
+        # else: default "text" format
+
+        full_chunks: list[str] = []
+        usage: dict[str, int] = {}
+
+        async def _handle_event(evt: dict[str, Any]):
+            nonlocal usage
+
+            etype = evt.get("type")
+
+            # Main text deltas
+            if etype == "response.output_text.delta":
+                delta = evt.get("delta") or ""
+                if delta:
+                    full_chunks.append(delta)
+                    if on_delta is not None:
+                        await on_delta(delta)
+
+            # Finalization – grab usage from completed response if present
+            elif etype in ("response.completed", "response.incomplete", "response.failed"):
+                resp = evt.get("response") or {}
+                # Usage may or may not be present, keep best-effort
+                usage = resp.get("usage") or usage
+
+            # Optional: basic error surface
+            elif etype == "error":
+                # in practice `error` may be structured differently; this is just a guardrail
+                msg = evt.get("message") or "Unknown streaming error"
+                raise RuntimeError(f"OpenAI streaming error: {msg}")
+
+        async def _call():
+            async with self._client.stream(
+                "POST",
+                url,
+                headers=headers,
+                json=body,
+            ) as r:
+                try:
+                    r.raise_for_status()
+                except httpx.HTTPStatusError as e:
+                    text = await r.aread()
+                    raise RuntimeError(f"OpenAI Responses streaming error: {text!r}") from e
+
+                # SSE: each event line is "data: {...}" + blank lines between events
+                async for line in r.aiter_lines():
+                    if not line:
+                        continue
+                    if not line.startswith("data:"):
+                        continue
+
+                    data_str = line[len("data:") :].strip()
+                    if not data_str or data_str == "[DONE]":
+                        # OpenAI ends stream with `data: [DONE]`
+                        break
+
+                    try:
+                        evt = json.loads(data_str)
+                    except Exception:
+                        # best-effort: ignore malformed chunks
+                        continue
+
+                    await _handle_event(evt)
+
+        await self._retry.run(_call)
+
+        return "".join(full_chunks), usage
+
     async def _chat_dispatch(
         self,
         messages: list[dict[str, Any]],
@@ -370,6 +619,10 @@ class GenericLLMClient(LLMClientProtocol):
         fail_on_unsupported: bool,
         **kw: Any,
     ) -> tuple[str, dict[str, int]]:
+        # Extract cross-provider extras if any
+        tools = kw.pop("tools", None)
+        tool_choice = kw.pop("tool_choice", None)
+
         # OpenAI is now symmetric too
         if self.provider == "openai":
             return await self._chat_openai_responses(
@@ -381,6 +634,9 @@ class GenericLLMClient(LLMClientProtocol):
                 json_schema=json_schema,
                 schema_name=schema_name,
                 strict_schema=strict_schema,
+                tools=tools,
+                tool_choice=tool_choice,
+                **kw,
             )
 
         # Everyone else
@@ -391,6 +647,8 @@ class GenericLLMClient(LLMClientProtocol):
                 output_format=output_format,
                 json_schema=json_schema,
                 fail_on_unsupported=fail_on_unsupported,
+                tools=tools,
+                tool_choice=tool_choice,
                 **kw,
             )
 
@@ -401,6 +659,8 @@ class GenericLLMClient(LLMClientProtocol):
                 output_format=output_format,
                 json_schema=json_schema,
                 fail_on_unsupported=fail_on_unsupported,
+                tools=tools,
+                tool_choice=tool_choice,
                 **kw,
             )
 
@@ -410,6 +670,8 @@ class GenericLLMClient(LLMClientProtocol):
                 model=model,
                 output_format=output_format,
                 json_schema=json_schema,
+                fail_on_unsupported=fail_on_unsupported,
+                tools=tools,
                 **kw,
             )
 
@@ -420,6 +682,7 @@ class GenericLLMClient(LLMClientProtocol):
                 output_format=output_format,
                 json_schema=json_schema,
                 fail_on_unsupported=fail_on_unsupported,
+                tools=tools,
                 **kw,
             )
 
@@ -463,6 +726,9 @@ class GenericLLMClient(LLMClientProtocol):
         json_schema: dict[str, Any] | None,
         schema_name: str,
         strict_schema: bool,
+        tools: list[dict[str, Any]] | None = None,
+        tool_choice: Any = None,
+        **kw: Any,
     ) -> tuple[str, dict[str, int]]:
         await self._ensure_client()
         assert self._client is not None
@@ -470,16 +736,16 @@ class GenericLLMClient(LLMClientProtocol):
         url = f"{self.base_url}/responses"
         headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
-        # Normalize input so vision works if caller used image_url parts
         input_messages = _normalize_openai_responses_input(messages)
 
         body: dict[str, Any] = {"model": model, "input": input_messages}
+
         if reasoning_effort is not None:
             body["reasoning"] = {"effort": reasoning_effort}
         if max_output_tokens is not None:
             body["max_output_tokens"] = max_output_tokens
 
-        # Structured output
+        # Structured output
         if output_format == "json_object":
             body["text"] = {"format": {"type": "json_object"}}
         elif output_format == "json_schema":
@@ -494,6 +760,12 @@ class GenericLLMClient(LLMClientProtocol):
                 }
             }
 
+        # Tools (Responses API style)
+        if tools is not None:
+            body["tools"] = tools
+        if tool_choice is not None:
+            body["tool_choice"] = tool_choice
+
         async def _call():
             r = await self._client.post(url, headers=headers, json=body)
             try:
@@ -502,12 +774,18 @@ class GenericLLMClient(LLMClientProtocol):
                 raise RuntimeError(f"OpenAI Responses API error: {e.response.text}") from e
 
             data = r.json()
+            usage = data.get("usage", {}) or {}
+
+            # If caller asked for raw provider payload, just return it as a JSON string
+            if output_format == "raw":
+                txt = json.dumps(data, ensure_ascii=False)
+                return txt, usage
+
+            # Existing parsing logic for message-only flows
             output = data.get("output")
             txt = ""
 
-            # Your existing parsing logic, but robust for list shape
             if isinstance(output, list) and output:
-                # concat all message outputs if multiple
                 chunks: list[str] = []
                 for item in output:
                     if isinstance(item, dict) and item.get("type") == "message":
@@ -531,7 +809,6 @@ class GenericLLMClient(LLMClientProtocol):
             else:
                 txt = ""
 
-            usage = data.get("usage", {}) or {}
             return txt, usage
 
         return await self._retry.run(_call)
@@ -544,29 +821,10 @@ class GenericLLMClient(LLMClientProtocol):
         output_format: ChatOutputFormat,
         json_schema: dict[str, Any] | None,
         fail_on_unsupported: bool,
+        tools: list[dict[str, Any]] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
         **kw: Any,
     ) -> tuple[str, dict[str, int]]:
-        """
-        Docstring for _chat_openai_like_chat_completions
-
-        :param self: Description
-        :param messages: Description
-        :type messages: list[dict[str, Any]]
-        :param model: Description
-        :type model: str
-        :param output_format: Description
-        :type output_format: ChatOutputFormat
-        :param json_schema: Description
-        :type json_schema: dict[str, Any] | None
-        :param fail_on_unsupported: Description
-        :type fail_on_unsupported: bool
-        :param kw: Description
-        :type kw: Any
-        :return: Description
-        :rtype: tuple[str, dict[str, int]]
-
-        Call OpenAI-like /chat/completions endpoint.
-        """
         await self._ensure_client()
         assert self._client is not None
 
@@ -580,7 +838,6 @@ class GenericLLMClient(LLMClientProtocol):
             response_format = {"type": "json_object"}
             msg_for_provider = _ensure_system_json_directive(messages, schema=None)
         elif output_format == "json_schema":
-            # not truly native in most openai-like providers
             if fail_on_unsupported:
                 raise RuntimeError(f"provider {self.provider} does not support native json_schema")
             msg_for_provider = _ensure_system_json_directive(messages, schema=json_schema)
@@ -594,6 +851,10 @@ class GenericLLMClient(LLMClientProtocol):
         }
         if response_format is not None:
             body["response_format"] = response_format
+        if tools is not None:
+            body["tools"] = tools
+        if tool_choice is not None:
+            body["tool_choice"] = tool_choice
 
             r = await self._client.post(
                 f"{self.base_url}/chat/completions",
@@ -606,8 +867,13 @@ class GenericLLMClient(LLMClientProtocol):
                 raise RuntimeError(f"OpenAI-like chat/completions error: {e.response.text}") from e
 
             data = r.json()
-            txt, _ = _first_text(data.get("choices", []))  # you already have _first_text in file
             usage = data.get("usage", {}) or {}
+
+            if output_format == "raw":
+                txt = json.dumps(data, ensure_ascii=False)
+                return txt, usage
+
+            txt, _ = _first_text(data.get("choices", []))
             return txt, usage
 
         return await self._retry.run(_call)
@@ -620,6 +886,8 @@ class GenericLLMClient(LLMClientProtocol):
         output_format: ChatOutputFormat,
         json_schema: dict[str, Any] | None,
         fail_on_unsupported: bool,
+        tools: list[dict[str, Any]] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
         **kw: Any,
     ) -> tuple[str, dict[str, int]]:
         await self._ensure_client()
@@ -650,6 +918,11 @@ class GenericLLMClient(LLMClientProtocol):
         )
         payload["messages"] = _ensure_system_json_directive(messages, schema=json_schema)
 
+        if tools is not None:
+            payload["tools"] = tools
+        if tool_choice is not None:
+            payload["tool_choice"] = tool_choice
+
         async def _call():
            r = await self._client.post(
                f"{self.base_url}/openai/deployments/{self.azure_deployment}/chat/completions?api-version=2024-08-01-preview",
@@ -662,8 +935,13 @@ class GenericLLMClient(LLMClientProtocol):
                 raise RuntimeError(f"Azure chat/completions error: {e.response.text}") from e
 
             data = r.json()
-            txt, _ = _first_text(data.get("choices", []))
             usage = data.get("usage", {}) or {}
+
+            if output_format == "raw":
+                txt = json.dumps(data, ensure_ascii=False)
+                return txt, usage
+
+            txt, _ = _first_text(data.get("choices", []))
             return txt, usage
 
         return await self._retry.run(_call)
@@ -675,11 +953,16 @@ class GenericLLMClient(LLMClientProtocol):
         model: str,
         output_format: ChatOutputFormat,
         json_schema: dict[str, Any] | None,
+        fail_on_unsupported: bool,
+        tools: list[dict[str, Any]] | None = None,
         **kw: Any,
     ) -> tuple[str, dict[str, int]]:
         await self._ensure_client()
         assert self._client is not None
 
+        if tools is not None and fail_on_unsupported:
+            raise RuntimeError("Anthropic tools/function calling not wired yet in this client")
+
         temperature = kw.get("temperature", 0.5)
         top_p = kw.get("top_p", 1.0)
 
@@ -746,9 +1029,14 @@ class GenericLLMClient(LLMClientProtocol):
                 raise RuntimeError(f"Anthropic API error ({e.response.status_code}): {body}") from e
 
             data = r.json()
+            usage = data.get("usage", {}) or {}
+
+            if output_format == "raw":
+                txt = json.dumps(data, ensure_ascii=False)
+                return txt, usage
+
             blocks = data.get("content") or []
             txt = "".join(b.get("text", "") for b in blocks if b.get("type") == "text")
-            usage = data.get("usage", {}) or {}
             return txt, usage
 
         return await self._retry.run(_call)
@@ -761,6 +1049,7 @@ class GenericLLMClient(LLMClientProtocol):
         output_format: ChatOutputFormat,
         json_schema: dict[str, Any] | None,
         fail_on_unsupported: bool,
+        tools: list[dict[str, Any]] | None = None,
         **kw: Any,
     ) -> tuple[str, dict[str, int]]:
         await self._ensure_client()
@@ -769,6 +1058,9 @@ class GenericLLMClient(LLMClientProtocol):
         temperature = kw.get("temperature", 0.5)
         top_p = kw.get("top_p", 1.0)
 
+        if tools is not None and fail_on_unsupported:
+            raise RuntimeError("Gemini tools/function calling not wired yet in this client")
+
         # Merge system messages into preamble
         system_parts: list[str] = []
         for m in messages:
@@ -815,14 +1107,18 @@ class GenericLLMClient(LLMClientProtocol):
                 ) from e
 
             data = r.json()
-            cand = (data.get("candidates") or [{}])[0]
-            txt = "".join(p.get("text", "") for p in (cand.get("content", {}).get("parts") or []))
-
             um = data.get("usageMetadata") or {}
             usage = {
                 "input_tokens": int(um.get("promptTokenCount", 0) or 0),
                 "output_tokens": int(um.get("candidatesTokenCount", 0) or 0),
             }
+
+            if output_format == "raw":
+                txt = json.dumps(data, ensure_ascii=False)
+                return txt, usage
+
+            cand = (data.get("candidates") or [{}])[0]
+            txt = "".join(p.get("text", "") for p in (cand.get("content", {}).get("parts") or []))
             return txt, usage
 
         return await self._retry.run(_call)
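Taken together, the generic_client.py changes add a streaming entry point (`chat_stream` with an `on_delta` callback), a `tools` / `tool_choice` passthrough in `_chat_dispatch`, and a raw provider-payload mode. The sketch below is based only on the signatures shown in this diff and is not taken from the package: the `llm` object stands in for a configured `GenericLLMClient` (its constructor is not shown here), and both the tool schema shape and the presence of `"raw"` in `ChatOutputFormat` are assumptions.

```python
# Illustrative sketch only; `llm` is assumed to be a configured GenericLLMClient.
from typing import Any


async def stream_demo(llm: Any) -> None:
    async def on_delta(delta: str) -> None:
        # chat_stream invokes this once per text delta; non-OpenAI providers
        # fall back to chat() and deliver the whole reply as a single delta.
        print(delta, end="", flush=True)

    text, usage = await llm.chat_stream(
        messages=[{"role": "user", "content": "Tell me a joke."}],
        on_delta=on_delta,
    )
    print("\n", usage)

    # tools / tool_choice are popped from **kw in _chat_dispatch and forwarded
    # to providers that accept them; output_format="raw" (assumed to be a valid
    # ChatOutputFormat value) returns the provider payload as a JSON string.
    raw_payload, usage = await llm.chat(
        messages=[{"role": "user", "content": "What's the weather in Oslo?"}],
        tools=[{  # hypothetical OpenAI Responses-style function tool
            "type": "function",
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }],
        tool_choice="auto",
        output_format="raw",
    )
```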