dlab-cli 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dlab/model_fallback.py ADDED
@@ -0,0 +1,360 @@
1
+ """
2
+ Model validation and provider fallback for agent configs.
3
+
4
+ When a decision-pack references model providers whose API keys are not
5
+ in the .env file, this module replaces those model strings with the
6
+ orchestrator's model so users only need a single API key to get started.
7
+
8
+ Two-phase design:
9
+ 1. preflight_check() — runs BEFORE session creation on source dpack files.
10
+ Catches fatal errors (orchestrator key missing, unknown models) early.
11
+ 2. process_opencode_dir() — runs DURING session setup on work-dir copies.
12
+ Applies fallback replacements for missing provider keys.
13
+ """
14
+
15
+ import difflib
16
+ import os
17
+ import re
18
+ from pathlib import Path
19
+
20
+ from dlab.create_dpack import KNOWN_PROVIDER_ENVS, get_model_list, get_provider_env_vars
21
+
22
+
23
# Matches provider/model-name patterns (e.g. "anthropic/claude-sonnet-4-5").
# The two lookarounds reject any candidate that is glued to a "/" on either
# side, so multi-segment paths never yield a spurious provider/model pair:
#   - negative lookahead (?!/) stops a match from being the *prefix* of a
#     longer path ("opencode/agents/foo.md" -> not "opencode/agents");
#   - negative lookbehind (?<!/) stops a match from being the *suffix* of one
#     ("opencode/agents/foo.md" -> not "agents/foo.md"), which also keeps the
#     fallback from rewriting only the tail of a nested id such as
#     "openrouter/google/gemini-2.0-flash".
_MODEL_PATTERN: re.Pattern[str] = re.compile(
    r"(?<!/)\b([a-zA-Z0-9_-]+/[a-zA-Z0-9._-]+)\b(?!/)"
)
28
+
29
+
30
+ def parse_env_file(env_file: str | None) -> dict[str, str]:
31
+ """
32
+ Parse a .env file into a key-value dict.
33
+
34
+ Parameters
35
+ ----------
36
+ env_file : str | None
37
+ Path to .env file, or None.
38
+
39
+ Returns
40
+ -------
41
+ dict[str, str]
42
+ Parsed environment variables. Empty dict if env_file is None
43
+ or file does not exist.
44
+ """
45
+ if not env_file:
46
+ return {}
47
+ path: Path = Path(env_file)
48
+ if not path.exists():
49
+ return {}
50
+
51
+ env: dict[str, str] = {}
52
+ for line in path.read_text().splitlines():
53
+ line = line.strip()
54
+ if not line or line.startswith("#"):
55
+ continue
56
+ key, _, value = line.partition("=")
57
+ value = value.strip().strip("'\"")
58
+ env[key.strip()] = value
59
+ return env
60
+
61
+
62
def get_available_providers(env_vars: dict[str, str]) -> set[str]:
    """
    Collect the providers usable with the given environment.

    A provider from ``KNOWN_PROVIDER_ENVS`` qualifies only if every variable
    it lists is present in *env_vars* with a truthy (non-empty) value. A
    provider that lists no variables qualifies trivially.

    Parameters
    ----------
    env_vars : dict[str, str]
        Parsed environment variables.

    Returns
    -------
    set[str]
        Names of the qualifying providers (e.g. {"anthropic", "google"}).
    """
    return {
        name
        for name, needed in KNOWN_PROVIDER_ENVS.items()
        if all(env_vars.get(var) for var in needed)
    }
82
+
83
+
84
def _strip_comments(text: str) -> str:
    """
    Drop whole-line ``#`` comments before scanning for models.

    A line is discarded when its first non-whitespace character is ``#``.
    Every other line is kept verbatim, and the survivors are re-joined with
    newline characters. A ``#`` appearing later in an otherwise non-empty
    line is left untouched.
    """
    kept = (row for row in text.splitlines() if not row.lstrip().startswith("#"))
    return "\n".join(kept)
92
+
93
+
94
def find_model_strings(text: str) -> list[str]:
    """
    Extract provider/model-name strings from the non-comment lines of *text*.

    Candidates come from ``_MODEL_PATTERN``. Only those whose provider
    prefix (the part before the first ``/``) is a key of
    ``KNOWN_PROVIDER_ENVS`` are kept.

    Parameters
    ----------
    text : str
        File content to scan.

    Returns
    -------
    list[str]
        Matching model strings in first-seen order, without duplicates.
    """
    provider_names: set[str] = set(KNOWN_PROVIDER_ENVS)
    candidates: list[str] = _MODEL_PATTERN.findall(_strip_comments(text))
    # dict.fromkeys keeps first-seen order while discarding repeats.
    ordered = dict.fromkeys(
        candidate
        for candidate in candidates
        if candidate.split("/")[0] in provider_names
    )
    return list(ordered)
119
+
120
+
121
def _collect_models_from_dir(directory: Path) -> list[str]:
    """
    Scan every ``.yaml``/``.yml``/``.md`` file under *directory* for models.

    The search is recursive and files are visited in sorted path order.
    Entries that match a glob but are not regular files are skipped (e.g. a
    directory named ``notes.md``, which would otherwise crash
    ``read_text``). Files are decoded as UTF-8 with ``surrogateescape``, so
    the result does not depend on the platform's locale encoding. A stray
    non-UTF-8 byte cannot abort the scan either; model identifiers are
    matched by an ASCII-only pattern.

    Parameters
    ----------
    directory : Path
        Root directory to scan.

    Returns
    -------
    list[str]
        Model strings in first-seen order, without duplicates.
    """
    config_files: list[Path] = sorted(
        p
        for pattern in ("*.yaml", "*.yml", "*.md")
        for p in directory.rglob(pattern)
        if p.is_file()
    )
    all_models: list[str] = []
    for f in config_files:
        text: str = f.read_bytes().decode("utf-8", errors="surrogateescape")
        all_models.extend(find_model_strings(text))
    return list(dict.fromkeys(all_models))  # deduplicate, preserve order
132
+
133
+
134
def _format_env_setup_hint(model: str) -> str:
    """
    Build a short hint telling the user which API key(s) to configure.

    Asks ``get_provider_env_vars`` for the env var names associated with
    *model*. If that yields nothing, a generic pointer to the provider's
    documentation is returned instead.
    """
    required: list[str] = get_provider_env_vars(model)
    if not required:
        return "Check provider documentation for required API key"
    return "Set " + ", ".join(required) + " in your .env file"
141
+
142
+
143
def _unknown_model_error(model: str, all_known: list[str]) -> str:
    """
    Build the "Unknown model" error message for *model*.

    Up to three close matches from *all_known* (``difflib`` similarity
    ratio >= 0.6, sorted alphabetically) are appended as suggestions when
    any exist.
    """
    suggestions: list[str] = sorted(
        difflib.get_close_matches(model, all_known, n=3, cutoff=0.6)
    )
    if suggestions:
        alt: str = ", ".join(suggestions)
        return f"Unknown model {model} — did you mean: {alt}?"
    return f"Unknown model {model}"


def preflight_check(
    orchestrator_model: str,
    config_dir: str,
    env_file: str | None,
    no_sandboxing: bool = False,
) -> tuple[list[str], list[str]]:
    """
    Validate models before session creation. Runs on source dpack files.

    Checks are performed in this order:

    1. The orchestrator model must be in ``get_model_list()``; otherwise an
       error is returned immediately.
    2. The orchestrator's provider key must be set, because it is the
       fallback target; otherwise an error is returned immediately.
    3. If ``<config_dir>/opencode`` is a directory, every model string found
       in its ``.yaml``/``.yml``/``.md`` files is checked:
       - unknown models are collected as errors;
       - models whose provider key is missing (and that are not the
         orchestrator model) are collected as fallback warnings.

    Parameters
    ----------
    orchestrator_model : str
        The orchestrator's model (from --model or config default_model).
    config_dir : str
        Path to the decision-pack config directory.
    env_file : str | None
        Path to .env file.
    no_sandboxing : bool
        If True, also check os.environ for API keys (local mode inherits
        the shell environment). Values from *env_file* take precedence.

    Returns
    -------
    tuple[list[str], list[str]]
        (errors, warnings). Errors are fatal and should abort the run.
        Warnings are informational (e.g. fallback will be applied).
    """
    errors: list[str] = []
    warnings: list[str] = []

    env_vars: dict[str, str] = dict(os.environ) if no_sandboxing else {}
    env_vars.update(parse_env_file(env_file))
    available: set[str] = get_available_providers(env_vars)

    # Validate orchestrator model name
    all_known: list[str] = get_model_list()
    known: set[str] = set(all_known)
    if orchestrator_model not in known:
        errors.append(_unknown_model_error(orchestrator_model, all_known))
        return errors, warnings

    # The orchestrator is the fallback target, so its own key is mandatory.
    orchestrator_provider: str = orchestrator_model.split("/")[0]
    if orchestrator_provider in KNOWN_PROVIDER_ENVS and orchestrator_provider not in available:
        errors.append(
            f"Orchestrator model {orchestrator_model} requires an API key "
            f"that is not set. {_format_env_setup_hint(orchestrator_model)}"
        )
        return errors, warnings

    # Scan source opencode/ dir for model strings
    opencode_dir: Path = Path(config_dir) / "opencode"
    if not opencode_dir.is_dir():
        return errors, warnings

    # _collect_models_from_dir already de-duplicates (first-seen order).
    all_models: list[str] = _collect_models_from_dir(opencode_dir)

    # Validate agent model names exist in known list
    errors.extend(
        _unknown_model_error(model, all_known)
        for model in all_models
        if model not in known
    )

    # Report agent models that will be swapped for the orchestrator model
    unavailable: set[str] = set(KNOWN_PROVIDER_ENVS) - available
    for model in all_models:
        provider: str = model.split("/")[0]
        if provider in unavailable and model != orchestrator_model:
            warnings.append(
                f"{model} -> {orchestrator_model} ({_format_env_setup_hint(model)})"
            )

    return errors, warnings
