ommlds 0.0.0.dev515__py3-none-any.whl → 0.0.0.dev516__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/__about__.py +3 -3
- ommlds/backends/mlx/generation.py +2 -2
- ommlds/backends/mlx/tokenization/detokenization/naive.py +4 -2
- ommlds/backends/mlx/tokenization/tokenization.py +1 -1
- ommlds/backends/transformers/streamers.py +7 -2
- ommlds/cli/_dataclasses.py +147 -0
- ommlds/cli/inject.py +5 -1
- ommlds/cli/main.py +18 -343
- ommlds/cli/profiles.py +449 -0
- ommlds/cli/sessions/chat/interfaces/textual/app.py +4 -0
- ommlds/cli/sessions/inject.py +11 -1
- ommlds/cli/sessions/types.py +7 -0
- ommlds/minichain/backends/impls/transformers/tokens.py +4 -2
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/METADATA +10 -10
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/RECORD +19 -17
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev515.dist-info → ommlds-0.0.0.dev516.dist-info}/top_level.txt +0 -0
ommlds/__about__.py
CHANGED
|
@@ -22,7 +22,7 @@ class Project(ProjectBase):
|
|
|
22
22
|
'llama-cpp-python ~= 0.3',
|
|
23
23
|
|
|
24
24
|
'mlx ~= 0.30; sys_platform == "darwin"',
|
|
25
|
-
'mlx-lm ~= 0.
|
|
25
|
+
'mlx-lm ~= 0.30; sys_platform == "darwin"',
|
|
26
26
|
|
|
27
27
|
'sentencepiece ~= 0.2',
|
|
28
28
|
|
|
@@ -35,12 +35,12 @@ class Project(ProjectBase):
|
|
|
35
35
|
|
|
36
36
|
'torch ~= 2.10',
|
|
37
37
|
|
|
38
|
-
'transformers ~=
|
|
38
|
+
'transformers ~= 5.0',
|
|
39
39
|
'sentence-transformers ~= 5.2',
|
|
40
40
|
],
|
|
41
41
|
|
|
42
42
|
'huggingface': [
|
|
43
|
-
'huggingface-hub ~=
|
|
43
|
+
'huggingface-hub ~= 1.3',
|
|
44
44
|
'datasets ~= 4.5',
|
|
45
45
|
],
|
|
46
46
|
|
|
@@ -238,7 +238,7 @@ class GenerationOutput:
|
|
|
238
238
|
def stream_generate(
|
|
239
239
|
model: 'mlx_nn.Module',
|
|
240
240
|
tokenization: Tokenization,
|
|
241
|
-
prompt: ta.Union[str, 'mx.array'],
|
|
241
|
+
prompt: ta.Union[str, ta.Sequence[int], 'mx.array'],
|
|
242
242
|
params: GenerationParams = GenerationParams(),
|
|
243
243
|
) -> ta.Generator[GenerationOutput]:
|
|
244
244
|
if not isinstance(prompt, mx.array):
|
|
@@ -312,7 +312,7 @@ def stream_generate(
|
|
|
312
312
|
def generate(
|
|
313
313
|
model: 'mlx_nn.Module',
|
|
314
314
|
tokenization: Tokenization,
|
|
315
|
-
prompt: ta.Union[str, 'mx.array'],
|
|
315
|
+
prompt: ta.Union[str, ta.Sequence[int], 'mx.array'],
|
|
316
316
|
params: GenerationParams = GenerationParams(),
|
|
317
317
|
*,
|
|
318
318
|
verbose: bool = False,
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from omlish import check
|
|
2
|
+
|
|
1
3
|
from ..types import Tokenizer
|
|
2
4
|
from .base import BaseStreamingDetokenizer
|
|
3
5
|
|
|
@@ -36,14 +38,14 @@ class NaiveStreamingDetokenizer(BaseStreamingDetokenizer):
|
|
|
36
38
|
self._tokens.append(token)
|
|
37
39
|
|
|
38
40
|
def finalize(self) -> None:
|
|
39
|
-
self._text += self._tokenizer.decode(self._current_tokens)
|
|
41
|
+
self._text += check.isinstance(self._tokenizer.decode(self._current_tokens), str)
|
|
40
42
|
self._current_tokens = []
|
|
41
43
|
self._current_text = ''
|
|
42
44
|
|
|
43
45
|
@property
|
|
44
46
|
def text(self) -> str:
|
|
45
47
|
if self._current_tokens:
|
|
46
|
-
self._current_text = self._tokenizer.decode(self._current_tokens)
|
|
48
|
+
self._current_text = check.isinstance(self._tokenizer.decode(self._current_tokens), str)
|
|
47
49
|
if self._current_text.endswith('\ufffd') or (
|
|
48
50
|
self._tokenizer.clean_up_tokenization_spaces and
|
|
49
51
|
self._current_text and
|
|
@@ -44,7 +44,7 @@ class Tokenization:
|
|
|
44
44
|
try:
|
|
45
45
|
token_id = int(token)
|
|
46
46
|
except ValueError:
|
|
47
|
-
token_id = self._tokenizer.convert_tokens_to_ids(token)
|
|
47
|
+
token_id = check.isinstance(self._tokenizer.convert_tokens_to_ids(check.isinstance(token, str)), int)
|
|
48
48
|
|
|
49
49
|
if token_id is None:
|
|
50
50
|
raise ValueError(f"'{token}' is not a token for this tokenizer")
|
|
@@ -12,19 +12,24 @@ P = ta.ParamSpec('P')
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class CancellableTextStreamer(tfm.TextStreamer):
|
|
15
|
+
class Tokenizer(ta.Protocol):
|
|
16
|
+
"""Transformers base class still hard deps an AutoTokenizer despite only needing this one method."""
|
|
17
|
+
|
|
18
|
+
def decode(self, tokens: ta.Sequence[int], **kwargs: ta.Any) -> str: ...
|
|
19
|
+
|
|
15
20
|
class Callback(ta.Protocol):
|
|
16
21
|
def __call__(self, text: str, *, stream_end: bool) -> None: ...
|
|
17
22
|
|
|
18
23
|
def __init__(
|
|
19
24
|
self,
|
|
20
|
-
tokenizer:
|
|
25
|
+
tokenizer: Tokenizer,
|
|
21
26
|
callback: Callback,
|
|
22
27
|
*,
|
|
23
28
|
skip_prompt: bool = False,
|
|
24
29
|
**decode_kwargs: ta.Any,
|
|
25
30
|
) -> None:
|
|
26
31
|
super().__init__(
|
|
27
|
-
tokenizer,
|
|
32
|
+
ta.cast(tfm.AutoTokenizer, tokenizer), # noqa
|
|
28
33
|
skip_prompt=skip_prompt,
|
|
29
34
|
**decode_kwargs,
|
|
30
35
|
)
|
ommlds/cli/_dataclasses.py
CHANGED
|
@@ -1450,6 +1450,153 @@ def _process_dataclass__d65d18393f357ae0fb02bb80268c6f1473462613():
|
|
|
1450
1450
|
return _process_dataclass
|
|
1451
1451
|
|
|
1452
1452
|
|
|
1453
|
+
@_register(
|
|
1454
|
+
plan_repr=(
|
|
1455
|
+
"Plans(tup=(CopyPlan(fields=('profile', 'args')), EqPlan(fields=('profile', 'args')), FrozenPlan(fields=('profi"
|
|
1456
|
+
"le', 'args'), allow_dynamic_dunder_attrs=False), HashPlan(action='add', fields=('profile', 'args'), cache=Fals"
|
|
1457
|
+
"e), InitPlan(fields=(InitPlan.Field(name='profile', annotation=OpRef(name='init.fields.0.annotation'), default"
|
|
1458
|
+
"=None, default_factory=None, init=True, override=False, field_type=FieldType.INSTANCE, coerce=None, validate=N"
|
|
1459
|
+
"one, check_type=None), InitPlan.Field(name='args', annotation=OpRef(name='init.fields.1.annotation'), default="
|
|
1460
|
+
"None, default_factory=None, init=True, override=False, field_type=FieldType.INSTANCE, coerce=None, validate=No"
|
|
1461
|
+
"ne, check_type=None)), self_param='self', std_params=('profile', 'args'), kw_only_params=(), frozen=True, slot"
|
|
1462
|
+
"s=False, post_init_params=None, init_fns=(), validate_fns=()), ReprPlan(fields=(ReprPlan.Field(name='profile',"
|
|
1463
|
+
" kw_only=False, fn=None), ReprPlan.Field(name='args', kw_only=False, fn=None)), id=False, terse=False, default"
|
|
1464
|
+
"_fn=None)))"
|
|
1465
|
+
),
|
|
1466
|
+
plan_repr_sha1='3ca2bdd121559ad92e95a0e4bdf8210452f20e5d',
|
|
1467
|
+
op_ref_idents=(
|
|
1468
|
+
'__dataclass__init__fields__0__annotation',
|
|
1469
|
+
'__dataclass__init__fields__1__annotation',
|
|
1470
|
+
),
|
|
1471
|
+
cls_names=(
|
|
1472
|
+
('ommlds.cli.main', 'ProfileAspect.ConfigureContext'),
|
|
1473
|
+
),
|
|
1474
|
+
)
|
|
1475
|
+
def _process_dataclass__3ca2bdd121559ad92e95a0e4bdf8210452f20e5d():
|
|
1476
|
+
def _process_dataclass(
|
|
1477
|
+
*,
|
|
1478
|
+
__dataclass__cls,
|
|
1479
|
+
__dataclass__init__fields__0__annotation,
|
|
1480
|
+
__dataclass__init__fields__1__annotation,
|
|
1481
|
+
__dataclass__FieldFnValidationError, # noqa
|
|
1482
|
+
__dataclass__FieldTypeValidationError, # noqa
|
|
1483
|
+
__dataclass__FnValidationError, # noqa
|
|
1484
|
+
__dataclass__FrozenInstanceError=dataclasses.FrozenInstanceError, # noqa
|
|
1485
|
+
__dataclass__FunctionType=types.FunctionType, # noqa
|
|
1486
|
+
__dataclass__HAS_DEFAULT_FACTORY=dataclasses._HAS_DEFAULT_FACTORY, # noqa
|
|
1487
|
+
__dataclass__MISSING=dataclasses.MISSING, # noqa
|
|
1488
|
+
__dataclass__None=None, # noqa
|
|
1489
|
+
__dataclass__TypeError=TypeError, # noqa
|
|
1490
|
+
__dataclass___recursive_repr=reprlib.recursive_repr, # noqa
|
|
1491
|
+
__dataclass__isinstance=isinstance, # noqa
|
|
1492
|
+
__dataclass__object_setattr=object.__setattr__, # noqa
|
|
1493
|
+
__dataclass__property=property, # noqa
|
|
1494
|
+
):
|
|
1495
|
+
def __copy__(self):
|
|
1496
|
+
if self.__class__ is not __dataclass__cls:
|
|
1497
|
+
raise TypeError(self)
|
|
1498
|
+
return __dataclass__cls( # noqa
|
|
1499
|
+
profile=self.profile,
|
|
1500
|
+
args=self.args,
|
|
1501
|
+
)
|
|
1502
|
+
|
|
1503
|
+
__copy__.__qualname__ = f"{__dataclass__cls.__qualname__}.__copy__"
|
|
1504
|
+
if '__copy__' in __dataclass__cls.__dict__:
|
|
1505
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __copy__ in class {__dataclass__cls.__name__}")
|
|
1506
|
+
setattr(__dataclass__cls, '__copy__', __copy__)
|
|
1507
|
+
|
|
1508
|
+
def __eq__(self, other):
|
|
1509
|
+
if self is other:
|
|
1510
|
+
return True
|
|
1511
|
+
if self.__class__ is not other.__class__:
|
|
1512
|
+
return NotImplemented
|
|
1513
|
+
return (
|
|
1514
|
+
self.profile == other.profile and
|
|
1515
|
+
self.args == other.args
|
|
1516
|
+
)
|
|
1517
|
+
|
|
1518
|
+
__eq__.__qualname__ = f"{__dataclass__cls.__qualname__}.__eq__"
|
|
1519
|
+
if '__eq__' in __dataclass__cls.__dict__:
|
|
1520
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __eq__ in class {__dataclass__cls.__name__}")
|
|
1521
|
+
setattr(__dataclass__cls, '__eq__', __eq__)
|
|
1522
|
+
|
|
1523
|
+
__dataclass___setattr_frozen_fields = {
|
|
1524
|
+
'profile',
|
|
1525
|
+
'args',
|
|
1526
|
+
}
|
|
1527
|
+
|
|
1528
|
+
def __setattr__(self, name, value):
|
|
1529
|
+
if (
|
|
1530
|
+
type(self) is __dataclass__cls
|
|
1531
|
+
or name in __dataclass___setattr_frozen_fields
|
|
1532
|
+
):
|
|
1533
|
+
raise __dataclass__FrozenInstanceError(f"cannot assign to field {name!r}")
|
|
1534
|
+
super(__dataclass__cls, self).__setattr__(name, value)
|
|
1535
|
+
|
|
1536
|
+
__setattr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__setattr__"
|
|
1537
|
+
if '__setattr__' in __dataclass__cls.__dict__:
|
|
1538
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __setattr__ in class {__dataclass__cls.__name__}")
|
|
1539
|
+
setattr(__dataclass__cls, '__setattr__', __setattr__)
|
|
1540
|
+
|
|
1541
|
+
__dataclass___delattr_frozen_fields = {
|
|
1542
|
+
'profile',
|
|
1543
|
+
'args',
|
|
1544
|
+
}
|
|
1545
|
+
|
|
1546
|
+
def __delattr__(self, name):
|
|
1547
|
+
if (
|
|
1548
|
+
type(self) is __dataclass__cls
|
|
1549
|
+
or name in __dataclass___delattr_frozen_fields
|
|
1550
|
+
):
|
|
1551
|
+
raise __dataclass__FrozenInstanceError(f"cannot delete field {name!r}")
|
|
1552
|
+
super(__dataclass__cls, self).__delattr__(name)
|
|
1553
|
+
|
|
1554
|
+
__delattr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__delattr__"
|
|
1555
|
+
if '__delattr__' in __dataclass__cls.__dict__:
|
|
1556
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __delattr__ in class {__dataclass__cls.__name__}")
|
|
1557
|
+
setattr(__dataclass__cls, '__delattr__', __delattr__)
|
|
1558
|
+
|
|
1559
|
+
def __hash__(self):
|
|
1560
|
+
return hash((
|
|
1561
|
+
self.profile,
|
|
1562
|
+
self.args,
|
|
1563
|
+
))
|
|
1564
|
+
|
|
1565
|
+
__hash__.__qualname__ = f"{__dataclass__cls.__qualname__}.__hash__"
|
|
1566
|
+
setattr(__dataclass__cls, '__hash__', __hash__)
|
|
1567
|
+
|
|
1568
|
+
def __init__(
|
|
1569
|
+
self,
|
|
1570
|
+
profile: __dataclass__init__fields__0__annotation,
|
|
1571
|
+
args: __dataclass__init__fields__1__annotation,
|
|
1572
|
+
) -> __dataclass__None:
|
|
1573
|
+
__dataclass__object_setattr(self, 'profile', profile)
|
|
1574
|
+
__dataclass__object_setattr(self, 'args', args)
|
|
1575
|
+
|
|
1576
|
+
__init__.__qualname__ = f"{__dataclass__cls.__qualname__}.__init__"
|
|
1577
|
+
if '__init__' in __dataclass__cls.__dict__:
|
|
1578
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __init__ in class {__dataclass__cls.__name__}")
|
|
1579
|
+
setattr(__dataclass__cls, '__init__', __init__)
|
|
1580
|
+
|
|
1581
|
+
@__dataclass___recursive_repr()
|
|
1582
|
+
def __repr__(self):
|
|
1583
|
+
parts = []
|
|
1584
|
+
parts.append(f"profile={self.profile!r}")
|
|
1585
|
+
parts.append(f"args={self.args!r}")
|
|
1586
|
+
return (
|
|
1587
|
+
f"{self.__class__.__qualname__}("
|
|
1588
|
+
f"{', '.join(parts)}"
|
|
1589
|
+
f")"
|
|
1590
|
+
)
|
|
1591
|
+
|
|
1592
|
+
__repr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__repr__"
|
|
1593
|
+
if '__repr__' in __dataclass__cls.__dict__:
|
|
1594
|
+
raise __dataclass__TypeError(f"Cannot overwrite attribute __repr__ in class {__dataclass__cls.__name__}")
|
|
1595
|
+
setattr(__dataclass__cls, '__repr__', __repr__)
|
|
1596
|
+
|
|
1597
|
+
return _process_dataclass
|
|
1598
|
+
|
|
1599
|
+
|
|
1453
1600
|
@_register(
|
|
1454
1601
|
plan_repr=(
|
|
1455
1602
|
"Plans(tup=(CopyPlan(fields=('markdown',)), EqPlan(fields=('markdown',)), FrozenPlan(fields=('markdown',), allo"
|
ommlds/cli/inject.py
CHANGED
|
@@ -17,6 +17,7 @@ with lang.auto_proxy_import(globals()):
|
|
|
17
17
|
def bind_main(
|
|
18
18
|
*,
|
|
19
19
|
session_cfg: SessionConfig,
|
|
20
|
+
profile_name: str | None = None,
|
|
20
21
|
) -> inj.Elements:
|
|
21
22
|
els: list[inj.Elemental] = []
|
|
22
23
|
|
|
@@ -30,7 +31,10 @@ def bind_main(
|
|
|
30
31
|
#
|
|
31
32
|
|
|
32
33
|
els.extend([
|
|
33
|
-
_sessions.bind_sessions(
|
|
34
|
+
_sessions.bind_sessions(
|
|
35
|
+
session_cfg,
|
|
36
|
+
profile_name=profile_name,
|
|
37
|
+
),
|
|
34
38
|
|
|
35
39
|
_state.bind_state(),
|
|
36
40
|
])
|
ommlds/cli/main.py
CHANGED
|
@@ -2,30 +2,20 @@
|
|
|
2
2
|
TODO:
|
|
3
3
|
- bootstrap lol
|
|
4
4
|
"""
|
|
5
|
-
import abc
|
|
6
5
|
import functools
|
|
7
|
-
import sys
|
|
8
6
|
import typing as ta
|
|
9
7
|
|
|
10
8
|
import anyio
|
|
11
9
|
|
|
12
|
-
from omlish import check
|
|
13
|
-
from omlish import dataclasses as dc
|
|
14
10
|
from omlish import inject as inj
|
|
15
|
-
from omlish import lang
|
|
16
11
|
from omlish.argparse import all as ap
|
|
17
12
|
from omlish.logs import all as logs
|
|
18
13
|
|
|
19
14
|
from .inject import bind_main
|
|
15
|
+
from .profiles import PROFILE_TYPES
|
|
20
16
|
from .secrets import install_env_secrets
|
|
21
17
|
from .sessions.base import Session
|
|
22
|
-
from .sessions.chat.configs import ChatConfig
|
|
23
|
-
from .sessions.chat.interfaces.bare.configs import BareInterfaceConfig
|
|
24
|
-
from .sessions.chat.interfaces.configs import InterfaceConfig
|
|
25
|
-
from .sessions.chat.interfaces.textual.configs import TextualInterfaceConfig
|
|
26
|
-
from .sessions.completion.configs import CompletionConfig
|
|
27
18
|
from .sessions.configs import SessionConfig
|
|
28
|
-
from .sessions.embedding.configs import EmbeddingConfig
|
|
29
19
|
|
|
30
20
|
|
|
31
21
|
##
|
|
@@ -48,337 +38,14 @@ def _process_main_extra_args(args: ap.Namespace) -> None:
|
|
|
48
38
|
##
|
|
49
39
|
|
|
50
40
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
##
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
# class ChatAspect(lang.Abstract):
|
|
61
|
-
# def get_parser_args(self) -> ta.Sequence[ap.Arg]: ...
|
|
62
|
-
# def set_args(self, args: ap.Namespace) -> None: ...
|
|
63
|
-
# def configure(self, cfg: ChatConfig) -> ChatConfig: ...
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class ChatProfile(Profile):
|
|
67
|
-
_args: ap.Namespace
|
|
68
|
-
|
|
69
|
-
#
|
|
70
|
-
|
|
71
|
-
BACKEND_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
72
|
-
ap.arg('-b', '--backend', group='backend'),
|
|
73
|
-
]
|
|
74
|
-
|
|
75
|
-
def configure_backend(self, cfg: ChatConfig) -> ChatConfig:
|
|
76
|
-
return dc.replace(
|
|
77
|
-
cfg,
|
|
78
|
-
driver=dc.replace(
|
|
79
|
-
cfg.driver,
|
|
80
|
-
backend=dc.replace(
|
|
81
|
-
cfg.driver.backend,
|
|
82
|
-
backend=self._args.backend,
|
|
83
|
-
),
|
|
84
|
-
),
|
|
85
|
-
)
|
|
86
|
-
|
|
87
|
-
#
|
|
88
|
-
|
|
89
|
-
INTERFACE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
90
|
-
ap.arg('-i', '--interactive', action='store_true', group='interface'),
|
|
91
|
-
ap.arg('-T', '--textual', action='store_true', group='interface'),
|
|
92
|
-
ap.arg('-e', '--editor', action='store_true', group='interface'),
|
|
93
|
-
]
|
|
94
|
-
|
|
95
|
-
def configure_interface(self, cfg: ChatConfig) -> ChatConfig:
|
|
96
|
-
if self._args.editor:
|
|
97
|
-
check.arg(not self._args.interactive)
|
|
98
|
-
check.arg(not self._args.message)
|
|
99
|
-
raise NotImplementedError
|
|
100
|
-
|
|
101
|
-
if self._args.textual:
|
|
102
|
-
check.isinstance(cfg.interface, BareInterfaceConfig)
|
|
103
|
-
cfg = dc.replace(
|
|
104
|
-
cfg,
|
|
105
|
-
interface=TextualInterfaceConfig(**{
|
|
106
|
-
f.name: getattr(cfg.interface, f.name)
|
|
107
|
-
for f in dc.fields(InterfaceConfig)
|
|
108
|
-
}),
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
else:
|
|
112
|
-
cfg = dc.replace(
|
|
113
|
-
cfg,
|
|
114
|
-
driver=dc.replace(
|
|
115
|
-
cfg.driver,
|
|
116
|
-
ai=dc.replace(
|
|
117
|
-
cfg.driver.ai,
|
|
118
|
-
verbose=True,
|
|
119
|
-
),
|
|
120
|
-
),
|
|
121
|
-
interface=dc.replace(
|
|
122
|
-
check.isinstance(cfg.interface, BareInterfaceConfig),
|
|
123
|
-
interactive=self._args.interactive,
|
|
124
|
-
),
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
return cfg
|
|
128
|
-
|
|
129
|
-
#
|
|
130
|
-
|
|
131
|
-
INPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
132
|
-
ap.arg('message', nargs='*', group='input'),
|
|
133
|
-
]
|
|
134
|
-
|
|
135
|
-
def configure_input(self, cfg: ChatConfig) -> ChatConfig:
|
|
136
|
-
if self._args.interactive or self._args.textual:
|
|
137
|
-
check.arg(not self._args.message)
|
|
138
|
-
|
|
139
|
-
elif self._args.message:
|
|
140
|
-
ps: list[str] = []
|
|
141
|
-
|
|
142
|
-
for a in self._args.message:
|
|
143
|
-
if a == '-':
|
|
144
|
-
ps.append(sys.stdin.read())
|
|
145
|
-
|
|
146
|
-
elif a.startswith('@'):
|
|
147
|
-
with open(a[1:]) as f:
|
|
148
|
-
ps.append(f.read())
|
|
149
|
-
|
|
150
|
-
else:
|
|
151
|
-
ps.append(a)
|
|
152
|
-
|
|
153
|
-
c = ' '.join(ps)
|
|
154
|
-
|
|
155
|
-
cfg = dc.replace(
|
|
156
|
-
cfg,
|
|
157
|
-
driver=dc.replace(
|
|
158
|
-
cfg.driver,
|
|
159
|
-
user=dc.replace(
|
|
160
|
-
cfg.driver.user,
|
|
161
|
-
initial_user_content=c,
|
|
162
|
-
),
|
|
163
|
-
),
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
else:
|
|
167
|
-
raise ValueError('Must specify input')
|
|
168
|
-
|
|
169
|
-
return cfg
|
|
170
|
-
|
|
171
|
-
#
|
|
172
|
-
|
|
173
|
-
STATE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
174
|
-
ap.arg('-n', '--new', action='store_true', group='state'),
|
|
175
|
-
ap.arg('--ephemeral', action='store_true', group='state'),
|
|
176
|
-
]
|
|
177
|
-
|
|
178
|
-
def configure_state(self, cfg: ChatConfig) -> ChatConfig:
|
|
179
|
-
return dc.replace(
|
|
180
|
-
cfg,
|
|
181
|
-
driver=dc.replace(
|
|
182
|
-
cfg.driver,
|
|
183
|
-
state=dc.replace(
|
|
184
|
-
cfg.driver.state,
|
|
185
|
-
state='ephemeral' if self._args.ephemeral else 'new' if self._args.new else 'continue',
|
|
186
|
-
),
|
|
187
|
-
),
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
#
|
|
191
|
-
|
|
192
|
-
OUTPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
193
|
-
ap.arg('-s', '--stream', action='store_true', group='output'),
|
|
194
|
-
ap.arg('-M', '--markdown', action='store_true', group='output'),
|
|
195
|
-
]
|
|
196
|
-
|
|
197
|
-
def configure_output(self, cfg: ChatConfig) -> ChatConfig:
|
|
198
|
-
return dc.replace(
|
|
199
|
-
cfg,
|
|
200
|
-
driver=dc.replace(
|
|
201
|
-
cfg.driver,
|
|
202
|
-
ai=dc.replace(
|
|
203
|
-
cfg.driver.ai,
|
|
204
|
-
stream=bool(self._args.stream),
|
|
205
|
-
),
|
|
206
|
-
),
|
|
207
|
-
rendering=dc.replace(
|
|
208
|
-
cfg.rendering,
|
|
209
|
-
markdown=bool(self._args.markdown),
|
|
210
|
-
),
|
|
211
|
-
)
|
|
212
|
-
|
|
213
|
-
#
|
|
214
|
-
|
|
215
|
-
TOOLS_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
216
|
-
ap.arg('--enable-fs-tools', action='store_true', group='tools'),
|
|
217
|
-
ap.arg('--enable-todo-tools', action='store_true', group='tools'),
|
|
218
|
-
# ap.arg('--enable-unsafe-tools-do-not-use-lol', action='store_true', group='tools'),
|
|
219
|
-
ap.arg('--enable-test-weather-tool', action='store_true', group='tools'),
|
|
220
|
-
]
|
|
221
|
-
|
|
222
|
-
def configure_tools(self, cfg: ChatConfig) -> ChatConfig:
|
|
223
|
-
if not (
|
|
224
|
-
self._args.enable_fs_tools or
|
|
225
|
-
self._args.enable_todo_tools or
|
|
226
|
-
# self._args.enable_unsafe_tools_do_not_use_lol or
|
|
227
|
-
self._args.enable_test_weather_tool or
|
|
228
|
-
self._args.code
|
|
229
|
-
):
|
|
230
|
-
return cfg
|
|
231
|
-
|
|
232
|
-
return dc.replace(
|
|
233
|
-
cfg,
|
|
234
|
-
driver=dc.replace(
|
|
235
|
-
cfg.driver,
|
|
236
|
-
ai=dc.replace(
|
|
237
|
-
cfg.driver.ai,
|
|
238
|
-
enable_tools=True,
|
|
239
|
-
),
|
|
240
|
-
tools=dc.replace(
|
|
241
|
-
cfg.driver.tools,
|
|
242
|
-
enabled_tools={ # noqa
|
|
243
|
-
*(cfg.driver.tools.enabled_tools or []),
|
|
244
|
-
*(['fs'] if self._args.enable_fs_tools else []),
|
|
245
|
-
*(['todo'] if self._args.enable_todo_tools else []),
|
|
246
|
-
*(['weather'] if self._args.enable_test_weather_tool else []),
|
|
247
|
-
},
|
|
248
|
-
),
|
|
249
|
-
),
|
|
250
|
-
interface=dc.replace(
|
|
251
|
-
cfg.interface,
|
|
252
|
-
enable_tools=True,
|
|
253
|
-
),
|
|
254
|
-
)
|
|
255
|
-
|
|
256
|
-
#
|
|
257
|
-
|
|
258
|
-
CODE_CONFIG: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
259
|
-
ap.arg('-c', '--code', action='store_true', group='code'),
|
|
260
|
-
]
|
|
261
|
-
|
|
262
|
-
def configure_code(self, cfg: ChatConfig) -> ChatConfig:
|
|
263
|
-
if not self._args.code:
|
|
264
|
-
return cfg
|
|
265
|
-
|
|
266
|
-
cfg = dc.replace(
|
|
267
|
-
cfg,
|
|
268
|
-
driver=dc.replace(
|
|
269
|
-
cfg.driver,
|
|
270
|
-
ai=dc.replace(
|
|
271
|
-
cfg.driver.ai,
|
|
272
|
-
enable_tools=True,
|
|
273
|
-
),
|
|
274
|
-
),
|
|
275
|
-
)
|
|
276
|
-
|
|
277
|
-
if self._args.new or self._args.ephemeral:
|
|
278
|
-
from ..minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT
|
|
279
|
-
system_content = CODE_AGENT_SYSTEM_PROMPT
|
|
280
|
-
|
|
281
|
-
cfg = dc.replace(
|
|
282
|
-
cfg,
|
|
283
|
-
driver=dc.replace(
|
|
284
|
-
cfg.driver,
|
|
285
|
-
user=dc.replace(
|
|
286
|
-
cfg.driver.user,
|
|
287
|
-
initial_system_content=system_content,
|
|
288
|
-
),
|
|
289
|
-
),
|
|
290
|
-
)
|
|
291
|
-
|
|
292
|
-
return cfg
|
|
293
|
-
|
|
294
|
-
#
|
|
295
|
-
|
|
296
|
-
def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
|
|
297
|
-
parser = ap.ArgumentParser()
|
|
298
|
-
|
|
299
|
-
for grp_name, grp_args in [
|
|
300
|
-
('backend', self.BACKEND_ARGS),
|
|
301
|
-
('interface', self.INTERFACE_ARGS),
|
|
302
|
-
('input', self.INPUT_ARGS),
|
|
303
|
-
('state', self.STATE_ARGS),
|
|
304
|
-
('output', self.OUTPUT_ARGS),
|
|
305
|
-
('tools', self.TOOLS_ARGS),
|
|
306
|
-
('code', self.CODE_CONFIG),
|
|
307
|
-
]:
|
|
308
|
-
grp = parser.add_argument_group(grp_name)
|
|
309
|
-
for a in grp_args:
|
|
310
|
-
grp.add_argument(*a.args, **a.kwargs)
|
|
311
|
-
|
|
312
|
-
self._args = parser.parse_args(argv)
|
|
313
|
-
|
|
314
|
-
cfg = ChatConfig()
|
|
315
|
-
cfg = self.configure_backend(cfg)
|
|
316
|
-
cfg = self.configure_interface(cfg)
|
|
317
|
-
cfg = self.configure_input(cfg)
|
|
318
|
-
cfg = self.configure_state(cfg)
|
|
319
|
-
cfg = self.configure_output(cfg)
|
|
320
|
-
cfg = self.configure_tools(cfg)
|
|
321
|
-
cfg = self.configure_code(cfg)
|
|
322
|
-
|
|
323
|
-
return cfg
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
##
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
class CompletionProfile(Profile):
|
|
330
|
-
def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
|
|
331
|
-
parser = ap.ArgumentParser()
|
|
332
|
-
parser.add_argument('prompt', nargs='*')
|
|
333
|
-
parser.add_argument('-b', '--backend', default='openai')
|
|
334
|
-
args = parser.parse_args(argv)
|
|
335
|
-
|
|
336
|
-
content = ' '.join(args.prompt)
|
|
337
|
-
|
|
338
|
-
cfg = CompletionConfig(
|
|
339
|
-
content=check.non_empty_str(content),
|
|
340
|
-
backend=args.backend,
|
|
341
|
-
)
|
|
342
|
-
|
|
343
|
-
return cfg
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
##
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
class EmbedProfile(Profile):
|
|
350
|
-
def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
|
|
351
|
-
parser = ap.ArgumentParser()
|
|
352
|
-
parser.add_argument('prompt', nargs='*')
|
|
353
|
-
parser.add_argument('-b', '--backend', default='openai')
|
|
354
|
-
args = parser.parse_args(argv)
|
|
355
|
-
|
|
356
|
-
content = ' '.join(args.prompt)
|
|
357
|
-
|
|
358
|
-
cfg = EmbeddingConfig(
|
|
359
|
-
content=check.non_empty_str(content),
|
|
360
|
-
backend=args.backend,
|
|
361
|
-
)
|
|
362
|
-
|
|
363
|
-
return cfg
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
##
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
PROFILE_TYPES: ta.Mapping[str, type[Profile]] = {
|
|
370
|
-
'chat': ChatProfile,
|
|
371
|
-
'complete': CompletionProfile,
|
|
372
|
-
'embed': EmbedProfile,
|
|
373
|
-
}
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
##
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
async def _run_session_cfg(session_cfg: SessionConfig) -> None:
|
|
41
|
+
async def _run_session_cfg(
|
|
42
|
+
session_cfg: SessionConfig,
|
|
43
|
+
*,
|
|
44
|
+
profile_name: str | None = None,
|
|
45
|
+
) -> None:
|
|
380
46
|
async with inj.create_async_managed_injector(bind_main(
|
|
381
47
|
session_cfg=session_cfg,
|
|
48
|
+
profile_name=profile_name,
|
|
382
49
|
)) as injector:
|
|
383
50
|
await (await injector[Session]).run()
|
|
384
51
|
|
|
@@ -388,12 +55,13 @@ async def _run_session_cfg(session_cfg: SessionConfig) -> None:
|
|
|
388
55
|
|
|
389
56
|
MAIN_PROFILE_ARGS: ta.Sequence[ap.Arg] = [
|
|
390
57
|
ap.arg('-p', '--profile', default='chat'),
|
|
58
|
+
ap.arg('-h', '--help', action='store_true'),
|
|
391
59
|
ap.arg('args', nargs=ap.REMAINDER),
|
|
392
60
|
]
|
|
393
61
|
|
|
394
62
|
|
|
395
63
|
async def _a_main(argv: ta.Any = None) -> None:
|
|
396
|
-
parser = ap.ArgumentParser()
|
|
64
|
+
parser = ap.ArgumentParser(add_help=False)
|
|
397
65
|
|
|
398
66
|
for a in [*MAIN_PROFILE_ARGS, *MAIN_EXTRA_ARGS]:
|
|
399
67
|
parser.add_argument(*a.args, **a.kwargs)
|
|
@@ -407,9 +75,16 @@ async def _a_main(argv: ta.Any = None) -> None:
|
|
|
407
75
|
profile_cls = PROFILE_TYPES[args.profile]
|
|
408
76
|
profile = profile_cls()
|
|
409
77
|
|
|
410
|
-
session_cfg = profile.configure([
|
|
78
|
+
session_cfg = profile.configure([
|
|
79
|
+
*unk_args,
|
|
80
|
+
*(['--help'] if args.help else []),
|
|
81
|
+
*args.args,
|
|
82
|
+
])
|
|
411
83
|
|
|
412
|
-
await _run_session_cfg(
|
|
84
|
+
await _run_session_cfg(
|
|
85
|
+
session_cfg,
|
|
86
|
+
profile_name=args.profile,
|
|
87
|
+
)
|
|
413
88
|
|
|
414
89
|
|
|
415
90
|
def _main(args: ta.Any = None) -> None:
|