ommlds 0.0.0.dev468__py3-none-any.whl → 0.0.0.dev470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ommlds might be problematic. Click here for more details.

Files changed (36):
  1. ommlds/.omlish-manifests.json +7 -7
  2. ommlds/_hacks/__init__.py +4 -0
  3. ommlds/_hacks/funcs.py +110 -0
  4. ommlds/_hacks/names.py +158 -0
  5. ommlds/_hacks/params.py +73 -0
  6. ommlds/_hacks/patches.py +0 -3
  7. ommlds/backends/transformers/filecache.py +109 -0
  8. ommlds/cli/main.py +11 -5
  9. ommlds/cli/sessions/chat/backends/catalog.py +13 -8
  10. ommlds/cli/sessions/chat/backends/inject.py +15 -0
  11. ommlds/cli/sessions/completion/session.py +1 -1
  12. ommlds/cli/sessions/embedding/session.py +1 -1
  13. ommlds/minichain/__init__.py +1 -0
  14. ommlds/minichain/backends/catalogs/base.py +20 -1
  15. ommlds/minichain/backends/catalogs/simple.py +2 -2
  16. ommlds/minichain/backends/catalogs/strings.py +9 -7
  17. ommlds/minichain/backends/impls/anthropic/chat.py +5 -1
  18. ommlds/minichain/backends/impls/anthropic/stream.py +10 -5
  19. ommlds/minichain/backends/impls/google/chat.py +9 -2
  20. ommlds/minichain/backends/impls/google/search.py +6 -1
  21. ommlds/minichain/backends/impls/google/stream.py +10 -5
  22. ommlds/minichain/backends/impls/mistral.py +9 -2
  23. ommlds/minichain/backends/impls/ollama/chat.py +12 -9
  24. ommlds/minichain/backends/impls/openai/chat.py +9 -2
  25. ommlds/minichain/backends/impls/openai/completion.py +9 -2
  26. ommlds/minichain/backends/impls/openai/embedding.py +9 -2
  27. ommlds/minichain/backends/impls/openai/stream.py +10 -5
  28. ommlds/minichain/backends/impls/transformers/transformers.py +64 -26
  29. ommlds/minichain/registries/globals.py +18 -4
  30. ommlds/tools/git.py +4 -1
  31. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/METADATA +3 -3
  32. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/RECORD +36 -32
  33. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/WHEEL +0 -0
  34. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/entry_points.txt +0 -0
  35. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/licenses/LICENSE +0 -0
  36. {ommlds-0.0.0.dev468.dist-info → ommlds-0.0.0.dev470.dist-info}/top_level.txt +0 -0
@@ -321,7 +321,7 @@
321
321
  "module": ".minichain.backends.impls.ollama.chat",
322
322
  "attr": null,
323
323
  "file": "ommlds/minichain/backends/impls/ollama/chat.py",
