ommlds 0.0.0.dev479__py3-none-any.whl → 0.0.0.dev481__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +40 -23
- ommlds/__about__.py +1 -1
- ommlds/backends/llamacpp/logging.py +4 -1
- ommlds/backends/mlx/caching.py +7 -3
- ommlds/backends/mlx/cli.py +10 -7
- ommlds/backends/mlx/generation.py +18 -16
- ommlds/backends/mlx/limits.py +10 -6
- ommlds/backends/mlx/loading.py +7 -4
- ommlds/backends/tavily/__init__.py +0 -0
- ommlds/backends/tavily/protocol.py +301 -0
- ommlds/backends/transformers/__init__.py +14 -0
- ommlds/minichain/__init__.py +1 -0
- ommlds/minichain/_dataclasses.py +46282 -0
- ommlds/minichain/backends/impls/anthropic/chat.py +23 -4
- ommlds/minichain/backends/impls/duckduckgo/search.py +5 -1
- ommlds/minichain/backends/impls/huggingface/repos.py +1 -5
- ommlds/minichain/backends/impls/llamacpp/chat.py +6 -3
- ommlds/minichain/backends/impls/llamacpp/completion.py +7 -3
- ommlds/minichain/backends/impls/llamacpp/stream.py +6 -3
- ommlds/minichain/backends/impls/mlx/chat.py +6 -3
- ommlds/minichain/backends/impls/openai/format.py +2 -0
- ommlds/minichain/backends/impls/openai/names.py +3 -1
- ommlds/minichain/backends/impls/sentencepiece/tokens.py +9 -6
- ommlds/minichain/backends/impls/tavily.py +66 -0
- ommlds/minichain/backends/impls/tinygrad/chat.py +7 -4
- ommlds/minichain/backends/impls/tokenizers/tokens.py +9 -6
- ommlds/minichain/backends/impls/transformers/sentence.py +5 -2
- ommlds/minichain/backends/impls/transformers/tokens.py +9 -6
- ommlds/minichain/backends/impls/transformers/transformers.py +10 -8
- ommlds/minichain/llms/types.py +4 -0
- ommlds/minichain/search.py +1 -1
- ommlds/minichain/standard.py +1 -0
- ommlds/specs/__init__.py +0 -0
- ommlds/specs/mcp/__init__.py +0 -0
- ommlds/specs/mcp/_marshal.py +23 -0
- ommlds/specs/mcp/clients.py +146 -0
- ommlds/specs/mcp/protocol.py +371 -0
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/METADATA +5 -5
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/RECORD +43 -34
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev479.dist-info → ommlds-0.0.0.dev481.dist-info}/top_level.txt +0 -0
ommlds/.omlish-manifests.json
CHANGED
@@ -18,7 +18,7 @@
 "module": ".minichain.backends.impls.anthropic.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/anthropic/chat.py",
-"line":
+"line": 42,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.anthropic.chat",
@@ -78,7 +78,7 @@
 "module": ".minichain.backends.impls.duckduckgo.search",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/duckduckgo/search.py",
-"line":
+"line": 17,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.duckduckgo.search",
@@ -252,7 +252,7 @@
 "module": ".minichain.backends.impls.huggingface.repos",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/huggingface/repos.py",
-"line":
+"line": 20,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.huggingface.repos",
@@ -269,7 +269,7 @@
 "module": ".minichain.backends.impls.llamacpp.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/llamacpp/chat.py",
-"line":
+"line": 36,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -284,7 +284,7 @@
 "module": ".minichain.backends.impls.llamacpp.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/llamacpp/chat.py",
-"line":
+"line": 45,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.llamacpp.chat",
@@ -299,7 +299,7 @@
 "module": ".minichain.backends.impls.llamacpp.completion",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/llamacpp/completion.py",
-"line":
+"line": 28,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.llamacpp.completion",
@@ -314,7 +314,7 @@
 "module": ".minichain.backends.impls.llamacpp.stream",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/llamacpp/stream.py",
-"line":
+"line": 35,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -329,7 +329,7 @@
 "module": ".minichain.backends.impls.llamacpp.stream",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/llamacpp/stream.py",
-"line":
+"line": 44,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.llamacpp.stream",
@@ -359,7 +359,7 @@
 "module": ".minichain.backends.impls.mlx.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/mlx/chat.py",
-"line":
+"line": 42,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -375,7 +375,7 @@
 "module": ".minichain.backends.impls.mlx.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/mlx/chat.py",
-"line":
+"line": 136,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.mlx.chat",
@@ -390,7 +390,7 @@
 "module": ".minichain.backends.impls.mlx.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/mlx/chat.py",
-"line":
+"line": 167,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.mlx.chat",
@@ -496,7 +496,7 @@
 "module": ".minichain.backends.impls.openai.names",
 "attr": "_CHAT_BACKEND_STRINGS_MANIFEST",
 "file": "ommlds/minichain/backends/impls/openai/names.py",
-"line":
+"line": 65,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -520,6 +520,7 @@
 "gpt-5-chat-latest": null,
 "gpt-5-mini": null,
 "gpt-5-nano": null,
+"gpt-5.1": null,
 "gpt3.5-turbo": "gpt-3.5-turbo",
 "gpt3.5-turbo-instruct": "gpt-3.5-turbo-instruct",
 "gpt4": "gpt-4",
@@ -533,7 +534,8 @@
 "gpt5-chat-latest": "gpt-5-chat-latest",
 "gpt5-mini": "gpt-5-mini",
 "gpt5-nano": "gpt-5-nano",
-"
+"gpt5.1": "gpt-5.1",
+"gpt": "gpt-5.1",
 "gpt-mini": "gpt-5-mini",
 "o3": null,
 "o3-mini": null,
@@ -548,7 +550,7 @@
 "module": ".minichain.backends.impls.openai.names",
 "attr": "_COMPLETION_BACKEND_STRINGS_MANIFEST",
 "file": "ommlds/minichain/backends/impls/openai/names.py",
-"line":
+"line": 79,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -563,7 +565,7 @@
 "module": ".minichain.backends.impls.openai.names",
 "attr": "_EMBEDDING_BACKEND_STRINGS_MANIFEST",
 "file": "ommlds/minichain/backends/impls/openai/names.py",
-"line":
+"line": 91,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -589,11 +591,26 @@
 }
 }
 },
+{
+"module": ".minichain.backends.impls.tavily",
+"attr": null,
+"file": "ommlds/minichain/backends/impls/tavily.py",
+"line": 19,
+"value": {
+"!.minichain.registries.manifests.RegistryManifest": {
+"module": "ommlds.minichain.backends.impls.tavily",
+"attr": "TavilySearchService",
+"name": "tavily",
+"aliases": null,
+"type": "SearchService"
+}
+}
+},
 {
 "module": ".minichain.backends.impls.tinygrad.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/tinygrad/chat.py",
-"line":
+"line": 118,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.tinygrad.chat",
@@ -608,7 +625,7 @@
 "module": ".minichain.backends.impls.tinygrad.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/tinygrad/chat.py",
-"line":
+"line": 138,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.tinygrad.chat",
@@ -623,7 +640,7 @@
 "module": ".minichain.backends.impls.tinygrad.chat",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/tinygrad/chat.py",
-"line":
+"line": 169,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -639,7 +656,7 @@
 "module": ".minichain.backends.impls.transformers.sentence",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/transformers/sentence.py",
-"line":
+"line": 22,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.transformers.sentence",
@@ -656,7 +673,7 @@
 "module": ".minichain.backends.impls.transformers.transformers",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
-"line":
+"line": 52,
 "value": {
 "!.minichain.backends.strings.manifests.BackendStringsManifest": {
 "service_cls_names": [
@@ -672,7 +689,7 @@
 "module": ".minichain.backends.impls.transformers.transformers",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
-"line":
+"line": 68,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.transformers.transformers",
@@ -689,7 +706,7 @@
 "module": ".minichain.backends.impls.transformers.transformers",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
-"line":
+"line": 199,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.transformers.transformers",
@@ -706,7 +723,7 @@
 "module": ".minichain.backends.impls.transformers.transformers",
 "attr": null,
 "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
-"line":
+"line": 229,
 "value": {
 "!.minichain.registries.manifests.RegistryManifest": {
 "module": "ommlds.minichain.backends.impls.transformers.transformers",
ommlds/__about__.py
CHANGED
@@ -18,7 +18,7 @@ class Project(ProjectBase):
 
 'llama-cpp-python ~= 0.3',
 
-'mlx ~= 0.
+'mlx ~= 0.30; sys_platform == "darwin"',
 'mlx-lm ~= 0.28; sys_platform == "darwin"',
 
 # 'sentencepiece ~= 0.2',  # FIXME: https://github.com/google/sentencepiece/issues/1121
ommlds/backends/llamacpp/logging.py
CHANGED
@@ -1,4 +1,7 @@
 """
+NOTE: This can't be cleaned up too much - the callback can't be a closure to hide its guts because it needs to be
+picklable for multiprocessing.
+
 FIXME:
  - it outputs newline-terminated so buffer and chop on newlines - DelimitingBuffer again
 """
@@ -27,4 +30,4 @@ def llama_log_callback(
 
 @lang.cached_function
 def install_logging_hook() -> None:
-    llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))
+    llama_cpp.llama_log_set(llama_log_callback, ct.c_void_p(0))  # noqa
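The NOTE added to the module docstring is about pickling: a callback handed across a multiprocessing boundary must be importable by name, which a closure is not. A standalone, hypothetical sketch of that constraint (not code from this package; the (level, text, user_data) signature is only illustrative):

import pickle


def module_level_callback(level, text, user_data):
    # Importable by qualified name, so pickle can serialize a reference to it.
    print(text, end='')


def make_closure_callback(prefix):
    def callback(level, text, user_data):
        # Defined inside a function: pickle refuses to serialize 'local objects'.
        print(prefix + text, end='')
    return callback


pickle.dumps(module_level_callback)  # works: pickled by reference

try:
    pickle.dumps(make_closure_callback('llama: '))
except Exception as e:
    print(type(e).__name__, e)  # fails: can't pickle local object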
ommlds/backends/mlx/caching.py
CHANGED
@@ -17,7 +17,11 @@
 # https://github.com/ml-explore/mlx-lm/blob/ce2358d297af245b002e690623f00195b6507da0/mlx_lm/generate.py
 import typing as ta
 
-import
+from omlish import lang
+
+
+with lang.auto_proxy_import(globals()):
+    import mlx_lm.models.cache as mlx_lm_models_cache
 
 
 ##
@@ -32,13 +36,13 @@ def maybe_quantize_kv_cache(
 ) -> None:
     if not (
             kv_bits is not None and
-            not isinstance(prompt_cache[0],
+            not isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache) and
             prompt_cache[0].offset > quantized_kv_start
     ):
         return
 
     for i in range(len(prompt_cache)):
-        if isinstance(prompt_cache[i],
+        if isinstance(prompt_cache[i], mlx_lm_models_cache.KVCache):
             prompt_cache[i] = prompt_cache[i].to_quantized(
                 bits=kv_bits,
                 group_size=kv_group_size,
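Every mlx backend module touched in this release moves its heavy third-party imports under lang.auto_proxy_import(globals()), as in the hunks above and in the cli.py, generation.py, limits.py and loading.py hunks below. A minimal sketch of the pattern, assuming (as the diff suggests) that imports made inside the block are bound as lazy proxies and only resolved on first use:

from omlish import lang

with lang.auto_proxy_import(globals()):
    # Not imported yet; the module-level name is bound to a proxy object.
    import mlx_lm.models.cache as mlx_lm_models_cache


def is_quantized(prompt_cache) -> bool:
    # First real use of the proxied name triggers the actual import of mlx_lm.
    return isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache)

This keeps module import cheap (and possible on platforms where mlx is absent) while leaving call sites unchanged.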
ommlds/backends/mlx/cli.py
CHANGED
@@ -20,16 +20,19 @@ import json
 import sys
 import typing as ta
 
-
-import mlx_lm.models.cache
-import mlx_lm.sample_utils
-import mlx_lm.utils
+from omlish import lang
 
 from .generation import GenerationParams
 from .generation import generate
 from .loading import load_model
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx_lm.models.cache as mlx_lm_models_cache
+    import mlx_lm.sample_utils as mlx_lm_sample_utils
+
+
 ##
 
 
@@ -214,11 +217,11 @@ def _main() -> None:
     # Load the prompt cache and metadata if a cache file is provided
     using_cache = args.prompt_cache_file is not None
     if using_cache:
-        prompt_cache, metadata =
+        prompt_cache, metadata = mlx_lm_models_cache.load_prompt_cache(
             args.prompt_cache_file,
             return_metadata=True,
         )
-        if isinstance(prompt_cache[0],
+        if isinstance(prompt_cache[0], mlx_lm_models_cache.QuantizedKVCache):
             if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
                 raise ValueError('--kv-bits does not match the kv cache loaded from --prompt-cache-file.')
             if args.kv_group_size != prompt_cache[0].group_size:
@@ -293,7 +296,7 @@ def _main() -> None:
     else:
         prompt = tokenizer.encode(prompt)
 
-    sampler =
+    sampler = mlx_lm_sample_utils.make_sampler(
         args.temp,
         args.top_p,
         args.min_p,
ommlds/backends/mlx/generation.py
CHANGED
@@ -21,10 +21,6 @@ import io
 import sys
 import typing as ta
 
-import mlx.core as mx
-import mlx_lm.models.cache
-from mlx import nn
-
 from omlish import check
 from omlish import lang
 
@@ -33,6 +29,12 @@ from .limits import wired_limit_context
 from .tokenization import Tokenization
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx.nn as mlx_nn
+    import mlx_lm.models.cache as mlx_lm_models_cache
+
+
 ##
 
 
@@ -47,9 +49,9 @@ def _generation_stream():
 class LogitProcessor(ta.Protocol):
     def __call__(
         self,
-        tokens: mx.array,
-        logits: mx.array,
-    ) -> mx.array:
+        tokens: 'mx.array',
+        logits: 'mx.array',
+    ) -> 'mx.array':
         ...
 
 
@@ -99,12 +101,12 @@ class GenerationParams:
 
 class _GenerationStep(ta.NamedTuple):
     token: int
-    logprobs: mx.array
+    logprobs: 'mx.array'
 
 
 def _generate_step(
-    prompt: mx.array,
-    model:
+    prompt: 'mx.array',
+    model: 'mlx_nn.Module',
     params: GenerationParams = GenerationParams(),
 ) -> ta.Generator[_GenerationStep]:
     y = prompt
@@ -113,7 +115,7 @@ def _generate_step(
     # Create the Kv cache for generation
     prompt_cache = params.prompt_cache
     if prompt_cache is None:
-        prompt_cache =
+        prompt_cache = mlx_lm_models_cache.make_prompt_cache(
             model,
             max_kv_size=params.max_kv_size,
         )
@@ -221,7 +223,7 @@ class GenerationOutput:
     token: int
 
     # A vector of log probabilities.
-    logprobs: mx.array
+    logprobs: 'mx.array'
 
     # The number of tokens in the prompt.
     prompt_tokens: int
@@ -234,9 +236,9 @@ class GenerationOutput:
 
 
 def stream_generate(
-    model:
+    model: 'mlx_nn.Module',
     tokenization: Tokenization,
-    prompt: str
+    prompt: ta.Union[str, 'mx.array'],
     params: GenerationParams = GenerationParams(),
 ) -> ta.Generator[GenerationOutput]:
     if not isinstance(prompt, mx.array):
@@ -308,9 +310,9 @@ def stream_generate(
 
 
 def generate(
-    model:
+    model: 'mlx_nn.Module',
    tokenization: Tokenization,
-    prompt: str
+    prompt: ta.Union[str, 'mx.array'],
     params: GenerationParams = GenerationParams(),
     *,
     verbose: bool = False,
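Alongside the proxied imports, the annotations here switch to string form ('mx.array', 'mlx_nn.Module'), and loading.py below adds a # ruff: noqa: TC002 suppression in the same vein. An illustrative sketch, assuming the intent is that merely defining these functions must not resolve the proxied mlx modules; quoted annotations are only looked up if the hints are ever evaluated:

import typing as ta

from omlish import lang

with lang.auto_proxy_import(globals()):
    import mlx.core as mx  # stays a proxy at definition time


def prompt_kind(prompt: ta.Union[str, 'mx.array']) -> str:
    # Defining this function never touches mx; only a runtime use of mx would.
    return 'text' if isinstance(prompt, str) else 'tokens'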
ommlds/backends/mlx/limits.py
CHANGED
@@ -19,9 +19,13 @@ import contextlib
 import sys
 import typing as ta
 
-
-
-
+from omlish import lang
+
+
+with lang.auto_proxy_import(globals()):
+    import mlx.core as mx
+    import mlx.nn as mlx_nn
+    import mlx.utils as mlx_utils
 
 
 ##
@@ -29,8 +33,8 @@ from mlx import nn
 
 @contextlib.contextmanager
 def wired_limit_context(
-    model:
-    streams: ta.Iterable[mx.Stream] | None = None,
+    model: 'mlx_nn.Module',
+    streams: ta.Iterable['mx.Stream'] | None = None,
 ) -> ta.Generator[None]:
     """
     A context manager to temporarily change the wired limit.
@@ -43,7 +47,7 @@ def wired_limit_context(
     yield
     return
 
-    model_bytes =
+    model_bytes = mlx_utils.tree_reduce(
         lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc,
         model,
         0,
ommlds/backends/mlx/loading.py
CHANGED
@@ -1,10 +1,8 @@
+# ruff: noqa: TC002
 import dataclasses as dc
 import pathlib
 import typing as ta
 
-import mlx_lm.utils
-from mlx import nn
-
 from omlish import check
 from omlish import lang
 
@@ -12,6 +10,11 @@ from .tokenization import Tokenization
 from .tokenization import load_tokenization
 
 
+with lang.auto_proxy_import(globals()):
+    import mlx.nn as mlx_nn
+    import mlx_lm.utils
+
+
 ##
 
 
@@ -76,7 +79,7 @@ def get_model_path(
 class LoadedModel:
     path: pathlib.Path
 
-    model:
+    model: 'mlx_nn.Module'
     config: dict
 
     #