ommlds 0.0.0.dev480__py3-none-any.whl → 0.0.0.dev503__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +100 -33
- ommlds/README.md +11 -0
- ommlds/__about__.py +9 -6
- ommlds/backends/anthropic/protocol/__init__.py +13 -1
- ommlds/backends/anthropic/protocol/_dataclasses.py +1625 -0
- ommlds/backends/anthropic/protocol/sse/events.py +2 -0
- ommlds/backends/cerebras/__init__.py +7 -0
- ommlds/backends/cerebras/_dataclasses.py +4254 -0
- ommlds/backends/cerebras/_marshal.py +24 -0
- ommlds/backends/cerebras/protocol.py +312 -0
- ommlds/backends/google/protocol/__init__.py +13 -0
- ommlds/backends/google/protocol/_dataclasses.py +5997 -0
- ommlds/backends/groq/__init__.py +7 -0
- ommlds/backends/groq/_dataclasses.py +3901 -0
- ommlds/backends/groq/clients.py +9 -0
- ommlds/backends/llamacpp/logging.py +4 -1
- ommlds/backends/mlx/caching.py +7 -3
- ommlds/backends/mlx/cli.py +10 -7
- ommlds/backends/mlx/generation.py +18 -16
- ommlds/backends/mlx/limits.py +10 -6
- ommlds/backends/mlx/loading.py +7 -4
- ommlds/backends/ollama/__init__.py +7 -0
- ommlds/backends/ollama/_dataclasses.py +3488 -0
- ommlds/backends/ollama/protocol.py +3 -0
- ommlds/backends/openai/protocol/__init__.py +15 -1
- ommlds/backends/openai/protocol/_dataclasses.py +7708 -0
- ommlds/backends/tavily/__init__.py +7 -0
- ommlds/backends/tavily/_dataclasses.py +1734 -0
- ommlds/backends/transformers/__init__.py +14 -0
- ommlds/cli/__init__.py +7 -0
- ommlds/cli/_dataclasses.py +3515 -0
- ommlds/cli/backends/catalog.py +0 -5
- ommlds/cli/backends/inject.py +70 -7
- ommlds/cli/backends/meta.py +82 -0
- ommlds/cli/content/messages.py +1 -1
- ommlds/cli/inject.py +11 -3
- ommlds/cli/main.py +137 -68
- ommlds/cli/rendering/types.py +6 -0
- ommlds/cli/secrets.py +2 -1
- ommlds/cli/sessions/base.py +1 -10
- ommlds/cli/sessions/chat/configs.py +9 -17
- ommlds/cli/sessions/chat/{chat → drivers}/ai/configs.py +3 -1
- ommlds/cli/sessions/chat/drivers/ai/events.py +57 -0
- ommlds/cli/sessions/chat/{chat → drivers}/ai/inject.py +10 -3
- ommlds/cli/sessions/chat/{chat → drivers}/ai/rendering.py +1 -1
- ommlds/cli/sessions/chat/{chat → drivers}/ai/services.py +1 -1
- ommlds/cli/sessions/chat/{chat → drivers}/ai/tools.py +4 -8
- ommlds/cli/sessions/chat/{chat → drivers}/ai/types.py +9 -0
- ommlds/cli/sessions/chat/drivers/configs.py +25 -0
- ommlds/cli/sessions/chat/drivers/events/inject.py +27 -0
- ommlds/cli/sessions/chat/drivers/events/injection.py +14 -0
- ommlds/cli/sessions/chat/drivers/events/manager.py +16 -0
- ommlds/cli/sessions/chat/drivers/events/types.py +38 -0
- ommlds/cli/sessions/chat/drivers/impl.py +50 -0
- ommlds/cli/sessions/chat/drivers/inject.py +70 -0
- ommlds/cli/sessions/chat/{chat → drivers}/state/configs.py +2 -0
- ommlds/cli/sessions/chat/drivers/state/ids.py +25 -0
- ommlds/cli/sessions/chat/drivers/state/inject.py +83 -0
- ommlds/cli/sessions/chat/{chat → drivers}/state/inmemory.py +0 -4
- ommlds/cli/sessions/chat/{chat → drivers}/state/storage.py +17 -10
- ommlds/cli/sessions/chat/{chat → drivers}/state/types.py +10 -5
- ommlds/cli/sessions/chat/{tools → drivers/tools}/configs.py +2 -2
- ommlds/cli/sessions/chat/drivers/tools/confirmation.py +44 -0
- ommlds/cli/sessions/chat/drivers/tools/errorhandling.py +39 -0
- ommlds/cli/sessions/chat/{tools → drivers/tools}/execution.py +3 -4
- ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/inject.py +3 -3
- ommlds/cli/sessions/chat/{tools → drivers/tools}/inject.py +7 -12
- ommlds/cli/sessions/chat/{tools → drivers/tools}/injection.py +5 -5
- ommlds/cli/sessions/chat/{tools → drivers/tools}/rendering.py +3 -3
- ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/inject.py +3 -3
- ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/tools.py +1 -1
- ommlds/cli/sessions/chat/drivers/types.py +31 -0
- ommlds/cli/sessions/chat/{chat → drivers}/user/configs.py +0 -3
- ommlds/cli/sessions/chat/drivers/user/inject.py +41 -0
- ommlds/cli/sessions/chat/facades/__init__.py +0 -0
- ommlds/cli/sessions/chat/facades/commands/__init__.py +0 -0
- ommlds/cli/sessions/chat/facades/commands/base.py +83 -0
- ommlds/cli/sessions/chat/facades/commands/configs.py +9 -0
- ommlds/cli/sessions/chat/facades/commands/inject.py +41 -0
- ommlds/cli/sessions/chat/facades/commands/injection.py +15 -0
- ommlds/cli/sessions/chat/facades/commands/manager.py +59 -0
- ommlds/cli/sessions/chat/facades/commands/simple.py +34 -0
- ommlds/cli/sessions/chat/facades/commands/types.py +13 -0
- ommlds/cli/sessions/chat/facades/configs.py +11 -0
- ommlds/cli/sessions/chat/facades/facade.py +26 -0
- ommlds/cli/sessions/chat/facades/inject.py +35 -0
- ommlds/cli/sessions/chat/facades/ui.py +34 -0
- ommlds/cli/sessions/chat/inject.py +8 -31
- ommlds/cli/sessions/chat/interfaces/__init__.py +0 -0
- ommlds/cli/sessions/chat/interfaces/bare/__init__.py +0 -0
- ommlds/cli/sessions/chat/interfaces/bare/configs.py +15 -0
- ommlds/cli/sessions/chat/interfaces/bare/inject.py +69 -0
- ommlds/cli/sessions/chat/interfaces/bare/interactive.py +49 -0
- ommlds/cli/sessions/chat/interfaces/bare/oneshot.py +21 -0
- ommlds/cli/sessions/chat/{tools/confirmation.py → interfaces/bare/tools.py} +3 -22
- ommlds/cli/sessions/chat/interfaces/base.py +13 -0
- ommlds/cli/sessions/chat/interfaces/configs.py +11 -0
- ommlds/cli/sessions/chat/interfaces/inject.py +29 -0
- ommlds/cli/sessions/chat/interfaces/textual/__init__.py +0 -0
- ommlds/cli/sessions/chat/interfaces/textual/app.py +310 -0
- ommlds/cli/sessions/chat/interfaces/textual/configs.py +11 -0
- ommlds/cli/sessions/chat/interfaces/textual/facades.py +19 -0
- ommlds/cli/sessions/chat/interfaces/textual/inject.py +97 -0
- ommlds/cli/sessions/chat/interfaces/textual/interface.py +24 -0
- ommlds/cli/sessions/chat/interfaces/textual/styles/__init__.py +29 -0
- ommlds/cli/sessions/chat/interfaces/textual/styles/input.tcss +53 -0
- ommlds/cli/sessions/chat/interfaces/textual/styles/markdown.tcss +7 -0
- ommlds/cli/sessions/chat/interfaces/textual/styles/messages.tcss +157 -0
- ommlds/cli/sessions/chat/interfaces/textual/tools.py +38 -0
- ommlds/cli/sessions/chat/interfaces/textual/widgets/__init__.py +0 -0
- ommlds/cli/sessions/chat/interfaces/textual/widgets/input.py +36 -0
- ommlds/cli/sessions/chat/interfaces/textual/widgets/messages.py +197 -0
- ommlds/cli/sessions/chat/session.py +8 -13
- ommlds/cli/sessions/completion/configs.py +3 -4
- ommlds/cli/sessions/completion/inject.py +1 -2
- ommlds/cli/sessions/completion/session.py +4 -8
- ommlds/cli/sessions/configs.py +10 -0
- ommlds/cli/sessions/embedding/configs.py +3 -4
- ommlds/cli/sessions/embedding/inject.py +1 -2
- ommlds/cli/sessions/embedding/session.py +4 -8
- ommlds/cli/sessions/inject.py +15 -15
- ommlds/cli/state/storage.py +7 -1
- ommlds/minichain/__init__.py +161 -38
- ommlds/minichain/_dataclasses.py +20452 -0
- ommlds/minichain/_typedvalues.py +11 -4
- ommlds/minichain/backends/impls/anthropic/names.py +3 -3
- ommlds/minichain/backends/impls/anthropic/protocol.py +2 -2
- ommlds/minichain/backends/impls/anthropic/stream.py +1 -1
- ommlds/minichain/backends/impls/cerebras/__init__.py +0 -0
- ommlds/minichain/backends/impls/cerebras/chat.py +80 -0
- ommlds/minichain/backends/impls/cerebras/names.py +45 -0
- ommlds/minichain/backends/impls/cerebras/protocol.py +143 -0
- ommlds/minichain/backends/impls/cerebras/stream.py +125 -0
- ommlds/minichain/backends/impls/duckduckgo/search.py +5 -1
- ommlds/minichain/backends/impls/google/names.py +6 -0
- ommlds/minichain/backends/impls/google/stream.py +1 -1
- ommlds/minichain/backends/impls/google/tools.py +2 -2
- ommlds/minichain/backends/impls/groq/chat.py +2 -0
- ommlds/minichain/backends/impls/groq/protocol.py +2 -2
- ommlds/minichain/backends/impls/groq/stream.py +3 -1
- ommlds/minichain/backends/impls/huggingface/repos.py +1 -5
- ommlds/minichain/backends/impls/llamacpp/chat.py +6 -3
- ommlds/minichain/backends/impls/llamacpp/completion.py +7 -3
- ommlds/minichain/backends/impls/llamacpp/stream.py +6 -3
- ommlds/minichain/backends/impls/mlx/chat.py +6 -3
- ommlds/minichain/backends/impls/ollama/chat.py +51 -57
- ommlds/minichain/backends/impls/ollama/protocol.py +144 -0
- ommlds/minichain/backends/impls/openai/format.py +4 -3
- ommlds/minichain/backends/impls/openai/names.py +3 -1
- ommlds/minichain/backends/impls/openai/stream.py +33 -1
- ommlds/minichain/backends/impls/sentencepiece/tokens.py +9 -6
- ommlds/minichain/backends/impls/tinygrad/chat.py +7 -4
- ommlds/minichain/backends/impls/tokenizers/tokens.py +9 -6
- ommlds/minichain/backends/impls/transformers/sentence.py +5 -2
- ommlds/minichain/backends/impls/transformers/tokens.py +9 -6
- ommlds/minichain/backends/impls/transformers/transformers.py +10 -8
- ommlds/minichain/backends/strings/resolving.py +1 -1
- ommlds/minichain/chat/content.py +42 -0
- ommlds/minichain/chat/messages.py +43 -39
- ommlds/minichain/chat/stream/joining.py +36 -12
- ommlds/minichain/chat/stream/types.py +1 -1
- ommlds/minichain/chat/templating.py +3 -3
- ommlds/minichain/content/__init__.py +19 -3
- ommlds/minichain/content/_marshal.py +181 -55
- ommlds/minichain/content/code.py +26 -0
- ommlds/minichain/content/composite.py +28 -0
- ommlds/minichain/content/content.py +27 -0
- ommlds/minichain/content/dynamic.py +12 -0
- ommlds/minichain/content/emphasis.py +27 -0
- ommlds/minichain/content/images.py +2 -2
- ommlds/minichain/content/json.py +2 -2
- ommlds/minichain/content/link.py +13 -0
- ommlds/minichain/content/markdown.py +12 -0
- ommlds/minichain/content/metadata.py +10 -0
- ommlds/minichain/content/namespaces.py +8 -0
- ommlds/minichain/content/placeholders.py +10 -9
- ommlds/minichain/content/quote.py +26 -0
- ommlds/minichain/content/raw.py +49 -0
- ommlds/minichain/content/recursive.py +12 -0
- ommlds/minichain/content/section.py +26 -0
- ommlds/minichain/content/sequence.py +17 -3
- ommlds/minichain/content/standard.py +32 -0
- ommlds/minichain/content/tag.py +28 -0
- ommlds/minichain/content/templates.py +13 -0
- ommlds/minichain/content/text.py +2 -2
- ommlds/minichain/content/transform/__init__.py +0 -0
- ommlds/minichain/content/transform/json.py +55 -0
- ommlds/minichain/content/transform/markdown.py +8 -0
- ommlds/minichain/content/transform/materialize.py +51 -0
- ommlds/minichain/content/transform/metadata.py +16 -0
- ommlds/minichain/content/{prepare.py → transform/prepare.py} +10 -15
- ommlds/minichain/content/transform/recursive.py +97 -0
- ommlds/minichain/content/transform/standard.py +43 -0
- ommlds/minichain/content/{transforms → transform}/stringify.py +1 -7
- ommlds/minichain/content/transform/strings.py +33 -0
- ommlds/minichain/content/transform/templates.py +25 -0
- ommlds/minichain/content/visitors.py +231 -0
- ommlds/minichain/lib/fs/tools/read.py +1 -1
- ommlds/minichain/lib/fs/tools/recursivels/rendering.py +1 -1
- ommlds/minichain/lib/fs/tools/recursivels/running.py +1 -1
- ommlds/minichain/lib/todo/tools/write.py +2 -1
- ommlds/minichain/lib/todo/types.py +1 -1
- ommlds/minichain/metadata.py +56 -2
- ommlds/minichain/resources.py +22 -1
- ommlds/minichain/services/README.md +154 -0
- ommlds/minichain/services/__init__.py +6 -2
- ommlds/minichain/services/_marshal.py +46 -10
- ommlds/minichain/services/_origclasses.py +11 -0
- ommlds/minichain/services/_typedvalues.py +8 -3
- ommlds/minichain/services/requests.py +73 -3
- ommlds/minichain/services/responses.py +73 -3
- ommlds/minichain/services/services.py +9 -0
- ommlds/minichain/stream/services.py +24 -1
- ommlds/minichain/text/applypatch.py +2 -1
- ommlds/minichain/text/toolparsing/llamacpp/types.py +1 -1
- ommlds/minichain/tokens/specials.py +1 -1
- ommlds/minichain/tools/execution/catalog.py +1 -1
- ommlds/minichain/tools/execution/errorhandling.py +36 -0
- ommlds/minichain/tools/execution/errors.py +2 -2
- ommlds/minichain/tools/execution/executors.py +1 -1
- ommlds/minichain/tools/fns.py +1 -1
- ommlds/minichain/tools/jsonschema.py +2 -2
- ommlds/minichain/tools/reflect.py +6 -6
- ommlds/minichain/tools/types.py +12 -15
- ommlds/minichain/vectors/_marshal.py +1 -1
- ommlds/minichain/vectors/embeddings.py +1 -1
- ommlds/minichain/wrappers/__init__.py +7 -0
- ommlds/minichain/wrappers/firstinwins.py +144 -0
- ommlds/minichain/wrappers/instrument.py +146 -0
- ommlds/minichain/wrappers/retry.py +168 -0
- ommlds/minichain/wrappers/services.py +98 -0
- ommlds/minichain/wrappers/stream.py +57 -0
- ommlds/nanochat/rustbpe/README.md +9 -0
- ommlds/nanochat/tokenizers.py +40 -6
- ommlds/specs/mcp/clients.py +146 -0
- ommlds/specs/mcp/protocol.py +123 -18
- ommlds/tools/git.py +82 -65
- {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/METADATA +13 -11
- ommlds-0.0.0.dev503.dist-info/RECORD +520 -0
- ommlds/cli/sessions/chat/chat/state/inject.py +0 -36
- ommlds/cli/sessions/chat/chat/user/inject.py +0 -62
- ommlds/cli/sessions/chat/chat/user/interactive.py +0 -31
- ommlds/cli/sessions/chat/chat/user/oneshot.py +0 -25
- ommlds/cli/sessions/chat/chat/user/types.py +0 -15
- ommlds/cli/sessions/chat/driver.py +0 -43
- ommlds/minichain/content/materialize.py +0 -196
- ommlds/minichain/content/simple.py +0 -47
- ommlds/minichain/content/transforms/base.py +0 -46
- ommlds/minichain/content/transforms/interleave.py +0 -70
- ommlds/minichain/content/transforms/squeeze.py +0 -72
- ommlds/minichain/content/transforms/strings.py +0 -24
- ommlds/minichain/content/types.py +0 -43
- ommlds/minichain/stream/wrap.py +0 -62
- ommlds-0.0.0.dev480.dist-info/RECORD +0 -427
- /ommlds/cli/sessions/chat/{chat → drivers}/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{chat → drivers}/ai/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{chat → drivers}/ai/injection.py +0 -0
- /ommlds/cli/sessions/chat/{chat/state → drivers/events}/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{chat/user → drivers/phases}/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{phases → drivers/phases}/inject.py +0 -0
- /ommlds/cli/sessions/chat/{phases → drivers/phases}/injection.py +0 -0
- /ommlds/cli/sessions/chat/{phases → drivers/phases}/manager.py +0 -0
- /ommlds/cli/sessions/chat/{phases → drivers/phases}/types.py +0 -0
- /ommlds/cli/sessions/chat/{phases → drivers/state}/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/fs/configs.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/todo/configs.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/__init__.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/configs.py +0 -0
- /ommlds/cli/sessions/chat/{tools → drivers/tools}/weather/inject.py +0 -0
- /ommlds/{minichain/content/transforms → cli/sessions/chat/drivers/user}/__init__.py +0 -0
- {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev480.dist-info → ommlds-0.0.0.dev503.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
|
+
NOTE: This can't be cleaned up too much - the callback can't be a closure to hide its guts because it needs to be
|
|
3
|
+
picklable for multiprocessing.
|
|
4
|
+
|
|
2
5
|
FIXME:
|
|
3
6
|
- it outputs newline-terminated so buffer and chop on newlines - DelimitingBuffer again
|
|
4
7
|
"""
|
|
@@ -27,4 +30,4 @@ def llama_log_callback(
|
|
|
27
30
|
|
|
28
31
|
@lang.cached_function
|
|
29
32
|
def install_logging_hook() -> None:
|
|
30
|
-
llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))
|
|
33
|
+
llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0)) # noqa
|
ommlds/backends/mlx/caching.py
CHANGED
|
@@ -17,7 +17,11 @@
|
|
|
17
17
|
# https://github.com/ml-explore/mlx-lm/blob/ce2358d297af245b002e690623f00195b6507da0/mlx_lm/generate.py
|
|
18
18
|
import typing as ta
|
|
19
19
|
|
|
20
|
-
import
|
|
20
|
+
from omlish import lang
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
with lang.auto_proxy_import(globals()):
|
|
24
|
+
import mlx_lm.models.cache as mlx_lm_models_cache
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
##
|
|
@@ -32,13 +36,13 @@ def maybe_quantize_kv_cache(
|
|
|
32
36
|
) -> None:
|
|
33
37
|
if not (
|
|
34
38
|
kv_bits is not None and
|
|
35
|
-
not isinstance(prompt_cache[0],
|
|
39
|
+
not isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache) and
|
|
36
40
|
prompt_cache[0].offset > quantized_kv_start
|
|
37
41
|
):
|
|
38
42
|
return
|
|
39
43
|
|
|
40
44
|
for i in range(len(prompt_cache)):
|
|
41
|
-
if isinstance(prompt_cache[i],
|
|
45
|
+
if isinstance(prompt_cache[i], mlx_lm_models_cache.KVCache):
|
|
42
46
|
prompt_cache[i] = prompt_cache[i].to_quantized(
|
|
43
47
|
bits=kv_bits,
|
|
44
48
|
group_size=kv_group_size,
|
ommlds/backends/mlx/cli.py
CHANGED
|
@@ -20,16 +20,19 @@ import json
|
|
|
20
20
|
import sys
|
|
21
21
|
import typing as ta
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
import mlx_lm.models.cache
|
|
25
|
-
import mlx_lm.sample_utils
|
|
26
|
-
import mlx_lm.utils
|
|
23
|
+
from omlish import lang
|
|
27
24
|
|
|
28
25
|
from .generation import GenerationParams
|
|
29
26
|
from .generation import generate
|
|
30
27
|
from .loading import load_model
|
|
31
28
|
|
|
32
29
|
|
|
30
|
+
with lang.auto_proxy_import(globals()):
|
|
31
|
+
import mlx.core as mx
|
|
32
|
+
import mlx_lm.models.cache as mlx_lm_models_cache
|
|
33
|
+
import mlx_lm.sample_utils as mlx_lm_sample_utils
|
|
34
|
+
|
|
35
|
+
|
|
33
36
|
##
|
|
34
37
|
|
|
35
38
|
|
|
@@ -214,11 +217,11 @@ def _main() -> None:
|
|
|
214
217
|
# Load the prompt cache and metadata if a cache file is provided
|
|
215
218
|
using_cache = args.prompt_cache_file is not None
|
|
216
219
|
if using_cache:
|
|
217
|
-
prompt_cache, metadata =
|
|
220
|
+
prompt_cache, metadata = mlx_lm_models_cache.load_prompt_cache(
|
|
218
221
|
args.prompt_cache_file,
|
|
219
222
|
return_metadata=True,
|
|
220
223
|
)
|
|
221
|
-
if isinstance(prompt_cache[0],
|
|
224
|
+
if isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache):
|
|
222
225
|
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
|
|
223
226
|
raise ValueError('--kv-bits does not match the kv cache loaded from --prompt-cache-file.')
|
|
224
227
|
if args.kv_group_size != prompt_cache[0].group_size:
|
|
@@ -293,7 +296,7 @@ def _main() -> None:
|
|
|
293
296
|
else:
|
|
294
297
|
prompt = tokenizer.encode(prompt)
|
|
295
298
|
|
|
296
|
-
sampler =
|
|
299
|
+
sampler = mlx_lm_sample_utils.make_sampler(
|
|
297
300
|
args.temp,
|
|
298
301
|
args.top_p,
|
|
299
302
|
args.min_p,
|
|
@@ -21,10 +21,6 @@ import io
|
|
|
21
21
|
import sys
|
|
22
22
|
import typing as ta
|
|
23
23
|
|
|
24
|
-
import mlx.core as mx
|
|
25
|
-
import mlx_lm.models.cache
|
|
26
|
-
from mlx import nn
|
|
27
|
-
|
|
28
24
|
from omlish import check
|
|
29
25
|
from omlish import lang
|
|
30
26
|
|
|
@@ -33,6 +29,12 @@ from .limits import wired_limit_context
|
|
|
33
29
|
from .tokenization import Tokenization
|
|
34
30
|
|
|
35
31
|
|
|
32
|
+
with lang.auto_proxy_import(globals()):
|
|
33
|
+
import mlx.core as mx
|
|
34
|
+
import mlx.nn as mlx_nn
|
|
35
|
+
import mlx_lm.models.cache as mlx_lm_models_cache
|
|
36
|
+
|
|
37
|
+
|
|
36
38
|
##
|
|
37
39
|
|
|
38
40
|
|
|
@@ -47,9 +49,9 @@ def _generation_stream():
|
|
|
47
49
|
class LogitProcessor(ta.Protocol):
|
|
48
50
|
def __call__(
|
|
49
51
|
self,
|
|
50
|
-
tokens: mx.array,
|
|
51
|
-
logits: mx.array,
|
|
52
|
-
) -> mx.array:
|
|
52
|
+
tokens: 'mx.array',
|
|
53
|
+
logits: 'mx.array',
|
|
54
|
+
) -> 'mx.array':
|
|
53
55
|
...
|
|
54
56
|
|
|
55
57
|
|
|
@@ -99,12 +101,12 @@ class GenerationParams:
|
|
|
99
101
|
|
|
100
102
|
class _GenerationStep(ta.NamedTuple):
|
|
101
103
|
token: int
|
|
102
|
-
logprobs: mx.array
|
|
104
|
+
logprobs: 'mx.array'
|
|
103
105
|
|
|
104
106
|
|
|
105
107
|
def _generate_step(
|
|
106
|
-
prompt: mx.array,
|
|
107
|
-
model:
|
|
108
|
+
prompt: 'mx.array',
|
|
109
|
+
model: 'mlx_nn.Module',
|
|
108
110
|
params: GenerationParams = GenerationParams(),
|
|
109
111
|
) -> ta.Generator[_GenerationStep]:
|
|
110
112
|
y = prompt
|
|
@@ -113,7 +115,7 @@ def _generate_step(
|
|
|
113
115
|
# Create the Kv cache for generation
|
|
114
116
|
prompt_cache = params.prompt_cache
|
|
115
117
|
if prompt_cache is None:
|
|
116
|
-
prompt_cache =
|
|
118
|
+
prompt_cache = mlx_lm_models_cache.make_prompt_cache(
|
|
117
119
|
model,
|
|
118
120
|
max_kv_size=params.max_kv_size,
|
|
119
121
|
)
|
|
@@ -221,7 +223,7 @@ class GenerationOutput:
|
|
|
221
223
|
token: int
|
|
222
224
|
|
|
223
225
|
# A vector of log probabilities.
|
|
224
|
-
logprobs: mx.array
|
|
226
|
+
logprobs: 'mx.array'
|
|
225
227
|
|
|
226
228
|
# The number of tokens in the prompt.
|
|
227
229
|
prompt_tokens: int
|
|
@@ -234,9 +236,9 @@ class GenerationOutput:
|
|
|
234
236
|
|
|
235
237
|
|
|
236
238
|
def stream_generate(
|
|
237
|
-
model:
|
|
239
|
+
model: 'mlx_nn.Module',
|
|
238
240
|
tokenization: Tokenization,
|
|
239
|
-
prompt: str
|
|
241
|
+
prompt: ta.Union[str, 'mx.array'],
|
|
240
242
|
params: GenerationParams = GenerationParams(),
|
|
241
243
|
) -> ta.Generator[GenerationOutput]:
|
|
242
244
|
if not isinstance(prompt, mx.array):
|
|
@@ -308,9 +310,9 @@ def stream_generate(
|
|
|
308
310
|
|
|
309
311
|
|
|
310
312
|
def generate(
|
|
311
|
-
model:
|
|
313
|
+
model: 'mlx_nn.Module',
|
|
312
314
|
tokenization: Tokenization,
|
|
313
|
-
prompt: str
|
|
315
|
+
prompt: ta.Union[str, 'mx.array'],
|
|
314
316
|
params: GenerationParams = GenerationParams(),
|
|
315
317
|
*,
|
|
316
318
|
verbose: bool = False,
|
ommlds/backends/mlx/limits.py
CHANGED
|
@@ -19,9 +19,13 @@ import contextlib
|
|
|
19
19
|
import sys
|
|
20
20
|
import typing as ta
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
22
|
+
from omlish import lang
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
with lang.auto_proxy_import(globals()):
|
|
26
|
+
import mlx.core as mx
|
|
27
|
+
import mlx.nn as mlx_nn
|
|
28
|
+
import mlx.utils as mlx_utils
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
##
|
|
@@ -29,8 +33,8 @@ from mlx import nn
|
|
|
29
33
|
|
|
30
34
|
@contextlib.contextmanager
|
|
31
35
|
def wired_limit_context(
|
|
32
|
-
model:
|
|
33
|
-
streams: ta.Iterable[mx.Stream] | None = None,
|
|
36
|
+
model: 'mlx_nn.Module',
|
|
37
|
+
streams: ta.Iterable['mx.Stream'] | None = None,
|
|
34
38
|
) -> ta.Generator[None]:
|
|
35
39
|
"""
|
|
36
40
|
A context manager to temporarily change the wired limit.
|
|
@@ -43,7 +47,7 @@ def wired_limit_context(
|
|
|
43
47
|
yield
|
|
44
48
|
return
|
|
45
49
|
|
|
46
|
-
model_bytes =
|
|
50
|
+
model_bytes = mlx_utils.tree_reduce(
|
|
47
51
|
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc,
|
|
48
52
|
model,
|
|
49
53
|
0,
|
ommlds/backends/mlx/loading.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
|
1
|
+
# ruff: noqa: TC002
|
|
1
2
|
import dataclasses as dc
|
|
2
3
|
import pathlib
|
|
3
4
|
import typing as ta
|
|
4
5
|
|
|
5
|
-
import mlx_lm.utils
|
|
6
|
-
from mlx import nn
|
|
7
|
-
|
|
8
6
|
from omlish import check
|
|
9
7
|
from omlish import lang
|
|
10
8
|
|
|
@@ -12,6 +10,11 @@ from .tokenization import Tokenization
|
|
|
12
10
|
from .tokenization import load_tokenization
|
|
13
11
|
|
|
14
12
|
|
|
13
|
+
with lang.auto_proxy_import(globals()):
|
|
14
|
+
import mlx.nn as mlx_nn
|
|
15
|
+
import mlx_lm.utils
|
|
16
|
+
|
|
17
|
+
|
|
15
18
|
##
|
|
16
19
|
|
|
17
20
|
|
|
@@ -76,7 +79,7 @@ def get_model_path(
|
|
|
76
79
|
class LoadedModel:
|
|
77
80
|
path: pathlib.Path
|
|
78
81
|
|
|
79
|
-
model:
|
|
82
|
+
model: 'mlx_nn.Module'
|
|
80
83
|
config: dict
|
|
81
84
|
|
|
82
85
|
#
|