ommlds-0.0.0.dev467-py3-none-any.whl → ommlds-0.0.0.dev469-py3-none-any.whl
This diff shows the changes between these two publicly released package versions, as they appear in their public registry, and is provided for informational purposes only.
- ommlds/.omlish-manifests.json +109 -2
- ommlds/__about__.py +2 -2
- ommlds/_hacks/__init__.py +4 -0
- ommlds/_hacks/funcs.py +110 -0
- ommlds/_hacks/names.py +158 -0
- ommlds/_hacks/params.py +73 -0
- ommlds/_hacks/patches.py +0 -3
- ommlds/backends/ollama/__init__.py +0 -0
- ommlds/backends/ollama/protocol.py +170 -0
- ommlds/backends/transformers/__init__.py +0 -0
- ommlds/backends/transformers/filecache.py +109 -0
- ommlds/backends/transformers/streamers.py +73 -0
- ommlds/cli/main.py +11 -5
- ommlds/cli/sessions/chat/backends/catalog.py +1 -1
- ommlds/cli/sessions/completion/session.py +1 -1
- ommlds/cli/sessions/embedding/session.py +1 -1
- ommlds/minichain/__init__.py +5 -0
- ommlds/minichain/backends/catalogs/base.py +14 -1
- ommlds/minichain/backends/catalogs/simple.py +2 -2
- ommlds/minichain/backends/catalogs/strings.py +9 -7
- ommlds/minichain/backends/impls/anthropic/stream.py +1 -2
- ommlds/minichain/backends/impls/google/stream.py +1 -2
- ommlds/minichain/backends/impls/llamacpp/chat.py +9 -0
- ommlds/minichain/backends/impls/llamacpp/stream.py +26 -10
- ommlds/minichain/backends/impls/ollama/__init__.py +0 -0
- ommlds/minichain/backends/impls/ollama/chat.py +199 -0
- ommlds/minichain/backends/impls/openai/stream.py +1 -2
- ommlds/minichain/backends/impls/transformers/transformers.py +134 -17
- ommlds/minichain/registries/globals.py +18 -4
- ommlds/minichain/standard.py +7 -0
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/METADATA +7 -7
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/RECORD +36 -26
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/top_level.txt +0 -0
ommlds/backends/transformers/filecache.py
ADDED
@@ -0,0 +1,109 @@
+import contextlib
+import dataclasses as dc
+import os
+import threading
+import typing as ta
+
+import transformers as tfm
+
+from omlish import lang
+
+from ..._hacks.funcs import create_detour
+
+
+##
+
+
+@dc.dataclass(frozen=True, kw_only=True)
+class _FileCachePatchContext:
+    local_first: bool = False
+    local_config_present_is_authoritative: bool = False
+
+
+_FILE_CACHE_PATCH_CONTEXT_TLS = threading.local()
+
+
+def _get_file_cache_patch_context() -> _FileCachePatchContext:
+    try:
+        return _FILE_CACHE_PATCH_CONTEXT_TLS.context
+    except AttributeError:
+        ctx = _FILE_CACHE_PATCH_CONTEXT_TLS.context = _FileCachePatchContext()
+        return ctx
+
+
+_FILE_CACHE_PATCH_LOCK = threading.Lock()
+
+
+@lang.cached_function(lock=_FILE_CACHE_PATCH_LOCK)
+def patch_file_cache() -> None:
+    """
+    I tried to make a `local_first_pipeline` function to be called instead of `tfm.pipeline`, I really did, but the
+    transformers code is such a disgusting rat's nest full of direct static calls to the caching code strewn about at
+    every layer with no concern whatsoever for forwarding kwargs where they need to go.
+    """
+
+    from transformers.utils.hub import cached_files
+
+    orig_cached_files: ta.Callable[..., str | None] = lang.copy_function(cached_files)  # type: ignore
+
+    get_file_cache_patch_context = _get_file_cache_patch_context
+
+    def new_cached_files(
+            path_or_repo_id: str | os.PathLike,
+            filenames: list[str],
+            **kwargs: ta.Any,
+    ) -> str | None:
+        ctx = get_file_cache_patch_context()
+
+        if ctx.local_first and not kwargs.get('local_files_only'):
+            try:
+                local = orig_cached_files(
+                    path_or_repo_id,
+                    filenames,
+                    **{**kwargs, 'local_files_only': True},
+                )
+            except OSError as e:  # noqa
+                pass
+            else:
+                return local
+
+            if ctx.local_config_present_is_authoritative:
+                try:
+                    local_config = orig_cached_files(
+                        path_or_repo_id,
+                        [tfm.CONFIG_NAME],
+                        **{**kwargs, 'local_files_only': True},
+                    )
+                except OSError as e:  # noqa
+                    pass
+                else:
+                    raise OSError(
+                        f'Files {filenames!r} requested under local_first '
+                        f'but local_config present at {local_config!r}, '
+                        f'assuming files do not exist.',
+                    )
+
+        return orig_cached_files(path_or_repo_id, filenames, **kwargs)
+
+    cached_files.__code__ = create_detour(cached_files, new_cached_files, as_kwargs=True)
+
+
+@contextlib.contextmanager
+def file_cache_patch_context(
+        *,
+        local_first: bool = False,
+        local_config_present_is_authoritative: bool = False,
+) -> ta.Generator[None]:
+    patch_file_cache()
+
+    new_ctx = dc.replace(
+        old_ctx := _get_file_cache_patch_context(),
+        local_first=local_first,
+        local_config_present_is_authoritative=local_config_present_is_authoritative,
+    )

+    _FILE_CACHE_PATCH_CONTEXT_TLS.context = new_ctx
+    try:
+        yield
+    finally:
+        _FILE_CACHE_PATCH_CONTEXT_TLS.context = old_ctx
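Usage note (not part of the diff): a minimal sketch of how the new patch context might be used, assuming `transformers` is installed and the module is importable as laid out in the wheel above; the pipeline task and model name are illustrative only.

    import transformers as tfm

    from ommlds.backends.transformers.filecache import file_cache_patch_context

    # Inside the context, the detoured cached_files() first retries with
    # local_files_only=True and only falls back to the normal lookup on OSError.
    with file_cache_patch_context(local_first=True):
        pipe = tfm.pipeline('text-generation', model='gpt2')  # illustrative model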
ommlds/backends/transformers/streamers.py
ADDED
@@ -0,0 +1,73 @@
+import functools
+import typing as ta
+
+import transformers as tfm
+
+
+T = ta.TypeVar('T')
+P = ta.ParamSpec('P')
+
+
+##
+
+
+class CancellableTextStreamer(tfm.TextStreamer):
+    class Callback(ta.Protocol):
+        def __call__(self, text: str, *, stream_end: bool) -> None: ...
+
+    def __init__(
+            self,
+            tokenizer: tfm.AutoTokenizer,
+            callback: Callback,
+            *,
+            skip_prompt: bool = False,
+            **decode_kwargs: ta.Any,
+    ) -> None:
+        super().__init__(
+            tokenizer,
+            skip_prompt=skip_prompt,
+            **decode_kwargs,
+        )
+
+        self.callback = callback
+
+    _cancelled: bool = False
+
+    #
+
+    @property
+    def cancelled(self) -> bool:
+        return self._cancelled
+
+    def cancel(self) -> None:
+        self._cancelled = True
+
+    class Cancelled(BaseException):  # noqa
+        pass
+
+    @staticmethod
+    def ignoring_cancelled(fn: ta.Callable[P, T]) -> ta.Callable[P, T | None]:
+        @functools.wraps(fn)
+        def inner(*args, **kwargs):
+            try:
+                return fn(*args, **kwargs)
+            except CancellableTextStreamer.Cancelled:
+                pass
+
+        return inner
+
+    def _maybe_raise_cancelled(self) -> None:
+        if self._cancelled:
+            raise CancellableTextStreamer.Cancelled
+
+    #
+
+    def put(self, value: ta.Any) -> None:
+        self._maybe_raise_cancelled()
+        super().put(value)
+        self._maybe_raise_cancelled()
+
+    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
+        self._maybe_raise_cancelled()
+        self.callback(text, stream_end=stream_end)
+        self._maybe_raise_cancelled()
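Usage note (not part of the diff): a hedged sketch of wiring `CancellableTextStreamer` into `transformers` generation, assuming the module is importable as laid out in the wheel above; the model name and generation arguments are illustrative.

    import transformers as tfm

    from ommlds.backends.transformers.streamers import CancellableTextStreamer

    tok = tfm.AutoTokenizer.from_pretrained('gpt2')        # illustrative model
    mdl = tfm.AutoModelForCausalLM.from_pretrained('gpt2')

    streamer = CancellableTextStreamer(
        tok,
        lambda text, *, stream_end: print(text, end='', flush=True),
        skip_prompt=True,
    )

    # Once cancel() is called (e.g. from another thread), put()/on_finalized_text()
    # raise Cancelled; ignoring_cancelled() swallows that and returns None.
    generate = CancellableTextStreamer.ignoring_cancelled(mdl.generate)
    generate(**tok('Hello', return_tensors='pt'), max_new_tokens=32, streamer=streamer)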
ommlds/cli/main.py
CHANGED
@@ -39,10 +39,6 @@ else:
 
 
 async def _a_main(args: ta.Any = None) -> None:
-    logs.configure_standard_logging('INFO')
-
-    #
-
     parser = argparse.ArgumentParser()
     parser.add_argument('prompt', nargs='*')
 
@@ -64,6 +60,8 @@ async def _a_main(args: ta.Any = None) -> None:
     parser.add_argument('-E', '--embed', action='store_true')
    parser.add_argument('-j', '--image', action='store_true')
 
+    parser.add_argument('-v', '--verbose', action='store_true')
+
     parser.add_argument('--enable-fs-tools', action='store_true')
     parser.add_argument('--enable-todo-tools', action='store_true')
     parser.add_argument('--enable-unsafe-tools-do-not-use-lol', action='store_true')
@@ -74,6 +72,14 @@ async def _a_main(args: ta.Any = None) -> None:
 
     #
 
+    if args.verbose:
+        logs.configure_standard_logging('DEBUG')
+    else:
+        logs.configure_standard_logging('INFO')
+        logs.silence_noisy_loggers()
+
+    #
+
     content: mc.Content | None
 
     if args.image:
@@ -161,7 +167,7 @@ async def _a_main(args: ta.Any = None) -> None:
            args.enable_test_weather_tool or
            args.code
        ),
-        enabled_tools={
+        enabled_tools={  # noqa
            *(['fs'] if args.enable_fs_tools else []),
            *(['todo'] if args.enable_todo_tools else []),
            *(['weather'] if args.enable_test_weather_tool else []),
ommlds/cli/sessions/chat/backends/catalog.py
CHANGED
@@ -32,7 +32,7 @@ class _CatalogBackendProvider(BackendProvider[ServiceT], lang.Abstract):
     @contextlib.asynccontextmanager
     async def _provide_backend(self, cls: type[ServiceT]) -> ta.AsyncIterator[ServiceT]:
         service: ServiceT
-        async with lang.async_or_sync_maybe_managing(self._catalog.
+        async with lang.async_or_sync_maybe_managing(self._catalog.new_backend(
            cls,
            self._name,
            *(self._configs or []),
ommlds/cli/sessions/completion/session.py
CHANGED
@@ -31,7 +31,7 @@ class CompletionSession(Session['CompletionSession.Config']):
         prompt = check.isinstance(self._config.content, str)
 
         mdl: mc.CompletionService
-        async with lang.async_maybe_managing(self._backend_catalog.
+        async with lang.async_maybe_managing(self._backend_catalog.new_backend(
            mc.CompletionService,
            self._config.backend or DEFAULT_COMPLETION_MODEL_BACKEND,
        )) as mdl:
ommlds/cli/sessions/embedding/session.py
CHANGED
@@ -29,7 +29,7 @@ class EmbeddingSession(Session['EmbeddingSession.Config']):
 
     async def run(self) -> None:
         mdl: mc.EmbeddingService
-        async with lang.async_maybe_managing(self._backend_catalog.
+        async with lang.async_maybe_managing(self._backend_catalog.new_backend(
            mc.EmbeddingService,
            self._config.backend or DEFAULT_EMBEDDING_MODEL_BACKEND,
        )) as mdl:
ommlds/minichain/__init__.py
CHANGED
@@ -322,6 +322,7 @@ with _lang.auto_proxy_init(
     ##
 
     from .registries.globals import (  # noqa
+        get_registry_cls,
         register_type,
         registry_new,
         registry_of,
@@ -558,6 +559,10 @@ with _lang.auto_proxy_init(
     )
 
     from .standard import (  # noqa
+        Device,
+
+        ApiUrl,
+
         ApiKey,
 
         DefaultOptions,
ommlds/minichain/backends/catalogs/base.py
CHANGED
@@ -3,15 +3,28 @@ import typing as ta
 
 from omlish import lang
 
+from ...configs import Config
+
+
+T = ta.TypeVar('T')
+
 
 ##
 
 
 class BackendCatalog(lang.Abstract):
+    class Backend(ta.NamedTuple):
+        factory: ta.Callable[..., ta.Any]
+        configs: ta.Sequence[Config] | None
+
     @abc.abstractmethod
-    def get_backend(self, service_cls:
+    def get_backend(self, service_cls: type[T], name: str) -> Backend:
         raise NotImplementedError
 
+    def new_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
+        be = self.get_backend(service_cls, name)
+        return be.factory(*be.configs or [], *args, **kwargs)
+
 # #
 #
 # class Bound(lang.Final, ta.Generic[T]):
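Usage note (not part of the diff): a standalone sketch of the new `Backend`/`new_backend` semantics, using stand-in names rather than ommlds' own catalog classes; the point is that the bundled configs are prepended to any caller-supplied positional arguments before the factory is called.

    import typing as ta


    class Backend(ta.NamedTuple):
        factory: ta.Callable[..., ta.Any]
        configs: ta.Sequence[ta.Any] | None


    def new_backend(be: Backend, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
        # Mirrors BackendCatalog.new_backend: configs first, then caller args.
        return be.factory(*(be.configs or []), *args, **kwargs)


    be = Backend(factory=lambda *a: list(a), configs=['model-path'])
    print(new_backend(be, 'extra'))  # ['model-path', 'extra']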
ommlds/minichain/backends/catalogs/simple.py
CHANGED
@@ -40,9 +40,9 @@ class SimpleBackendCatalog(BackendCatalog):
                sc_dct[e.name] = e
        self._dct = dct
 
-    def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) ->
+    def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> BackendCatalog.Backend:
         e = self._dct[service_cls][name]
-        return e.factory_fn
+        return BackendCatalog.Backend(e.factory_fn, None)
 
 
 ##
ommlds/minichain/backends/catalogs/strings.py
CHANGED
@@ -5,7 +5,7 @@ from omlish import check
 from ...models.configs import ModelPath
 from ...models.configs import ModelRepo
 from ...models.repos.resolving import ModelRepoResolver
-from ...registries.globals import
+from ...registries.globals import get_registry_cls
 from ..strings.parsing import parse_backend_string
 from ..strings.resolving import BackendStringResolver
 from ..strings.resolving import ResolveBackendStringArgs
@@ -30,14 +30,14 @@ class BackendStringBackendCatalog(BackendCatalog):
         self._string_resolver = string_resolver
         self._model_repo_resolver = model_repo_resolver
 
-    def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) ->
+    def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> BackendCatalog.Backend:
         ps = parse_backend_string(name)
         rs = check.not_none(self._string_resolver.resolve_backend_string(ResolveBackendStringArgs(
            service_cls,
            ps,
        )))
 
-        al = list(rs.args or [])
+        al: list = list(rs.args or [])
 
         # FIXME: lol
         if al and isinstance(al[0], ModelRepo):
@@ -46,10 +46,12 @@ class BackendStringBackendCatalog(BackendCatalog):
            mrp = check.not_none(mrr.resolve(mr))
            al = [ModelPath(mrp.path), *al[1:]]
 
-
+        cls = get_registry_cls(
            service_cls,
            rs.name,
-
-
-
+        )
+
+        return BackendCatalog.Backend(
+            cls,
+            al,
        )
ommlds/minichain/backends/impls/anthropic/stream.py
CHANGED
@@ -95,8 +95,7 @@ class AnthropicChatChoicesStreamService:
                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
                sd = sse.SseDecoder()
                while True:
-
-                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                    for l in db.feed(b):
                        if isinstance(l, DelimitingBuffer.Incomplete):
                            # FIXME: handle
ommlds/minichain/backends/impls/google/stream.py
CHANGED
@@ -169,8 +169,7 @@ class GoogleChatChoicesStreamService:
            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
                while True:
-
-                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                    for bl in db.feed(b):
                        if isinstance(bl, DelimitingBuffer.Incomplete):
                            # FIXME: handle
ommlds/minichain/backends/impls/llamacpp/chat.py
CHANGED
@@ -30,6 +30,15 @@ from .format import get_msg_content
 ##
 
 
+# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+#     ['ChatChoicesService'],
+#     'llamacpp',
+# )
+
+
+##
+
+
 # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
 #     name='llamacpp',
 #     type='ChatChoicesService',
ommlds/minichain/backends/impls/llamacpp/stream.py
CHANGED
@@ -29,6 +29,15 @@ from .format import get_msg_content
 ##
 
 
+# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+#     ['ChatChoicesStreamService'],
+#     'llamacpp',
+# )
+
+
+##
+
+
 # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
 #     name='llamacpp',
 #     type='ChatChoicesStreamService',
@@ -76,18 +85,25 @@ class LlamacppChatChoicesStreamService(lang.ExitStacked):
         rs.enter_context(lang.defer(close_output))
 
         async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
+            last_role: ta.Any = None
+
             for chunk in output:
                 check.state(chunk['object'] == 'chat.completion.chunk')
-
-
-
-
-
-
-
-
-
-
+
+                choice = check.single(chunk['choices'])
+
+                if not (delta := choice.get('delta', {})):
+                    continue
+
+                # FIXME: check role is assistant
+                if (role := delta.get('role')) != last_role:
+                    last_role = role
+
+                # FIXME: stop reason
+
+                if (content := delta.get('content', '')):
+                    await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiChoiceDelta(content)])]))
+
             return None
 
         return await new_stream_response(rs, inner)
ommlds/minichain/backends/impls/ollama/chat.py
ADDED
@@ -0,0 +1,199 @@
+import typing as ta
+
+from omlish import check
+from omlish import lang
+from omlish import marshal as msh
+from omlish import typedvalues as tv
+from omlish.formats import json
+from omlish.http import all as http
+from omlish.io.buffers import DelimitingBuffer
+
+from .....backends.ollama import protocol as pt
+from ....chat.choices.services import ChatChoicesOutputs
+from ....chat.choices.services import ChatChoicesRequest
+from ....chat.choices.services import ChatChoicesResponse
+from ....chat.choices.services import static_check_is_chat_choices_service
+from ....chat.choices.types import AiChoice
+from ....chat.messages import AiMessage
+from ....chat.messages import AnyAiMessage
+from ....chat.messages import Message
+from ....chat.messages import SystemMessage
+from ....chat.messages import UserMessage
+from ....chat.stream.services import ChatChoicesStreamRequest
+from ....chat.stream.services import ChatChoicesStreamResponse
+from ....chat.stream.services import static_check_is_chat_choices_stream_service
+from ....chat.stream.types import AiChoiceDeltas
+from ....chat.stream.types import AiChoicesDeltas
+from ....chat.stream.types import ContentAiChoiceDelta
+from ....models.configs import ModelName
+from ....resources import UseResources
+from ....standard import ApiUrl
+from ....stream.services import StreamResponseSink
+from ....stream.services import new_stream_response
+
+
+##
+
+
+# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+#     [
+#         'ChatChoicesService',
+#         'ChatChoicesStreamService',
+#     ],
+#     'ollama',
+# )
+
+
+##
+
+
+class BaseOllamaChatChoicesService(lang.Abstract):
+    DEFAULT_API_URL: ta.ClassVar[ApiUrl] = ApiUrl('http://localhost:11434/api')
+    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName('llama3.2')
+
+    def __init__(
+            self,
+            *configs: ApiUrl | ModelName,
+            http_client: http.AsyncHttpClient | None = None,
+    ) -> None:
+        super().__init__()
+
+        self._http_client = http_client
+
+        with tv.consume(*configs) as cc:
+            self._api_url = cc.pop(self.DEFAULT_API_URL)
+            self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
+
+    #
+
+    ROLE_MAP: ta.ClassVar[ta.Mapping[type[Message], pt.Role]] = {
+        SystemMessage: 'system',
+        UserMessage: 'user',
+        AiMessage: 'assistant',
+    }
+
+    @classmethod
+    def _get_message_content(cls, m: Message) -> str | None:
+        if isinstance(m, (AiMessage, UserMessage, SystemMessage)):
+            return check.isinstance(m.c, str)
+        else:
+            raise TypeError(m)
+
+    @classmethod
+    def _build_request_messages(cls, mc_msgs: ta.Iterable[Message]) -> ta.Sequence[pt.Message]:
+        messages: list[pt.Message] = []
+        for m in mc_msgs:
+            messages.append(pt.Message(
+                role=cls.ROLE_MAP[type(m)],
+                content=cls._get_message_content(m),
+            ))
+        return messages
+
+
+##
+
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='ollama',
+#     type='ChatChoicesService',
+# )
+@static_check_is_chat_choices_service
+class OllamaChatChoicesService(BaseOllamaChatChoicesService):
+    async def invoke(
+            self,
+            request: ChatChoicesRequest,
+    ) -> ChatChoicesResponse:
+        messages = self._build_request_messages(request.v)
+
+        a_req = pt.ChatRequest(
+            model=self._model_name.v,
+            messages=messages,
+            # tools=tools or None,
+            stream=False,
+        )
+
+        raw_request = msh.marshal(a_req)
+
+        async with http.manage_async_client(self._http_client) as http_client:
+            raw_response = await http_client.request(http.HttpRequest(
+                self._api_url.v.removesuffix('/') + '/chat',
+                data=json.dumps(raw_request).encode('utf-8'),
+            ))
+
+        json_response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
+
+        resp = msh.unmarshal(json_response, pt.ChatResponse)
+
+        out: list[AnyAiMessage] = []
+        if resp.message.role == 'assistant':
+            out.append(AiMessage(
+                check.not_none(resp.message.content),
+            ))
+        else:
+            raise TypeError(resp.message.role)
+
+        return ChatChoicesResponse([
+            AiChoice(out),
+        ])
+
+
+##
+
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='ollama',
+#     type='ChatChoicesStreamService',
+# )
+@static_check_is_chat_choices_stream_service
+class OllamaChatChoicesStreamService(BaseOllamaChatChoicesService):
+    READ_CHUNK_SIZE = 64 * 1024
+
+    async def invoke(
+            self,
+            request: ChatChoicesStreamRequest,
+    ) -> ChatChoicesStreamResponse:
+        messages = self._build_request_messages(request.v)
+
+        a_req = pt.ChatRequest(
+            model=self._model_name.v,
+            messages=messages,
+            # tools=tools or None,
+            stream=True,
+        )
+
+        raw_request = msh.marshal(a_req)
+
+        http_request = http.HttpRequest(
+            self._api_url.v.removesuffix('/') + '/chat',
+            data=json.dumps(raw_request).encode('utf-8'),
+        )
+
+        async with UseResources.or_new(request.options) as rs:
+            http_client = await rs.enter_async_context(http.manage_async_client(self._http_client))
+            http_response = await rs.enter_async_context(await http_client.stream_request(http_request))
+
+            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
+                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
+                while True:
+                    b = await http_response.stream.read1(self.READ_CHUNK_SIZE)
+                    for l in db.feed(b):
+                        if isinstance(l, DelimitingBuffer.Incomplete):
+                            # FIXME: handle
+                            return []
+
+                        lj = json.loads(l.decode('utf-8'))
+                        lp: pt.ChatResponse = msh.unmarshal(lj, pt.ChatResponse)
+
+                        check.state(lp.message.role == 'assistant')
+                        check.none(lp.message.tool_name)
+                        check.state(not lp.message.tool_calls)
+
+                        if (c := lp.message.content):
+                            await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiChoiceDelta(
+                                c,
+                            )])]))
+
+                    if not b:
+                        return []
+
+            return await new_stream_response(rs, inner)
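Usage note (not part of the diff): a hedged sketch of constructing the new Ollama backend directly, assuming a local Ollama server at the default URL and that the absolute import paths match the relative ones shown in the diff; the config values simply restate the defaults above.

    from ommlds.minichain.backends.impls.ollama.chat import OllamaChatChoicesService
    from ommlds.minichain.models.configs import ModelName
    from ommlds.minichain.standard import ApiUrl

    svc = OllamaChatChoicesService(
        ApiUrl('http://localhost:11434/api'),  # default from the diff
        ModelName('llama3.2'),                 # default from the diff
    )
    # Invocation follows the service protocol in the diff:
    #     response = await svc.invoke(request)  # request.v carries the chat Messages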
ommlds/minichain/backends/impls/openai/stream.py
CHANGED
@@ -88,8 +88,7 @@ class OpenaiChatChoicesStreamService:
                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
                sd = sse.SseDecoder()
                while True:
-
-                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                    for l in db.feed(b):
                        if isinstance(l, DelimitingBuffer.Incomplete):
                            # FIXME: handle