licos-dev-sdk 0.2.1__tar.gz → 0.2.2__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/PKG-INFO +1 -1
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/pyproject.toml +1 -1
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/__init__.py +4 -2
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/model.py +126 -113
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/tests/test_model.py +90 -74
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/.gitignore +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/_utils.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/archive.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/chart.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/data.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/diagram.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/document.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/image.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/presentation.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/spreadsheet.py +0 -0
- {licos_dev_sdk-0.2.1 → licos_dev_sdk-0.2.2}/src/licos_dev_sdk/web.py +0 -0
|
@@ -43,12 +43,14 @@ def __getattr__(name: str):
|
|
|
43
43
|
"ConfigurationError": ("model", "ConfigurationError"),
|
|
44
44
|
"LLMClient": ("model", "LLMClient"),
|
|
45
45
|
"VisionClient": ("model", "VisionClient"),
|
|
46
|
+
"VisionUnderstandingClient": ("model", "VisionUnderstandingClient"),
|
|
46
47
|
"ImageGenerationClient": ("model", "ImageGenerationClient"),
|
|
47
48
|
"VideoGenerationClient": ("model", "VideoGenerationClient"),
|
|
48
49
|
"SpeechRecognitionClient": ("model", "SpeechRecognitionClient"),
|
|
49
50
|
"ASRClient": ("model", "ASRClient"),
|
|
50
51
|
"fetch_model_catalogs": ("model", "fetch_model_catalogs"),
|
|
51
52
|
"resolve_llm_endpoint": ("model", "resolve_llm_endpoint"),
|
|
53
|
+
"resolve_vision_endpoint": ("model", "resolve_vision_endpoint"),
|
|
52
54
|
"resolve_image_generation_endpoint": ("model", "resolve_image_generation_endpoint"),
|
|
53
55
|
"resolve_video_generation_endpoint": ("model", "resolve_video_generation_endpoint"),
|
|
54
56
|
"resolve_speech_recognition_endpoint": ("model", "resolve_speech_recognition_endpoint"),
|
|
@@ -78,9 +80,9 @@ __all__ = [
|
|
|
78
80
|
"create_pptx",
|
|
79
81
|
"ModelRuntime", "ModelEndpoint", "ModelResult",
|
|
80
82
|
"ApiError", "ConfigurationError",
|
|
81
|
-
"LLMClient", "VisionClient", "ImageGenerationClient", "VideoGenerationClient",
|
|
83
|
+
"LLMClient", "VisionClient", "VisionUnderstandingClient", "ImageGenerationClient", "VideoGenerationClient",
|
|
82
84
|
"SpeechRecognitionClient", "ASRClient",
|
|
83
|
-
"fetch_model_catalogs", "resolve_llm_endpoint",
|
|
85
|
+
"fetch_model_catalogs", "resolve_llm_endpoint", "resolve_vision_endpoint",
|
|
84
86
|
"resolve_image_generation_endpoint", "resolve_video_generation_endpoint",
|
|
85
87
|
"resolve_speech_recognition_endpoint",
|
|
86
88
|
"invoke_llm", "generate_image", "generate_video", "recognize_speech", "understand_image",
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
import time
|
|
5
|
-
from dataclasses import dataclass, field, replace
|
|
5
|
+
from dataclasses import dataclass, field, replace
|
|
6
6
|
from typing import Any, Iterator, Sequence
|
|
7
7
|
from urllib import error as urlerror
|
|
8
8
|
from urllib import parse, request
|
|
@@ -11,11 +11,11 @@ from licos_platform_sdk._runtime import (
|
|
|
11
11
|
ApiError,
|
|
12
12
|
ConfigurationError,
|
|
13
13
|
env,
|
|
14
|
-
normalize_base_url,
|
|
15
|
-
platform_base_url,
|
|
16
|
-
resolve_user_token,
|
|
17
|
-
should_refresh_user_token,
|
|
18
|
-
)
|
|
14
|
+
normalize_base_url,
|
|
15
|
+
platform_base_url,
|
|
16
|
+
resolve_user_token,
|
|
17
|
+
should_refresh_user_token,
|
|
18
|
+
)
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
MODEL_CATALOG_PATH = "/api/v1/llm-gateway/ai/model-catalog"
|
|
@@ -112,6 +112,16 @@ def resolve_llm_endpoint(
|
|
|
112
112
|
return _resolve_endpoint(runtime, "chat", model_group=model_group)
|
|
113
113
|
|
|
114
114
|
|
|
115
|
+
def resolve_vision_endpoint(
|
|
116
|
+
*,
|
|
117
|
+
base_url: str | None = None,
|
|
118
|
+
user_token: str | None = None,
|
|
119
|
+
user_id: str | None = None,
|
|
120
|
+
) -> ModelEndpoint:
|
|
121
|
+
runtime = _model_runtime(base_url=base_url, user_token=user_token, user_id=user_id)
|
|
122
|
+
return _resolve_endpoint(runtime, "chat", model_group="vision")
|
|
123
|
+
|
|
124
|
+
|
|
115
125
|
def resolve_image_generation_endpoint(
|
|
116
126
|
*,
|
|
117
127
|
base_url: str | None = None,
|
|
@@ -266,6 +276,9 @@ class VisionClient:
|
|
|
266
276
|
)
|
|
267
277
|
|
|
268
278
|
|
|
279
|
+
VisionUnderstandingClient = VisionClient
|
|
280
|
+
|
|
281
|
+
|
|
269
282
|
class ImageGenerationClient:
|
|
270
283
|
def __init__(
|
|
271
284
|
self,
|
|
@@ -442,7 +455,7 @@ def clear_model_catalog_cache_for_tests() -> None:
|
|
|
442
455
|
_CATALOG_CACHE.clear()
|
|
443
456
|
|
|
444
457
|
|
|
445
|
-
def _model_runtime(
|
|
458
|
+
def _model_runtime(
|
|
446
459
|
*,
|
|
447
460
|
base_url: str | None = None,
|
|
448
461
|
user_token: str | None = None,
|
|
@@ -450,13 +463,13 @@ def _model_runtime(
|
|
|
450
463
|
) -> ModelRuntime:
|
|
451
464
|
resolved_base_url = normalize_base_url(base_url) if base_url else platform_base_url()
|
|
452
465
|
owner_user_id = user_id or env("LICOS_USER_ID") or env("AGENT_USER_ID")
|
|
453
|
-
token = (user_token or "").strip() or resolve_user_token(resolved_base_url, owner_user_id)
|
|
454
|
-
return ModelRuntime(base_url=resolved_base_url, token=token, user_id=owner_user_id)
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
def _refresh_model_runtime(runtime: ModelRuntime) -> ModelRuntime:
|
|
458
|
-
token = resolve_user_token(runtime.base_url, runtime.user_id, force_refresh=True)
|
|
459
|
-
return replace(runtime, token=token)
|
|
466
|
+
token = (user_token or "").strip() or resolve_user_token(resolved_base_url, owner_user_id)
|
|
467
|
+
return ModelRuntime(base_url=resolved_base_url, token=token, user_id=owner_user_id)
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def _refresh_model_runtime(runtime: ModelRuntime) -> ModelRuntime:
|
|
471
|
+
token = resolve_user_token(runtime.base_url, runtime.user_id, force_refresh=True)
|
|
472
|
+
return replace(runtime, token=token)
|
|
460
473
|
|
|
461
474
|
|
|
462
475
|
def _fetch_model_catalogs(runtime: ModelRuntime, *, refresh: bool = False) -> list[dict[str, Any]]:
|
|
@@ -466,17 +479,17 @@ def _fetch_model_catalogs(runtime: ModelRuntime, *, refresh: bool = False) -> li
|
|
|
466
479
|
if cached and not refresh and time.time() - cached[0] <= ttl:
|
|
467
480
|
return cached[1]
|
|
468
481
|
|
|
469
|
-
try:
|
|
470
|
-
payload = _request_json(
|
|
471
|
-
"GET",
|
|
472
|
-
f"{runtime.base_url}{MODEL_CATALOG_PATH}",
|
|
473
|
-
token=runtime.token,
|
|
474
|
-
timeout=30,
|
|
475
|
-
)
|
|
476
|
-
except ApiError as exc:
|
|
477
|
-
if not refresh and should_refresh_user_token(exc):
|
|
478
|
-
return _fetch_model_catalogs(_refresh_model_runtime(runtime), refresh=True)
|
|
479
|
-
raise
|
|
482
|
+
try:
|
|
483
|
+
payload = _request_json(
|
|
484
|
+
"GET",
|
|
485
|
+
f"{runtime.base_url}{MODEL_CATALOG_PATH}",
|
|
486
|
+
token=runtime.token,
|
|
487
|
+
timeout=30,
|
|
488
|
+
)
|
|
489
|
+
except ApiError as exc:
|
|
490
|
+
if not refresh and should_refresh_user_token(exc):
|
|
491
|
+
return _fetch_model_catalogs(_refresh_model_runtime(runtime), refresh=True)
|
|
492
|
+
raise
|
|
480
493
|
catalogs = _catalogs_from_payload(payload)
|
|
481
494
|
if not catalogs:
|
|
482
495
|
raise ApiError("model catalog has no provider entries", details=payload)
|
|
@@ -596,35 +609,35 @@ def _first_string(value: Any) -> str | None:
|
|
|
596
609
|
return None
|
|
597
610
|
|
|
598
611
|
|
|
599
|
-
def _post_model_json(
|
|
612
|
+
def _post_model_json(
|
|
600
613
|
endpoint: ModelEndpoint,
|
|
601
614
|
runtime: ModelRuntime,
|
|
602
615
|
body: dict[str, Any],
|
|
603
616
|
*,
|
|
604
617
|
timeout: int | None = None,
|
|
605
|
-
) -> Any:
|
|
606
|
-
request_timeout = timeout or _request_timeout()
|
|
607
|
-
try:
|
|
608
|
-
return _request_json(
|
|
609
|
-
"POST",
|
|
610
|
-
endpoint.base_url,
|
|
611
|
-
token=runtime.token,
|
|
612
|
-
body=body,
|
|
613
|
-
headers=endpoint.required_headers,
|
|
614
|
-
timeout=request_timeout,
|
|
615
|
-
)
|
|
616
|
-
except ApiError as exc:
|
|
617
|
-
if should_refresh_user_token(exc):
|
|
618
|
-
refreshed = _refresh_model_runtime(runtime)
|
|
619
|
-
return _request_json(
|
|
620
|
-
"POST",
|
|
621
|
-
endpoint.base_url,
|
|
622
|
-
token=refreshed.token,
|
|
623
|
-
body=body,
|
|
624
|
-
headers=endpoint.required_headers,
|
|
625
|
-
timeout=request_timeout,
|
|
626
|
-
)
|
|
627
|
-
raise
|
|
618
|
+
) -> Any:
|
|
619
|
+
request_timeout = timeout or _request_timeout()
|
|
620
|
+
try:
|
|
621
|
+
return _request_json(
|
|
622
|
+
"POST",
|
|
623
|
+
endpoint.base_url,
|
|
624
|
+
token=runtime.token,
|
|
625
|
+
body=body,
|
|
626
|
+
headers=endpoint.required_headers,
|
|
627
|
+
timeout=request_timeout,
|
|
628
|
+
)
|
|
629
|
+
except ApiError as exc:
|
|
630
|
+
if should_refresh_user_token(exc):
|
|
631
|
+
refreshed = _refresh_model_runtime(runtime)
|
|
632
|
+
return _request_json(
|
|
633
|
+
"POST",
|
|
634
|
+
endpoint.base_url,
|
|
635
|
+
token=refreshed.token,
|
|
636
|
+
body=body,
|
|
637
|
+
headers=endpoint.required_headers,
|
|
638
|
+
timeout=request_timeout,
|
|
639
|
+
)
|
|
640
|
+
raise
|
|
628
641
|
|
|
629
642
|
|
|
630
643
|
def _submit_model_task(
|
|
@@ -657,46 +670,46 @@ def _submit_model_task(
|
|
|
657
670
|
)
|
|
658
671
|
|
|
659
672
|
|
|
660
|
-
def _stream_model_json(
|
|
673
|
+
def _stream_model_json(
|
|
661
674
|
endpoint: ModelEndpoint,
|
|
662
675
|
runtime: ModelRuntime,
|
|
663
676
|
body: dict[str, Any],
|
|
664
677
|
*,
|
|
665
678
|
timeout: int | None = None,
|
|
666
|
-
) -> Iterator[str]:
|
|
667
|
-
active_runtime = runtime
|
|
668
|
-
for attempt in range(2):
|
|
669
|
-
req = _json_request(
|
|
670
|
-
"POST",
|
|
671
|
-
endpoint.base_url,
|
|
672
|
-
token=active_runtime.token,
|
|
673
|
-
body=body,
|
|
674
|
-
headers=endpoint.required_headers,
|
|
675
|
-
)
|
|
676
|
-
try:
|
|
677
|
-
with request.urlopen(req, timeout=timeout or _request_timeout()) as response:
|
|
678
|
-
for raw_line in response:
|
|
679
|
-
line = raw_line.decode("utf-8", errors="replace").strip()
|
|
680
|
-
if not line:
|
|
681
|
-
continue
|
|
682
|
-
if line.startswith("data:"):
|
|
683
|
-
data = line[len("data:") :].strip()
|
|
684
|
-
if data == "[DONE]":
|
|
685
|
-
break
|
|
686
|
-
content = _extract_stream_content(data)
|
|
687
|
-
if content:
|
|
688
|
-
yield content
|
|
689
|
-
continue
|
|
690
|
-
yield line
|
|
691
|
-
return
|
|
692
|
-
except urlerror.HTTPError as exc:
|
|
693
|
-
error = _api_error_from_http(exc)
|
|
694
|
-
if attempt == 0 and should_refresh_user_token(error):
|
|
695
|
-
active_runtime = _refresh_model_runtime(active_runtime)
|
|
696
|
-
continue
|
|
697
|
-
raise error from exc
|
|
698
|
-
except urlerror.URLError as exc:
|
|
699
|
-
raise ApiError(f"model stream request failed: {exc}") from exc
|
|
679
|
+
) -> Iterator[str]:
|
|
680
|
+
active_runtime = runtime
|
|
681
|
+
for attempt in range(2):
|
|
682
|
+
req = _json_request(
|
|
683
|
+
"POST",
|
|
684
|
+
endpoint.base_url,
|
|
685
|
+
token=active_runtime.token,
|
|
686
|
+
body=body,
|
|
687
|
+
headers=endpoint.required_headers,
|
|
688
|
+
)
|
|
689
|
+
try:
|
|
690
|
+
with request.urlopen(req, timeout=timeout or _request_timeout()) as response:
|
|
691
|
+
for raw_line in response:
|
|
692
|
+
line = raw_line.decode("utf-8", errors="replace").strip()
|
|
693
|
+
if not line:
|
|
694
|
+
continue
|
|
695
|
+
if line.startswith("data:"):
|
|
696
|
+
data = line[len("data:") :].strip()
|
|
697
|
+
if data == "[DONE]":
|
|
698
|
+
break
|
|
699
|
+
content = _extract_stream_content(data)
|
|
700
|
+
if content:
|
|
701
|
+
yield content
|
|
702
|
+
continue
|
|
703
|
+
yield line
|
|
704
|
+
return
|
|
705
|
+
except urlerror.HTTPError as exc:
|
|
706
|
+
error = _api_error_from_http(exc)
|
|
707
|
+
if attempt == 0 and should_refresh_user_token(error):
|
|
708
|
+
active_runtime = _refresh_model_runtime(active_runtime)
|
|
709
|
+
continue
|
|
710
|
+
raise error from exc
|
|
711
|
+
except urlerror.URLError as exc:
|
|
712
|
+
raise ApiError(f"model stream request failed: {exc}") from exc
|
|
700
713
|
|
|
701
714
|
|
|
702
715
|
def _request_json(
|
|
@@ -769,28 +782,28 @@ def _await_async_model_result(
|
|
|
769
782
|
raise ApiError("async model response is missing task_id", details=submit_response)
|
|
770
783
|
query_url = _build_task_query_url(endpoint, task_id)
|
|
771
784
|
deadline = time.time() + (max_wait_seconds or _async_timeout())
|
|
772
|
-
last_status = _task_status(submit_response) or "UNKNOWN"
|
|
773
|
-
while time.time() < deadline:
|
|
774
|
-
try:
|
|
775
|
-
response = _request_json(
|
|
776
|
-
"GET",
|
|
777
|
-
query_url,
|
|
778
|
-
token=runtime.token,
|
|
779
|
-
headers=endpoint.required_headers,
|
|
780
|
-
timeout=timeout or _request_timeout(),
|
|
781
|
-
)
|
|
782
|
-
except ApiError as exc:
|
|
783
|
-
if should_refresh_user_token(exc):
|
|
784
|
-
refreshed = _refresh_model_runtime(runtime)
|
|
785
|
-
response = _request_json(
|
|
786
|
-
"GET",
|
|
787
|
-
query_url,
|
|
788
|
-
token=refreshed.token,
|
|
789
|
-
headers=endpoint.required_headers,
|
|
790
|
-
timeout=timeout or _request_timeout(),
|
|
791
|
-
)
|
|
792
|
-
else:
|
|
793
|
-
raise
|
|
785
|
+
last_status = _task_status(submit_response) or "UNKNOWN"
|
|
786
|
+
while time.time() < deadline:
|
|
787
|
+
try:
|
|
788
|
+
response = _request_json(
|
|
789
|
+
"GET",
|
|
790
|
+
query_url,
|
|
791
|
+
token=runtime.token,
|
|
792
|
+
headers=endpoint.required_headers,
|
|
793
|
+
timeout=timeout or _request_timeout(),
|
|
794
|
+
)
|
|
795
|
+
except ApiError as exc:
|
|
796
|
+
if should_refresh_user_token(exc):
|
|
797
|
+
refreshed = _refresh_model_runtime(runtime)
|
|
798
|
+
response = _request_json(
|
|
799
|
+
"GET",
|
|
800
|
+
query_url,
|
|
801
|
+
token=refreshed.token,
|
|
802
|
+
headers=endpoint.required_headers,
|
|
803
|
+
timeout=timeout or _request_timeout(),
|
|
804
|
+
)
|
|
805
|
+
else:
|
|
806
|
+
raise
|
|
794
807
|
last_status = _task_status(response) or last_status
|
|
795
808
|
if _is_successful_task_response(response):
|
|
796
809
|
return response
|
|
@@ -942,13 +955,13 @@ def _normalize_messages(messages: Sequence[Any] | str) -> list[dict[str, Any]]:
|
|
|
942
955
|
return result
|
|
943
956
|
|
|
944
957
|
|
|
945
|
-
def _selected_model(model: str | None, default: str) -> str:
|
|
946
|
-
if not isinstance(model, str):
|
|
947
|
-
return default
|
|
948
|
-
selected = model.strip()
|
|
949
|
-
if not selected or selected.lower() == "auto":
|
|
950
|
-
return default
|
|
951
|
-
return selected
|
|
958
|
+
def _selected_model(model: str | None, default: str) -> str:
|
|
959
|
+
if not isinstance(model, str):
|
|
960
|
+
return default
|
|
961
|
+
selected = model.strip()
|
|
962
|
+
if not selected or selected.lower() == "auto":
|
|
963
|
+
return default
|
|
964
|
+
return selected
|
|
952
965
|
|
|
953
966
|
|
|
954
967
|
def _image_count(count: int | None, parameters: dict[str, Any]) -> int:
|
|
@@ -4,10 +4,10 @@ import json
|
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
6
|
import unittest
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Any
|
|
9
|
-
from unittest import mock
|
|
10
|
-
from urllib import error as urlerror
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
from unittest import mock
|
|
10
|
+
from urllib import error as urlerror
|
|
11
11
|
|
|
12
12
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
|
13
13
|
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "licos-platform-sdk" / "src"))
|
|
@@ -16,7 +16,7 @@ from licos_dev_sdk import model
|
|
|
16
16
|
from licos_platform_sdk import _runtime
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
class _FakeResponse:
|
|
19
|
+
class _FakeResponse:
|
|
20
20
|
status = 200
|
|
21
21
|
|
|
22
22
|
def __init__(self, payload: dict[str, Any]) -> None:
|
|
@@ -28,19 +28,19 @@ class _FakeResponse:
|
|
|
28
28
|
def __exit__(self, *_args: Any) -> None:
|
|
29
29
|
return None
|
|
30
30
|
|
|
31
|
-
def read(self) -> bytes:
|
|
32
|
-
return json.dumps(self._payload).encode("utf-8")
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class _FakeErrorBody:
|
|
36
|
-
def __init__(self, payload: dict[str, Any]) -> None:
|
|
37
|
-
self._payload = payload
|
|
38
|
-
|
|
39
|
-
def read(self) -> bytes:
|
|
40
|
-
return json.dumps(self._payload).encode("utf-8")
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def _catalog_payload() -> dict[str, Any]:
|
|
31
|
+
def read(self) -> bytes:
|
|
32
|
+
return json.dumps(self._payload).encode("utf-8")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _FakeErrorBody:
|
|
36
|
+
def __init__(self, payload: dict[str, Any]) -> None:
|
|
37
|
+
self._payload = payload
|
|
38
|
+
|
|
39
|
+
def read(self) -> bytes:
|
|
40
|
+
return json.dumps(self._payload).encode("utf-8")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _catalog_payload() -> dict[str, Any]:
|
|
44
44
|
return {
|
|
45
45
|
"code": 0,
|
|
46
46
|
"success": True,
|
|
@@ -111,62 +111,62 @@ class ModelSdkTests(unittest.TestCase):
|
|
|
111
111
|
self.assertEqual(result.text, "hello")
|
|
112
112
|
self.assertEqual(captured["exchange_headers"]["Authorization"], "Bearer ai-agent-token")
|
|
113
113
|
self.assertEqual(captured["exchange_body"], {"userId": "user-1"})
|
|
114
|
-
self.assertEqual(captured["catalog_headers"]["Authorization"], "Bearer user-token")
|
|
115
|
-
self.assertEqual(captured["chat_headers"]["Authorization"], "Bearer user-token")
|
|
116
|
-
self.assertEqual(captured["chat_body"]["model"], "chat-text")
|
|
117
|
-
|
|
118
|
-
def test_llm_explicit_model_overrides_catalog_default(self) -> None:
|
|
119
|
-
captured: dict[str, Any] = {}
|
|
120
|
-
|
|
121
|
-
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
122
|
-
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
123
|
-
return _FakeResponse({"code": 0, "success": True, "data": {"accessToken": "user-token"}})
|
|
124
|
-
if req.full_url == "http://platform.example/api/v1/llm-gateway/ai/model-catalog":
|
|
125
|
-
return _FakeResponse(_catalog_payload())
|
|
126
|
-
if req.full_url == "http://gateway.example/v1/chat/completions":
|
|
127
|
-
captured["chat_body"] = json.loads(req.data.decode("utf-8"))
|
|
128
|
-
return _FakeResponse({"choices": [{"message": {"content": "hello"}}]})
|
|
129
|
-
raise AssertionError(req.full_url)
|
|
130
|
-
|
|
131
|
-
with mock.patch.object(model.request, "urlopen", fake_urlopen):
|
|
132
|
-
result = model.LLMClient().invoke("Say hello", model="custom-chat-model")
|
|
133
|
-
|
|
134
|
-
self.assertEqual(result.text, "hello")
|
|
135
|
-
self.assertEqual(captured["chat_body"]["model"], "custom-chat-model")
|
|
136
|
-
|
|
137
|
-
def test_llm_invoke_refreshes_user_token_once_after_unauthorized(self) -> None:
|
|
138
|
-
tokens = iter(["old-token", "new-token"])
|
|
139
|
-
catalog_tokens: list[str] = []
|
|
140
|
-
chat_tokens: list[str] = []
|
|
141
|
-
|
|
142
|
-
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
143
|
-
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
144
|
-
return _FakeResponse({"code": 0, "success": True, "data": {"accessToken": next(tokens)}})
|
|
145
|
-
if req.full_url == "http://platform.example/api/v1/llm-gateway/ai/model-catalog":
|
|
146
|
-
catalog_tokens.append(dict(req.header_items())["Authorization"])
|
|
147
|
-
return _FakeResponse(_catalog_payload())
|
|
148
|
-
if req.full_url == "http://gateway.example/v1/chat/completions":
|
|
149
|
-
chat_tokens.append(dict(req.header_items())["Authorization"])
|
|
150
|
-
if len(chat_tokens) == 1:
|
|
151
|
-
raise urlerror.HTTPError(
|
|
152
|
-
req.full_url,
|
|
153
|
-
401,
|
|
154
|
-
"Unauthorized",
|
|
155
|
-
hdrs=None,
|
|
156
|
-
fp=_FakeErrorBody({"code": 10002, "message": "token invalid or expired", "success": False}),
|
|
157
|
-
)
|
|
158
|
-
return _FakeResponse({"choices": [{"message": {"content": "hello"}}]})
|
|
159
|
-
raise AssertionError(req.full_url)
|
|
160
|
-
|
|
161
|
-
with mock.patch.object(model.request, "urlopen", fake_urlopen):
|
|
162
|
-
result = model.LLMClient().invoke("Say hello", model="auto")
|
|
163
|
-
|
|
164
|
-
self.assertEqual(result.text, "hello")
|
|
165
|
-
self.assertEqual(catalog_tokens, ["Bearer old-token"])
|
|
166
|
-
self.assertEqual(chat_tokens, ["Bearer old-token", "Bearer new-token"])
|
|
167
|
-
|
|
168
|
-
def test_image_generation_defaults_to_one_image(self) -> None:
|
|
169
|
-
captured: dict[str, Any] = {}
|
|
114
|
+
self.assertEqual(captured["catalog_headers"]["Authorization"], "Bearer user-token")
|
|
115
|
+
self.assertEqual(captured["chat_headers"]["Authorization"], "Bearer user-token")
|
|
116
|
+
self.assertEqual(captured["chat_body"]["model"], "chat-text")
|
|
117
|
+
|
|
118
|
+
def test_llm_explicit_model_overrides_catalog_default(self) -> None:
|
|
119
|
+
captured: dict[str, Any] = {}
|
|
120
|
+
|
|
121
|
+
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
122
|
+
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
123
|
+
return _FakeResponse({"code": 0, "success": True, "data": {"accessToken": "user-token"}})
|
|
124
|
+
if req.full_url == "http://platform.example/api/v1/llm-gateway/ai/model-catalog":
|
|
125
|
+
return _FakeResponse(_catalog_payload())
|
|
126
|
+
if req.full_url == "http://gateway.example/v1/chat/completions":
|
|
127
|
+
captured["chat_body"] = json.loads(req.data.decode("utf-8"))
|
|
128
|
+
return _FakeResponse({"choices": [{"message": {"content": "hello"}}]})
|
|
129
|
+
raise AssertionError(req.full_url)
|
|
130
|
+
|
|
131
|
+
with mock.patch.object(model.request, "urlopen", fake_urlopen):
|
|
132
|
+
result = model.LLMClient().invoke("Say hello", model="custom-chat-model")
|
|
133
|
+
|
|
134
|
+
self.assertEqual(result.text, "hello")
|
|
135
|
+
self.assertEqual(captured["chat_body"]["model"], "custom-chat-model")
|
|
136
|
+
|
|
137
|
+
def test_llm_invoke_refreshes_user_token_once_after_unauthorized(self) -> None:
|
|
138
|
+
tokens = iter(["old-token", "new-token"])
|
|
139
|
+
catalog_tokens: list[str] = []
|
|
140
|
+
chat_tokens: list[str] = []
|
|
141
|
+
|
|
142
|
+
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
143
|
+
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
144
|
+
return _FakeResponse({"code": 0, "success": True, "data": {"accessToken": next(tokens)}})
|
|
145
|
+
if req.full_url == "http://platform.example/api/v1/llm-gateway/ai/model-catalog":
|
|
146
|
+
catalog_tokens.append(dict(req.header_items())["Authorization"])
|
|
147
|
+
return _FakeResponse(_catalog_payload())
|
|
148
|
+
if req.full_url == "http://gateway.example/v1/chat/completions":
|
|
149
|
+
chat_tokens.append(dict(req.header_items())["Authorization"])
|
|
150
|
+
if len(chat_tokens) == 1:
|
|
151
|
+
raise urlerror.HTTPError(
|
|
152
|
+
req.full_url,
|
|
153
|
+
401,
|
|
154
|
+
"Unauthorized",
|
|
155
|
+
hdrs=None,
|
|
156
|
+
fp=_FakeErrorBody({"code": 10002, "message": "token invalid or expired", "success": False}),
|
|
157
|
+
)
|
|
158
|
+
return _FakeResponse({"choices": [{"message": {"content": "hello"}}]})
|
|
159
|
+
raise AssertionError(req.full_url)
|
|
160
|
+
|
|
161
|
+
with mock.patch.object(model.request, "urlopen", fake_urlopen):
|
|
162
|
+
result = model.LLMClient().invoke("Say hello", model="auto")
|
|
163
|
+
|
|
164
|
+
self.assertEqual(result.text, "hello")
|
|
165
|
+
self.assertEqual(catalog_tokens, ["Bearer old-token"])
|
|
166
|
+
self.assertEqual(chat_tokens, ["Bearer old-token", "Bearer new-token"])
|
|
167
|
+
|
|
168
|
+
def test_image_generation_defaults_to_one_image(self) -> None:
|
|
169
|
+
captured: dict[str, Any] = {}
|
|
170
170
|
|
|
171
171
|
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
172
172
|
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
@@ -205,6 +205,22 @@ class ModelSdkTests(unittest.TestCase):
|
|
|
205
205
|
self.assertEqual(captured["body"]["messages"][0]["content"][1]["image_url"]["url"], "https://cdn.example/a.png")
|
|
206
206
|
self.assertEqual(result.text, "a blue sky")
|
|
207
207
|
|
|
208
|
+
def test_resolve_vision_endpoint_uses_vision_model_group(self) -> None:
|
|
209
|
+
def fake_urlopen(req: Any, timeout: int = 0) -> _FakeResponse:
|
|
210
|
+
if req.full_url == "http://platform.example/api/v1/internal/auth/ai-user-token":
|
|
211
|
+
return _FakeResponse({"code": 0, "success": True, "data": {"accessToken": "user-token"}})
|
|
212
|
+
if req.full_url == "http://platform.example/api/v1/llm-gateway/ai/model-catalog":
|
|
213
|
+
return _FakeResponse(_catalog_payload())
|
|
214
|
+
raise AssertionError(req.full_url)
|
|
215
|
+
|
|
216
|
+
with mock.patch.object(model.request, "urlopen", fake_urlopen):
|
|
217
|
+
endpoint = model.resolve_vision_endpoint()
|
|
218
|
+
|
|
219
|
+
self.assertEqual(endpoint.capability, "chat")
|
|
220
|
+
self.assertEqual(endpoint.model, "chat-vision")
|
|
221
|
+
self.assertEqual(endpoint.base_url, "http://gateway.example/v1/chat/completions")
|
|
222
|
+
self.assertIs(model.VisionUnderstandingClient, model.VisionClient)
|
|
223
|
+
|
|
208
224
|
|
|
209
225
|
if __name__ == "__main__":
|
|
210
226
|
unittest.main()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|