weco 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
weco/api.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import sys
2
- from typing import Dict, Any, Optional, Union, Tuple, List
2
+ from typing import Dict, Any, Optional, Union, Tuple
3
3
  import requests
4
4
  from rich.console import Console
5
5
 
6
6
  from weco import __pkg_version__, __base_url__
7
- from .utils import truncate_output, determine_model_for_onboarding
7
+ from .utils import truncate_output
8
8
 
9
9
 
10
10
  def handle_api_error(e: requests.exceptions.HTTPError, console: Console) -> None:
@@ -109,31 +109,31 @@ def start_optimization_run(
109
109
  log_dir: str = ".runs",
110
110
  auth_headers: dict = {},
111
111
  timeout: Union[int, Tuple[int, int]] = (10, 3650),
112
+ api_keys: Optional[Dict[str, str]] = None,
112
113
  ) -> Optional[Dict[str, Any]]:
113
114
  """Start the optimization run."""
114
115
  with console.status("[bold green]Starting Optimization..."):
115
116
  try:
116
- response = requests.post(
117
- f"{__base_url__}/runs/",
118
- json={
119
- "source_code": source_code,
120
- "source_path": source_path,
121
- "additional_instructions": additional_instructions,
122
- "objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize},
123
- "optimizer": {
124
- "steps": steps,
125
- "code_generator": code_generator_config,
126
- "evaluator": evaluator_config,
127
- "search_policy": search_policy_config,
128
- },
129
- "eval_timeout": eval_timeout,
130
- "save_logs": save_logs,
131
- "log_dir": log_dir,
132
- "metadata": {"client_name": "cli", "client_version": __pkg_version__},
117
+ request_json = {
118
+ "source_code": source_code,
119
+ "source_path": source_path,
120
+ "additional_instructions": additional_instructions,
121
+ "objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize},
122
+ "optimizer": {
123
+ "steps": steps,
124
+ "code_generator": code_generator_config,
125
+ "evaluator": evaluator_config,
126
+ "search_policy": search_policy_config,
133
127
  },
134
- headers=auth_headers,
135
- timeout=timeout,
136
- )
128
+ "eval_timeout": eval_timeout,
129
+ "save_logs": save_logs,
130
+ "log_dir": log_dir,
131
+ "metadata": {"client_name": "cli", "client_version": __pkg_version__},
132
+ }
133
+ if api_keys:
134
+ request_json["api_keys"] = api_keys
135
+
136
+ response = requests.post(f"{__base_url__}/runs/", json=request_json, headers=auth_headers, timeout=timeout)
137
137
  response.raise_for_status()
138
138
  result = response.json()
139
139
  # Handle None values for code and plan fields
@@ -156,11 +156,10 @@ def resume_optimization_run(
156
156
  """Request the backend to resume an interrupted run."""
157
157
  with console.status("[bold green]Resuming run..."):
158
158
  try:
159
+ request_json = {"metadata": {"client_name": "cli", "client_version": __pkg_version__}}
160
+
159
161
  response = requests.post(
160
- f"{__base_url__}/runs/{run_id}/resume",
161
- json={"metadata": {"client_name": "cli", "client_version": __pkg_version__}},
162
- headers=auth_headers,
163
- timeout=timeout,
162
+ f"{__base_url__}/runs/{run_id}/resume", json=request_json, headers=auth_headers, timeout=timeout
164
163
  )
165
164
  response.raise_for_status()
166
165
  result = response.json()
@@ -180,17 +179,19 @@ def evaluate_feedback_then_suggest_next_solution(
180
179
  execution_output: str,
181
180
  auth_headers: dict = {},
182
181
  timeout: Union[int, Tuple[int, int]] = (10, 3650),
182
+ api_keys: Optional[Dict[str, str]] = None,
183
183
  ) -> Dict[str, Any]:
184
184
  """Evaluate the feedback and suggest the next solution."""
185
185
  try:
186
186
  # Truncate the execution output before sending to backend
187
187
  truncated_output = truncate_output(execution_output)
188
188
 
189
+ request_json = {"execution_output": truncated_output, "metadata": {}}
190
+ if api_keys:
191
+ request_json["api_keys"] = api_keys
192
+
189
193
  response = requests.post(
190
- f"{__base_url__}/runs/{run_id}/suggest",
191
- json={"execution_output": truncated_output, "metadata": {}},
192
- headers=auth_headers,
193
- timeout=timeout,
194
+ f"{__base_url__}/runs/{run_id}/suggest", json=request_json, headers=auth_headers, timeout=timeout
194
195
  )
195
196
  response.raise_for_status()
196
197
  result = response.json()
@@ -314,145 +315,3 @@ def report_termination(
314
315
  except Exception as e:
315
316
  print(f"Warning: Failed to report termination to backend for run {run_id}: {e}", file=sys.stderr)
316
317
  return False
317
-
318
-
319
- def get_optimization_suggestions_from_codebase(
320
- console: Console,
321
- gitingest_summary: str,
322
- gitingest_tree: str,
323
- gitingest_content_str: str,
324
- auth_headers: dict = {},
325
- timeout: Union[int, Tuple[int, int]] = (10, 3650),
326
- ) -> Optional[List[Dict[str, Any]]]:
327
- """Analyze codebase and get optimization suggestions using the model-agnostic backend API."""
328
- try:
329
- model = determine_model_for_onboarding()
330
- response = requests.post(
331
- f"{__base_url__}/onboard/analyze-codebase",
332
- json={
333
- "gitingest_summary": gitingest_summary,
334
- "gitingest_tree": gitingest_tree,
335
- "gitingest_content": gitingest_content_str,
336
- "model": model,
337
- "metadata": {},
338
- },
339
- headers=auth_headers,
340
- timeout=timeout,
341
- )
342
- response.raise_for_status()
343
- result = response.json()
344
- return [option for option in result.get("options", [])]
345
-
346
- except requests.exceptions.HTTPError as e:
347
- handle_api_error(e, console)
348
- return None
349
- except Exception as e:
350
- console.print(f"[bold red]Error: {e}[/]")
351
- return None
352
-
353
-
354
- def generate_evaluation_script_and_metrics(
355
- console: Console,
356
- target_file: str,
357
- description: str,
358
- gitingest_content_str: str,
359
- auth_headers: dict = {},
360
- timeout: Union[int, Tuple[int, int]] = (10, 3650),
361
- ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
362
- """Generate evaluation script and determine metrics using the model-agnostic backend API."""
363
- try:
364
- model = determine_model_for_onboarding()
365
- response = requests.post(
366
- f"{__base_url__}/onboard/generate-script",
367
- json={
368
- "target_file": target_file,
369
- "description": description,
370
- "gitingest_content": gitingest_content_str,
371
- "model": model,
372
- "metadata": {},
373
- },
374
- headers=auth_headers,
375
- timeout=timeout,
376
- )
377
- response.raise_for_status()
378
- result = response.json()
379
- return result.get("script_content"), result.get("metric_name"), result.get("goal"), result.get("reasoning")
380
- except requests.exceptions.HTTPError as e:
381
- handle_api_error(e, console)
382
- return None, None, None, None
383
- except Exception as e:
384
- console.print(f"[bold red]Error: {e}[/]")
385
- return None, None, None, None
386
-
387
-
388
- def analyze_evaluation_environment(
389
- console: Console,
390
- target_file: str,
391
- description: str,
392
- gitingest_summary: str,
393
- gitingest_tree: str,
394
- gitingest_content_str: str,
395
- auth_headers: dict = {},
396
- timeout: Union[int, Tuple[int, int]] = (10, 3650),
397
- ) -> Optional[Dict[str, Any]]:
398
- """Analyze existing evaluation scripts and environment using the model-agnostic backend API."""
399
- try:
400
- model = determine_model_for_onboarding()
401
- response = requests.post(
402
- f"{__base_url__}/onboard/analyze-environment",
403
- json={
404
- "target_file": target_file,
405
- "description": description,
406
- "gitingest_summary": gitingest_summary,
407
- "gitingest_tree": gitingest_tree,
408
- "gitingest_content": gitingest_content_str,
409
- "model": model,
410
- "metadata": {},
411
- },
412
- headers=auth_headers,
413
- timeout=timeout,
414
- )
415
- response.raise_for_status()
416
- return response.json()
417
-
418
- except requests.exceptions.HTTPError as e:
419
- handle_api_error(e, console)
420
- return None
421
- except Exception as e:
422
- console.print(f"[bold red]Error: {e}[/]")
423
- return None
424
-
425
-
426
- def analyze_script_execution_requirements(
427
- console: Console,
428
- script_content: str,
429
- script_path: str,
430
- target_file: str,
431
- auth_headers: dict = {},
432
- timeout: Union[int, Tuple[int, int]] = (10, 3650),
433
- ) -> Optional[str]:
434
- """Analyze script to determine proper execution command using the model-agnostic backend API."""
435
- try:
436
- model = determine_model_for_onboarding()
437
- response = requests.post(
438
- f"{__base_url__}/onboard/analyze-script",
439
- json={
440
- "script_content": script_content,
441
- "script_path": script_path,
442
- "target_file": target_file,
443
- "model": model,
444
- "metadata": {},
445
- },
446
- headers=auth_headers,
447
- timeout=timeout,
448
- )
449
- response.raise_for_status()
450
- result = response.json()
451
- return result.get("command", f"python {script_path}")
452
-
453
- except requests.exceptions.HTTPError as e:
454
- handle_api_error(e, console)
455
- return f"python {script_path}"
456
- except Exception as e:
457
- console.print(f"[bold red]Error: {e}[/]")
458
- return f"python {script_path}"
weco/cli.py CHANGED
@@ -1,16 +1,47 @@
1
1
  import argparse
2
2
  import sys
3
- import pathlib
4
3
  from rich.console import Console
5
4
  from rich.traceback import install
6
5
 
7
6
  from .auth import clear_api_key
8
- from .utils import check_for_cli_updates
7
+ from .constants import DEFAULT_MODELS
8
+ from .utils import check_for_cli_updates, get_default_model, UnrecognizedAPIKeysError, DefaultModelNotFoundError
9
+
9
10
 
10
11
  install(show_locals=True)
11
12
  console = Console()
12
13
 
13
14
 
15
+ def parse_api_keys(api_key_args: list[str] | None) -> dict[str, str]:
16
+ """Parse API key arguments from CLI into a dictionary.
17
+
18
+ Args:
19
+ api_key_args: List of strings in format 'provider=key' (e.g., ['openai=sk-xxx', 'anthropic=sk-ant-yyy'])
20
+
21
+ Returns:
22
+ Dictionary mapping provider names to API keys. Returns empty dict if no keys provided.
23
+
24
+ Raises:
25
+ ValueError: If any argument is not in the correct format.
26
+ """
27
+ if not api_key_args:
28
+ return {}
29
+
30
+ api_keys = {}
31
+ for arg in api_key_args:
32
+ try:
33
+ provider, key = (s.strip() for s in arg.split("=", 1))
34
+ except Exception:
35
+ raise ValueError(f"Invalid API key format: '{arg}'. Expected format: 'provider=key'")
36
+
37
+ if not provider or not key:
38
+ raise ValueError(f"Invalid API key format: '{arg}'. Provider and key must be non-empty.")
39
+
40
+ api_keys[provider.lower()] = key
41
+
42
+ return api_keys
43
+
44
+
14
45
  # Function to define and return the run_parser (or configure it on a passed subparser object)
15
46
  # This helps keep main() cleaner and centralizes run command arg definitions.
16
47
  def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
@@ -78,6 +109,29 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
78
109
  help="Automatically apply the best solution to the source file without prompting",
79
110
  )
80
111
 
112
+ default_api_keys = " ".join([f"{provider}=xxx" for provider, _ in DEFAULT_MODELS])
113
+ supported_providers = ", ".join([provider for provider, _ in DEFAULT_MODELS])
114
+ default_models_for_providers = "\n".join([f"- {provider}: {model}" for provider, model in DEFAULT_MODELS])
115
+ run_parser.add_argument(
116
+ "--api-key",
117
+ nargs="+",
118
+ type=str,
119
+ default=None,
120
+ help=f"""Provide one or more API keys for supported LLM providers. Specify a model with the --model flag.
121
+ Weco will use the default model for the provider if no model is specified.
122
+
123
+ Use the format 'provider=KEY', separated by spaces to specify multiple keys.
124
+
125
+ Example:
126
+ --api-key {default_api_keys}
127
+
128
+ Supported provider names: {supported_providers}.
129
+
130
+ Default models for providers:
131
+ {default_models_for_providers}
132
+ """,
133
+ )
134
+
81
135
 
82
136
  def configure_credits_parser(credits_parser: argparse.ArgumentParser) -> None:
83
137
  """Configure the credits command parser and all its subcommands."""
@@ -129,24 +183,62 @@ def configure_resume_parser(resume_parser: argparse.ArgumentParser) -> None:
129
183
  help="Automatically apply the best solution to the source file without prompting",
130
184
  )
131
185
 
186
+ default_api_keys = " ".join([f"{provider}=xxx" for provider, _ in DEFAULT_MODELS])
187
+ supported_providers = ", ".join([provider for provider, _ in DEFAULT_MODELS])
188
+
189
+ resume_parser.add_argument(
190
+ "--api-key",
191
+ nargs="+",
192
+ type=str,
193
+ default=None,
194
+ help=f"""Provide one or more API keys for supported LLM providers.
195
+ Weco will use the model associated with the run you are resuming.
196
+
197
+ Use the format 'provider=KEY', separated by spaces to specify multiple keys.
198
+
199
+ Example:
200
+ --api-key {default_api_keys}
201
+
202
+ Supported provider names: {supported_providers}.
203
+ """,
204
+ )
205
+
132
206
 
133
207
def execute_run_command(args: argparse.Namespace) -> None:
    """Execute the 'weco run' command with all its logic.

    Parses any ``--api-key`` values and resolves the model to use: the explicit
    ``--model`` value if one was given, otherwise the default model for the
    supplied providers. It then runs the optimization and exits the process with
    status 0 on success or 1 on failure. Malformed API key arguments, or a
    default model that cannot be resolved, also exit with status 1.
    """
    from .optimizer import execute_optimization

    try:
        api_keys = parse_api_keys(args.api_key)
    except ValueError as e:
        console.print(f"[bold red]Error parsing API keys: {e}[/]")
        sys.exit(1)

    model = args.model
    if not model:
        try:
            model = get_default_model(api_keys=api_keys)
        except (UnrecognizedAPIKeysError, DefaultModelNotFoundError) as e:
            console.print(f"[bold red]Error: {e}[/]")
            sys.exit(1)
        if api_keys:
            console.print(f"[bold yellow]Custom API keys provided. Using default model: {model} for the run.[/]")
    elif api_keys:
        # The model was chosen explicitly, so don't describe it as a "default".
        console.print(f"[bold yellow]Custom API keys provided. Using model: {model} for the run.[/]")

    success = execute_optimization(
        source=args.source,
        eval_command=args.eval_command,
        metric=args.metric,
        goal=args.goal,
        model=model,
        steps=args.steps,
        log_dir=args.log_dir,
        additional_instructions=args.additional_instructions,
        console=console,
        eval_timeout=args.eval_timeout,
        save_logs=args.save_logs,
        apply_change=args.apply_change,
        api_keys=api_keys,
    )
    exit_code = 0 if success else 1
    sys.exit(exit_code)
@@ -156,7 +248,13 @@ def execute_resume_command(args: argparse.Namespace) -> None:
156
248
  """Execute the 'weco resume' command with all its logic."""
157
249
  from .optimizer import resume_optimization
158
250
 
159
- success = resume_optimization(run_id=args.run_id, console=console, apply_change=args.apply_change)
251
+ try:
252
+ api_keys = parse_api_keys(args.api_key)
253
+ except ValueError as e:
254
+ console.print(f"[bold red]Error parsing API keys: {e}[/]")
255
+ sys.exit(1)
256
+
257
+ success = resume_optimization(run_id=args.run_id, console=console, api_keys=api_keys, apply_change=args.apply_change)
160
258
  sys.exit(0 if success else 1)
161
259
 
162
260
 
@@ -169,22 +267,13 @@ def main() -> None:
169
267
  formatter_class=argparse.RawDescriptionHelpFormatter,
170
268
  )
171
269
 
172
- # Add global model argument
173
- parser.add_argument(
174
- "-M",
175
- "--model",
176
- type=str,
177
- default=None,
178
- help="Model to use for optimization. Defaults to `o4-mini`. See full list at docs.weco.ai/cli/supported-models",
179
- )
180
-
181
270
  subparsers = parser.add_subparsers(
182
271
  dest="command", help="Available commands"
183
272
  ) # Removed required=True for now to handle chatbot case easily
184
273
 
185
274
  # --- Run Command Parser Setup ---
186
275
  run_parser = subparsers.add_parser(
187
- "run", help="Run code optimization", formatter_class=argparse.RawDescriptionHelpFormatter, allow_abbrev=False
276
+ "run", help="Run code optimization", formatter_class=argparse.RawTextHelpFormatter, allow_abbrev=False
188
277
  )
189
278
  configure_run_parser(run_parser) # Use the helper to add arguments
190
279
 
@@ -199,86 +288,11 @@ def main() -> None:
199
288
  resume_parser = subparsers.add_parser(
200
289
  "resume",
201
290
  help="Resume an interrupted optimization run",
202
- formatter_class=argparse.RawDescriptionHelpFormatter,
291
+ formatter_class=argparse.RawTextHelpFormatter,
203
292
  allow_abbrev=False,
204
293
  )
205
294
  configure_resume_parser(resume_parser)
206
295
 
207
- # Check if we should run the chatbot
208
- # This logic needs to be robust. If 'run' or 'logout' is present, or -h/--help, don't run chatbot.
209
- # Otherwise, if it's just 'weco' or 'weco <path>' (with optional --model), run chatbot.
210
-
211
- def should_run_chatbot(args_list):
212
- """Determine if we should run chatbot by filtering out model arguments."""
213
- filtered = []
214
- i = 0
215
- while i < len(args_list):
216
- if args_list[i] in ["-M", "--model"]:
217
- # Skip the model argument and its value (if it exists)
218
- i += 1 # Skip the model flag
219
- if i < len(args_list): # Skip the model value if it exists
220
- i += 1
221
- elif args_list[i].startswith("--model="):
222
- i += 1 # Skip --model=value format
223
- else:
224
- filtered.append(args_list[i])
225
- i += 1
226
-
227
- # Apply existing chatbot detection logic to filtered args
228
- return len(filtered) == 0 or (len(filtered) == 1 and not filtered[0].startswith("-"))
229
-
230
- # Check for known commands by looking at the first non-option argument
231
- def get_first_non_option_arg():
232
- for arg in sys.argv[1:]:
233
- if not arg.startswith("-"):
234
- return arg
235
- return None
236
-
237
- first_non_option = get_first_non_option_arg()
238
- is_known_command = first_non_option in ["run", "logout", "credits"]
239
- is_help_command = len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help"] # Check for global help
240
-
241
- should_run_chatbot_result = should_run_chatbot(sys.argv[1:])
242
- should_run_chatbot_flag = not is_known_command and not is_help_command and should_run_chatbot_result
243
-
244
- if should_run_chatbot_flag:
245
- from .chatbot import run_onboarding_chatbot # Moved import inside
246
-
247
- # Create a simple parser just for extracting the model argument
248
- model_parser = argparse.ArgumentParser(add_help=False)
249
- model_parser.add_argument("-M", "--model", type=str, default=None)
250
-
251
- # Parse args to extract model
252
- args, unknown = model_parser.parse_known_args()
253
-
254
- # Determine project path from remaining arguments
255
- filtered_args = []
256
- i = 1
257
- while i < len(sys.argv):
258
- if sys.argv[i] in ["-M", "--model"]:
259
- # Skip the model argument and its value (if it exists)
260
- i += 1 # Skip the model flag
261
- if i < len(sys.argv): # Skip the model value if it exists
262
- i += 1
263
- elif sys.argv[i].startswith("--model="):
264
- i += 1 # Skip --model=value format
265
- else:
266
- filtered_args.append(sys.argv[i])
267
- i += 1
268
-
269
- project_path = pathlib.Path(filtered_args[0]) if filtered_args else pathlib.Path.cwd()
270
- if not project_path.is_dir():
271
- console.print(
272
- f"[bold red]Error:[/] The path '{project_path}' is not a valid directory. Please provide a valid directory path."
273
- )
274
- sys.exit(1)
275
-
276
- # Pass the run_parser and model to the chatbot
277
- run_onboarding_chatbot(project_path, console, run_parser, model=args.model)
278
- sys.exit(0)
279
-
280
- # If not running chatbot, proceed with normal arg parsing
281
- # If we reached here, a command (run, logout) or help is expected.
282
296
  args = parser.parse_args()
283
297
 
284
298
  if args.command == "logout":
weco/constants.py CHANGED
@@ -7,8 +7,8 @@ Constants for the Weco CLI package.
7
7
  TRUNCATION_THRESHOLD = 51000 # Maximum length before truncation
8
8
  TRUNCATION_KEEP_LENGTH = 25000 # Characters to keep from beginning and end
9
9
 
10
- # Default model configuration
11
- DEFAULT_MODEL = "o4-mini"
12
-
13
10
  # Supported file extensions for additional instructions
14
11
  SUPPORTED_FILE_EXTENSIONS = [".md", ".txt", ".rst"]
12
+
13
+ # Default models for each provider in order of preference
14
+ DEFAULT_MODELS = [("gemini", "gemini-3-pro-preview"), ("openai", "o4-mini"), ("vertex_ai", "claude-opus-4-5")]
weco/optimizer.py CHANGED
@@ -127,14 +127,15 @@ def execute_optimization(
127
127
  eval_command: str,
128
128
  metric: str,
129
129
  goal: str, # "maximize" or "minimize"
130
+ model: str,
130
131
  steps: int = 100,
131
- model: Optional[str] = None,
132
132
  log_dir: str = ".runs",
133
133
  additional_instructions: Optional[str] = None,
134
134
  console: Optional[Console] = None,
135
135
  eval_timeout: Optional[int] = None,
136
136
  save_logs: bool = False,
137
137
  apply_change: bool = False,
138
+ api_keys: Optional[dict[str, str]] = None,
138
139
  ) -> bool:
139
140
  """
140
141
  Execute the core optimization logic.
@@ -203,11 +204,6 @@ def execute_optimization(
203
204
  # --- Process Parameters ---
204
205
  maximize = goal.lower() in ["maximize", "max"]
205
206
 
206
- # Determine the model to use
207
- if model is None:
208
- # Default to o4-mini with credit-based billing
209
- model = "o4-mini"
210
-
211
207
  code_generator_config = {"model": model}
212
208
  evaluator_config = {"model": model, "include_analysis": True}
213
209
  search_policy_config = {
@@ -245,6 +241,7 @@ def execute_optimization(
245
241
  save_logs=save_logs,
246
242
  log_dir=log_dir,
247
243
  auth_headers=auth_headers,
244
+ api_keys=api_keys,
248
245
  )
249
246
  # Indicate the endpoint failed to return a response and the optimization was unsuccessful
250
247
  if run_response is None:
@@ -371,7 +368,12 @@ def execute_optimization(
371
368
 
372
369
  # Send feedback and get next suggestion
373
370
  eval_and_next_solution_response = evaluate_feedback_then_suggest_next_solution(
374
- console=console, step=step, run_id=run_id, execution_output=term_out, auth_headers=auth_headers
371
+ console=console,
372
+ step=step,
373
+ run_id=run_id,
374
+ execution_output=term_out,
375
+ auth_headers=auth_headers,
376
+ api_keys=api_keys,
375
377
  )
376
378
  # Save next solution (.runs/<run-id>/step_<step>.<extension>)
377
379
  write_to_path(fp=runs_dir / f"step_{step}{source_fp.suffix}", content=eval_and_next_solution_response["code"])
@@ -445,7 +447,12 @@ def execute_optimization(
445
447
  if not user_stop_requested_flag:
446
448
  # Evaluate the final solution thats been generated
447
449
  eval_and_next_solution_response = evaluate_feedback_then_suggest_next_solution(
448
- console=console, step=steps, run_id=run_id, execution_output=term_out, auth_headers=auth_headers
450
+ console=console,
451
+ step=steps,
452
+ run_id=run_id,
453
+ execution_output=term_out,
454
+ auth_headers=auth_headers,
455
+ api_keys=api_keys,
449
456
  )
450
457
  summary_panel.set_step(step=steps)
451
458
  status_response = get_optimization_run_status(
@@ -544,7 +551,9 @@ def execute_optimization(
544
551
  return optimization_completed_normally or user_stop_requested_flag
545
552
 
546
553
 
547
- def resume_optimization(run_id: str, console: Optional[Console] = None, apply_change: bool = False) -> bool:
554
+ def resume_optimization(
555
+ run_id: str, console: Optional[Console] = None, apply_change: bool = False, api_keys: Optional[dict[str, str]] = None
556
+ ) -> bool:
548
557
  """Resume an interrupted run from the most recent node and continue optimization."""
549
558
  if console is None:
550
559
  console = Console()
@@ -792,6 +801,7 @@ def resume_optimization(run_id: str, console: Optional[Console] = None, apply_ch
792
801
  run_id=resume_resp["run_id"],
793
802
  execution_output=term_out,
794
803
  auth_headers=auth_headers,
804
+ api_keys=api_keys,
795
805
  )
796
806
 
797
807
  # Save next solution to logs
@@ -863,6 +873,7 @@ def resume_optimization(run_id: str, console: Optional[Console] = None, apply_ch
863
873
  run_id=resume_resp["run_id"],
864
874
  execution_output=term_out,
865
875
  auth_headers=auth_headers,
876
+ api_keys=api_keys,
866
877
  )
867
878
  summary_panel.set_step(step=total_steps)
868
879
  status_response = get_optimization_run_status(
weco/utils.py CHANGED
@@ -9,13 +9,26 @@ from rich.panel import Panel
9
9
  import pathlib
10
10
  import requests
11
11
  from packaging.version import parse as parse_version
12
- from .constants import TRUNCATION_THRESHOLD, TRUNCATION_KEEP_LENGTH, DEFAULT_MODEL, SUPPORTED_FILE_EXTENSIONS
12
+ from .constants import TRUNCATION_THRESHOLD, TRUNCATION_KEEP_LENGTH, SUPPORTED_FILE_EXTENSIONS, DEFAULT_MODELS
13
13
 
14
14
 
15
- # Env/arg helper functions
16
- def determine_model_for_onboarding() -> str:
17
- """Determine which model to use for onboarding chatbot. Defaults to o4-mini."""
18
- return DEFAULT_MODEL
15
class UnrecognizedAPIKeysError(Exception):
    """Raised when API keys are supplied for providers not listed in ``DEFAULT_MODELS``.

    Attributes:
        api_keys: The provider -> API key mapping that was supplied. It holds the
            raw keys; only provider names are included in the message.
        unrecognized: Sorted provider names from ``api_keys`` that are not in
            ``DEFAULT_MODELS``. If there are none, this holds all supplied provider names.
    """

    def __init__(self, api_keys: dict[str, str]):
        self.api_keys = api_keys
        # Keep DEFAULT_MODELS' preference order: joining a set would make the
        # listed order depend on string-hash randomisation.
        supported = [provider for provider, _ in DEFAULT_MODELS]
        # Name the offending providers. Fall back to every supplied provider if
        # the caller passed only known ones, so the message is never blank.
        self.unrecognized = sorted(set(api_keys) - set(supported)) or sorted(api_keys)
        super().__init__(
            f"Unrecognized API key provider(s): {', '.join(self.unrecognized)}. "
            f"Supported providers: {', '.join(supported)}"
        )
24
+
25
+
26
class DefaultModelNotFoundError(Exception):
    """Raised when no default model can be found for the supplied API keys.

    Attributes:
        api_keys: The provider -> API key mapping that was supplied. It holds the
            raw keys; only provider names are included in the message.
    """

    def __init__(self, api_keys: dict[str, str]):
        self.api_keys = api_keys
        # Sort provider names so the message is deterministic; a set's repr
        # order varies with string-hash randomisation.
        providers = ", ".join(sorted(api_keys))
        super().__init__(f"No default model found for any of the provided API keys: {providers}")
19
32
 
20
33
 
21
34
  def read_additional_instructions(additional_instructions: str | None) -> str | None:
@@ -218,3 +231,18 @@ def check_for_cli_updates():
218
231
  except Exception:
219
232
  # Catch any other unexpected error during the check
220
233
  pass
234
+
235
+
236
def get_default_model(api_keys: dict[str, str] | None = None) -> str:
    """Return the model to run with when the user did not pick one explicitly.

    Args:
        api_keys: Optional provider -> API key mapping supplied by the user.

    Returns:
        The model paired with the earliest provider in ``DEFAULT_MODELS`` that
        appears in ``api_keys``. With no keys, returns the model of the first
        entry of ``DEFAULT_MODELS``.

    Raises:
        UnrecognizedAPIKeysError: If any provider in ``api_keys`` is not listed
            in ``DEFAULT_MODELS``.
        DefaultModelNotFoundError: Defensive only. It is not reachable once every
            provider has been checked against ``DEFAULT_MODELS``.
    """
    if not api_keys:
        return DEFAULT_MODELS[0][1]

    known_providers = {name for name, _ in DEFAULT_MODELS}
    if any(name not in known_providers for name in api_keys):
        raise UnrecognizedAPIKeysError(api_keys)

    # DEFAULT_MODELS is ordered by preference, so the first matching provider wins.
    for name, candidate in DEFAULT_MODELS:
        if name in api_keys:
            return candidate

    raise DefaultModelNotFoundError(api_keys)