ommlds 0.0.0.dev467__py3-none-any.whl → 0.0.0.dev469__py3-none-any.whl

Files changed (36)
  1. ommlds/.omlish-manifests.json +109 -2
  2. ommlds/__about__.py +2 -2
  3. ommlds/_hacks/__init__.py +4 -0
  4. ommlds/_hacks/funcs.py +110 -0
  5. ommlds/_hacks/names.py +158 -0
  6. ommlds/_hacks/params.py +73 -0
  7. ommlds/_hacks/patches.py +0 -3
  8. ommlds/backends/ollama/__init__.py +0 -0
  9. ommlds/backends/ollama/protocol.py +170 -0
  10. ommlds/backends/transformers/__init__.py +0 -0
  11. ommlds/backends/transformers/filecache.py +109 -0
  12. ommlds/backends/transformers/streamers.py +73 -0
  13. ommlds/cli/main.py +11 -5
  14. ommlds/cli/sessions/chat/backends/catalog.py +1 -1
  15. ommlds/cli/sessions/completion/session.py +1 -1
  16. ommlds/cli/sessions/embedding/session.py +1 -1
  17. ommlds/minichain/__init__.py +5 -0
  18. ommlds/minichain/backends/catalogs/base.py +14 -1
  19. ommlds/minichain/backends/catalogs/simple.py +2 -2
  20. ommlds/minichain/backends/catalogs/strings.py +9 -7
  21. ommlds/minichain/backends/impls/anthropic/stream.py +1 -2
  22. ommlds/minichain/backends/impls/google/stream.py +1 -2
  23. ommlds/minichain/backends/impls/llamacpp/chat.py +9 -0
  24. ommlds/minichain/backends/impls/llamacpp/stream.py +26 -10
  25. ommlds/minichain/backends/impls/ollama/__init__.py +0 -0
  26. ommlds/minichain/backends/impls/ollama/chat.py +199 -0
  27. ommlds/minichain/backends/impls/openai/stream.py +1 -2
  28. ommlds/minichain/backends/impls/transformers/transformers.py +134 -17
  29. ommlds/minichain/registries/globals.py +18 -4
  30. ommlds/minichain/standard.py +7 -0
  31. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/METADATA +7 -7
  32. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/RECORD +36 -26
  33. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/WHEEL +0 -0
  34. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/entry_points.txt +0 -0
  35. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/licenses/LICENSE +0 -0
  36. {ommlds-0.0.0.dev467.dist-info → ommlds-0.0.0.dev469.dist-info}/top_level.txt +0 -0
ommlds/backends/transformers/filecache.py ADDED
@@ -0,0 +1,109 @@
+ import contextlib
+ import dataclasses as dc
+ import os
+ import threading
+ import typing as ta
+
+ import transformers as tfm
+
+ from omlish import lang
+
+ from ..._hacks.funcs import create_detour
+
+
+ ##
+
+
+ @dc.dataclass(frozen=True, kw_only=True)
+ class _FileCachePatchContext:
+     local_first: bool = False
+     local_config_present_is_authoritative: bool = False
+
+
+ _FILE_CACHE_PATCH_CONTEXT_TLS = threading.local()
+
+
+ def _get_file_cache_patch_context() -> _FileCachePatchContext:
+     try:
+         return _FILE_CACHE_PATCH_CONTEXT_TLS.context
+     except AttributeError:
+         ctx = _FILE_CACHE_PATCH_CONTEXT_TLS.context = _FileCachePatchContext()
+         return ctx
+
+
+ _FILE_CACHE_PATCH_LOCK = threading.Lock()
+
+
+ @lang.cached_function(lock=_FILE_CACHE_PATCH_LOCK)
+ def patch_file_cache() -> None:
+     """
+     I tried to make a `local_first_pipeline` function to be called instead of `tfm.pipeline`, I really did, but the
+     transformers code is such a disgusting rat's nest full of direct static calls to the caching code strewn about at
+     every layer with no concern whatsoever for forwarding kwargs where they need to go.
+     """
+
+     from transformers.utils.hub import cached_files
+
+     orig_cached_files: ta.Callable[..., str | None] = lang.copy_function(cached_files)  # type: ignore
+
+     get_file_cache_patch_context = _get_file_cache_patch_context
+
+     def new_cached_files(
+             path_or_repo_id: str | os.PathLike,
+             filenames: list[str],
+             **kwargs: ta.Any,
+     ) -> str | None:
+         ctx = get_file_cache_patch_context()
+
+         if ctx.local_first and not kwargs.get('local_files_only'):
+             try:
+                 local = orig_cached_files(
+                     path_or_repo_id,
+                     filenames,
+                     **{**kwargs, 'local_files_only': True},
+                 )
+             except OSError as e:  # noqa
+                 pass
+             else:
+                 return local
+
+             if ctx.local_config_present_is_authoritative:
+                 try:
+                     local_config = orig_cached_files(
+                         path_or_repo_id,
+                         [tfm.CONFIG_NAME],
+                         **{**kwargs, 'local_files_only': True},
+                     )
+                 except OSError as e:  # noqa
+                     pass
+                 else:
+                     raise OSError(
+                         f'Files {filenames!r} requested under local_first '
+                         f'but local_config present at {local_config!r}, '
+                         f'assuming files do not exist.',
+                     )
+
+         return orig_cached_files(path_or_repo_id, filenames, **kwargs)
+
+     cached_files.__code__ = create_detour(cached_files, new_cached_files, as_kwargs=True)
+
+
+ @contextlib.contextmanager
+ def file_cache_patch_context(
+         *,
+         local_first: bool = False,
+         local_config_present_is_authoritative: bool = False,
+ ) -> ta.Generator[None]:
+     patch_file_cache()
+
+     new_ctx = dc.replace(
+         old_ctx := _get_file_cache_patch_context(),
+         local_first=local_first,
+         local_config_present_is_authoritative=local_config_present_is_authoritative,
+     )
+
+     _FILE_CACHE_PATCH_CONTEXT_TLS.context = new_ctx
+     try:
+         yield
+     finally:
+         _FILE_CACHE_PATCH_CONTEXT_TLS.context = old_ctx
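The patch above is applied lazily (once, under a lock) and scoped through a thread-local context, so callers opt in per call site rather than globally. A minimal usage sketch, assuming the module is importable as `ommlds.backends.transformers.filecache` (inferred from the file list above) and using an illustrative model id:

```python
# Hypothetical usage sketch; the import path and model id are illustrative, not documented API.
import transformers as tfm

from ommlds.backends.transformers.filecache import file_cache_patch_context

# Inside the context, transformers' cached_files() is detoured: the locally cached copy is
# tried first, and the hub is only consulted on a local miss. With the second flag set, a
# repo that already has a local config.json treats a missing sibling file as authoritative
# (raises OSError) instead of falling through to a network lookup.
with file_cache_patch_context(
        local_first=True,
        local_config_present_is_authoritative=True,
):
    pipe = tfm.pipeline('text-generation', model='Qwen/Qwen2-0.5B-Instruct')
```

On exit the previous thread-local context is restored, so nested uses compose per thread.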
ommlds/backends/transformers/streamers.py ADDED
@@ -0,0 +1,73 @@
+ import functools
+ import typing as ta
+
+ import transformers as tfm
+
+
+ T = ta.TypeVar('T')
+ P = ta.ParamSpec('P')
+
+
+ ##
+
+
+ class CancellableTextStreamer(tfm.TextStreamer):
+     class Callback(ta.Protocol):
+         def __call__(self, text: str, *, stream_end: bool) -> None: ...
+
+     def __init__(
+             self,
+             tokenizer: tfm.AutoTokenizer,
+             callback: Callback,
+             *,
+             skip_prompt: bool = False,
+             **decode_kwargs: ta.Any,
+     ) -> None:
+         super().__init__(
+             tokenizer,
+             skip_prompt=skip_prompt,
+             **decode_kwargs,
+         )
+
+         self.callback = callback
+
+     _cancelled: bool = False
+
+     #
+
+     @property
+     def cancelled(self) -> bool:
+         return self._cancelled
+
+     def cancel(self) -> None:
+         self._cancelled = True
+
+     class Cancelled(BaseException):  # noqa
+         pass
+
+     @staticmethod
+     def ignoring_cancelled(fn: ta.Callable[P, T]) -> ta.Callable[P, T | None]:
+         @functools.wraps(fn)
+         def inner(*args, **kwargs):
+             try:
+                 return fn(*args, **kwargs)
+             except CancellableTextStreamer.Cancelled:
+                 pass
+
+         return inner
+
+     def _maybe_raise_cancelled(self) -> None:
+         if self._cancelled:
+             raise CancellableTextStreamer.Cancelled
+
+     #
+
+     def put(self, value: ta.Any) -> None:
+         self._maybe_raise_cancelled()
+         super().put(value)
+         self._maybe_raise_cancelled()
+
+     def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
+         self._maybe_raise_cancelled()
+         self.callback(text, stream_end=stream_end)
+         self._maybe_raise_cancelled()
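`CancellableTextStreamer` lets a callback abort an in-flight `generate()` by raising a dedicated control-flow exception from `put`/`on_finalized_text`. A sketch of how it might be wired up, with illustrative model names and a length-based cancellation rule:

```python
# Hypothetical usage sketch; the model/tokenizer ids and stopping rule are illustrative.
import transformers as tfm

from ommlds.backends.transformers.streamers import CancellableTextStreamer

tok = tfm.AutoTokenizer.from_pretrained('gpt2')
mdl = tfm.AutoModelForCausalLM.from_pretrained('gpt2')

pieces: list[str] = []


def on_text(text: str, *, stream_end: bool) -> None:
    pieces.append(text)
    if sum(map(len, pieces)) > 200:  # arbitrary cutoff for the example
        streamer.cancel()


streamer = CancellableTextStreamer(tok, on_text, skip_prompt=True)

# Once cancel() is called, the streamer raises Cancelled out of generate();
# ignoring_cancelled() swallows that exception so the call simply returns None.
run = CancellableTextStreamer.ignoring_cancelled(mdl.generate)
run(**tok('Once upon a time', return_tensors='pt'), max_new_tokens=128, streamer=streamer)

print(''.join(pieces))
```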
ommlds/cli/main.py CHANGED
@@ -39,10 +39,6 @@ else:
 
 
  async def _a_main(args: ta.Any = None) -> None:
-     logs.configure_standard_logging('INFO')
-
-     #
-
      parser = argparse.ArgumentParser()
      parser.add_argument('prompt', nargs='*')
 
@@ -64,6 +60,8 @@ async def _a_main(args: ta.Any = None) -> None:
      parser.add_argument('-E', '--embed', action='store_true')
      parser.add_argument('-j', '--image', action='store_true')
 
+     parser.add_argument('-v', '--verbose', action='store_true')
+
      parser.add_argument('--enable-fs-tools', action='store_true')
      parser.add_argument('--enable-todo-tools', action='store_true')
      parser.add_argument('--enable-unsafe-tools-do-not-use-lol', action='store_true')
@@ -74,6 +72,14 @@ async def _a_main(args: ta.Any = None) -> None:
 
      #
 
+     if args.verbose:
+         logs.configure_standard_logging('DEBUG')
+     else:
+         logs.configure_standard_logging('INFO')
+         logs.silence_noisy_loggers()
+
+     #
+
      content: mc.Content | None
 
      if args.image:
@@ -161,7 +167,7 @@ async def _a_main(args: ta.Any = None) -> None:
              args.enable_test_weather_tool or
              args.code
          ),
-         enabled_tools={
+         enabled_tools={  # noqa
              *(['fs'] if args.enable_fs_tools else []),
              *(['todo'] if args.enable_todo_tools else []),
              *(['weather'] if args.enable_test_weather_tool else []),
ommlds/cli/sessions/chat/backends/catalog.py CHANGED
@@ -32,7 +32,7 @@ class _CatalogBackendProvider(BackendProvider[ServiceT], lang.Abstract):
      @contextlib.asynccontextmanager
      async def _provide_backend(self, cls: type[ServiceT]) -> ta.AsyncIterator[ServiceT]:
          service: ServiceT
-         async with lang.async_or_sync_maybe_managing(self._catalog.get_backend(
+         async with lang.async_or_sync_maybe_managing(self._catalog.new_backend(
              cls,
              self._name,
              *(self._configs or []),
ommlds/cli/sessions/completion/session.py CHANGED
@@ -31,7 +31,7 @@ class CompletionSession(Session['CompletionSession.Config']):
          prompt = check.isinstance(self._config.content, str)
 
          mdl: mc.CompletionService
-         async with lang.async_maybe_managing(self._backend_catalog.get_backend(
+         async with lang.async_maybe_managing(self._backend_catalog.new_backend(
              mc.CompletionService,
              self._config.backend or DEFAULT_COMPLETION_MODEL_BACKEND,
          )) as mdl:
ommlds/cli/sessions/embedding/session.py CHANGED
@@ -29,7 +29,7 @@ class EmbeddingSession(Session['EmbeddingSession.Config']):
 
      async def run(self) -> None:
          mdl: mc.EmbeddingService
-         async with lang.async_maybe_managing(self._backend_catalog.get_backend(
+         async with lang.async_maybe_managing(self._backend_catalog.new_backend(
              mc.EmbeddingService,
              self._config.backend or DEFAULT_EMBEDDING_MODEL_BACKEND,
          )) as mdl:
ommlds/minichain/__init__.py CHANGED
@@ -322,6 +322,7 @@ with _lang.auto_proxy_init(
      ##
 
      from .registries.globals import (  # noqa
+         get_registry_cls,
          register_type,
          registry_new,
          registry_of,
@@ -558,6 +559,10 @@ with _lang.auto_proxy_init(
      )
 
      from .standard import (  # noqa
+         Device,
+
+         ApiUrl,
+
          ApiKey,
 
          DefaultOptions,
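`Device` and `ApiUrl` join `ApiKey` as typed-value configs re-exported from the top-level `minichain` proxy; backends accept them positionally and pull them out with `tv.consume`, as the new Ollama service further down does. A small construction sketch, assuming the export above and the backend import path shown in the file list (both are inferences from this diff):

```python
# Hypothetical sketch; relies on the ApiUrl export above and the ollama backend added below.
from ommlds import minichain as mc
from ommlds.minichain.backends.impls.ollama.chat import OllamaChatChoicesService
from ommlds.minichain.models.configs import ModelName

svc = OllamaChatChoicesService(
    mc.ApiUrl('http://localhost:11434/api'),  # overrides DEFAULT_API_URL
    ModelName('llama3.2'),                    # overrides DEFAULT_MODEL_NAME
)
```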
ommlds/minichain/backends/catalogs/base.py CHANGED
@@ -3,15 +3,28 @@ import typing as ta
 
  from omlish import lang
 
+ from ...configs import Config
+
+
+ T = ta.TypeVar('T')
+
 
  ##
 
 
  class BackendCatalog(lang.Abstract):
+     class Backend(ta.NamedTuple):
+         factory: ta.Callable[..., ta.Any]
+         configs: ta.Sequence[Config] | None
+
      @abc.abstractmethod
-     def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
+     def get_backend(self, service_cls: type[T], name: str) -> Backend:
          raise NotImplementedError
 
+     def new_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
+         be = self.get_backend(service_cls, name)
+         return be.factory(*be.configs or [], *args, **kwargs)
+
      # #
      #
      # class Bound(lang.Final, ta.Generic[T]):
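The catalog interface now separates lookup from construction: `get_backend` returns a `Backend` pair of factory plus baked-in configs, and the `new_backend` helper applies them together with any caller-supplied arguments (which is what the CLI sessions above switched to). A minimal in-memory catalog sketch of that contract; the service class and import path here are illustrative:

```python
# Hypothetical sketch of the new contract; FakeChatService and the import path are illustrative.
import typing as ta

from ommlds.minichain.backends.catalogs.base import BackendCatalog


class FakeChatService:
    def __init__(self, *args: ta.Any) -> None:
        self.args = args


class DictBackendCatalog(BackendCatalog):
    def __init__(self, dct: ta.Mapping[tuple[type, str], ta.Callable[..., ta.Any]]) -> None:
        super().__init__()
        self._dct = dct

    def get_backend(self, service_cls, name):
        # No catalog-level configs here, so the second Backend field is None.
        return BackendCatalog.Backend(self._dct[(service_cls, name)], None)


catalog = DictBackendCatalog({(FakeChatService, 'fake'): FakeChatService})

# new_backend() resolves the factory, then calls it with the catalog's configs (if any)
# followed by the caller's extra arguments.
svc = catalog.new_backend(FakeChatService, 'fake')
assert isinstance(svc, FakeChatService)
```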
ommlds/minichain/backends/catalogs/simple.py CHANGED
@@ -40,9 +40,9 @@ class SimpleBackendCatalog(BackendCatalog):
              sc_dct[e.name] = e
          self._dct = dct
 
-     def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
+     def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> BackendCatalog.Backend:
          e = self._dct[service_cls][name]
-         return e.factory_fn(*args, **kwargs)
+         return BackendCatalog.Backend(e.factory_fn, None)
 
 
  ##
ommlds/minichain/backends/catalogs/strings.py CHANGED
@@ -5,7 +5,7 @@ from omlish import check
  from ...models.configs import ModelPath
  from ...models.configs import ModelRepo
  from ...models.repos.resolving import ModelRepoResolver
- from ...registries.globals import registry_new
+ from ...registries.globals import get_registry_cls
  from ..strings.parsing import parse_backend_string
  from ..strings.resolving import BackendStringResolver
  from ..strings.resolving import ResolveBackendStringArgs
@@ -30,14 +30,14 @@ class BackendStringBackendCatalog(BackendCatalog):
          self._string_resolver = string_resolver
          self._model_repo_resolver = model_repo_resolver
 
-     def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> ta.Any:
+     def get_backend(self, service_cls: ta.Any, name: str, *args: ta.Any, **kwargs: ta.Any) -> BackendCatalog.Backend:
          ps = parse_backend_string(name)
          rs = check.not_none(self._string_resolver.resolve_backend_string(ResolveBackendStringArgs(
              service_cls,
              ps,
          )))
 
-         al = list(rs.args or [])
+         al: list = list(rs.args or [])
 
          # FIXME: lol
          if al and isinstance(al[0], ModelRepo):
@@ -46,10 +46,12 @@ class BackendStringBackendCatalog(BackendCatalog):
              mrp = check.not_none(mrr.resolve(mr))
              al = [ModelPath(mrp.path), *al[1:]]
 
-         return registry_new(
+         cls = get_registry_cls(
              service_cls,
              rs.name,
-             *al,
-             *args,
-             **kwargs,
+         )
+
+         return BackendCatalog.Backend(
+             cls,
+             al,
          )
ommlds/minichain/backends/impls/anthropic/stream.py CHANGED
@@ -95,8 +95,7 @@ class AnthropicChatChoicesStreamService:
              db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
              sd = sse.SseDecoder()
              while True:
-                 # FIXME: read1 not on response stream protocol
-                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                  for l in db.feed(b):
                      if isinstance(l, DelimitingBuffer.Incomplete):
                          # FIXME: handle
ommlds/minichain/backends/impls/google/stream.py CHANGED
@@ -169,8 +169,7 @@ class GoogleChatChoicesStreamService:
          async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
              db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
              while True:
-                 # FIXME: read1 not on response stream protocol
-                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                  for bl in db.feed(b):
                      if isinstance(bl, DelimitingBuffer.Incomplete):
                          # FIXME: handle
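These stream backends (anthropic, google, openai, and the new ollama one below) share the same inner loop: read a chunk from the HTTP response stream with `read1`, feed it to a `DelimitingBuffer` to recover complete delimited lines, and decode each line; dropping the `# type: ignore` here suggests `read1` is now part of the response stream protocol. A stripped-down sketch of that loop shape, with `http_response` standing in for the object the real services already hold:

```python
# Schematic of the shared streaming read loop; not a drop-in for any particular backend.
from omlish.io.buffers import DelimitingBuffer

READ_CHUNK_SIZE = 64 * 1024


def drain_lines(http_response) -> list[bytes]:
    """Collect delimiter-terminated payload lines from a chunked response stream."""
    db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
    out: list[bytes] = []
    while True:
        b = http_response.stream.read1(READ_CHUNK_SIZE)
        for l in db.feed(b):
            if isinstance(l, DelimitingBuffer.Incomplete):
                # partial trailing line; the real backends currently just bail out here
                return out
            out.append(bytes(l))
        if not b:  # stream exhausted
            return out
```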
ommlds/minichain/backends/impls/llamacpp/chat.py CHANGED
@@ -30,6 +30,15 @@ from .format import get_msg_content
  ##
 
 
+ # @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+ #     ['ChatChoicesService'],
+ #     'llamacpp',
+ # )
+
+
+ ##
+
+
  # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
  #     name='llamacpp',
  #     type='ChatChoicesService',
ommlds/minichain/backends/impls/llamacpp/stream.py CHANGED
@@ -29,6 +29,15 @@ from .format import get_msg_content
  ##
 
 
+ # @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+ #     ['ChatChoicesStreamService'],
+ #     'llamacpp',
+ # )
+
+
+ ##
+
+
  # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
  #     name='llamacpp',
  #     type='ChatChoicesStreamService',
@@ -76,18 +85,25 @@ class LlamacppChatChoicesStreamService(lang.ExitStacked):
          rs.enter_context(lang.defer(close_output))
 
          async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
+             last_role: ta.Any = None
+
              for chunk in output:
                  check.state(chunk['object'] == 'chat.completion.chunk')
-                 l: list[AiChoiceDeltas] = []
-                 for choice in chunk['choices']:
-                     # FIXME: check role is assistant
-                     # FIXME: stop reason
-                     if not (delta := choice.get('delta', {})):
-                         continue
-                     if not (content := delta.get('content', '')):
-                         continue
-                     l.append(AiChoiceDeltas([ContentAiChoiceDelta(content)]))
-                 await sink.emit(AiChoicesDeltas(l))
+
+                 choice = check.single(chunk['choices'])
+
+                 if not (delta := choice.get('delta', {})):
+                     continue
+
+                 # FIXME: check role is assistant
+                 if (role := delta.get('role')) != last_role:
+                     last_role = role
+
+                 # FIXME: stop reason
+
+                 if (content := delta.get('content', '')):
+                     await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiChoiceDelta(content)])]))
+
              return None
 
          return await new_stream_response(rs, inner)
ommlds/minichain/backends/impls/ollama/__init__.py ADDED
File without changes
ommlds/minichain/backends/impls/ollama/chat.py ADDED
@@ -0,0 +1,199 @@
+ import typing as ta
+
+ from omlish import check
+ from omlish import lang
+ from omlish import marshal as msh
+ from omlish import typedvalues as tv
+ from omlish.formats import json
+ from omlish.http import all as http
+ from omlish.io.buffers import DelimitingBuffer
+
+ from .....backends.ollama import protocol as pt
+ from ....chat.choices.services import ChatChoicesOutputs
+ from ....chat.choices.services import ChatChoicesRequest
+ from ....chat.choices.services import ChatChoicesResponse
+ from ....chat.choices.services import static_check_is_chat_choices_service
+ from ....chat.choices.types import AiChoice
+ from ....chat.messages import AiMessage
+ from ....chat.messages import AnyAiMessage
+ from ....chat.messages import Message
+ from ....chat.messages import SystemMessage
+ from ....chat.messages import UserMessage
+ from ....chat.stream.services import ChatChoicesStreamRequest
+ from ....chat.stream.services import ChatChoicesStreamResponse
+ from ....chat.stream.services import static_check_is_chat_choices_stream_service
+ from ....chat.stream.types import AiChoiceDeltas
+ from ....chat.stream.types import AiChoicesDeltas
+ from ....chat.stream.types import ContentAiChoiceDelta
+ from ....models.configs import ModelName
+ from ....resources import UseResources
+ from ....standard import ApiUrl
+ from ....stream.services import StreamResponseSink
+ from ....stream.services import new_stream_response
+
+
+ ##
+
+
+ # @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+ #     [
+ #         'ChatChoicesService',
+ #         'ChatChoicesStreamService',
+ #     ],
+ #     'ollama',
+ # )
+
+
+ ##
+
+
+ class BaseOllamaChatChoicesService(lang.Abstract):
+     DEFAULT_API_URL: ta.ClassVar[ApiUrl] = ApiUrl('http://localhost:11434/api')
+     DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName('llama3.2')
+
+     def __init__(
+             self,
+             *configs: ApiUrl | ModelName,
+             http_client: http.AsyncHttpClient | None = None,
+     ) -> None:
+         super().__init__()
+
+         self._http_client = http_client
+
+         with tv.consume(*configs) as cc:
+             self._api_url = cc.pop(self.DEFAULT_API_URL)
+             self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
+
+     #
+
+     ROLE_MAP: ta.ClassVar[ta.Mapping[type[Message], pt.Role]] = {
+         SystemMessage: 'system',
+         UserMessage: 'user',
+         AiMessage: 'assistant',
+     }
+
+     @classmethod
+     def _get_message_content(cls, m: Message) -> str | None:
+         if isinstance(m, (AiMessage, UserMessage, SystemMessage)):
+             return check.isinstance(m.c, str)
+         else:
+             raise TypeError(m)
+
+     @classmethod
+     def _build_request_messages(cls, mc_msgs: ta.Iterable[Message]) -> ta.Sequence[pt.Message]:
+         messages: list[pt.Message] = []
+         for m in mc_msgs:
+             messages.append(pt.Message(
+                 role=cls.ROLE_MAP[type(m)],
+                 content=cls._get_message_content(m),
+             ))
+         return messages
+
+
+ ##
+
+
+ # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+ #     name='ollama',
+ #     type='ChatChoicesService',
+ # )
+ @static_check_is_chat_choices_service
+ class OllamaChatChoicesService(BaseOllamaChatChoicesService):
+     async def invoke(
+             self,
+             request: ChatChoicesRequest,
+     ) -> ChatChoicesResponse:
+         messages = self._build_request_messages(request.v)
+
+         a_req = pt.ChatRequest(
+             model=self._model_name.v,
+             messages=messages,
+             # tools=tools or None,
+             stream=False,
+         )
+
+         raw_request = msh.marshal(a_req)
+
+         async with http.manage_async_client(self._http_client) as http_client:
+             raw_response = await http_client.request(http.HttpRequest(
+                 self._api_url.v.removesuffix('/') + '/chat',
+                 data=json.dumps(raw_request).encode('utf-8'),
+             ))
+
+         json_response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
+
+         resp = msh.unmarshal(json_response, pt.ChatResponse)
+
+         out: list[AnyAiMessage] = []
+         if resp.message.role == 'assistant':
+             out.append(AiMessage(
+                 check.not_none(resp.message.content),
+             ))
+         else:
+             raise TypeError(resp.message.role)
+
+         return ChatChoicesResponse([
+             AiChoice(out),
+         ])
+
+
+ ##
+
+
+ # @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+ #     name='ollama',
+ #     type='ChatChoicesStreamService',
+ # )
+ @static_check_is_chat_choices_stream_service
+ class OllamaChatChoicesStreamService(BaseOllamaChatChoicesService):
+     READ_CHUNK_SIZE = 64 * 1024
+
+     async def invoke(
+             self,
+             request: ChatChoicesStreamRequest,
+     ) -> ChatChoicesStreamResponse:
+         messages = self._build_request_messages(request.v)
+
+         a_req = pt.ChatRequest(
+             model=self._model_name.v,
+             messages=messages,
+             # tools=tools or None,
+             stream=True,
+         )
+
+         raw_request = msh.marshal(a_req)
+
+         http_request = http.HttpRequest(
+             self._api_url.v.removesuffix('/') + '/chat',
+             data=json.dumps(raw_request).encode('utf-8'),
+         )
+
+         async with UseResources.or_new(request.options) as rs:
+             http_client = await rs.enter_async_context(http.manage_async_client(self._http_client))
+             http_response = await rs.enter_async_context(await http_client.stream_request(http_request))
+
+             async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
+                 db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
+                 while True:
+                     b = await http_response.stream.read1(self.READ_CHUNK_SIZE)
+                     for l in db.feed(b):
+                         if isinstance(l, DelimitingBuffer.Incomplete):
+                             # FIXME: handle
+                             return []
+
+                         lj = json.loads(l.decode('utf-8'))
+                         lp: pt.ChatResponse = msh.unmarshal(lj, pt.ChatResponse)
+
+                         check.state(lp.message.role == 'assistant')
+                         check.none(lp.message.tool_name)
+                         check.state(not lp.message.tool_calls)
+
+                         if (c := lp.message.content):
+                             await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiChoiceDelta(
+                                 c,
+                             )])]))
+
+                     if not b:
+                         return []
+
+             return await new_stream_response(rs, inner)
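For context, the request that `OllamaChatChoicesService` marshals targets Ollama's `/api/chat` endpoint. A hedged standard-library sketch of the equivalent raw exchange; the field names follow the `pt.ChatRequest`/`pt.ChatResponse` usage above and Ollama's documented API, while the protocol dataclasses themselves live in `ommlds/backends/ollama/protocol.py` (added in this release but not shown here):

```python
# Rough equivalent of the non-streaming request above, using urllib only; treat as an approximation.
import json
import urllib.request

req = {
    'model': 'llama3.2',
    'messages': [
        {'role': 'system', 'content': 'You are terse.'},
        {'role': 'user', 'content': 'Name a prime number.'},
    ],
    'stream': False,
}

http_req = urllib.request.Request(
    'http://localhost:11434/api/chat',
    data=json.dumps(req).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)

with urllib.request.urlopen(http_req) as resp:
    body = json.loads(resp.read().decode('utf-8'))

# The service maps body['message'] (expected role 'assistant') onto an AiMessage in a single AiChoice.
print(body['message']['content'])
```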
ommlds/minichain/backends/impls/openai/stream.py CHANGED
@@ -88,8 +88,7 @@ class OpenaiChatChoicesStreamService:
              db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
              sd = sse.SseDecoder()
              while True:
-                 # FIXME: read1 not on response stream protocol
-                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
+                 b = http_response.stream.read1(self.READ_CHUNK_SIZE)
                  for l in db.feed(b):
                      if isinstance(l, DelimitingBuffer.Incomplete):
                          # FIXME: handle