ommlds 0.0.0.dev514__py3-none-any.whl → 0.0.0.dev516__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/__about__.py +3 -3
- ommlds/backends/mlx/generation.py +2 -2
- ommlds/backends/mlx/tokenization/detokenization/naive.py +4 -2
- ommlds/backends/mlx/tokenization/tokenization.py +1 -1
- ommlds/backends/transformers/streamers.py +7 -2
- ommlds/cli/_dataclasses.py +147 -0
- ommlds/cli/inject.py +5 -1
- ommlds/cli/main.py +18 -343
- ommlds/cli/profiles.py +449 -0
- ommlds/cli/sessions/chat/interfaces/textual/app.py +8 -2
- ommlds/cli/sessions/chat/interfaces/textual/inputhistory.py +67 -30
- ommlds/cli/sessions/inject.py +11 -1
- ommlds/cli/sessions/types.py +7 -0
- ommlds/minichain/backends/impls/transformers/tokens.py +4 -2
- ommlds/minichain/backends/impls/transformers/transformers.py +2 -2
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/METADATA +10 -10
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/RECORD +21 -19
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev514.dist-info → ommlds-0.0.0.dev516.dist-info}/top_level.txt +0 -0
ommlds/cli/profiles.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import sys
|
|
3
|
+
import typing as ta
|
|
4
|
+
|
|
5
|
+
from omlish import check
|
|
6
|
+
from omlish import dataclasses as dc
|
|
7
|
+
from omlish import lang
|
|
8
|
+
from omlish.argparse import all as ap
|
|
9
|
+
|
|
10
|
+
from .sessions.chat.configs import ChatConfig
|
|
11
|
+
from .sessions.chat.interfaces.bare.configs import BareInterfaceConfig
|
|
12
|
+
from .sessions.chat.interfaces.configs import InterfaceConfig
|
|
13
|
+
from .sessions.chat.interfaces.textual.configs import TextualInterfaceConfig
|
|
14
|
+
from .sessions.completion.configs import CompletionConfig
|
|
15
|
+
from .sessions.configs import SessionConfig
|
|
16
|
+
from .sessions.embedding.configs import EmbeddingConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Type variables tying a profile / aspect to the concrete SessionConfig subtype it produces. SessionConfigU is the one
# used by the nested, separately-generic ProfileAspect.ConfigureContext, so it does not reuse the enclosing aspect's
# SessionConfigT.
SessionConfigT = ta.TypeVar('SessionConfigT', bound=SessionConfig)
SessionConfigU = ta.TypeVar('SessionConfigU', bound=SessionConfig)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
##
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Profile(lang.Abstract, ta.Generic[SessionConfigT]):
    """
    A CLI entry point that turns raw command-line arguments into a session config.

    Concrete profiles are looked up by name in PROFILE_TYPES.
    """

    @abc.abstractmethod
    def configure(self, argv: ta.Sequence[str]) -> SessionConfigT:
        """Parse ``argv`` and return the session config it describes."""

        raise NotImplementedError
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
##
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ProfileAspect(lang.Abstract, ta.Generic[SessionConfigT]):
    """
    One pluggable slice of an AspectProfile.

    An aspect contributes command-line arguments to the profile's parser. It then folds the parsed arguments into the
    session config.
    """

    @property
    def name(self) -> str:
        """The concrete class name, converted with lang.camel_to_snake and lower-cased."""

        snake = lang.camel_to_snake(type(self).__name__)
        return snake.lower()

    @property
    def default_parser_arg_group(self) -> str | None:
        """
        Argument group for parser args that do not name a group themselves.

        Returning None places such args directly on the top-level parser.
        """

        return self.name

    @property
    def parser_args(self) -> ta.Sequence[ap.Arg]:
        """
        Arguments this aspect adds to the profile's parser.

        Subclasses in this module shadow this property with a class attribute.
        """

        return []

    @ta.final
    @dc.dataclass(frozen=True)
    class ConfigureContext(ta.Generic[SessionConfigU]):
        """Read-only state handed to every aspect's ``configure``."""

        profile: 'Profile[SessionConfigU]'  # The profile currently being configured.
        args: ap.Namespace  # The parsed command line; one namespace shared by all aspects of the profile.

    @abc.abstractmethod
    def configure(self, ctx: ConfigureContext[SessionConfigT], cfg: SessionConfigT) -> SessionConfigT:
        """
        Return ``cfg`` updated according to ``ctx.args``.

        Implementations in this module build a new config with dc.replace rather than mutating ``cfg``.
        """

        raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AspectProfile(Profile[SessionConfigT], lang.Abstract):
    """
    A profile assembled from ProfileAspects.

    Every aspect contributes its parser args to one shared ArgumentParser. The resulting namespace is then threaded
    through each aspect's ``configure`` in order, starting from ``initial_config()``.
    """

    @abc.abstractmethod
    def _build_aspects(self) -> ta.Sequence[ProfileAspect[SessionConfigT]]:
        """
        Return this profile's aspects in the order they are applied.

        The base implementation contributes none, so overrides can extend it with ``super()``.
        """

        return []

    __aspects: ta.Sequence[ProfileAspect[SessionConfigT]]

    @ta.final
    @property
    def aspects(self) -> ta.Sequence[ProfileAspect[SessionConfigT]]:
        """The result of ``_build_aspects``, built on first access and cached on the instance as a tuple."""

        try:
            return self.__aspects
        except AttributeError:
            pass

        built = tuple(self._build_aspects())
        self.__aspects = built
        return built

    #

    @abc.abstractmethod
    def initial_config(self) -> SessionConfigT:
        """The config that the first aspect receives."""

        raise NotImplementedError

    #

    def _build_arg_parser(self) -> ap.ArgumentParser:
        # Each arg's group name is lang.opt_coalesce of the arg's own group and its aspect's default group.
        # Args whose resulting group name is None go directly on the parser.
        # Each named group is created once, the first time an arg needs it.
        parser = ap.ArgumentParser()
        groups: dict[str, ta.Any] = {}

        for aspect in self.aspects:
            for arg in aspect.parser_args:
                group_name = lang.opt_coalesce(arg.group, aspect.default_parser_arg_group)
                if group_name is None:
                    parser.add_argument(*arg.args, **arg.kwargs)
                    continue

                check.non_empty_str(group_name)
                if group_name not in groups:
                    groups[group_name] = parser.add_argument_group(group_name)
                groups[group_name].add_argument(*arg.args, **arg.kwargs)

        return parser

    def configure(self, argv: ta.Sequence[str]) -> SessionConfigT:
        """Parse ``argv`` with the combined aspect args, then apply every aspect in turn to ``initial_config()``."""

        args = self._build_arg_parser().parse_args(argv)
        ctx = ProfileAspect.ConfigureContext(self, args)

        cfg = self.initial_config()
        for aspect in self.aspects:
            cfg = aspect.configure(ctx, cfg)
        return cfg
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
##
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ChatProfile(AspectProfile[ChatConfig]):
    """The chat profile, composed of the nested aspects below and applied in ``_build_aspects`` order."""

    class Backend(ProfileAspect[ChatConfig]):
        """Sets the chat backend from ``-b/--backend``."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('-b', '--backend'),
        ]

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            # The backend is always overwritten. An omitted flag therefore stores the option's argparse default.
            backend = dc.replace(cfg.driver.backend, backend=ctx.args.backend)
            driver = dc.replace(cfg.driver, backend=backend)
            return dc.replace(cfg, driver=driver)

    #

    class Interface(ProfileAspect[ChatConfig]):
        """Chooses between the bare interface (optionally interactive) and the textual interface."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('-i', '--interactive', action='store_true'),
            ap.arg('-T', '--textual', action='store_true'),
            ap.arg('-e', '--editor', action='store_true'),
        ]

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            args = ctx.args

            if args.editor:
                # Editor mode excludes interactive mode and inline messages. It is not implemented yet.
                check.arg(not args.interactive)
                check.arg(not args.message)
                raise NotImplementedError

            if args.textual:
                # Swap the bare interface config for a textual one.
                # Every field declared on InterfaceConfig is copied over from the bare config.
                check.isinstance(cfg.interface, BareInterfaceConfig)
                shared = {
                    fld.name: getattr(cfg.interface, fld.name)
                    for fld in dc.fields(InterfaceConfig)
                }
                return dc.replace(cfg, interface=TextualInterfaceConfig(**shared))

            # Bare interface: turn on verbose AI output and apply the interactive flag.
            verbose_ai = dc.replace(cfg.driver.ai, verbose=True)
            driver = dc.replace(cfg.driver, ai=verbose_ai)
            bare = dc.replace(
                check.isinstance(cfg.interface, BareInterfaceConfig),
                interactive=args.interactive,
            )
            return dc.replace(cfg, driver=driver, interface=bare)

    #

    class Input(ProfileAspect[ChatConfig]):
        """Builds the initial user message from the positional ``message`` arguments."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('message', nargs='*'),
        ]

        @staticmethod
        def _resolve_message_part(part: str) -> str:
            # ``-`` reads all of stdin.
            if part == '-':
                return sys.stdin.read()

            # ``@path`` reads that file's contents.
            if part.startswith('@'):
                with open(part[1:]) as f:
                    return f.read()

            # Anything else is used literally.
            return part

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            args = ctx.args

            if args.interactive or args.textual:
                # Interactive modes take input from the UI, so inline messages are rejected.
                check.arg(not args.message)
                return cfg

            if not args.message:
                raise ValueError('Must specify input')

            content = ' '.join([self._resolve_message_part(part) for part in args.message])

            user = dc.replace(cfg.driver.user, initial_user_content=content)
            driver = dc.replace(cfg.driver, user=user)
            return dc.replace(cfg, driver=driver)

    #

    class State(ProfileAspect[ChatConfig]):
        """Chooses whether to continue the previous chat state, start a new one, or use an ephemeral one."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('-n', '--new', action='store_true'),
            ap.arg('--ephemeral', action='store_true'),
        ]

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            # --ephemeral takes precedence over --new. With neither flag, the existing state is continued.
            if ctx.args.ephemeral:
                mode = 'ephemeral'
            elif ctx.args.new:
                mode = 'new'
            else:
                mode = 'continue'

            state = dc.replace(cfg.driver.state, state=mode)
            driver = dc.replace(cfg.driver, state=state)
            return dc.replace(cfg, driver=driver)

    #

    class Output(ProfileAspect[ChatConfig]):
        """Controls response streaming and markdown rendering."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('-s', '--stream', action='store_true'),
            ap.arg('-M', '--markdown', action='store_true'),
        ]

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            ai = dc.replace(cfg.driver.ai, stream=bool(ctx.args.stream))
            driver = dc.replace(cfg.driver, ai=ai)
            rendering = dc.replace(cfg.rendering, markdown=bool(ctx.args.markdown))
            return dc.replace(cfg, driver=driver, rendering=rendering)

    #

    class Tools(ProfileAspect[ChatConfig]):
        """Enables tool use for whichever tool sets were requested on the command line."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('--enable-fs-tools', action='store_true'),
            ap.arg('--enable-todo-tools', action='store_true'),
            # ap.arg('--enable-unsafe-tools-do-not-use-lol', action='store_true'),
            ap.arg('--enable-test-weather-tool', action='store_true'),
        ]

        def configure_with_tools(
                self,
                ctx: ProfileAspect.ConfigureContext[ChatConfig],
                cfg: ChatConfig,
                enabled_tools: ta.Iterable[str],
        ) -> ChatConfig:
            """
            Turn on tool use in the AI driver and the interface.

            ``enabled_tools`` is an iterable of tool names, never a single string. Its names are added to any tools
            the config already enables.
            """

            check.not_isinstance(enabled_tools, str)

            ai = dc.replace(cfg.driver.ai, enable_tools=True)
            merged = {  # noqa
                *(cfg.driver.tools.enabled_tools or []),
                *enabled_tools,
            }
            tools = dc.replace(cfg.driver.tools, enabled_tools=merged)
            driver = dc.replace(cfg.driver, ai=ai, tools=tools)
            interface = dc.replace(cfg.interface, enable_tools=True)
            return dc.replace(cfg, driver=driver, interface=interface)

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            args = ctx.args

            # Map each tool name to its flag. The unsafe-tools flag is currently disabled (see parser_args).
            flagged = (
                ('fs', args.enable_fs_tools),
                ('todo', args.enable_todo_tools),
                ('weather', args.enable_test_weather_tool),
            )
            requested = {name for name, on in flagged if on}

            if not requested:
                return cfg

            return self.configure_with_tools(ctx, cfg, requested)

    #

    class Code(ProfileAspect[ChatConfig]):
        """Switches the chat into coding-agent mode with ``-c/--code``."""

        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = [
            ap.arg('-c', '--code', action='store_true'),
        ]

        def configure_for_code(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            """
            Enable tool use.

            For a new or ephemeral session, also seed the coding-agent system prompt. A continued session is left
            otherwise unchanged.
            """

            ai = dc.replace(cfg.driver.ai, enable_tools=True)
            cfg = dc.replace(cfg, driver=dc.replace(cfg.driver, ai=ai))

            if not (ctx.args.new or ctx.args.ephemeral):
                return cfg

            from ..minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT

            user = dc.replace(cfg.driver.user, initial_system_content=CODE_AGENT_SYSTEM_PROMPT)
            return dc.replace(cfg, driver=dc.replace(cfg.driver, user=user))

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            return self.configure_for_code(ctx, cfg) if ctx.args.code else cfg

    #

    def _build_aspects(self) -> ta.Sequence[ProfileAspect[ChatConfig]]:
        # Aspects are applied in this order. They share one parsed namespace, so several read flags declared by a
        # sibling aspect (e.g. Input checks Interface's --interactive/--textual, and Code checks State's --new).
        # The nested classes are looked up on ``self`` so subclasses such as CodeProfile can substitute their own.
        aspects: list[ProfileAspect[ChatConfig]] = list(super()._build_aspects())
        aspects.extend([
            self.Backend(),
            self.Interface(),
            self.Input(),
            self.State(),
            self.Output(),
            self.Tools(),
            self.Code(),
        ])
        return aspects

    def initial_config(self) -> ChatConfig:
        return ChatConfig()
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
#
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
class CodeProfile(ChatProfile):
    """
    A ChatProfile preset for coding.

    The fs and todo tools and the coding-agent setup are always on, so their flags are not offered.
    """

    class Tools(ChatProfile.Tools):
        # No tool flags are offered: the fs and todo tool sets are always enabled.
        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = []

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            always_on = {'fs', 'todo'}
            return self.configure_with_tools(ctx, cfg, always_on)

    class Code(ChatProfile.Code):
        # ``-c/--code`` is implied, so it is not offered; code mode is applied unconditionally.
        parser_args: ta.ClassVar[ta.Sequence[ap.Arg]] = []

        def configure(self, ctx: ProfileAspect.ConfigureContext[ChatConfig], cfg: ChatConfig) -> ChatConfig:
            return self.configure_for_code(ctx, cfg)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
##
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
class CompletionProfile(Profile[CompletionConfig]):
    """
    One-shot text completion.

    The positional words are joined with spaces into the prompt, and ``-b/--backend`` picks the backend
    (default ``openai``).
    """

    def configure(self, argv: ta.Sequence[str]) -> CompletionConfig:
        """
        Parse ``argv`` into a CompletionConfig.

        Raises whatever ``check.non_empty_str`` raises when no non-empty prompt was given.
        """

        parser = ap.ArgumentParser()
        parser.add_argument('prompt', nargs='*')
        parser.add_argument('-b', '--backend', default='openai')
        args = parser.parse_args(argv)

        content = ' '.join(args.prompt)

        cfg = CompletionConfig(
            content=check.non_empty_str(content),
            backend=args.backend,
        )

        return cfg
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
##
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class EmbedProfile(Profile[EmbeddingConfig]):
    """
    One-shot embedding.

    The positional words are joined with spaces into the content to embed, and ``-b/--backend`` picks the backend
    (default ``openai``).
    """

    def configure(self, argv: ta.Sequence[str]) -> EmbeddingConfig:
        """
        Parse ``argv`` into an EmbeddingConfig.

        Raises whatever ``check.non_empty_str`` raises when no non-empty content was given.
        """

        parser = ap.ArgumentParser()
        parser.add_argument('prompt', nargs='*')
        parser.add_argument('-b', '--backend', default='openai')
        args = parser.parse_args(argv)

        content = ' '.join(args.prompt)

        cfg = EmbeddingConfig(
            content=check.non_empty_str(content),
            backend=args.backend,
        )

        return cfg
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
##
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# Registry of selectable profiles, keyed by profile name.
PROFILE_TYPES: ta.Mapping[str, type[Profile]] = {
    # Chat-style sessions.
    'chat': ChatProfile,
    'code': CodeProfile,

    # One-shot completion.
    'complete': CompletionProfile,

    # One-shot embedding.
    'embed': EmbedProfile,
}
|
|
@@ -16,6 +16,7 @@ from omlish.logs import all as logs
|
|
|
16
16
|
|
|
17
17
|
from ...... import minichain as mc
|
|
18
18
|
from .....backends.types import BackendName
|
|
19
|
+
from ....types import SessionProfileName
|
|
19
20
|
from ...drivers.events.types import AiDeltaChatEvent
|
|
20
21
|
from ...drivers.events.types import AiMessagesChatEvent
|
|
21
22
|
from ...drivers.types import ChatDriver
|
|
@@ -100,6 +101,7 @@ class ChatApp(
|
|
|
100
101
|
backend_name: BackendName | None = None,
|
|
101
102
|
devtools_setup: tx.DevtoolsSetup | None = None,
|
|
102
103
|
input_history_manager: InputHistoryManager,
|
|
104
|
+
session_profile_name: SessionProfileName | None = None,
|
|
103
105
|
) -> None:
|
|
104
106
|
super().__init__()
|
|
105
107
|
|
|
@@ -111,6 +113,7 @@ class ChatApp(
|
|
|
111
113
|
self._chat_event_queue = chat_event_queue
|
|
112
114
|
self._backend_name = backend_name
|
|
113
115
|
self._input_history_manager = input_history_manager
|
|
116
|
+
self._session_profile_name = session_profile_name
|
|
114
117
|
|
|
115
118
|
self._chat_action_queue: asyncio.Queue[ta.Any] = asyncio.Queue()
|
|
116
119
|
|
|
@@ -328,6 +331,7 @@ class ChatApp(
|
|
|
328
331
|
|
|
329
332
|
await self._mount_messages(
|
|
330
333
|
WelcomeMessage('\n'.join([
|
|
334
|
+
*([f'Profile: {self._session_profile_name}'] if self._session_profile_name is not None else []),
|
|
331
335
|
f'Backend: {self._backend_name or "?"}',
|
|
332
336
|
f'Dir: {os.getcwd()}',
|
|
333
337
|
])),
|
|
@@ -335,7 +339,7 @@ class ChatApp(
|
|
|
335
339
|
|
|
336
340
|
async def on_unmount(self) -> None:
|
|
337
341
|
if (cat := self._chat_action_queue_task) is not None:
|
|
338
|
-
await self.
|
|
342
|
+
await self._chat_action_queue.put(None)
|
|
339
343
|
await cat
|
|
340
344
|
|
|
341
345
|
await self._chat_driver.stop()
|
|
@@ -359,7 +363,7 @@ class ChatApp(
|
|
|
359
363
|
),
|
|
360
364
|
)
|
|
361
365
|
|
|
362
|
-
self._input_history_manager.add(event.text)
|
|
366
|
+
await self._input_history_manager.add(event.text)
|
|
363
367
|
|
|
364
368
|
await self._chat_action_queue.put(ChatApp.UserInput(event.text))
|
|
365
369
|
|
|
@@ -371,12 +375,14 @@ class ChatApp(
|
|
|
371
375
|
|
|
372
376
|
@tx.on(InputTextArea.HistoryPrevious)
|
|
373
377
|
async def on_input_text_area_history_previous(self, event: InputTextArea.HistoryPrevious) -> None:
|
|
378
|
+
await self._input_history_manager.load_if_necessary()
|
|
374
379
|
if (entry := self._input_history_manager.get_previous(event.text)) is not None:
|
|
375
380
|
self._get_input_text_area().text = entry
|
|
376
381
|
self._move_input_cursor_to_end()
|
|
377
382
|
|
|
378
383
|
@tx.on(InputTextArea.HistoryNext)
|
|
379
384
|
async def on_input_text_area_history_next(self, event: InputTextArea.HistoryNext) -> None:
|
|
385
|
+
await self._input_history_manager.load_if_necessary()
|
|
380
386
|
if (entry := self._input_history_manager.get_next(event.text)) is not None:
|
|
381
387
|
ita = self._get_input_text_area()
|
|
382
388
|
ita.text = entry
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
import json
|
|
3
2
|
import os
|
|
4
3
|
import typing as ta
|
|
5
4
|
|
|
6
5
|
from omlish import lang
|
|
6
|
+
from omlish.formats import json
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
##
|
|
@@ -11,11 +11,11 @@ from omlish import lang
|
|
|
11
11
|
|
|
12
12
|
class InputHistoryStorage(lang.Abstract):
|
|
13
13
|
@abc.abstractmethod
|
|
14
|
-
def load(self) -> list[str]:
|
|
14
|
+
def load(self) -> ta.Awaitable[list[str]]:
|
|
15
15
|
raise NotImplementedError
|
|
16
16
|
|
|
17
17
|
@abc.abstractmethod
|
|
18
|
-
def save(self, entries: ta.Sequence[str]) -> None:
|
|
18
|
+
def save(self, entries: ta.Sequence[str]) -> ta.Awaitable[None]:
|
|
19
19
|
raise NotImplementedError
|
|
20
20
|
|
|
21
21
|
|
|
@@ -25,10 +25,10 @@ class InMemoryInputHistoryStorage(InputHistoryStorage):
|
|
|
25
25
|
|
|
26
26
|
self._entries: list[str] = []
|
|
27
27
|
|
|
28
|
-
def load(self) -> list[str]:
|
|
28
|
+
async def load(self) -> list[str]:
|
|
29
29
|
return list(self._entries)
|
|
30
30
|
|
|
31
|
-
def save(self, entries: ta.Sequence[str]) -> None:
|
|
31
|
+
async def save(self, entries: ta.Sequence[str]) -> None:
|
|
32
32
|
self._entries = list(entries)
|
|
33
33
|
|
|
34
34
|
|
|
@@ -38,26 +38,31 @@ class FileInputHistoryStorage(InputHistoryStorage):
|
|
|
38
38
|
|
|
39
39
|
self._path = path
|
|
40
40
|
|
|
41
|
-
def load(self) -> list[str]:
|
|
41
|
+
async def load(self) -> list[str]:
|
|
42
42
|
if not os.path.exists(self._path):
|
|
43
43
|
return []
|
|
44
44
|
|
|
45
45
|
try:
|
|
46
|
-
with open(self._path) as f:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
return data
|
|
50
|
-
return []
|
|
51
|
-
except (json.JSONDecodeError, OSError):
|
|
46
|
+
with open(self._path) as f: # noqa
|
|
47
|
+
content = f.read()
|
|
48
|
+
except OSError:
|
|
52
49
|
return []
|
|
53
50
|
|
|
54
|
-
|
|
51
|
+
data = json.loads(content)
|
|
52
|
+
|
|
53
|
+
if isinstance(data, list) and all(isinstance(e, str) for e in data):
|
|
54
|
+
return data
|
|
55
|
+
return []
|
|
56
|
+
|
|
57
|
+
async def save(self, entries: ta.Sequence[str]) -> None:
|
|
58
|
+
content = json.dumps_pretty(list(entries))
|
|
59
|
+
dir_path = os.path.dirname(self._path)
|
|
60
|
+
|
|
55
61
|
try:
|
|
56
|
-
dir_path = os.path.dirname(self._path)
|
|
57
62
|
if dir_path:
|
|
58
63
|
os.makedirs(dir_path, exist_ok=True)
|
|
59
|
-
with open(self._path, 'w') as f:
|
|
60
|
-
|
|
64
|
+
with open(self._path, 'w') as f: # noqa
|
|
65
|
+
f.write(content)
|
|
61
66
|
except OSError:
|
|
62
67
|
pass
|
|
63
68
|
|
|
@@ -88,16 +93,34 @@ class InputHistoryManager:
|
|
|
88
93
|
self._storage = storage
|
|
89
94
|
self._max_entries = max_entries
|
|
90
95
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
96
|
+
#
|
|
97
|
+
|
|
98
|
+
_entries: list[str]
|
|
99
|
+
_position: int = 0
|
|
100
|
+
|
|
101
|
+
async def load_if_necessary(self) -> None:
|
|
102
|
+
try:
|
|
103
|
+
self._entries # noqa
|
|
104
|
+
except AttributeError:
|
|
105
|
+
pass
|
|
106
|
+
else:
|
|
107
|
+
return
|
|
108
|
+
|
|
109
|
+
self._entries = await self._storage.load()
|
|
110
|
+
self._position = len(self._entries)
|
|
94
111
|
|
|
95
|
-
|
|
112
|
+
#
|
|
113
|
+
|
|
114
|
+
_current_draft: str = ''
|
|
115
|
+
|
|
116
|
+
async def add(self, text: str) -> None:
|
|
96
117
|
"""Add a new history entry and reset position."""
|
|
97
118
|
|
|
98
119
|
if not text.strip():
|
|
99
120
|
return
|
|
100
121
|
|
|
122
|
+
await self.load_if_necessary()
|
|
123
|
+
|
|
101
124
|
# Don't add duplicate consecutive entries
|
|
102
125
|
if self._entries and self._entries[-1] == text:
|
|
103
126
|
self.reset_position()
|
|
@@ -109,7 +132,7 @@ class InputHistoryManager:
|
|
|
109
132
|
if len(self._entries) > self._max_entries:
|
|
110
133
|
self._entries = self._entries[-self._max_entries:]
|
|
111
134
|
|
|
112
|
-
self._storage.save(self._entries)
|
|
135
|
+
await self._storage.save(self._entries)
|
|
113
136
|
self.reset_position()
|
|
114
137
|
|
|
115
138
|
def get_previous(self, text: str | None = None) -> str | None:
|
|
@@ -123,20 +146,24 @@ class InputHistoryManager:
|
|
|
123
146
|
The previous history entry, or None if at the beginning
|
|
124
147
|
"""
|
|
125
148
|
|
|
126
|
-
|
|
149
|
+
try:
|
|
150
|
+
entries = self._entries
|
|
151
|
+
except AttributeError:
|
|
152
|
+
return None
|
|
153
|
+
if entries:
|
|
127
154
|
return None
|
|
128
155
|
|
|
129
156
|
# Save current draft if we're at the end
|
|
130
|
-
if self._position == len(
|
|
157
|
+
if self._position == len(entries) and text is not None:
|
|
131
158
|
self._current_draft = text
|
|
132
159
|
|
|
133
160
|
# Move to previous entry
|
|
134
161
|
if self._position > 0:
|
|
135
162
|
self._position -= 1
|
|
136
|
-
return
|
|
163
|
+
return entries[self._position]
|
|
137
164
|
|
|
138
165
|
# Already at oldest entry
|
|
139
|
-
return
|
|
166
|
+
return entries[0] if entries else None
|
|
140
167
|
|
|
141
168
|
def get_next(self, text: str | None = None) -> str | None:
|
|
142
169
|
"""
|
|
@@ -149,20 +176,24 @@ class InputHistoryManager:
|
|
|
149
176
|
The next history entry, the saved draft if moving past the end, or None
|
|
150
177
|
"""
|
|
151
178
|
|
|
152
|
-
|
|
179
|
+
try:
|
|
180
|
+
entries = self._entries
|
|
181
|
+
except AttributeError:
|
|
182
|
+
return None
|
|
183
|
+
if entries:
|
|
153
184
|
return None
|
|
154
185
|
|
|
155
186
|
# Move to next entry
|
|
156
|
-
if self._position < len(
|
|
187
|
+
if self._position < len(entries):
|
|
157
188
|
self._position += 1
|
|
158
189
|
|
|
159
190
|
# If we moved past the end, return the draft
|
|
160
|
-
if self._position == len(
|
|
191
|
+
if self._position == len(entries):
|
|
161
192
|
draft = self._current_draft
|
|
162
193
|
self._current_draft = ''
|
|
163
194
|
return draft
|
|
164
195
|
|
|
165
|
-
return
|
|
196
|
+
return entries[self._position]
|
|
166
197
|
|
|
167
198
|
# Already at newest position
|
|
168
199
|
return None
|
|
@@ -170,5 +201,11 @@ class InputHistoryManager:
|
|
|
170
201
|
def reset_position(self) -> None:
|
|
171
202
|
"""Reset history position to the end (no history item selected)."""
|
|
172
203
|
|
|
173
|
-
|
|
204
|
+
try:
|
|
205
|
+
entries = self._entries
|
|
206
|
+
except AttributeError:
|
|
207
|
+
self._position = 0
|
|
208
|
+
else:
|
|
209
|
+
self._position = len(entries)
|
|
210
|
+
|
|
174
211
|
self._current_draft = ''
|