324
- "line": 93,
324
+ "line": 96,
325
325
  "value": {
326
326
  "!.minichain.registries.manifests.RegistryManifest": {
327
327
  "module": "ommlds.minichain.backends.impls.ollama.chat",
@@ -336,7 +336,7 @@
336
336
  "module": ".minichain.backends.impls.ollama.chat",
337
337
  "attr": null,
338
338
  "file": "ommlds/minichain/backends/impls/ollama/chat.py",
339
- "line": 139,
339
+ "line": 143,
340
340
  "value": {
341
341
  "!.minichain.registries.manifests.RegistryManifest": {
342
342
  "module": "ommlds.minichain.backends.impls.ollama.chat",
@@ -526,7 +526,7 @@
526
526
  "module": ".minichain.backends.impls.transformers.transformers",
527
527
  "attr": null,
528
528
  "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
529
- "line": 46,
529
+ "line": 50,
530
530
  "value": {
531
531
  "!.minichain.backends.strings.manifests.BackendStringsManifest": {
532
532
  "service_cls_names": [
@@ -542,7 +542,7 @@
542
542
  "module": ".minichain.backends.impls.transformers.transformers",
543
543
  "attr": null,
544
544
  "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
545
- "line": 62,
545
+ "line": 66,
546
546
  "value": {
547
547
  "!.minichain.registries.manifests.RegistryManifest": {
548
548
  "module": "ommlds.minichain.backends.impls.transformers.transformers",
@@ -559,7 +559,7 @@
559
559
  "module": ".minichain.backends.impls.transformers.transformers",
560
560
  "attr": null,
561
561
  "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
562
- "line": 189,
562
+ "line": 197,
563
563
  "value": {
564
564
  "!.minichain.registries.manifests.RegistryManifest": {
565
565
  "module": "ommlds.minichain.backends.impls.transformers.transformers",
@@ -576,7 +576,7 @@
576
576
  "module": ".minichain.backends.impls.transformers.transformers",
577
577
  "attr": null,
578
578
  "file": "ommlds/minichain/backends/impls/transformers/transformers.py",
579
- "line": 219,
579
+ "line": 227,
580
580
  "value": {
581
581
  "!.minichain.registries.manifests.RegistryManifest": {
582
582
  "module": "ommlds.minichain.backends.impls.transformers.transformers",
@@ -714,7 +714,7 @@
714
714
  "module": ".tools.git",
715
715
  "attr": null,
716
716
  "file": "ommlds/tools/git.py",
717
- "line": 186,
717
+ "line": 189,
718
718
  "value": {
719
719
  "!omdev.tools.git.messages.GitMessageGeneratorManifest": {
720
720
  "module": "ommlds.tools.git",
ommlds/_hacks/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ """
2
+ Currently, and ideally, the only thing this codebase will ever interact with requiring these kinds of awful hacks is ML
3
+ code - thus these 'helpers' are kept here, not in the core library.
4
+ """
ommlds/_hacks/funcs.py ADDED
@@ -0,0 +1,110 @@
1
+ import dataclasses as dc
2
+ import linecache
3
+ import textwrap
4
+ import threading
5
+ import types
6
+ import typing as ta
7
+ import uuid
8
+ import warnings
9
+
10
+ from omlish import check
11
+ from omlish import lang
12
+
13
+ from .names import NamespaceBuilder
14
+ from .params import render_param_spec_call
15
+ from .params import render_param_spec_def
16
+
17
+
18
+ ##
19
+
20
+
21
@dc.dataclass()
class _ReservedFilenameEntry:
    """Mutable per-thread state used by `reserve_linecache_filename`."""

    # Random token written into this thread's placeholder linecache entries so they compare unequal to any entry
    # stored under the same filename by someone else.
    unique_id: str

    # Next sequence number to try when forming a candidate filename.
    seq: int = 0


# Holds one `_ReservedFilenameEntry` per thread, stored under the attribute name `unique_id`.
_RESERVED_FILENAME_UUID_TLS = threading.local()


def reserve_linecache_filename(prefix: str) -> str:
    """
    Claim and return a pseudo-filename of the form ``<generated:{prefix}:{n}>`` not used by any other `linecache`
    entry.

    Candidates are tried with increasing ``n`` (a counter shared by all prefixes on the current thread). Each one is
    claimed by inserting a placeholder entry with `dict.setdefault`, so an existing entry is never overwritten. Callers
    (e.g. `create_function`) later overwrite the placeholder with the real source lines.
    """

    state = getattr(_RESERVED_FILENAME_UUID_TLS, 'unique_id', None)
    if state is None:
        state = _ReservedFilenameEntry(str(uuid.uuid4()))
        _RESERVED_FILENAME_UUID_TLS.unique_id = state

    while True:
        candidate = f'<generated:{prefix}:{state.seq}>'
        placeholder = (1, None, (state.unique_id,), candidate)
        state.seq += 1

        # setdefault returns our placeholder only if nothing was stored under `candidate` yet.
        claimed = linecache.cache.setdefault(candidate, placeholder)  # type: ignore
        if claimed == placeholder:
            return candidate
42
+
43
+
44
+ ##
45
+
46
+
47
def create_function(
        name: str,
        params: lang.CanParamSpec,
        body: str,
        *,
        globals: ta.Mapping[str, ta.Any] | None = None,  # noqa
        locals: ta.Mapping[str, ta.Any] | None = None,  # noqa
        indent: str = ' ',
) -> types.FunctionType:
    """
    Compile and return a new function built from source text.

    The generated source is ``def {name}{signature}:`` followed by `body`. The body is dedented, stripped, and
    re-indented by `indent`. That definition is nested inside a generated ``__create_fn__`` factory whose parameters
    are the entries of `locals` plus any annotation / default objects collected while rendering the signature. Those
    objects thus reach the new function as closure variables, or are evaluated at definition time for annotations and
    defaults.

    :param name: Name of the generated function; also used as the linecache filename prefix.
    :param params: Anything accepted by `lang.ParamSpec.of`, describing the signature.
    :param body: Source of the function body. May be indented as a whole (e.g. a triple-quoted block); common leading
        indentation is removed.
    :param globals: Namespace used as the new function's globals. A ``dict`` is used as-is, even when empty, so the
        function sees the caller's live namespace. Any other mapping is copied into a new dict because `eval` requires
        one. Defaults to a fresh empty dict.
    :param locals: Extra names made available to the body through the enclosing factory's parameters.
    :param indent: Indentation unit used when nesting generated source.
    :returns: The compiled function. Its ``def`` source is attached as ``__source__``. The full generated source is
        registered in `linecache` under a reserved pseudo-filename so tools reading linecache (e.g. tracebacks) can
        show it.
    """

    params = lang.ParamSpec.of(params)
    check.isinstance(body, str)
    locals = dict(locals or {})  # noqa

    # `eval` requires a real dict for globals. Only substitute a fresh namespace when none was given at all.
    fn_globals: dict[str, ta.Any]
    if globals is None:
        fn_globals = {}
    elif isinstance(globals, dict):
        fn_globals = globals
    else:
        fn_globals = dict(globals)

    # Names generated for annotations / defaults must not shadow anything the caller supplied.
    nsb = NamespaceBuilder(reserved_names=set(locals) | set(fn_globals))
    sig = render_param_spec_def(params, nsb)
    for k, v in nsb.items():
        check.not_in(k, locals)
        locals[k] = v

    # Dedent *before* stripping. Stripping first would remove only the first line's indentation, leaving later lines
    # of a uniformly indented body over-indented ("unexpected indent").
    body_txt = '\n'.join([
        f'def {name}{sig}:',
        textwrap.indent(textwrap.dedent(body).strip(), indent),
    ])

    # Wrap the definition in a factory so `locals` become closure variables of the generated function.
    exec_txt = '\n'.join([
        f'def __create_fn__({", ".join(locals.keys())}):',
        textwrap.indent(body_txt, indent),
        f'{indent}return {name}',
    ])

    ns: dict = {}
    filename = reserve_linecache_filename(name)
    bytecode = compile(exec_txt, filename, 'exec')
    eval(bytecode, fn_globals, ns)  # type: ignore # noqa

    fn = ns['__create_fn__'](**locals)
    fn.__source__ = body_txt
    # Replace the placeholder entry reserved above with the real source lines.
    linecache.cache[filename] = (len(exec_txt), None, exec_txt.splitlines(True), filename)
    return fn
86
+
87
+
88
+ ##
89
+
90
+
91
def create_detour(
        params: lang.CanParamSpec,
        target: ta.Callable,
        *,
        as_kwargs: bool = False,
) -> types.CodeType:
    """
    Build a code object with signature `params` whose body forwards every argument to `target` and returns the result.

    The returned code object is meant to be assigned to an existing function's ``__code__`` (see `patch_file_cache`),
    detouring all of that function's calls to `target`.

    :param params: Anything accepted by `lang.ParamSpec.of`; should match the signature of the function being patched.
    :param target: Callable the generated code calls.
    :param as_kwargs: Forward every parameter as ``name=name`` instead of by its natural kind.
    """

    spec = lang.ParamSpec.of(params)
    check.callable(target)

    # The placeholder body calls the int constant `1`, e.g. ``return 1(a, b=b)``. That compiles (CPython only emits a
    # SyntaxWarning, silenced here) and gives `1` a predictable slot in co_consts. The slot is then swapped for
    # `target` itself.
    call_src = render_param_spec_call(spec, as_kwargs=as_kwargs)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=SyntaxWarning)
        placeholder_fn = create_function('_', spec, f'return 1{call_src}')

    code = placeholder_fn.__code__
    # Relies on the compiler emitting (None, 1, ...) as the leading constants; fail loudly if that ever changes.
    check.state(code.co_consts[:2] == (None, 1))
    return code.replace(co_consts=(None, target, *code.co_consts[2:]))
ommlds/_hacks/names.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ TODO:
3
+ - releaseable names
4
+ """
5
+ import string
6
+ import typing as ta
7
+
8
+ from omlish import check
9
+ from omlish import collections as col
10
+
11
+
12
+ ##
13
+
14
+
15
class NameGenerator(ta.Protocol):
    """Structural type for callables that produce a name from an optional `prefix`."""

    def __call__(self, prefix: str = '') -> str:
        """Return a name derived from `prefix`."""
        ...
17
+
18
+
19
+ #
20
+
21
+
22
+ class NameGeneratorImpl:
23
+ DEFAULT_PREFIX: ta.ClassVar[str] = '_'
24
+
25
+ def __init__(
26
+ self,
27
+ *,
28
+ reserved_names: ta.Iterable[str] | None = None,
29
+ global_prefix: str | None = None,
30
+ use_global_prefix_if_present: bool = False,
31
+ add_global_prefix_before_number: bool = False,
32
+ ) -> None:
33
+ super().__init__()
34
+
35
+ check.arg(not isinstance(reserved_names, str))
36
+ self._reserved_names = {check.isinstance(n, str) for n in (reserved_names or [])}
37
+ self._global_prefix = global_prefix if global_prefix is not None else self.DEFAULT_PREFIX
38
+ self._use_global_prefix_if_present = bool(use_global_prefix_if_present)
39
+ self._add_global_prefix_before_number = bool(add_global_prefix_before_number)
40
+
41
+ self._name_counts: dict[str, int] = {}
42
+
43
+ def __call__(self, prefix: str = '') -> str:
44
+ if self._use_global_prefix_if_present and prefix.startswith(self._global_prefix):
45
+ base_name = prefix
46
+ else:
47
+ base_name = self._global_prefix + prefix
48
+
49
+ base_count = -1
50
+ if base_name[-1] in string.digits:
51
+ i = len(base_name) - 2
52
+ while i >= 0 and base_name[i] in string.digits:
53
+ i -= 1
54
+ i += 1
55
+ base_count = int(base_name[i:])
56
+ base_name = base_name[:i]
57
+
58
+ if self._add_global_prefix_before_number:
59
+ if not (self._use_global_prefix_if_present and base_name.endswith(self._global_prefix)):
60
+ base_name += self._global_prefix
61
+
62
+ if base_count >= 0:
63
+ count = self._name_counts.setdefault(base_name, 0)
64
+ if base_count > count:
65
+ self._name_counts[base_name] = base_count
66
+
67
+ while True:
68
+ count = self._name_counts.get(base_name, 0)
69
+ self._name_counts[base_name] = count + 1
70
+ name = base_name + str(count)
71
+ if name not in self._reserved_names:
72
+ return name
73
+
74
+
75
+ name_generator = NameGeneratorImpl
76
+
77
+
78
+ ##
79
+
80
+
81
class NamespaceBuilder(ta.Mapping[str, ta.Any]):
    """
    Read-only mapping of names to values, populated through `put`, which picks a usable name for each value.

    Names in `reserved_names` are never bound. When a requested name is unavailable, or none is requested, a name is
    produced by `name_generator`. That defaults to a `NameGeneratorImpl` seeded with the same reserved names.
    """

    def __init__(
            self,
            *,
            reserved_names: ta.Iterable[str] | None = None,
            name_generator: NameGenerator | None = None,  # noqa
    ) -> None:
        super().__init__()

        self._reserved_names = {
            check.isinstance(n, str)
            for n in (check.not_isinstance(reserved_names, str) or [])
        }
        self._name_generator = check.callable(name_generator) if name_generator is not None else \
            NameGeneratorImpl(reserved_names=self._reserved_names, use_global_prefix_if_present=True)

        self._dct: ta.MutableMapping[str, ta.Any] = {}
        # Value -> name for `dedupe=True` puts. IdentityKeyDict presumably keys on object identity (matching the
        # `existing is value` check in `put`) - verify against omlish.collections.
        self._dedupe_dct: ta.MutableMapping[ta.Any, str] = col.IdentityKeyDict()

    @property
    def reserved_names(self) -> ta.AbstractSet[str]:
        return self._reserved_names

    @property
    def name_generator(self) -> NameGenerator:
        return self._name_generator

    def __getitem__(self, k: str) -> ta.Any:
        return self._dct[k]

    def __len__(self):
        return len(self._dct)

    def __iter__(self) -> ta.Iterator[str]:
        return iter(self._dct)

    def items(self) -> ta.ItemsView[str, ta.Any]:
        return self._dct.items()

    def put(
            self,
            value: ta.Any,
            name: str | None = None,
            *,
            exact: bool = False,
            dedupe: bool = False,
    ) -> str:
        """
        Store `value` and return the name it is bound under.

        :param value: Object to store.
        :param name: Preferred name. It is used verbatim when it is not reserved and is either unbound or already bound
            to this very object. Otherwise it is passed to the name generator as a prefix.
        :param exact: Require exactly `name` (which must then be given). Raises KeyError instead of generating an
            alternative when `name` is reserved or bound to a different object.
        :param dedupe: Return the name previously recorded for this same object by a `dedupe=True` put, if any. The name
            chosen now is recorded for future deduplication.
        :raises KeyError: See `exact`.
        """

        check.arg(not (name is None and exact))
        if name is not None:
            check.isinstance(name, str)

        if dedupe:
            try:
                return self._dedupe_dct[value]
            except KeyError:
                pass

        if name is not None:
            if name in self._reserved_names:
                # A reserved name can never be bound, so an exact request cannot be satisfied.
                if exact:
                    raise KeyError(name)

            else:
                try:
                    existing = self._dct[name]
                except KeyError:
                    self._dct[name] = value
                    if dedupe:
                        self._dedupe_dct[value] = name
                    return name

                if existing is value:
                    # Already bound to this object. Remember it so later dedupe lookups find it.
                    if dedupe:
                        self._dedupe_dct[value] = name
                    return name

                if exact:
                    raise KeyError(name)

        gen_name = self._name_generator(name or '')
        check.not_in(gen_name, self._dct)
        self._dct[gen_name] = value
        if dedupe:
            self._dedupe_dct[value] = gen_name
        return gen_name
@@ -0,0 +1,73 @@
1
+ import typing as ta
2
+
3
+ from omlish import lang
4
+
5
+ from .names import NamespaceBuilder
6
+
7
+
8
+ ##
9
+
10
+
11
def render_param_spec_call(
        params: lang.ParamSpec,
        *,
        as_kwargs: bool = False,
) -> str:
    """
    Render a parenthesized call-argument list that forwards every parameter of `params` by name, e.g.
    ``(a, *args, b=b, **kwargs)``.

    :param params: Parameter spec to forward.
    :param as_kwargs: Pass every parameter as ``name=name`` regardless of its kind.
    :raises TypeError: On an entry of an unrecognized kind (or, with `as_kwargs`, one that is not a `lang.Param`).
    """

    args: list[str] = []

    for p in params:
        # Separator markers have no call-site counterpart. Commas are produced by the final join rather than by
        # position, so a skipped leading separator can't leave a dangling ', '.
        if isinstance(p, lang.ParamSeparator):
            continue

        if as_kwargs:
            if not isinstance(p, lang.Param):
                raise TypeError(p)
            args.append(f'{p.name}={p.name}')

        elif isinstance(p, lang.ArgsParam):
            args.append(f'*{p.name}')
        elif isinstance(p, lang.KwargsParam):
            args.append(f'**{p.name}')
        elif isinstance(p, lang.PosOnlyParam):
            args.append(p.name)
        elif isinstance(p, lang.KwOnlyParam):
            args.append(f'{p.name}={p.name}')
        elif isinstance(p, lang.ValParam):
            args.append(p.name)
        else:
            raise TypeError(p)

    return f'({", ".join(args)})'
48
+
49
+
50
def render_param_spec_def(
        params: lang.ParamSpec,
        nsb: NamespaceBuilder,
        *,
        return_ann: lang.Maybe[ta.Any] = lang.empty(),
) -> str:
    """
    Render a ``def`` signature, e.g. ``(a: ann_a, b=dfl_b) -> ann_return``, for `params`.

    Annotation and default objects are not rendered literally. Each is stored in `nsb` (under a name derived from
    ``ann_<param>`` / ``dfl_<param>``) and the signature refers to it by that name. The caller must therefore make
    `nsb`'s contents available where the signature is evaluated.

    :param params: Parameter spec to render.
    :param nsb: Namespace receiving annotation / default objects.
    :param return_ann: Optional return annotation, stored as ``ann_return``.
    """

    rendered: list[str] = []
    for param in params:
        # The callbacks read `param` from the enclosing loop; this relies on `param_render` invoking them before the
        # next iteration.
        rendered.append(lang.param_render(
            param,
            render_annotation=lambda ann: nsb.put(ann, f'ann_{param.name}'),  # noqa
            render_default=lambda dfl: nsb.put(dfl, f'dfl_{param.name}'),  # noqa
        ))

    sig = '(' + ', '.join(rendered) + ')'

    if return_ann.present:
        sig += f' -> {nsb.put(return_ann.must(), "ann_return")}'

    return sig
ommlds/_hacks/patches.py CHANGED
@@ -1,7 +1,4 @@
1
1
  """
2
- Currently, and ideally, the only thing this codebase will ever interact with requiring these kinds of awful hacks is ML
3
- code - thus these 'helpers' are kept here, not in the core library.
4
-
5
2
  TODO:
6
3
  - patch lock
7
4
  - thread / context local gating
@@ -0,0 +1,109 @@
1
+ import contextlib
2
+ import dataclasses as dc
3
+ import os
4
+ import threading
5
+ import typing as ta
6
+
7
+ import transformers as tfm
8
+
9
+ from omlish import lang
10
+
11
+ from ..._hacks.funcs import create_detour
12
+
13
+
14
+ ##
15
+
16
+
17
+ @dc.dataclass(frozen=True, kw_only=True)
18
+ class _FileCachePatchContext:
19
+ local_first: bool = False
20
+ local_config_present_is_authoritative: bool = False
21
+
22
+
23
+ _FILE_CACHE_PATCH_CONTEXT_TLS = threading.local()
24
+
25
+
26
+ def _get_file_cache_patch_context() -> _FileCachePatchContext:
27
+ try:
28
+ return _FILE_CACHE_PATCH_CONTEXT_TLS.context
29
+ except AttributeError:
30
+ ctx = _FILE_CACHE_PATCH_CONTEXT_TLS.context = _FileCachePatchContext()
31
+ return ctx
32
+
33
+
34
_FILE_CACHE_PATCH_LOCK = threading.Lock()


@lang.cached_function(lock=_FILE_CACHE_PATCH_LOCK)
def patch_file_cache() -> None:
    """
    Reroute `transformers.utils.hub.cached_files` through a wrapper driven by the calling thread's
    `_FileCachePatchContext`.

    The patch is applied in place by swapping the function's ``__code__``, so existing references to `cached_files`
    are affected too. The `lang.cached_function` decorator (under a lock) presumably makes repeated calls a no-op.

    Behavior of the wrapper:
    - With ``local_first`` set, and unless the caller already asked for ``local_files_only``, the requested files are
      first resolved with ``local_files_only=True``. A successful local result is returned directly.
    - If that local lookup raised OSError and ``local_config_present_is_authoritative`` is also set, a locally cached
      config is taken to mean the local snapshot is complete. OSError is then raised instead of going to the network.
    - Otherwise the original lookup runs unchanged.

    I tried to make a `local_first_pipeline` function to be called instead of `tfm.pipeline`, I really did, but the
    transformers code is such a disgusting rat's nest full of direct static calls to the caching code strewn about at
    every layer with no concern whatsoever for forwarding kwargs where they need to go.
    """

    from transformers.utils.hub import cached_files

    # Callable copy of the original implementation - `cached_files` itself is rewritten below.
    orig_cached_files: ta.Callable[..., str | None] = lang.copy_function(cached_files)  # type: ignore

    # Bound locally so the wrapper does not depend on module globals at call time.
    get_file_cache_patch_context = _get_file_cache_patch_context

    def new_cached_files(
            path_or_repo_id: str | os.PathLike,
            filenames: list[str],
            **kwargs: ta.Any,
    ) -> str | None:
        ctx = get_file_cache_patch_context()

        if ctx.local_first and not kwargs.get('local_files_only'):
            local_kwargs = {**kwargs, 'local_files_only': True}

            try:
                return orig_cached_files(path_or_repo_id, filenames, **local_kwargs)
            except OSError:
                pass

            # Only meaningful once the local lookup above was actually attempted and failed. Applying it otherwise
            # would reject files that are present locally, including the config itself.
            if ctx.local_config_present_is_authoritative:
                try:
                    local_config = orig_cached_files(path_or_repo_id, [tfm.CONFIG_NAME], **local_kwargs)
                except OSError:
                    pass
                else:
                    if local_config is not None:
                        raise OSError(
                            f'Files {filenames!r} requested under local_first '
                            f'but local_config present at {local_config!r}, '
                            f'assuming files do not exist.',
                        )

        return orig_cached_files(path_or_repo_id, filenames, **kwargs)

    cached_files.__code__ = create_detour(cached_files, new_cached_files, as_kwargs=True)
89
+
90
+
91
@contextlib.contextmanager
def file_cache_patch_context(
        *,
        local_first: bool = False,
        local_config_present_is_authoritative: bool = False,
) -> ta.Generator[None]:
    """
    Context manager that ensures the `cached_files` patch is installed. It overrides the calling thread's file-cache
    settings for the duration of the ``with`` block.

    `patch_file_cache()` is called on entry (presumably idempotent thanks to its caching decorator). The thread's
    previous context is restored on exit, including when the block raises.

    :param local_first: See `_FileCachePatchContext.local_first`.
    :param local_config_present_is_authoritative: See `_FileCachePatchContext.local_config_present_is_authoritative`.
    """

    patch_file_cache()

    previous = _get_file_cache_patch_context()
    _FILE_CACHE_PATCH_CONTEXT_TLS.context = dc.replace(
        previous,
        local_first=local_first,
        local_config_present_is_authoritative=local_config_present_is_authoritative,
    )

    try:
        yield
    finally:
        _FILE_CACHE_PATCH_CONTEXT_TLS.context = previous
ommlds/cli/main.py CHANGED
@@ -39,10 +39,6 @@ else:
39
39
 
40
40
 
41
41
  async def _a_main(args: ta.Any = None) -> None:
42
- logs.configure_standard_logging('INFO')
43
-
44
- #
45
-
46
42
  parser = argparse.ArgumentParser()
47
43
  parser.add_argument('prompt', nargs='*')
48
44
 
@@ -64,6 +60,8 @@ async def _a_main(args: ta.Any = None) -> None:
64
60
  parser.add_argument('-E', '--embed', action='store_true')
65
61
  parser.add_argument('-j', '--image', action='store_true')
66
62
 
63
+ parser.add_argument('-v', '--verbose', action='store_true')
64
+
67
65
  parser.add_argument('--enable-fs-tools', action='store_true')
68
66
  parser.add_argument('--enable-todo-tools', action='store_true')
69
67
  parser.add_argument('--enable-unsafe-tools-do-not-use-lol', action='store_true')
@@ -74,6 +72,14 @@ async def _a_main(args: ta.Any = None) -> None:
74
72
 
75
73
  #
76
74
 
75
+ if args.verbose:
76
+ logs.configure_standard_logging('DEBUG')
77
+ else:
78
+ logs.configure_standard_logging('INFO')
79
+ logs.silence_noisy_loggers()
80
+
81
+ #
82
+
77
83
  content: mc.Content | None
78
84
 
79
85
  if args.image:
@@ -161,7 +167,7 @@ async def _a_main(args: ta.Any = None) -> None:
161
167
  args.enable_test_weather_tool or
162
168
  args.code
163
169
  ),
164
- enabled_tools={
170
+ enabled_tools={ # noqa
165
171
  *(['fs'] if args.enable_fs_tools else []),
166
172
  *(['todo'] if args.enable_todo_tools else []),
167
173
  *(['weather'] if args.enable_test_weather_tool else []),
@@ -15,33 +15,38 @@ from .types import ServiceT
15
15
  ##
16
16
 
17
17
 
18
- class _CatalogBackendProvider(BackendProvider[ServiceT], lang.Abstract):
18
+ class CatalogBackendProvider(BackendProvider[ServiceT], lang.Abstract):
19
+ class Instantiator(lang.Func2['mc.BackendCatalog.Backend', BackendConfigs | None, ta.Awaitable[ta.Any]]):
20
+ pass
21
+
19
22
  def __init__(
20
23
  self,
21
24
  *,
22
25
  name: BackendName,
23
26
  catalog: 'mc.BackendCatalog',
24
27
  configs: BackendConfigs | None = None,
28
+ instantiator: Instantiator | None = None,
25
29
  ) -> None:
26
30
  super().__init__()
27
31
 
28
32
  self._name = name
29
33
  self._catalog = catalog
30
34
  self._configs = configs
35
+ if instantiator is None:
36
+ instantiator = CatalogBackendProvider.Instantiator(lang.as_async(lambda be, cfgs: be.factory(*cfgs or [])))
37
+ self._instantiator = instantiator
31
38
 
32
39
  @contextlib.asynccontextmanager
33
40
  async def _provide_backend(self, cls: type[ServiceT]) -> ta.AsyncIterator[ServiceT]:
41
+ be = self._catalog.get_backend(cls, self._name)
42
+
34
43
  service: ServiceT
35
- async with lang.async_or_sync_maybe_managing(self._catalog.get_backend(
36
- cls,
37
- self._name,
38
- *(self._configs or []),
39
- )) as service:
44
+ async with lang.async_or_sync_maybe_managing(await self._instantiator(be, self._configs)) as service:
40
45
  yield service
41
46
 
42
47
 
43
48
  class CatalogChatChoicesServiceBackendProvider(
44
- _CatalogBackendProvider['mc.ChatChoicesService'],
49
+ CatalogBackendProvider['mc.ChatChoicesService'],
45
50
  ChatChoicesServiceBackendProvider,
46
51
  ):
47
52
  def provide_backend(self) -> ta.AsyncContextManager['mc.ChatChoicesService']:
@@ -49,7 +54,7 @@ class CatalogChatChoicesServiceBackendProvider(
49
54
 
50
55
 
51
56
  class CatalogChatChoicesStreamServiceBackendProvider(
52
- _CatalogBackendProvider['mc.ChatChoicesStreamService'],
57
+ CatalogBackendProvider['mc.ChatChoicesStreamService'],
53
58
  ChatChoicesStreamServiceBackendProvider,
54
59
  ):
55
60
  def provide_backend(self) -> ta.AsyncContextManager['mc.ChatChoicesStreamService']:
@@ -1,6 +1,9 @@
1
+ import typing as ta
2
+
1
3
  from omlish import inject as inj
2
4
  from omlish import lang
3
5
 
6
+ from ..... import minichain as mc
4
7
  from .injection import backend_configs
5
8
 
6
9
 
@@ -34,4 +37,16 @@ def bind_backends(
34
37
 
35
38
  #
36
39
 
40
+ async def catalog_backend_instantiator_provider(injector: inj.AsyncInjector) -> _catalog.CatalogBackendProvider.Instantiator: # noqa
41
+ async def inner(be: 'mc.BackendCatalog.Backend', cfgs: _types.BackendConfigs | None) -> ta.Any:
42
+ kwt = inj.build_kwargs_target(be.factory, non_strict=True)
43
+ kw = await injector.provide_kwargs(kwt)
44
+ return be.factory(*cfgs or [], **kw)
45
+
46
+ return _catalog.CatalogBackendProvider.Instantiator(inner)
47
+
48
+ els.append(inj.bind(_catalog.CatalogBackendProvider.Instantiator, to_async_fn=catalog_backend_instantiator_provider)) # noqa
49
+
50
+ #
51
+
37
52
  return inj.as_elements(*els)
@@ -31,7 +31,7 @@ class CompletionSession(Session['CompletionSession.Config']):
31
31
  prompt = check.isinstance(self._config.content, str)
32
32
 
33
33
  mdl: mc.CompletionService
34
- async with lang.async_maybe_managing(self._backend_catalog.get_backend(
34
+ async with lang.async_maybe_managing(self._backend_catalog.new_backend(
35
35
  mc.CompletionService,
36
36
  self._config.backend or DEFAULT_COMPLETION_MODEL_BACKEND,
37
37
  )) as mdl:
@@ -29,7 +29,7 @@ class EmbeddingSession(Session['EmbeddingSession.Config']):
29
29
 
30
30
  async def run(self) -> None:
31
31
  mdl: mc.EmbeddingService
32
- async with lang.async_maybe_managing(self._backend_catalog.get_backend(
32
+ async with lang.async_maybe_managing(self._backend_catalog.new_backend(
33
33
  mc.EmbeddingService,
34
34
  self._config.backend or DEFAULT_EMBEDDING_MODEL_BACKEND,
35
35
  )) as mdl:
@@ -322,6 +322,7 @@ with _lang.auto_proxy_init(
322
322
  ##
323
323
 
324
324
  from .registries.globals import ( # noqa
325
+ get_registry_cls,
325
326
  register_type,
326
327
  registry_new,
327
328
  registry_of,