ommlds 0.0.0.dev426__py3-none-any.whl → 0.0.0.dev485__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +336 -39
- ommlds/__about__.py +16 -10
- ommlds/_hacks/__init__.py +4 -0
- ommlds/_hacks/funcs.py +110 -0
- ommlds/_hacks/names.py +158 -0
- ommlds/_hacks/params.py +73 -0
- ommlds/_hacks/patches.py +0 -3
- ommlds/backends/anthropic/protocol/__init__.py +13 -1
- ommlds/backends/anthropic/protocol/_dataclasses.py +1625 -0
- ommlds/backends/anthropic/protocol/_marshal.py +2 -2
- ommlds/backends/anthropic/protocol/sse/_marshal.py +1 -1
- ommlds/backends/anthropic/protocol/sse/assemble.py +23 -7
- ommlds/backends/anthropic/protocol/sse/events.py +13 -0
- ommlds/backends/anthropic/protocol/types.py +40 -8
- ommlds/backends/google/protocol/__init__.py +16 -0
- ommlds/backends/google/protocol/_dataclasses.py +5997 -0
- ommlds/backends/google/protocol/_marshal.py +16 -0
- ommlds/backends/google/protocol/types.py +626 -0
- ommlds/backends/groq/__init__.py +7 -0
- ommlds/backends/groq/_dataclasses.py +3901 -0
- ommlds/backends/groq/_marshal.py +23 -0
- ommlds/backends/groq/protocol.py +249 -0
- ommlds/backends/llamacpp/logging.py +4 -1
- ommlds/backends/mlx/caching.py +7 -3
- ommlds/backends/mlx/cli.py +10 -7
- ommlds/backends/mlx/generation.py +19 -17
- ommlds/backends/mlx/limits.py +10 -6
- ommlds/backends/mlx/loading.py +65 -5
- ommlds/backends/ollama/__init__.py +7 -0
- ommlds/backends/ollama/_dataclasses.py +3458 -0
- ommlds/backends/ollama/protocol.py +170 -0
- ommlds/backends/openai/protocol/__init__.py +24 -29
- ommlds/backends/openai/protocol/_common.py +18 -0
- ommlds/backends/openai/protocol/_dataclasses.py +7708 -0
- ommlds/backends/openai/protocol/_marshal.py +27 -0
- ommlds/backends/openai/protocol/chatcompletion/chunk.py +58 -31
- ommlds/backends/openai/protocol/chatcompletion/contentpart.py +49 -44
- ommlds/backends/openai/protocol/chatcompletion/message.py +55 -43
- ommlds/backends/openai/protocol/chatcompletion/request.py +114 -66
- ommlds/backends/openai/protocol/chatcompletion/response.py +71 -45
- ommlds/backends/openai/protocol/chatcompletion/responseformat.py +27 -20
- ommlds/backends/openai/protocol/chatcompletion/tokenlogprob.py +16 -7
- ommlds/backends/openai/protocol/completionusage.py +24 -15
- ommlds/backends/tavily/__init__.py +7 -0
- ommlds/backends/tavily/_dataclasses.py +1734 -0
- ommlds/backends/tavily/protocol.py +301 -0
- ommlds/backends/tinygrad/models/llama3/__init__.py +22 -14
- ommlds/backends/transformers/__init__.py +14 -0
- ommlds/backends/transformers/filecache.py +109 -0
- ommlds/backends/transformers/streamers.py +73 -0
- ommlds/cli/__init__.py +7 -0
- ommlds/cli/_dataclasses.py +2562 -0
- ommlds/cli/asyncs.py +30 -0
- ommlds/cli/backends/catalog.py +93 -0
- ommlds/cli/backends/configs.py +9 -0
- ommlds/cli/backends/inject.py +31 -36
- ommlds/cli/backends/injection.py +16 -0
- ommlds/cli/backends/types.py +46 -0
- ommlds/cli/content/messages.py +34 -0
- ommlds/cli/content/strings.py +42 -0
- ommlds/cli/inject.py +17 -32
- ommlds/cli/inputs/__init__.py +0 -0
- ommlds/cli/inputs/asyncs.py +32 -0
- ommlds/cli/inputs/sync.py +75 -0
- ommlds/cli/main.py +270 -110
- ommlds/cli/rendering/__init__.py +0 -0
- ommlds/cli/rendering/configs.py +9 -0
- ommlds/cli/rendering/inject.py +31 -0
- ommlds/cli/rendering/markdown.py +52 -0
- ommlds/cli/rendering/raw.py +73 -0
- ommlds/cli/rendering/types.py +21 -0
- ommlds/cli/secrets.py +21 -0
- ommlds/cli/sessions/base.py +1 -1
- ommlds/cli/sessions/chat/chat/__init__.py +0 -0
- ommlds/cli/sessions/chat/chat/ai/__init__.py +0 -0
- ommlds/cli/sessions/chat/chat/ai/configs.py +11 -0
- ommlds/cli/sessions/chat/chat/ai/inject.py +74 -0
- ommlds/cli/sessions/chat/chat/ai/injection.py +14 -0
- ommlds/cli/sessions/chat/chat/ai/rendering.py +70 -0
- ommlds/cli/sessions/chat/chat/ai/services.py +79 -0
- ommlds/cli/sessions/chat/chat/ai/tools.py +44 -0
- ommlds/cli/sessions/chat/chat/ai/types.py +28 -0
- ommlds/cli/sessions/chat/chat/state/__init__.py +0 -0
- ommlds/cli/sessions/chat/chat/state/configs.py +11 -0
- ommlds/cli/sessions/chat/chat/state/inject.py +36 -0
- ommlds/cli/sessions/chat/chat/state/inmemory.py +33 -0
- ommlds/cli/sessions/chat/chat/state/storage.py +52 -0
- ommlds/cli/sessions/chat/chat/state/types.py +38 -0
- ommlds/cli/sessions/chat/chat/user/__init__.py +0 -0
- ommlds/cli/sessions/chat/chat/user/configs.py +17 -0
- ommlds/cli/sessions/chat/chat/user/inject.py +62 -0
- ommlds/cli/sessions/chat/chat/user/interactive.py +31 -0
- ommlds/cli/sessions/chat/chat/user/oneshot.py +25 -0
- ommlds/cli/sessions/chat/chat/user/types.py +15 -0
- ommlds/cli/sessions/chat/configs.py +27 -0
- ommlds/cli/sessions/chat/driver.py +43 -0
- ommlds/cli/sessions/chat/inject.py +33 -65
- ommlds/cli/sessions/chat/phases/__init__.py +0 -0
- ommlds/cli/sessions/chat/phases/inject.py +27 -0
- ommlds/cli/sessions/chat/phases/injection.py +14 -0
- ommlds/cli/sessions/chat/phases/manager.py +29 -0
- ommlds/cli/sessions/chat/phases/types.py +29 -0
- ommlds/cli/sessions/chat/session.py +27 -0
- ommlds/cli/sessions/chat/tools/__init__.py +0 -0
- ommlds/cli/sessions/chat/tools/configs.py +22 -0
- ommlds/cli/sessions/chat/tools/confirmation.py +46 -0
- ommlds/cli/sessions/chat/tools/execution.py +66 -0
- ommlds/cli/sessions/chat/tools/fs/__init__.py +0 -0
- ommlds/cli/sessions/chat/tools/fs/configs.py +12 -0
- ommlds/cli/sessions/chat/tools/fs/inject.py +35 -0
- ommlds/cli/sessions/chat/tools/inject.py +88 -0
- ommlds/cli/sessions/chat/tools/injection.py +44 -0
- ommlds/cli/sessions/chat/tools/rendering.py +58 -0
- ommlds/cli/sessions/chat/tools/todo/__init__.py +0 -0
- ommlds/cli/sessions/chat/tools/todo/configs.py +12 -0
- ommlds/cli/sessions/chat/tools/todo/inject.py +31 -0
- ommlds/cli/sessions/chat/tools/weather/__init__.py +0 -0
- ommlds/cli/sessions/chat/tools/weather/configs.py +12 -0
- ommlds/cli/sessions/chat/tools/weather/inject.py +22 -0
- ommlds/cli/{tools/weather.py → sessions/chat/tools/weather/tools.py} +1 -1
- ommlds/cli/sessions/completion/configs.py +21 -0
- ommlds/cli/sessions/completion/inject.py +42 -0
- ommlds/cli/sessions/completion/session.py +35 -0
- ommlds/cli/sessions/embedding/configs.py +21 -0
- ommlds/cli/sessions/embedding/inject.py +42 -0
- ommlds/cli/sessions/embedding/session.py +33 -0
- ommlds/cli/sessions/inject.py +28 -11
- ommlds/cli/state/__init__.py +0 -0
- ommlds/cli/state/inject.py +28 -0
- ommlds/cli/{state.py → state/storage.py} +41 -24
- ommlds/minichain/__init__.py +84 -24
- ommlds/minichain/_dataclasses.py +15401 -0
- ommlds/minichain/_marshal.py +49 -9
- ommlds/minichain/_typedvalues.py +2 -4
- ommlds/minichain/backends/catalogs/base.py +20 -1
- ommlds/minichain/backends/catalogs/simple.py +2 -2
- ommlds/minichain/backends/catalogs/strings.py +10 -8
- ommlds/minichain/backends/impls/anthropic/chat.py +65 -27
- ommlds/minichain/backends/impls/anthropic/names.py +10 -8
- ommlds/minichain/backends/impls/anthropic/protocol.py +109 -0
- ommlds/minichain/backends/impls/anthropic/stream.py +111 -43
- ommlds/minichain/backends/impls/duckduckgo/search.py +6 -2
- ommlds/minichain/backends/impls/dummy/__init__.py +0 -0
- ommlds/minichain/backends/impls/dummy/chat.py +69 -0
- ommlds/minichain/backends/impls/google/chat.py +114 -22
- ommlds/minichain/backends/impls/google/search.py +7 -2
- ommlds/minichain/backends/impls/google/stream.py +219 -0
- ommlds/minichain/backends/impls/google/tools.py +149 -0
- ommlds/minichain/backends/impls/groq/__init__.py +0 -0
- ommlds/minichain/backends/impls/groq/chat.py +75 -0
- ommlds/minichain/backends/impls/groq/names.py +48 -0
- ommlds/minichain/backends/impls/groq/protocol.py +143 -0
- ommlds/minichain/backends/impls/groq/stream.py +125 -0
- ommlds/minichain/backends/impls/huggingface/repos.py +1 -5
- ommlds/minichain/backends/impls/llamacpp/chat.py +40 -22
- ommlds/minichain/backends/impls/llamacpp/completion.py +9 -5
- ommlds/minichain/backends/impls/llamacpp/format.py +4 -2
- ommlds/minichain/backends/impls/llamacpp/stream.py +43 -23
- ommlds/minichain/backends/impls/mistral.py +20 -5
- ommlds/minichain/backends/impls/mlx/chat.py +101 -24
- ommlds/minichain/backends/impls/ollama/__init__.py +0 -0
- ommlds/minichain/backends/impls/ollama/chat.py +199 -0
- ommlds/minichain/backends/impls/openai/chat.py +18 -8
- ommlds/minichain/backends/impls/openai/completion.py +10 -3
- ommlds/minichain/backends/impls/openai/embedding.py +10 -3
- ommlds/minichain/backends/impls/openai/format.py +131 -106
- ommlds/minichain/backends/impls/openai/names.py +31 -5
- ommlds/minichain/backends/impls/openai/stream.py +43 -25
- ommlds/minichain/backends/impls/sentencepiece/tokens.py +9 -6
- ommlds/minichain/backends/impls/tavily.py +66 -0
- ommlds/minichain/backends/impls/tinygrad/chat.py +30 -20
- ommlds/minichain/backends/impls/tokenizers/tokens.py +9 -6
- ommlds/minichain/backends/impls/transformers/sentence.py +6 -3
- ommlds/minichain/backends/impls/transformers/tokens.py +10 -7
- ommlds/minichain/backends/impls/transformers/transformers.py +160 -37
- ommlds/minichain/backends/strings/parsing.py +1 -1
- ommlds/minichain/backends/strings/resolving.py +4 -1
- ommlds/minichain/chat/_marshal.py +16 -9
- ommlds/minichain/chat/choices/adapters.py +4 -4
- ommlds/minichain/chat/choices/services.py +1 -1
- ommlds/minichain/chat/choices/stream/__init__.py +0 -0
- ommlds/minichain/chat/choices/stream/adapters.py +35 -0
- ommlds/minichain/chat/choices/stream/joining.py +31 -0
- ommlds/minichain/chat/choices/stream/services.py +45 -0
- ommlds/minichain/chat/choices/stream/types.py +43 -0
- ommlds/minichain/chat/choices/types.py +2 -2
- ommlds/minichain/chat/history.py +3 -3
- ommlds/minichain/chat/messages.py +55 -19
- ommlds/minichain/chat/services.py +3 -3
- ommlds/minichain/chat/stream/_marshal.py +16 -0
- ommlds/minichain/chat/stream/joining.py +85 -0
- ommlds/minichain/chat/stream/services.py +15 -21
- ommlds/minichain/chat/stream/types.py +32 -19
- ommlds/minichain/chat/tools/execution.py +8 -7
- ommlds/minichain/chat/tools/ids.py +9 -15
- ommlds/minichain/chat/tools/parsing.py +17 -26
- ommlds/minichain/chat/transforms/base.py +29 -38
- ommlds/minichain/chat/transforms/metadata.py +30 -4
- ommlds/minichain/chat/transforms/services.py +9 -11
- ommlds/minichain/content/_marshal.py +44 -20
- ommlds/minichain/content/json.py +13 -0
- ommlds/minichain/content/materialize.py +14 -21
- ommlds/minichain/content/prepare.py +4 -0
- ommlds/minichain/content/transforms/interleave.py +1 -1
- ommlds/minichain/content/transforms/squeeze.py +1 -1
- ommlds/minichain/content/transforms/stringify.py +1 -1
- ommlds/minichain/json.py +20 -0
- ommlds/minichain/lib/code/__init__.py +0 -0
- ommlds/minichain/lib/code/prompts.py +6 -0
- ommlds/minichain/lib/fs/binfiles.py +108 -0
- ommlds/minichain/lib/fs/context.py +126 -0
- ommlds/minichain/lib/fs/errors.py +101 -0
- ommlds/minichain/lib/fs/suggestions.py +36 -0
- ommlds/minichain/lib/fs/tools/__init__.py +0 -0
- ommlds/minichain/lib/fs/tools/edit.py +104 -0
- ommlds/minichain/lib/fs/tools/ls.py +38 -0
- ommlds/minichain/lib/fs/tools/read.py +115 -0
- ommlds/minichain/lib/fs/tools/recursivels/__init__.py +0 -0
- ommlds/minichain/lib/fs/tools/recursivels/execution.py +40 -0
- ommlds/minichain/lib/todo/__init__.py +0 -0
- ommlds/minichain/lib/todo/context.py +54 -0
- ommlds/minichain/lib/todo/tools/__init__.py +0 -0
- ommlds/minichain/lib/todo/tools/read.py +44 -0
- ommlds/minichain/lib/todo/tools/write.py +335 -0
- ommlds/minichain/lib/todo/types.py +60 -0
- ommlds/minichain/llms/_marshal.py +25 -17
- ommlds/minichain/llms/types.py +4 -0
- ommlds/minichain/registries/globals.py +18 -4
- ommlds/minichain/resources.py +68 -45
- ommlds/minichain/search.py +1 -1
- ommlds/minichain/services/_marshal.py +46 -39
- ommlds/minichain/services/facades.py +3 -3
- ommlds/minichain/services/services.py +1 -1
- ommlds/minichain/standard.py +8 -0
- ommlds/minichain/stream/services.py +152 -38
- ommlds/minichain/stream/wrap.py +22 -24
- ommlds/minichain/text/toolparsing/llamacpp/hermes2.py +3 -2
- ommlds/minichain/text/toolparsing/llamacpp/llama31.py +3 -2
- ommlds/minichain/text/toolparsing/llamacpp/utils.py +3 -2
- ommlds/minichain/tools/_marshal.py +1 -1
- ommlds/minichain/tools/execution/catalog.py +2 -1
- ommlds/minichain/tools/execution/context.py +34 -14
- ommlds/minichain/tools/execution/errors.py +15 -0
- ommlds/minichain/tools/execution/executors.py +8 -3
- ommlds/minichain/tools/execution/reflect.py +40 -5
- ommlds/minichain/tools/fns.py +46 -9
- ommlds/minichain/tools/jsonschema.py +14 -5
- ommlds/minichain/tools/reflect.py +54 -18
- ommlds/minichain/tools/types.py +33 -1
- ommlds/minichain/utils.py +27 -0
- ommlds/minichain/vectors/_marshal.py +11 -10
- ommlds/minichain/vectors/types.py +1 -1
- ommlds/nanochat/LICENSE +21 -0
- ommlds/nanochat/__init__.py +0 -0
- ommlds/nanochat/rustbpe/LICENSE +21 -0
- ommlds/nanochat/tokenizers.py +406 -0
- ommlds/server/cli.py +1 -2
- ommlds/server/server.py +5 -5
- ommlds/server/service.py +1 -1
- ommlds/specs/__init__.py +0 -0
- ommlds/specs/mcp/__init__.py +0 -0
- ommlds/specs/mcp/_marshal.py +23 -0
- ommlds/specs/mcp/clients.py +146 -0
- ommlds/specs/mcp/protocol.py +371 -0
- ommlds/tools/git.py +35 -12
- ommlds/tools/ocr.py +8 -9
- ommlds/wiki/analyze.py +6 -7
- ommlds/wiki/text/mfh.py +1 -5
- ommlds/wiki/text/wtp.py +1 -3
- ommlds/wiki/utils/xml.py +5 -5
- {ommlds-0.0.0.dev426.dist-info → ommlds-0.0.0.dev485.dist-info}/METADATA +24 -21
- ommlds-0.0.0.dev485.dist-info/RECORD +436 -0
- ommlds/cli/backends/standard.py +0 -20
- ommlds/cli/sessions/chat/base.py +0 -42
- ommlds/cli/sessions/chat/interactive.py +0 -73
- ommlds/cli/sessions/chat/printing.py +0 -96
- ommlds/cli/sessions/chat/prompt.py +0 -143
- ommlds/cli/sessions/chat/state.py +0 -109
- ommlds/cli/sessions/chat/tools.py +0 -91
- ommlds/cli/sessions/completion/completion.py +0 -44
- ommlds/cli/sessions/embedding/embedding.py +0 -42
- ommlds/cli/tools/config.py +0 -13
- ommlds/cli/tools/inject.py +0 -64
- ommlds/minichain/chat/stream/adapters.py +0 -69
- ommlds/minichain/lib/fs/ls/execution.py +0 -32
- ommlds-0.0.0.dev426.dist-info/RECORD +0 -303
- /ommlds/{cli/tools → backends/google}/__init__.py +0 -0
- /ommlds/{huggingface.py → backends/huggingface.py} +0 -0
- /ommlds/{minichain/lib/fs/ls → cli/content}/__init__.py +0 -0
- /ommlds/minichain/lib/fs/{ls → tools/recursivels}/rendering.py +0 -0
- /ommlds/minichain/lib/fs/{ls → tools/recursivels}/running.py +0 -0
- {ommlds-0.0.0.dev426.dist-info → ommlds-0.0.0.dev485.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev426.dist-info → ommlds-0.0.0.dev485.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev426.dist-info → ommlds-0.0.0.dev485.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev426.dist-info → ommlds-0.0.0.dev485.dist-info}/top_level.txt +0 -0
ommlds/minichain/backends/impls/mlx/chat.py

@@ -1,32 +1,46 @@
+import contextlib
 import typing as ta
 
 from omlish import check
 from omlish import lang
 from omlish import typedvalues as tv
 
-from
+from ....chat.choices.services import ChatChoicesOutputs
 from ....chat.choices.services import ChatChoicesRequest
 from ....chat.choices.services import ChatChoicesResponse
 from ....chat.choices.services import static_check_is_chat_choices_service
+from ....chat.choices.stream.services import ChatChoicesStreamRequest
+from ....chat.choices.stream.services import ChatChoicesStreamResponse
+from ....chat.choices.stream.services import static_check_is_chat_choices_stream_service
+from ....chat.choices.stream.types import AiChoiceDeltas
+from ....chat.choices.stream.types import AiChoicesDeltas
 from ....chat.choices.types import AiChoice
 from ....chat.choices.types import ChatChoicesOptions
 from ....chat.messages import AiMessage
 from ....chat.messages import Message
 from ....chat.messages import SystemMessage
 from ....chat.messages import UserMessage
+from ....chat.stream.types import ContentAiDelta
 from ....configs import Config
 from ....llms.types import MaxTokens
 from ....models.configs import ModelPath
 from ....models.configs import ModelRepo
 from ....models.configs import ModelSpecifier
+from ....resources import UseResources
 from ....standard import DefaultOptions
+from ....stream.services import StreamResponseSink
+from ....stream.services import new_stream_response
+
+
+with lang.auto_proxy_import(globals()):
+    from .....backends import mlx as mlxu
 
 
 ##
 
 
 # @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
-#     ['ChatChoicesService'],
+#     ['ChatChoicesService', 'ChatChoicesStreamService'],
 #     'mlx',
 # )
 
@@ -34,12 +48,7 @@ from ....standard import DefaultOptions
 ##
 
 
-# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
-#     name='mlx',
-#     type='ChatChoicesService',
-# )
-@static_check_is_chat_choices_service
-class MlxChatChoicesService(lang.ExitStacked):
+class BaseMlxChatChoicesService(lang.ExitStacked):
     DEFAULT_MODEL: ta.ClassVar[ModelSpecifier] = (
         # 'mlx-community/DeepSeek-Coder-V2-Lite-Instruct-8bit'
         # 'mlx-community/Llama-3.3-70B-Instruct-4bit'
@@ -52,8 +61,8 @@ class MlxChatChoicesService(lang.ExitStacked):
         # 'mlx-community/Qwen2.5-0.5B-4bit'
         # 'mlx-community/Qwen2.5-32B-Instruct-8bit'
        # 'mlx-community/Qwen2.5-Coder-32B-Instruct-8bit'
-        # 'mlx-community/mamba-2.8b-hf-f16'
         # 'mlx-community/Qwen3-30B-A3B-6bit'
+        # 'mlx-community/mamba-2.8b-hf-f16'
     )
 
     def __init__(self, *configs: Config) -> None:
@@ -70,17 +79,14 @@ class MlxChatChoicesService(lang.ExitStacked):
     }
 
     def _get_msg_content(self, m: Message) -> str | None:
-        if isinstance(m, AiMessage):
-            return check.isinstance(m.c, str)
-
-        elif isinstance(m, (SystemMessage, UserMessage)):
+        if isinstance(m, (AiMessage, SystemMessage, UserMessage)):
             return check.isinstance(m.c, str)
 
         else:
            raise TypeError(m)
 
     @lang.cached_function(transient=True)
-    def _load_model(self) -> mlxu.LoadedModel:
+    def _load_model(self) -> 'mlxu.LoadedModel':
         # FIXME: walk state, find all mx.arrays, dealloc/set to empty
         check.not_none(self._exit_stack)
 
@@ -96,10 +102,9 @@ class MlxChatChoicesService(lang.ExitStacked):
         max_tokens=MaxTokens,
     )
 
-
-
-
-        tokenizer = loaded_model.tokenization.tokenizer
+    @lang.cached_function(transient=True)
+    def _get_tokenizer(self) -> 'mlxu.tokenization.Tokenizer':
+        tokenizer = self._load_model().tokenization.tokenizer
 
         if not (
             hasattr(tokenizer, 'apply_chat_template') and
@@ -107,26 +112,44 @@ class MlxChatChoicesService(lang.ExitStacked):
         ):
             raise RuntimeError(tokenizer)
 
-
+        return tokenizer
+
+    def _build_prompt(self, messages: ta.Sequence[Message]) -> str:
+        return check.isinstance(self._get_tokenizer().apply_chat_template(
             [  # type: ignore[arg-type]
                 dict(
                     role=self.ROLES_MAP[type(m)],
                     content=self._get_msg_content(m),
                 )
-                for m in request.v
+                for m in messages
             ],
             tokenize=False,
             add_generation_prompt=True,
-        )
+        ), str)
+
+    def _build_kwargs(self, oc: tv.TypedValuesConsumer) -> dict[str, ta.Any]:
+        kwargs: dict[str, ta.Any] = {}
+        kwargs.update(oc.pop_scalar_kwargs(**self._OPTION_KWARG_NAMES_MAP))
+        return kwargs
 
-
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='mlx',
+#     type='ChatChoicesService',
+# )
+@static_check_is_chat_choices_service
+class MlxChatChoicesService(BaseMlxChatChoicesService):
+    async def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
+        loaded_model = self._load_model()
+
+        prompt = self._build_prompt(request.v)
 
         with tv.consume(
             *self._default_options,
             *request.options,
             override=True,
         ) as oc:
-            kwargs.update(oc.pop_scalar_kwargs(**self._OPTION_KWARG_NAMES_MAP))
+            kwargs = self._build_kwargs(oc)
 
             response = mlxu.generate(
                 loaded_model.model,
@@ -137,5 +160,59 @@ class MlxChatChoicesService(lang.ExitStacked):
             )
 
         return ChatChoicesResponse([
-            AiChoice(AiMessage(response))  # noqa
+            AiChoice([AiMessage(response)])  # noqa
         ])
+
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='mlx',
+#     type='ChatChoicesStreamService',
+# )
+@static_check_is_chat_choices_stream_service
+class MlxChatChoicesStreamService(BaseMlxChatChoicesService):
+    def __init__(self, *configs: Config) -> None:
+        super().__init__()
+
+        with tv.consume(*configs) as cc:
+            self._model = cc.pop(MlxChatChoicesService.DEFAULT_MODEL)
+            self._default_options: tv.TypedValues = DefaultOptions.pop(cc)
+
+    READ_CHUNK_SIZE = 64 * 1024
+
+    async def invoke(
+            self,
+            request: ChatChoicesStreamRequest,
+            *,
+            max_tokens: int = 4096,  # FIXME: ChatOption
+    ) -> ChatChoicesStreamResponse:
+        loaded_model = self._load_model()
+
+        prompt = self._build_prompt(request.v)
+
+        with tv.consume(
+            *self._default_options,
+            *request.options,
+            override=True,
+        ) as oc:
+            oc.pop(UseResources, None)
+            kwargs = self._build_kwargs(oc)
+
+        async with UseResources.or_new(request.options) as rs:
+            gen: ta.Iterator[mlxu.GenerationOutput] = rs.enter_context(contextlib.closing(mlxu.stream_generate(
+                loaded_model.model,
+                loaded_model.tokenization,
+                check.isinstance(prompt, str),
+                mlxu.GenerationParams(**kwargs),
+                # verbose=True,
+            )))
+
+            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs]:
+                for go in gen:
+                    if go.text:
+                        await sink.emit(AiChoicesDeltas([AiChoiceDeltas([
+                            ContentAiDelta(go.text),
+                        ])]))
+
+                return []
+
+            return await new_stream_response(rs, inner)
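The notable pattern in the new MlxChatChoicesStreamService is the bridge from MLX's synchronous token generator into the library's async sink: the generator is wrapped in contextlib.closing and registered with the request's resource scope, so it is closed even if the consumer abandons the stream early. A minimal self-contained sketch of that bridge, using only the standard library (fake_token_gen and drain_to_sink are hypothetical stand-ins, not ommlds APIs):

import asyncio
import contextlib
import typing as ta


def fake_token_gen() -> ta.Iterator[str]:
    # Stand-in for mlxu.stream_generate: a synchronous generator of text chunks.
    yield from ('Hello', ', ', 'world')


async def drain_to_sink(emit: ta.Callable[[str], ta.Awaitable[None]]) -> None:
    # Mirrors inner() above: contextlib.closing guarantees gen.close() runs,
    # so generator finalizers fire even if the consumer stops early.
    with contextlib.closing(fake_token_gen()) as gen:
        for text in gen:
            if text:
                await emit(text)


async def _demo() -> None:
    chunks: list[str] = []

    async def emit(s: str) -> None:
        chunks.append(s)

    await drain_to_sink(emit)
    assert ''.join(chunks) == 'Hello, world'


asyncio.run(_demo())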
ommlds/minichain/backends/impls/ollama/chat.py (new file)

@@ -0,0 +1,199 @@
+import typing as ta
+
+from omlish import check
+from omlish import lang
+from omlish import marshal as msh
+from omlish import typedvalues as tv
+from omlish.formats import json
+from omlish.http import all as http
+from omlish.io.buffers import DelimitingBuffer
+
+from .....backends.ollama import protocol as pt
+from ....chat.choices.services import ChatChoicesOutputs
+from ....chat.choices.services import ChatChoicesRequest
+from ....chat.choices.services import ChatChoicesResponse
+from ....chat.choices.services import static_check_is_chat_choices_service
+from ....chat.choices.stream.services import ChatChoicesStreamRequest
+from ....chat.choices.stream.services import ChatChoicesStreamResponse
+from ....chat.choices.stream.services import static_check_is_chat_choices_stream_service
+from ....chat.choices.stream.types import AiChoiceDeltas
+from ....chat.choices.stream.types import AiChoicesDeltas
+from ....chat.choices.types import AiChoice
+from ....chat.messages import AiMessage
+from ....chat.messages import AnyAiMessage
+from ....chat.messages import Message
+from ....chat.messages import SystemMessage
+from ....chat.messages import UserMessage
+from ....chat.stream.types import ContentAiDelta
+from ....models.configs import ModelName
+from ....resources import UseResources
+from ....standard import ApiUrl
+from ....stream.services import StreamResponseSink
+from ....stream.services import new_stream_response
+
+
+##
+
+
+# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
+#     [
+#         'ChatChoicesService',
+#         'ChatChoicesStreamService',
+#     ],
+#     'ollama',
+# )
+
+
+##
+
+
+class BaseOllamaChatChoicesService(lang.Abstract):
+    DEFAULT_API_URL: ta.ClassVar[ApiUrl] = ApiUrl('http://localhost:11434/api')
+    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName('llama3.2')
+
+    def __init__(
+            self,
+            *configs: ApiUrl | ModelName,
+            http_client: http.AsyncHttpClient | None = None,
+    ) -> None:
+        super().__init__()
+
+        self._http_client = http_client
+
+        with tv.consume(*configs) as cc:
+            self._api_url = cc.pop(self.DEFAULT_API_URL)
+            self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
+
+    #
+
+    ROLE_MAP: ta.ClassVar[ta.Mapping[type[Message], pt.Role]] = {  # noqa
+        SystemMessage: 'system',
+        UserMessage: 'user',
+        AiMessage: 'assistant',
+    }
+
+    @classmethod
+    def _get_message_content(cls, m: Message) -> str | None:
+        if isinstance(m, (AiMessage, UserMessage, SystemMessage)):
+            return check.isinstance(m.c, str)
+        else:
+            raise TypeError(m)
+
+    @classmethod
+    def _build_request_messages(cls, mc_msgs: ta.Iterable[Message]) -> ta.Sequence[pt.Message]:
+        messages: list[pt.Message] = []
+        for m in mc_msgs:
+            messages.append(pt.Message(
+                role=cls.ROLE_MAP[type(m)],
+                content=cls._get_message_content(m),
+            ))
+        return messages
+
+
+##
+
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='ollama',
+#     type='ChatChoicesService',
+# )
+@static_check_is_chat_choices_service
+class OllamaChatChoicesService(BaseOllamaChatChoicesService):
+    async def invoke(
+            self,
+            request: ChatChoicesRequest,
+    ) -> ChatChoicesResponse:
+        messages = self._build_request_messages(request.v)
+
+        a_req = pt.ChatRequest(
+            model=self._model_name.v,
+            messages=messages,
+            # tools=tools or None,
+            stream=False,
+        )
+
+        raw_request = msh.marshal(a_req)
+
+        async with http.manage_async_client(self._http_client) as http_client:
+            raw_response = await http_client.request(http.HttpRequest(
+                self._api_url.v.removesuffix('/') + '/chat',
+                data=json.dumps(raw_request).encode('utf-8'),
+            ))
+
+        json_response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
+
+        resp = msh.unmarshal(json_response, pt.ChatResponse)
+
+        out: list[AnyAiMessage] = []
+        if resp.message.role == 'assistant':
+            out.append(AiMessage(
+                check.not_none(resp.message.content),
+            ))
+        else:
+            raise TypeError(resp.message.role)
+
+        return ChatChoicesResponse([
+            AiChoice(out),
+        ])
+
+
+##
+
+
+# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
+#     name='ollama',
+#     type='ChatChoicesStreamService',
+# )
+@static_check_is_chat_choices_stream_service
+class OllamaChatChoicesStreamService(BaseOllamaChatChoicesService):
+    READ_CHUNK_SIZE: ta.ClassVar[int] = -1
+
+    async def invoke(
+            self,
+            request: ChatChoicesStreamRequest,
+    ) -> ChatChoicesStreamResponse:
+        messages = self._build_request_messages(request.v)
+
+        a_req = pt.ChatRequest(
+            model=self._model_name.v,
+            messages=messages,
+            # tools=tools or None,
+            stream=True,
+        )
+
+        raw_request = msh.marshal(a_req)
+
+        http_request = http.HttpRequest(
+            self._api_url.v.removesuffix('/') + '/chat',
+            data=json.dumps(raw_request).encode('utf-8'),
+        )
+
+        async with UseResources.or_new(request.options) as rs:
+            http_client = await rs.enter_async_context(http.manage_async_client(self._http_client))
+            http_response = await rs.enter_async_context(await http_client.stream_request(http_request))
+
+            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
+                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
+                while True:
+                    b = await http_response.stream.read1(self.READ_CHUNK_SIZE)
+                    for l in db.feed(b):
+                        if isinstance(l, DelimitingBuffer.Incomplete):
+                            # FIXME: handle
+                            return []
+
+                        lj = json.loads(l.decode('utf-8'))
+                        lp: pt.ChatResponse = msh.unmarshal(lj, pt.ChatResponse)
+
+                        check.state(lp.message.role == 'assistant')
+                        check.none(lp.message.tool_name)
+                        check.state(not lp.message.tool_calls)
+
+                        if (c := lp.message.content):
+                            await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiDelta(
+                                c,
+                            )])]))
+
+                        if not b:
+                            return []
+
+            return await new_stream_response(rs, inner)
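OllamaChatChoicesStreamService buffers the response body on CR/LF boundaries and unmarshals each line as a pt.ChatResponse, which matches Ollama's public /api/chat streaming format: newline-delimited JSON objects, the last carrying "done": true. A standard-library sketch of that exchange (assumes a local Ollama server with the llama3.2 model pulled):

import json
import urllib.request

req = urllib.request.Request(
    'http://localhost:11434/api/chat',
    data=json.dumps({
        'model': 'llama3.2',
        'messages': [{'role': 'user', 'content': 'Say hi.'}],
        'stream': True,
    }).encode('utf-8'),
)
with urllib.request.urlopen(req) as resp:
    for line in resp:  # one JSON object per newline-delimited chunk
        chunk = json.loads(line)
        print(chunk['message']['content'], end='', flush=True)
        if chunk.get('done'):
            break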
ommlds/minichain/backends/impls/openai/chat.py

@@ -14,10 +14,12 @@ TODO:
 import typing as ta
 
 from omlish import check
+from omlish import marshal as msh
 from omlish import typedvalues as tv
 from omlish.formats import json
 from omlish.http import all as http
 
+from .....backends.openai import protocol as pt
 from ....chat.choices.services import ChatChoicesRequest
 from ....chat.choices.services import ChatChoicesResponse
 from ....chat.choices.services import static_check_is_chat_choices_service
@@ -25,7 +27,8 @@ from ....models.configs import ModelName
 from ....standard import ApiKey
 from ....standard import DefaultOptions
 from .format import OpenaiChatRequestHandler
-from .names import MODEL_NAMES
+from .format import build_mc_choices_response
+from .names import CHAT_MODEL_NAMES
 
 
 ##
@@ -37,17 +40,23 @@ from .names import MODEL_NAMES
 # )
 @static_check_is_chat_choices_service
 class OpenaiChatChoicesService:
-    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName(check.not_none(MODEL_NAMES.default))
+    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName(check.not_none(CHAT_MODEL_NAMES.default))
 
-    def __init__(self, *configs: ApiKey | ModelName | DefaultOptions) -> None:
+    def __init__(
+            self,
+            *configs: ApiKey | ModelName | DefaultOptions,
+            http_client: http.AsyncHttpClient | None = None,
+    ) -> None:
         super().__init__()
 
+        self._http_client = http_client
+
         with tv.consume(*configs) as cc:
            self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
             self._api_key = ApiKey.pop_secret(cc, env='OPENAI_API_KEY')
             self._default_options: tv.TypedValues = DefaultOptions.pop(cc)
 
-    def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
+    async def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
         # check.isinstance(request, ChatRequest)
 
         rh = OpenaiChatRequestHandler(
@@ -57,23 +66,24 @@ class OpenaiChatChoicesService:
                *request.options,
                override=True,
            ),
-            model=
+            model=CHAT_MODEL_NAMES.resolve(self._model_name.v),
            mandatory_kwargs=dict(
                stream=False,
            ),
        )
 
-        raw_request = rh.
+        raw_request = msh.marshal(rh.oai_request())
 
-        http_response = http.request(
+        http_response = await http.async_request(
             'https://api.openai.com/v1/chat/completions',
             headers={
                 http.consts.HEADER_CONTENT_TYPE: http.consts.CONTENT_TYPE_JSON,
                 http.consts.HEADER_AUTH: http.consts.format_bearer_auth_header(check.not_none(self._api_key).reveal()),
             },
             data=json.dumps(raw_request).encode('utf-8'),
+            client=self._http_client,
         )
 
         raw_response = json.loads(check.not_none(http_response.data).decode('utf-8'))
 
-        return
+        return build_mc_choices_response(msh.unmarshal(raw_response, pt.ChatCompletionResponse))
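OpenaiChatChoicesService.invoke() is now a coroutine: it marshals a typed protocol request, awaits http.async_request against the public Chat Completions endpoint, and unmarshals the reply into pt.ChatCompletionResponse. The underlying wire exchange, reduced to the standard library ('gpt-4o-mini' is an illustrative model name, not the service's default):

import json
import os
import urllib.request

req = urllib.request.Request(
    'https://api.openai.com/v1/chat/completions',
    data=json.dumps({
        'model': 'gpt-4o-mini',
        'messages': [{'role': 'user', 'content': 'Say hi.'}],
        'stream': False,
    }).encode('utf-8'),
    headers={
        'Content-Type': 'application/json',
        'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY'],
    },
)
with urllib.request.urlopen(req) as resp:
    out = json.loads(resp.read().decode('utf-8'))
print(out['choices'][0]['message']['content'])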
ommlds/minichain/backends/impls/openai/completion.py

@@ -23,13 +23,19 @@ from ....standard import ApiKey
 class OpenaiCompletionService:
     DEFAULT_MODEL_NAME: ta.ClassVar[str] = 'gpt-3.5-turbo-instruct'
 
-    def __init__(self, *configs: Config) -> None:
+    def __init__(
+            self,
+            *configs: Config,
+            http_client: http.AsyncHttpClient | None = None,
+    ) -> None:
         super().__init__()
 
+        self._http_client = http_client
+
         with tv.consume(*configs) as cc:
             self._api_key = ApiKey.pop_secret(cc, env='OPENAI_API_KEY')
 
-    def invoke(self, t: CompletionRequest) -> CompletionResponse:
+    async def invoke(self, t: CompletionRequest) -> CompletionResponse:
         raw_request = dict(
             model=self.DEFAULT_MODEL_NAME,
             prompt=t.v,
@@ -41,13 +47,14 @@ class OpenaiCompletionService:
             stream=False,
         )
 
-        raw_response = http.request(
+        raw_response = await http.async_request(
             'https://api.openai.com/v1/completions',
             headers={
                 http.consts.HEADER_CONTENT_TYPE: http.consts.CONTENT_TYPE_JSON,
                 http.consts.HEADER_AUTH: http.consts.format_bearer_auth_header(check.not_none(self._api_key).reveal()),
             },
             data=json.dumps(raw_request).encode('utf-8'),
+            client=self._http_client,
         )
 
         response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
ommlds/minichain/backends/impls/openai/embedding.py

@@ -22,25 +22,32 @@ from ....vectors.types import Vector
 class OpenaiEmbeddingService:
     model = 'text-embedding-3-small'
 
-    def __init__(self, *configs: Config) -> None:
+    def __init__(
+            self,
+            *configs: Config,
+            http_client: http.AsyncHttpClient | None = None,
+    ) -> None:
         super().__init__()
 
+        self._http_client = http_client
+
         with tv.consume(*configs) as cc:
             self._api_key = ApiKey.pop_secret(cc, env='OPENAI_API_KEY')
 
-    def invoke(self, request: EmbeddingRequest) -> EmbeddingResponse:
+    async def invoke(self, request: EmbeddingRequest) -> EmbeddingResponse:
         raw_request = dict(
             model=self.model,
             input=check.isinstance(request.v, str),
         )
 
-        raw_response = http.request(
+        raw_response = await http.async_request(
             'https://api.openai.com/v1/embeddings',
             headers={
                 http.consts.HEADER_CONTENT_TYPE: http.consts.CONTENT_TYPE_JSON,
                 http.consts.HEADER_AUTH: http.consts.format_bearer_auth_header(check.not_none(self._api_key).reveal()),
             },
             data=json.dumps(raw_request).encode('utf-8'),
+            client=self._http_client,
         )
 
         response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
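The completion and embedding services receive the same treatment: an async invoke() plus an injectable http_client. For reference, the embeddings call above reduces to a single POST; a standard-library sketch using the class's default model:

import json
import os
import urllib.request

req = urllib.request.Request(
    'https://api.openai.com/v1/embeddings',
    data=json.dumps({
        'model': 'text-embedding-3-small',
        'input': 'hello world',
    }).encode('utf-8'),
    headers={
        'Content-Type': 'application/json',
        'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY'],
    },
)
with urllib.request.urlopen(req) as resp:
    vec = json.loads(resp.read().decode('utf-8'))['data'][0]['embedding']
print(len(vec))  # 1536 dimensions for text-embedding-3-small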