ommlds 0.0.0.dev514__py3-none-any.whl → 0.0.0.dev516__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ommlds/__about__.py CHANGED
@@ -22,7 +22,7 @@ class Project(ProjectBase):
22
22
  'llama-cpp-python ~= 0.3',
23
23
 
24
24
  'mlx ~= 0.30; sys_platform == "darwin"',
25
- 'mlx-lm ~= 0.29; sys_platform == "darwin"',
25
+ 'mlx-lm ~= 0.30; sys_platform == "darwin"',
26
26
 
27
27
  'sentencepiece ~= 0.2',
28
28
 
@@ -35,12 +35,12 @@ class Project(ProjectBase):
35
35
 
36
36
  'torch ~= 2.10',
37
37
 
38
- 'transformers ~= 4.57',
38
+ 'transformers ~= 5.0',
39
39
  'sentence-transformers ~= 5.2',
40
40
  ],
41
41
 
42
42
  'huggingface': [
43
- 'huggingface-hub ~= 0.36',
43
+ 'huggingface-hub ~= 1.3',
44
44
  'datasets ~= 4.5',
45
45
  ],
46
46
 
@@ -238,7 +238,7 @@ class GenerationOutput:
238
238
  def stream_generate(
239
239
  model: 'mlx_nn.Module',
240
240
  tokenization: Tokenization,
241
- prompt: ta.Union[str, 'mx.array'],
241
+ prompt: ta.Union[str, ta.Sequence[int], 'mx.array'],
242
242
  params: GenerationParams = GenerationParams(),
243
243
  ) -> ta.Generator[GenerationOutput]:
244
244
  if not isinstance(prompt, mx.array):