248
+
249
+
250
def apply_model_fallback(
    text: str,
    orchestrator_model: str,
    unavailable_providers: set[str],
) -> tuple[str, list[str]]:
    """
    Substitute *orchestrator_model* for every model whose provider is missing.

    Lines whose first non-whitespace character is ``#`` are passed through
    untouched. On all other lines, each ``_MODEL_PATTERN`` match whose
    provider prefix is in *unavailable_providers* is replaced. Line endings
    are preserved because lines are split with ``keepends=True`` and
    re-joined verbatim.

    Parameters
    ----------
    text : str
        File content.
    orchestrator_model : str
        Model to substitute in place of unavailable ones.
    unavailable_providers : set[str]
        Provider names whose API keys are missing.

    Returns
    -------
    tuple[str, list[str]]
        (modified_text, one "old -> new" description per replacement, in
        document order).
    """
    if not unavailable_providers:
        return text, []

    log: list[str] = []

    def _swap(m: re.Match[str]) -> str:
        found: str = m.group(1)
        if found.split("/")[0] not in unavailable_providers:
            return found
        log.append(f"{found} -> {orchestrator_model}")
        return orchestrator_model

    rewritten: list[str] = [
        line if line.lstrip().startswith("#") else _MODEL_PATTERN.sub(_swap, line)
        for line in text.splitlines(keepends=True)
    ]
    return "".join(rewritten), log
294
+
295
+
296
+ def process_opencode_dir(
297
+ opencode_dir: str,
298
+ orchestrator_model: str,
299
+ env_file: str | None,
300
+ no_sandboxing: bool = False,
301
+ ) -> list[str]:
302
+ """
303
+ Apply model fallback to all config files in .opencode/ (work-dir copies).
304
+
305
+ Assumes preflight_check() has already validated the orchestrator model.
306
+ Only applies replacements — no validation here.
307
+
308
+ Parameters
309
+ ----------
310
+ opencode_dir : str
311
+ Path to the .opencode/ directory in the work dir.
312
+ orchestrator_model : str
313
+ The orchestrator's model (fallback target).
314
+ env_file : str | None
315
+ Path to .env file.
316
+ no_sandboxing : bool
317
+ If True, also check os.environ for API keys.
318
+
319
+ Returns
320
+ -------
321
+ list[str]
322
+ Replacement messages (e.g. "parallel_agents/poet.yaml: google/gemini-2.0-flash -> ...").
323
+ """
324
+ opencode_path: Path = Path(opencode_dir)
325
+ if not opencode_path.exists():
326
+ return []
327
+
328
+ env_vars: dict[str, str] = {}
329
+ if no_sandboxing:
330
+ env_vars.update(os.environ)
331
+ env_vars.update(parse_env_file(env_file))
332
+ available: set[str] = get_available_providers(env_vars)
333
+
334
+ orchestrator_provider: str = orchestrator_model.split("/")[0]
335
+ if orchestrator_provider in KNOWN_PROVIDER_ENVS and orchestrator_provider not in available:
336
+ return []
337
+
338
+ unavailable: set[str] = set(KNOWN_PROVIDER_ENVS.keys()) - available
339
+ if not unavailable:
340
+ return []
341
+
342
+ messages: list[str] = []
343
+ config_files: list[Path] = sorted(
344
+ list(opencode_path.rglob("*.yaml"))
345
+ + list(opencode_path.rglob("*.yml"))
346
+ + list(opencode_path.rglob("*.md"))
347
+ )
348
+
349
+ for f in config_files:
350
+ text: str = f.read_text()
351
+ new_text, replacements = apply_model_fallback(
352
+ text, orchestrator_model, unavailable,
353
+ )
354
+ if replacements:
355
+ f.write_text(new_text)
356
+ rel: str = str(f.relative_to(opencode_path))
357
+ for r in replacements:
358
+ messages.append(f"{rel}: {r}")
359
+
360
+ return messages
dlab/parallel_tool.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Template for the parallel-agents.ts tool.
3
+
4
+ This module loads the parallel-agents TypeScript source from dlab/js/
5
+ and exposes it as PARALLEL_AGENTS_SOURCE for use by session setup.
6
+
7
+ WARNING: The template contains an evil hack (git init) to work around OpenCode's config
8
+ traversal behavior. See "Git Init Hack" section in CLAUDE.md for details.
9
+ This should be replaced with a proper solution when OpenCode supports
10
+ disabling parent directory config traversal.
11
+ """
12
+
13
+ from importlib.resources import files
14
+
15
+
16
# Loaded once, at import time, from the packaged resource
# dlab/js/parallel-agents.ts. The encoding is pinned to UTF-8 because
# Traversable.read_text() without an explicit encoding may fall back to the
# platform's locale default.
PARALLEL_AGENTS_SOURCE: str = (
    files("dlab.js").joinpath("parallel-agents.ts").read_text(encoding="utf-8")
)