@@ -312,7 +312,7 @@ def stream_generate(
312
312
  def generate(
313
313
  model: 'mlx_nn.Module',
314
314
  tokenization: Tokenization,
315
- prompt: ta.Union[str, 'mx.array'],
315
+ prompt: ta.Union[str, ta.Sequence[int], 'mx.array'],
316
316
  params: GenerationParams = GenerationParams(),
317
317
  *,
318
318
  verbose: bool = False,
@@ -1,3 +1,5 @@
1
+ from omlish import check
2
+
1
3
  from ..types import Tokenizer
2
4
  from .base import BaseStreamingDetokenizer
3
5
 
@@ -36,14 +38,14 @@ class NaiveStreamingDetokenizer(BaseStreamingDetokenizer):
36
38
  self._tokens.append(token)
37
39
 
38
40
  def finalize(self) -> None:
39
- self._text += self._tokenizer.decode(self._current_tokens)
41
+ self._text += check.isinstance(self._tokenizer.decode(self._current_tokens), str)
40
42
  self._current_tokens = []
41
43
  self._current_text = ''
42
44
 
43
45
  @property
44
46
  def text(self) -> str:
45
47
  if self._current_tokens:
46
- self._current_text = self._tokenizer.decode(self._current_tokens)
48
+ self._current_text = check.isinstance(self._tokenizer.decode(self._current_tokens), str)
47
49
  if self._current_text.endswith('\ufffd') or (
48
50
  self._tokenizer.clean_up_tokenization_spaces and
49
51
  self._current_text and
@@ -44,7 +44,7 @@ class Tokenization:
44
44
  try:
45
45
  token_id = int(token)
46
46
  except ValueError:
47
- token_id = self._tokenizer.convert_tokens_to_ids(token)
47
+ token_id = check.isinstance(self._tokenizer.convert_tokens_to_ids(check.isinstance(token, str)), int)
48
48
 
49
49
  if token_id is None:
50
50
  raise ValueError(f"'{token}' is not a token for this tokenizer")
@@ -12,19 +12,24 @@ P = ta.ParamSpec('P')
12
12
 
13
13
 
14
14
  class CancellableTextStreamer(tfm.TextStreamer):
15
+ class Tokenizer(ta.Protocol):
16
+ """Transformers base class still hard deps an AutoTokenizer despite only needing this one method."""
17
+
18
+ def decode(self, tokens: ta.Sequence[int], **kwargs: ta.Any) -> str: ...
19
+
15
20
  class Callback(ta.Protocol):
16
21
  def __call__(self, text: str, *, stream_end: bool) -> None: ...
17
22
 
18
23
  def __init__(
19
24
  self,
20
- tokenizer: tfm.AutoTokenizer,
25
+ tokenizer: Tokenizer,
21
26
  callback: Callback,
22
27
  *,
23
28
  skip_prompt: bool = False,
24
29
  **decode_kwargs: ta.Any,
25
30
  ) -> None:
26
31
  super().__init__(
27
- tokenizer,
32
+ ta.cast(tfm.AutoTokenizer, tokenizer), # noqa
28
33
  skip_prompt=skip_prompt,
29
34
  **decode_kwargs,
30
35
  )
@@ -1450,6 +1450,153 @@ def _process_dataclass__d65d18393f357ae0fb02bb80268c6f1473462613():
1450
1450
  return _process_dataclass
1451
1451
 
1452
1452
 
1453
+ @_register(
1454
+ plan_repr=(
1455
+ "Plans(tup=(CopyPlan(fields=('profile', 'args')), EqPlan(fields=('profile', 'args')), FrozenPlan(fields=('profi"
1456
+ "le', 'args'), allow_dynamic_dunder_attrs=False), HashPlan(action='add', fields=('profile', 'args'), cache=Fals"
1457
+ "e), InitPlan(fields=(InitPlan.Field(name='profile', annotation=OpRef(name='init.fields.0.annotation'), default"
1458
+ "=None, default_factory=None, init=True, override=False, field_type=FieldType.INSTANCE, coerce=None, validate=N"
1459
+ "one, check_type=None), InitPlan.Field(name='args', annotation=OpRef(name='init.fields.1.annotation'), default="
1460
+ "None, default_factory=None, init=True, override=False, field_type=FieldType.INSTANCE, coerce=None, validate=No"
1461
+ "ne, check_type=None)), self_param='self', std_params=('profile', 'args'), kw_only_params=(), frozen=True, slot"
1462
+ "s=False, post_init_params=None, init_fns=(), validate_fns=()), ReprPlan(fields=(ReprPlan.Field(name='profile',"
1463
+ " kw_only=False, fn=None), ReprPlan.Field(name='args', kw_only=False, fn=None)), id=False, terse=False, default"
1464
+ "_fn=None)))"
1465
+ ),
1466
+ plan_repr_sha1='3ca2bdd121559ad92e95a0e4bdf8210452f20e5d',
1467
+ op_ref_idents=(
1468
+ '__dataclass__init__fields__0__annotation',
1469
+ '__dataclass__init__fields__1__annotation',
1470
+ ),
1471
+ cls_names=(
1472
+ ('ommlds.cli.main', 'ProfileAspect.ConfigureContext'),
1473
+ ),
1474
+ )
1475
+ def _process_dataclass__3ca2bdd121559ad92e95a0e4bdf8210452f20e5d():
1476
+ def _process_dataclass(
1477
+ *,
1478
+ __dataclass__cls,
1479
+ __dataclass__init__fields__0__annotation,
1480
+ __dataclass__init__fields__1__annotation,
1481
+ __dataclass__FieldFnValidationError, # noqa
1482
+ __dataclass__FieldTypeValidationError, # noqa
1483
+ __dataclass__FnValidationError, # noqa
1484
+ __dataclass__FrozenInstanceError=dataclasses.FrozenInstanceError, # noqa
1485
+ __dataclass__FunctionType=types.FunctionType, # noqa
1486
+ __dataclass__HAS_DEFAULT_FACTORY=dataclasses._HAS_DEFAULT_FACTORY, # noqa
1487
+ __dataclass__MISSING=dataclasses.MISSING, # noqa
1488
+ __dataclass__None=None, # noqa
1489
+ __dataclass__TypeError=TypeError, # noqa
1490
+ __dataclass___recursive_repr=reprlib.recursive_repr, # noqa
1491
+ __dataclass__isinstance=isinstance, # noqa
1492
+ __dataclass__object_setattr=object.__setattr__, # noqa
1493
+ __dataclass__property=property, # noqa
1494
+ ):
1495
+ def __copy__(self):
1496
+ if self.__class__ is not __dataclass__cls:
1497
+ raise TypeError(self)
1498
+ return __dataclass__cls( # noqa
1499
+ profile=self.profile,
1500
+ args=self.args,
1501
+ )
1502
+
1503
+ __copy__.__qualname__ = f"{__dataclass__cls.__qualname__}.__copy__"
1504
+ if '__copy__' in __dataclass__cls.__dict__:
1505
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __copy__ in class {__dataclass__cls.__name__}")
1506
+ setattr(__dataclass__cls, '__copy__', __copy__)
1507
+
1508
+ def __eq__(self, other):
1509
+ if self is other:
1510
+ return True
1511
+ if self.__class__ is not other.__class__:
1512
+ return NotImplemented
1513
+ return (
1514
+ self.profile == other.profile and
1515
+ self.args == other.args
1516
+ )
1517
+
1518
+ __eq__.__qualname__ = f"{__dataclass__cls.__qualname__}.__eq__"
1519
+ if '__eq__' in __dataclass__cls.__dict__:
1520
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __eq__ in class {__dataclass__cls.__name__}")
1521
+ setattr(__dataclass__cls, '__eq__', __eq__)
1522
+
1523
+ __dataclass___setattr_frozen_fields = {
1524
+ 'profile',
1525
+ 'args',
1526
+ }
1527
+
1528
+ def __setattr__(self, name, value):
1529
+ if (
1530
+ type(self) is __dataclass__cls
1531
+ or name in __dataclass___setattr_frozen_fields
1532
+ ):
1533
+ raise __dataclass__FrozenInstanceError(f"cannot assign to field {name!r}")
1534
+ super(__dataclass__cls, self).__setattr__(name, value)
1535
+
1536
+ __setattr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__setattr__"
1537
+ if '__setattr__' in __dataclass__cls.__dict__:
1538
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __setattr__ in class {__dataclass__cls.__name__}")
1539
+ setattr(__dataclass__cls, '__setattr__', __setattr__)
1540
+
1541
+ __dataclass___delattr_frozen_fields = {
1542
+ 'profile',
1543
+ 'args',
1544
+ }
1545
+
1546
+ def __delattr__(self, name):
1547
+ if (
1548
+ type(self) is __dataclass__cls
1549
+ or name in __dataclass___delattr_frozen_fields
1550
+ ):
1551
+ raise __dataclass__FrozenInstanceError(f"cannot delete field {name!r}")
1552
+ super(__dataclass__cls, self).__delattr__(name)
1553
+
1554
+ __delattr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__delattr__"
1555
+ if '__delattr__' in __dataclass__cls.__dict__:
1556
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __delattr__ in class {__dataclass__cls.__name__}")
1557
+ setattr(__dataclass__cls, '__delattr__', __delattr__)
1558
+
1559
+ def __hash__(self):
1560
+ return hash((
1561
+ self.profile,
1562
+ self.args,
1563
+ ))
1564
+
1565
+ __hash__.__qualname__ = f"{__dataclass__cls.__qualname__}.__hash__"
1566
+ setattr(__dataclass__cls, '__hash__', __hash__)
1567
+
1568
+ def __init__(
1569
+ self,
1570
+ profile: __dataclass__init__fields__0__annotation,
1571
+ args: __dataclass__init__fields__1__annotation,
1572
+ ) -> __dataclass__None:
1573
+ __dataclass__object_setattr(self, 'profile', profile)
1574
+ __dataclass__object_setattr(self, 'args', args)
1575
+
1576
+ __init__.__qualname__ = f"{__dataclass__cls.__qualname__}.__init__"
1577
+ if '__init__' in __dataclass__cls.__dict__:
1578
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __init__ in class {__dataclass__cls.__name__}")
1579
+ setattr(__dataclass__cls, '__init__', __init__)
1580
+
1581
+ @__dataclass___recursive_repr()
1582
+ def __repr__(self):
1583
+ parts = []
1584
+ parts.append(f"profile={self.profile!r}")
1585
+ parts.append(f"args={self.args!r}")
1586
+ return (
1587
+ f"{self.__class__.__qualname__}("
1588
+ f"{', '.join(parts)}"
1589
+ f")"
1590
+ )
1591
+
1592
+ __repr__.__qualname__ = f"{__dataclass__cls.__qualname__}.__repr__"
1593
+ if '__repr__' in __dataclass__cls.__dict__:
1594
+ raise __dataclass__TypeError(f"Cannot overwrite attribute __repr__ in class {__dataclass__cls.__name__}")
1595
+ setattr(__dataclass__cls, '__repr__', __repr__)
1596
+
1597
+ return _process_dataclass
1598
+
1599
+
1453
1600
  @_register(
1454
1601
  plan_repr=(
1455
1602
  "Plans(tup=(CopyPlan(fields=('markdown',)), EqPlan(fields=('markdown',)), FrozenPlan(fields=('markdown',), allo"
ommlds/cli/inject.py CHANGED
@@ -17,6 +17,7 @@ with lang.auto_proxy_import(globals()):
17
17
  def bind_main(
18
18
  *,
19
19
  session_cfg: SessionConfig,
20
+ profile_name: str | None = None,
20
21
  ) -> inj.Elements:
21
22
  els: list[inj.Elemental] = []
22
23
 
@@ -30,7 +31,10 @@ def bind_main(
30
31
  #
31
32
 
32
33
  els.extend([
33
- _sessions.bind_sessions(session_cfg),
34
+ _sessions.bind_sessions(
35
+ session_cfg,
36
+ profile_name=profile_name,
37
+ ),
34
38
 
35
39
  _state.bind_state(),
36
40
  ])
ommlds/cli/main.py CHANGED
@@ -2,30 +2,20 @@
2
2
  TODO:
3
3
  - bootstrap lol
4
4
  """
5
- import abc
6
5
  import functools
7
- import sys
8
6
  import typing as ta
9
7
 
10
8
  import anyio
11
9
 
12
- from omlish import check
13
- from omlish import dataclasses as dc
14
10
  from omlish import inject as inj
15
- from omlish import lang
16
11
  from omlish.argparse import all as ap
17
12
  from omlish.logs import all as logs
18
13
 
19
14
  from .inject import bind_main
15
+ from .profiles import PROFILE_TYPES
20
16
  from .secrets import install_env_secrets
21
17
  from .sessions.base import Session
22
- from .sessions.chat.configs import ChatConfig
23
- from .sessions.chat.interfaces.bare.configs import BareInterfaceConfig
24
- from .sessions.chat.interfaces.configs import InterfaceConfig
25
- from .sessions.chat.interfaces.textual.configs import TextualInterfaceConfig
26
- from .sessions.completion.configs import CompletionConfig
27
18
  from .sessions.configs import SessionConfig
28
- from .sessions.embedding.configs import EmbeddingConfig
29
19
 
30
20
 
31
21
  ##
@@ -48,337 +38,14 @@ def _process_main_extra_args(args: ap.Namespace) -> None:
48
38
  ##
49
39
 
50
40
 
51
- class Profile(lang.Abstract):
52
- @abc.abstractmethod
53
- def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
54
- raise NotImplementedError
55
-
56
-
57
- ##
58
-
59
-
60
- # class ChatAspect(lang.Abstract):
61
- # def get_parser_args(self) -> ta.Sequence[ap.Arg]: ...
62
- # def set_args(self, args: ap.Namespace) -> None: ...
63
- # def configure(self, cfg: ChatConfig) -> ChatConfig: ...
64
-
65
-
66
- class ChatProfile(Profile):
67
- _args: ap.Namespace
68
-
69
- #
70
-
71
- BACKEND_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
72
- ap.arg('-b', '--backend', group='backend'),
73
- ]
74
-
75
- def configure_backend(self, cfg: ChatConfig) -> ChatConfig:
76
- return dc.replace(
77
- cfg,
78
- driver=dc.replace(
79
- cfg.driver,
80
- backend=dc.replace(
81
- cfg.driver.backend,
82
- backend=self._args.backend,
83
- ),
84
- ),
85
- )
86
-
87
- #
88
-
89
- INTERFACE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
90
- ap.arg('-i', '--interactive', action='store_true', group='interface'),
91
- ap.arg('-T', '--textual', action='store_true', group='interface'),
92
- ap.arg('-e', '--editor', action='store_true', group='interface'),
93
- ]
94
-
95
- def configure_interface(self, cfg: ChatConfig) -> ChatConfig:
96
- if self._args.editor:
97
- check.arg(not self._args.interactive)
98
- check.arg(not self._args.message)
99
- raise NotImplementedError
100
-
101
- if self._args.textual:
102
- check.isinstance(cfg.interface, BareInterfaceConfig)
103
- cfg = dc.replace(
104
- cfg,
105
- interface=TextualInterfaceConfig(**{
106
- f.name: getattr(cfg.interface, f.name)
107
- for f in dc.fields(InterfaceConfig)
108
- }),
109
- )
110
-
111
- else:
112
- cfg = dc.replace(
113
- cfg,
114
- driver=dc.replace(
115
- cfg.driver,
116
- ai=dc.replace(
117
- cfg.driver.ai,
118
- verbose=True,
119
- ),
120
- ),
121
- interface=dc.replace(
122
- check.isinstance(cfg.interface, BareInterfaceConfig),
123
- interactive=self._args.interactive,
124
- ),
125
- )
126
-
127
- return cfg
128
-
129
- #
130
-
131
- INPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
132
- ap.arg('message', nargs='*', group='input'),
133
- ]
134
-
135
- def configure_input(self, cfg: ChatConfig) -> ChatConfig:
136
- if self._args.interactive or self._args.textual:
137
- check.arg(not self._args.message)
138
-
139
- elif self._args.message:
140
- ps: list[str] = []
141
-
142
- for a in self._args.message:
143
- if a == '-':
144
- ps.append(sys.stdin.read())
145
-
146
- elif a.startswith('@'):
147
- with open(a[1:]) as f:
148
- ps.append(f.read())
149
-
150
- else:
151
- ps.append(a)
152
-
153
- c = ' '.join(ps)
154
-
155
- cfg = dc.replace(
156
- cfg,
157
- driver=dc.replace(
158
- cfg.driver,
159
- user=dc.replace(
160
- cfg.driver.user,
161
- initial_user_content=c,
162
- ),
163
- ),
164
- )
165
-
166
- else:
167
- raise ValueError('Must specify input')
168
-
169
- return cfg
170
-
171
- #
172
-
173
- STATE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
174
- ap.arg('-n', '--new', action='store_true', group='state'),
175
- ap.arg('--ephemeral', action='store_true', group='state'),
176
- ]
177
-
178
- def configure_state(self, cfg: ChatConfig) -> ChatConfig:
179
- return dc.replace(
180
- cfg,
181
- driver=dc.replace(
182
- cfg.driver,
183
- state=dc.replace(
184
- cfg.driver.state,
185
- state='ephemeral' if self._args.ephemeral else 'new' if self._args.new else 'continue',
186
- ),
187
- ),
188
- )
189
-
190
- #
191
-
192
- OUTPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
193
- ap.arg('-s', '--stream', action='store_true', group='output'),
194
- ap.arg('-M', '--markdown', action='store_true', group='output'),
195
- ]
196
-
197
- def configure_output(self, cfg: ChatConfig) -> ChatConfig:
198
- return dc.replace(
199
- cfg,
200
- driver=dc.replace(
201
- cfg.driver,
202
- ai=dc.replace(
203
- cfg.driver.ai,
204
- stream=bool(self._args.stream),
205
- ),
206
- ),
207
- rendering=dc.replace(
208
- cfg.rendering,
209
- markdown=bool(self._args.markdown),
210
- ),
211
- )
212
-
213
- #
214
-
215
- TOOLS_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
216
- ap.arg('--enable-fs-tools', action='store_true', group='tools'),
217
- ap.arg('--enable-todo-tools', action='store_true', group='tools'),
218
- # ap.arg('--enable-unsafe-tools-do-not-use-lol', action='store_true', group='tools'),
219
- ap.arg('--enable-test-weather-tool', action='store_true', group='tools'),
220
- ]
221
-
222
- def configure_tools(self, cfg: ChatConfig) -> ChatConfig:
223
- if not (
224
- self._args.enable_fs_tools or
225
- self._args.enable_todo_tools or
226
- # self._args.enable_unsafe_tools_do_not_use_lol or
227
- self._args.enable_test_weather_tool or
228
- self._args.code
229
- ):
230
- return cfg
231
-
232
- return dc.replace(
233
- cfg,
234
- driver=dc.replace(
235
- cfg.driver,
236
- ai=dc.replace(
237
- cfg.driver.ai,
238
- enable_tools=True,
239
- ),
240
- tools=dc.replace(
241
- cfg.driver.tools,
242
- enabled_tools={ # noqa
243
- *(cfg.driver.tools.enabled_tools or []),
244
- *(['fs'] if self._args.enable_fs_tools else []),
245
- *(['todo'] if self._args.enable_todo_tools else []),
246
- *(['weather'] if self._args.enable_test_weather_tool else []),
247
- },
248
- ),
249
- ),
250
- interface=dc.replace(
251
- cfg.interface,
252
- enable_tools=True,
253
- ),
254
- )
255
-
256
- #
257
-
258
- CODE_CONFIG: ta.ClassVar[ta.Sequence[ap.Arg]] = [
259
- ap.arg('-c', '--code', action='store_true', group='code'),
260
- ]
261
-
262
- def configure_code(self, cfg: ChatConfig) -> ChatConfig:
263
- if not self._args.code:
264
- return cfg
265
-
266
- cfg = dc.replace(
267
- cfg,
268
- driver=dc.replace(
269
- cfg.driver,
270
- ai=dc.replace(
271
- cfg.driver.ai,
272
- enable_tools=True,
273
- ),
274
- ),
275
- )
276
-
277
- if self._args.new or self._args.ephemeral:
278
- from ..minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT
279
- system_content = CODE_AGENT_SYSTEM_PROMPT
280
-
281
- cfg = dc.replace(
282
- cfg,
283
- driver=dc.replace(
284
- cfg.driver,
285
- user=dc.replace(
286
- cfg.driver.user,
287
- initial_system_content=system_content,
288
- ),
289
- ),
290
- )
291
-
292
- return cfg
293
-
294
- #
295
-
296
- def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
297
- parser = ap.ArgumentParser()
298
-
299
- for grp_name, grp_args in [
300
- ('backend', self.BACKEND_ARGS),
301
- ('interface', self.INTERFACE_ARGS),
302
- ('input', self.INPUT_ARGS),
303
- ('state', self.STATE_ARGS),
304
- ('output', self.OUTPUT_ARGS),
305
- ('tools', self.TOOLS_ARGS),
306
- ('code', self.CODE_CONFIG),
307
- ]:
308
- grp = parser.add_argument_group(grp_name)
309
- for a in grp_args:
310
- grp.add_argument(*a.args, **a.kwargs)
311
-
312
- self._args = parser.parse_args(argv)
313
-
314
- cfg = ChatConfig()
315
- cfg = self.configure_backend(cfg)
316
- cfg = self.configure_interface(cfg)
317
- cfg = self.configure_input(cfg)
318
- cfg = self.configure_state(cfg)
319
- cfg = self.configure_output(cfg)
320
- cfg = self.configure_tools(cfg)
321
- cfg = self.configure_code(cfg)
322
-
323
- return cfg
324
-
325
-
326
- ##
327
-
328
-
329
- class CompletionProfile(Profile):
330
- def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
331
- parser = ap.ArgumentParser()
332
- parser.add_argument('prompt', nargs='*')
333
- parser.add_argument('-b', '--backend', default='openai')
334
- args = parser.parse_args(argv)
335
-
336
- content = ' '.join(args.prompt)
337
-
338
- cfg = CompletionConfig(
339
- content=check.non_empty_str(content),
340
- backend=args.backend,
341
- )
342
-
343
- return cfg
344
-
345
-
346
- ##
347
-
348
-
349
- class EmbedProfile(Profile):
350
- def configure(self, argv: ta.Sequence[str]) -> SessionConfig:
351
- parser = ap.ArgumentParser()
352
- parser.add_argument('prompt', nargs='*')
353
- parser.add_argument('-b', '--backend', default='openai')
354
- args = parser.parse_args(argv)
355
-
356
- content = ' '.join(args.prompt)
357
-
358
- cfg = EmbeddingConfig(
359
- content=check.non_empty_str(content),
360
- backend=args.backend,
361
- )
362
-
363
- return cfg
364
-
365
-
366
- ##
367
-
368
-
369
- PROFILE_TYPES: ta.Mapping[str, type[Profile]] = {
370
- 'chat': ChatProfile,
371
- 'complete': CompletionProfile,
372
- 'embed': EmbedProfile,
373
- }
374
-
375
-
376
- ##
377
-
378
-
379
- async def _run_session_cfg(session_cfg: SessionConfig) -> None:
41
+ async def _run_session_cfg(
42
+ session_cfg: SessionConfig,
43
+ *,
44
+ profile_name: str | None = None,
45
+ ) -> None:
380
46
  async with inj.create_async_managed_injector(bind_main(
381
47
  session_cfg=session_cfg,
48
+ profile_name=profile_name,
382
49
  )) as injector:
383
50
  await (await injector[Session]).run()
384
51
 
@@ -388,12 +55,13 @@ async def _run_session_cfg(session_cfg: SessionConfig) -> None:
388
55
 
389
56
  MAIN_PROFILE_ARGS: ta.Sequence[ap.Arg] = [
390
57
  ap.arg('-p', '--profile', default='chat'),
58
+ ap.arg('-h', '--help', action='store_true'),
391
59
  ap.arg('args', nargs=ap.REMAINDER),
392
60
  ]
393
61
 
394
62
 
395
63
  async def _a_main(argv: ta.Any = None) -> None:
396
- parser = ap.ArgumentParser()
64
+ parser = ap.ArgumentParser(add_help=False)
397
65
 
398
66
  for a in [*MAIN_PROFILE_ARGS, *MAIN_EXTRA_ARGS]:
399
67
  parser.add_argument(*a.args, **a.kwargs)
@@ -407,9 +75,16 @@ async def _a_main(argv: ta.Any = None) -> None:
407
75
  profile_cls = PROFILE_TYPES[args.profile]
408
76
  profile = profile_cls()
409
77
 
410
- session_cfg = profile.configure([*unk_args, *args.args])
78
+ session_cfg = profile.configure([
79
+ *unk_args,
80
+ *(['--help'] if args.help else []),
81
+ *args.args,
82
+ ])
411
83
 
412
- await _run_session_cfg(session_cfg)
84
+ await _run_session_cfg(
85
+ session_cfg,
86
+ profile_name=args.profile,
87
+ )
413
88
 
414
89
 
415
90
  def _main(args: ta.Any = None) -> None: