model-forge-llm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modelforge/cli.py ADDED
@@ -0,0 +1,720 @@
1
+ # Standard library imports
2
+ import json
3
+ import random
4
+ import time
5
+ from typing import Any
6
+
7
+ # Third-party imports
8
+ import click
9
+ from langchain_core.messages import BaseMessage
10
+ from langchain_core.output_parsers import StrOutputParser
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+
13
+ # Local imports
14
+ from . import auth, config
15
+ from .logging_config import get_logger
16
+ from .modelsdev import ModelsDevClient
17
+ from .registry import ModelForgeRegistry
18
+
19
# Module-level logger shared by every CLI command in this file.
logger = get_logger(__name__)
20
+
21
+
22
def _handle_authentication(
    provider: str, provider_data: dict[str, Any], api_key: str | None, dev_auth: bool
) -> None:
    """Store an API key or run the device flow for ``provider``.

    An explicit ``api_key`` wins over ``dev_auth``; when neither is supplied the
    function does nothing. Device-flow failures are logged and reported on
    stderr rather than raised.

    Args:
        provider: Name of the provider being configured.
        provider_data: The provider's configuration entry, handed to
            ``auth.get_auth_strategy`` for the device flow.
        api_key: API key to persist, if one was given.
        dev_auth: Whether to start the device authentication flow.
    """
    if api_key:
        # Persist the key through the API-key strategy's storage hook.
        auth.ApiKeyAuth(provider)._save_auth_data({"api_key": api_key})
        click.echo(f"API key stored for provider '{provider}'.")
        return

    if not dev_auth:
        return

    click.echo("Starting device authentication flow...")
    try:
        flow = auth.get_auth_strategy(provider, provider_data)
        if flow.authenticate():
            click.echo(f"Authentication successful for provider '{provider}'.")
        else:
            click.echo("Authentication failed.", err=True)
    except Exception as e:
        logger.exception("Device authentication failed")
        click.echo(f"Device authentication failed: {e}", err=True)
45
+
46
+
47
# Root command group; every sub-group and command below is attached to it.
@click.group()
def cli() -> None:
    """ModelForge CLI for managing LLM configurations."""
50
+
51
+
52
# The group name is given explicitly so the command is always `modelforge config`,
# which is what every help/usage message in this module tells the user to type.
# Without it, Click derives the name from the function name, and only Click >= 8.2
# strips the `_group` suffix (older versions would register `config-group`).
@cli.group(name="config")
def config_group() -> None:
    """Configuration management commands."""
55
+
56
+
57
@config_group.command(name="show")
def show_config() -> None:
    """Shows the current configuration."""
    try:
        active_config, active_path = config.get_config()

        # Report which config file is in effect before dumping its content.
        is_local = active_path == config.LOCAL_CONFIG_FILE
        scope = "local" if is_local else "global"
        click.echo(f"Active ModelForge Config ({scope}): {active_path}")

        has_providers = bool(active_config) and bool(active_config.get("providers"))
        if has_providers:
            click.echo(json.dumps(active_config, indent=4))
        else:
            click.echo(
                "Configuration is empty. Add models using 'modelforge config add'."
            )
    except Exception as e:
        # Best-effort command: log the traceback, show a short message, exit normally.
        logger.exception("Failed to show configuration")
        click.echo(f"Error: Failed to show configuration: {e}", err=True)
75
+
76
+
77
@config_group.command(name="migrate")
def migrate_config() -> None:
    """Migrates configuration from old location to new global location."""
    # All of the work is delegated to the config module.
    config.migrate_old_config()
81
+
82
+
83
def _default_provider_settings(provider: str) -> dict[str, Any]:
    """Return a fresh dict of built-in defaults for ``provider`` (empty if unknown)."""
    # Built on every call so callers may freely mutate the returned dict.
    provider_defaults: dict[str, dict[str, Any]] = {
        "openai": {
            "llm_type": "openai_compatible",
            "base_url": "https://api.openai.com/v1",
            "auth_strategy": "api_key",
        },
        "openrouter": {
            # Same llm_type spelling as the "openai" entry; the previous
            # hyphenated "openai-compatible" did not match it.
            "llm_type": "openai_compatible",
            "base_url": "https://openrouter.ai/api/v1",
            "auth_strategy": "api_key",
        },
        "google": {
            "llm_type": "google_genai",
            "auth_strategy": "api_key",
        },
        "github_copilot": {
            "llm_type": "github_copilot",
            "base_url": "https://api.githubcopilot.com",
            "auth_strategy": "device_flow",
            "auth_details": {
                "client_id": "01ab8ac9400c4e429b23",
                "device_code_url": "https://github.com/login/device/code",
                "token_url": "https://github.com/login/oauth/access_token",
                "scope": "read:user",
            },
        },
        "ollama": {"llm_type": "ollama", "base_url": "http://localhost:11434"},
    }
    return provider_defaults.get(provider, {})


@config_group.command(name="add")
@click.option(
    "--provider",
    required=True,
    help=(
        "The name of the provider (e.g., 'openai', 'ollama', "
        "'github_copilot', 'google')."
    ),
)
@click.option(
    "--model",
    required=True,
    help="A local, memorable name for the model (e.g., 'copilot-chat').",
)
@click.option(
    "--api-model-name",
    help="The actual model name the API expects (e.g., 'claude-3.7-sonnet-thought').",
)
@click.option("--api-key", help="The API key for the provider, if applicable.")
@click.option(
    "--dev-auth", is_flag=True, help="Use device authentication flow, if applicable."
)
@click.option(
    "--local",
    is_flag=True,
    help="Save to local project config (./.model-forge/config.json).",
)
def add_model(
    provider: str,
    model: str,
    api_model_name: str | None = None,
    api_key: str | None = None,
    dev_auth: bool = False,
    local: bool = False,
) -> None:
    """Add a new model configuration."""
    try:
        # Load the configuration file that will be written back to.
        target_config_path = config.get_config_path(local=local)
        current_config, _ = config.get_config_from_path(target_config_path)

        # Create the provider entry from built-in defaults on first use only;
        # an existing entry is left untouched.
        providers = current_config.setdefault("providers", {})
        if provider not in providers:
            providers[provider] = _default_provider_settings(provider)
        provider_entry = providers[provider]

        # Register (or overwrite) the model alias under the provider.
        model_config: dict[str, Any] = {}
        if api_model_name:
            model_config["api_model_name"] = api_model_name
        provider_entry.setdefault("models", {})[model] = model_config

        # Persist before authenticating, as the original flow did.
        config.save_config(current_config, local=local)

        _handle_authentication(provider, provider_entry, api_key, dev_auth)

        scope_msg = "local" if local else "global"
        click.echo(
            f"Successfully configured model '{model}' for provider '{provider}' "
            f"in the {scope_msg} config."
        )
        click.echo("Run 'modelforge config show' to see the updated configuration.")

    except Exception as e:
        # Best-effort command: log the traceback and show a short message.
        logger.exception("Failed to add model configuration")
        click.echo(f"Error: {e}", err=True)
192
+
193
+
194
@config_group.command(name="use")
@click.option(
    "--provider", "provider_name", required=True, help="The name of the provider."
)
@click.option(
    "--model", "model_alias", required=True, help="The alias of the model to use."
)
@click.option(
    "--local", is_flag=True, help="Set the current model in the local project config."
)
def use_model(provider_name: str, model_alias: str, local: bool) -> None:
    """Set the current model to use."""
    # A falsy result from the config layer is surfaced as a CLI error.
    if config.set_current_model(provider_name, model_alias, local=local):
        return
    raise click.ClickException("Model not found")
209
+
210
+
211
@config_group.command(name="remove")
@click.option("--provider", required=True, help="The name of the provider.")
@click.option("--model", required=True, help="The alias of the model to remove.")
@click.option(
    "--keep-credentials",
    is_flag=True,
    help="Keep stored credentials (don't remove from config).",
)
@click.option("--local", is_flag=True, help="Remove from the local project config.")
def remove_model(
    provider: str, model: str | None, keep_credentials: bool, local: bool
) -> None:
    """Removes a model configuration and optionally its stored credentials."""
    target_config_path = config.get_config_path(local=local)
    current_config, _ = config.get_config_from_path(target_config_path)

    providers = current_config.get("providers", {})

    if provider not in providers:
        click.echo(f"Error: Provider '{provider}' not found.")
        return

    # Keep a reference to the provider entry: it is still needed for credential
    # cleanup after the entry may have been deleted from `providers` below.
    provider_data = providers[provider]
    models = provider_data.get("models", {})

    if model not in models:
        click.echo(f"Error: Model '{model}' not found for provider '{provider}'.")
        return

    # Remove the model from configuration
    del models[model]

    # If no models left for this provider, remove the entire provider
    if not models:
        del providers[provider]
        click.echo(f"Removed provider '{provider}' (no models remaining).")
    else:
        click.echo(f"Removed model '{model}' from provider '{provider}'.")

    # Clear the current selection if it pointed at the removed model
    current_model = current_config.get("current_model", {})
    if (
        current_model.get("provider") == provider
        and current_model.get("model") == model
    ):
        current_config["current_model"] = {}
        click.echo("Cleared current model selection (removed model was selected).")

    # Save the updated configuration
    config.save_config(current_config, local=local)

    if keep_credentials:
        click.echo("Kept stored credentials (--keep-credentials flag used).")
        return

    try:
        # Use the captured `provider_data` rather than re-reading
        # current_config["providers"][provider]: that lookup raised KeyError
        # whenever the provider had just been removed, so credentials were
        # never cleared in that case.
        auth_strategy = auth.get_auth_strategy(provider, provider_data)
        auth_strategy.clear_credentials()
        click.echo(f"Removed stored credentials for {provider}")
    except Exception as e:
        # Credential cleanup is best-effort; the model removal already succeeded.
        click.echo(f"Warning: Could not remove credentials from config: {e}")
275
+
276
+
277
@cli.command(name="test")
@click.option("--prompt", required=True, help="The prompt to send to the model.")
@click.option("--verbose", is_flag=True, help="Enable verbose debug output.")
def test_model(prompt: str, verbose: bool) -> None:
    """Tests the currently selected model with a prompt."""
    try:
        selection = config.get_current_model()
        if not selection:
            logger.error("No model selected for testing")
            click.echo(
                "Error: No model selected. Use 'modelforge config use'.", err=True
            )
            return

        provider_name = selection.get("provider")
        model_alias = selection.get("model")

        logger.info("Testing model %s/%s with prompt", provider_name, model_alias)
        click.echo(
            f"Sending prompt to the selected model [{provider_name}/{model_alias}]..."
        )

        # Resolve the currently selected model through the registry.
        llm = ModelForgeRegistry(verbose=verbose).get_llm()
        if not llm:
            logger.error(
                "Failed to instantiate language model for %s/%s",
                provider_name,
                model_alias,
            )
            click.echo(
                "Failed to instantiate the language model. Check logs for details.",
                err=True,
            )
            return

        # Single human message -> model -> plain string.
        template = ChatPromptTemplate.from_messages([("human", "{input}")])
        chain = template | llm | StrOutputParser()
        payload = {"input": prompt}

        # Only GitHub Copilot calls go through the rate-limit retry wrapper.
        if provider_name == "github_copilot":
            response = _invoke_with_smart_retry(chain, payload, verbose)
        else:
            response = chain.invoke(payload)

        click.echo(response)

    except Exception as e:
        logger.exception("Error occurred while running model test")
        click.echo(f"\nAn error occurred while running the model: {e}", err=True)
330
+
331
+
332
# Sub-group exposed as `modelforge auth`; login/logout/status attach to it.
@cli.group(name="auth")
def auth_group() -> None:
    """Authentication management commands."""
335
+
336
+
337
@auth_group.command(name="login")
@click.option("--provider", required=True, help="The provider to authenticate with")
@click.option("--api-key", help="API key for authentication (skips interactive prompt)")
@click.option(
    "--force", is_flag=True, help="Force re-authentication even if credentials exist"
)
def auth_login(provider: str, api_key: str | None, force: bool) -> None:
    """Authenticate with a provider using API key or device flow."""
    try:
        loaded_config, _ = config.get_config()
        known_providers = loaded_config.get("providers", {})

        if provider not in known_providers:
            click.echo(f"❌ Provider '{provider}' not found in configuration")
            click.echo("Use 'modelforge models list --provider models.dev'")
            click.echo("to discover providers")
            return

        strategy = auth.get_auth_strategy(provider, known_providers[provider])

        # Skip the flow when credentials already exist, unless --force was given.
        if not force and strategy.get_credentials():
            click.echo(f"✅ Already authenticated with '{provider}'")
            click.echo("Use --force to re-authenticate")
            return

        uses_api_key = isinstance(strategy, auth.ApiKeyAuth)

        # An API key passed on the command line is stored directly.
        if uses_api_key and api_key:
            strategy.store_api_key(api_key)
            click.echo(f"✅ API key stored for provider '{provider}'")
            return

        if not uses_api_key:
            if not isinstance(strategy, auth.DeviceFlowAuth):
                click.echo(f"✅ Provider '{provider}' doesn't require authentication")
                return
            click.echo(f"🔐 Starting device flow authentication for '{provider}'...")

        # Both remaining paths (API-key strategy without a key, device flow)
        # run the strategy's own authenticate() and report the outcome.
        if strategy.authenticate():
            click.echo(f"✅ Successfully authenticated with '{provider}'")
        else:
            click.echo(f"❌ Authentication failed for '{provider}'")

    except Exception as e:
        logger.exception("Authentication failed")
        click.echo(f"❌ Authentication failed: {e}", err=True)
388
+
389
+
390
@auth_group.command(name="logout")
@click.option("--provider", required=False, help="The provider to log out from")
@click.option("--all-providers", is_flag=True, help="Log out from all providers")
def auth_logout(provider: str | None, all_providers: bool) -> None:
    """Clear stored credentials for a provider."""
    # Reject the call before touching the config when no target was given.
    if not all_providers and not provider:
        raise click.BadParameter("Either --provider or --all-providers is required")

    loaded_config, _ = config.get_config()
    known_providers = loaded_config.get("providers", {})

    if all_providers:
        # One failing provider must not stop the others from being cleared.
        for name, data in known_providers.items():
            try:
                auth.get_auth_strategy(name, data).clear_credentials()
                click.echo(f"✅ Cleared credentials for '{name}'")
            except Exception as e:
                click.echo(f"⚠️ Failed to clear credentials for '{name}': {e}")
        return

    if provider not in known_providers:
        click.echo(f"❌ Provider '{provider}' not found")
        return

    # Single-provider path: errors propagate to the caller, as before.
    strategy = auth.get_auth_strategy(provider, known_providers[provider])
    strategy.clear_credentials()
    click.echo(f"✅ Cleared credentials for '{provider}'")
421
+
422
+
423
@auth_group.command(name="status")
@click.option("--provider", help="Check status for specific provider")
@click.option("--verbose", is_flag=True, help="Show detailed token information")
def auth_status(provider: str | None, verbose: bool) -> None:
    """Check authentication status for providers."""
    try:
        loaded_config, _ = config.get_config()
        known_providers = loaded_config.get("providers", {})

        if not provider:
            # No filter: report every configured provider, blank line after each.
            click.echo("🔍 Authentication Status for All Providers:\n")
            for name, data in known_providers.items():
                _check_provider_status(name, data, verbose)
                click.echo()
            return

        if provider not in known_providers:
            click.echo(f"❌ Provider '{provider}' not found in configuration")
            return

        _check_provider_status(provider, known_providers[provider], verbose)

    except Exception as e:
        logger.exception("Failed to check authentication status")
        click.echo(f"❌ Error checking status: {e}", err=True)
448
+
449
+
450
# Sub-group for the list/search/info discovery commands defined below.
@cli.group()
def models() -> None:
    """Model discovery and management commands."""
453
+
454
+
455
@models.command(name="list")
@click.option("--provider", help="Filter by specific provider")
@click.option("--refresh", is_flag=True, help="Force refresh from models.dev")
@click.option(
    "--format", "output_format", type=click.Choice(["table", "json"]), default="table"
)
def list_models(provider: str | None, refresh: bool, output_format: str) -> None:
    """List available models from models.dev."""
    try:
        found = ModelsDevClient().get_models(provider=provider, force_refresh=refresh)

        if not found:
            click.echo("No models found")
            return

        if output_format == "json":
            click.echo(json.dumps(found, indent=2))
            return

        # Table format
        click.echo(f"\n📋 Found {len(found)} models:\n")

        # Bucket the models by provider, keeping first-seen provider order.
        by_provider: dict[str, list[dict[str, Any]]] = {}
        for entry in found:
            by_provider.setdefault(entry.get("provider", "unknown"), []).append(entry)

        for prov, entries in by_provider.items():
            click.echo(f"🤖 {prov.upper()}:")
            for entry in entries:
                name = entry.get("name", "unknown")
                display_name = entry.get("display_name", name)
                description = entry.get("description", "")
                # Descriptions are capped at 50 characters including the ellipsis.
                if len(description) > 50:
                    description = description[:47] + "..."

                click.echo(f" • {display_name} - {description}")
            click.echo()

    except Exception as e:
        logger.exception("Failed to list models")
        click.echo(f"❌ Error listing models: {e}", err=True)
500
+
501
+
502
@models.command(name="search")
@click.argument("query")
@click.option("--provider", help="Filter by specific provider")
@click.option("--capability", multiple=True, help="Filter by capabilities")
@click.option("--max-price", type=float, help="Maximum price per 1K tokens")
@click.option("--refresh", is_flag=True, help="Force refresh from models.dev")
def search_models(
    query: str,
    provider: str | None,
    capability: tuple[str, ...],
    max_price: float | None,
    refresh: bool,
) -> None:
    """Search models by name, description, or capabilities."""
    try:
        # An empty --capability tuple is passed on as "no capability filter".
        wanted_capabilities = list(capability) if capability else None
        matches = ModelsDevClient().search_models(
            query=query,
            provider=provider,
            capabilities=wanted_capabilities,
            max_price=max_price,
            force_refresh=refresh,
        )

        if not matches:
            click.echo("No models found matching criteria")
            return

        click.echo(f"\n🔍 Found {len(matches)} matching models:\n")
        for entry in matches:
            name = entry.get("name", "unknown")
            display_name = entry.get("display_name", name)
            provider_name = entry.get("provider", "unknown")
            description = entry.get("description", "")

            click.echo(f"🤖 {display_name} ({provider_name})")
            click.echo(f" {description}")

            # Optional sections are printed only when the entry carries them.
            capabilities = entry.get("capabilities")
            if capabilities:
                caps = ", ".join(capabilities)
                click.echo(f" Capabilities: {caps}")

            pricing = entry.get("pricing")
            if pricing and "input_per_1k_tokens" in pricing:
                click.echo(f" Price: ${pricing['input_per_1k_tokens']}/1K tokens")

            click.echo()

    except Exception as e:
        logger.exception("Failed to search models")
        click.echo(f"❌ Error searching models: {e}", err=True)
554
+
555
+
556
@models.command(name="info")
@click.option("--provider", required=True, help="The provider name")
@click.option("--model", required=True, help="The model name")
@click.option("--refresh", is_flag=True, help="Force refresh from models.dev")
def model_info(provider: str, model: str, refresh: bool) -> None:
    """Get detailed information about a specific model."""
    try:
        details = ModelsDevClient().get_model_info(
            provider=provider, model=model, force_refresh=refresh
        )
        # The client's answer is printed as-is, pretty-printed as JSON.
        click.echo(json.dumps(details, indent=2))
    except Exception as e:
        logger.exception("Failed to get model info")
        click.echo(f"❌ Error getting model info: {e}", err=True)
573
+
574
+
575
@cli.command(name="status")
@click.option("--provider", help="Check status for specific provider")
@click.option("--verbose", is_flag=True, help="Show detailed token information")
def status(provider: str | None, verbose: bool) -> None:
    """Check authentication status for providers (deprecated, use 'auth status')."""
    click.echo("⚠️ This command is deprecated. Use 'modelforge auth status' instead.")
    # `auth_status` is a click.Command, not a plain function. Calling it as
    # `auth_status(provider, verbose)` runs Command.main(args=provider,
    # prog_name=verbose), which re-parses the command line instead of running
    # the callback. Context.invoke is Click's supported way to run another
    # command's callback with explicit parameter values.
    click.get_current_context().invoke(
        auth_status, provider=provider, verbose=verbose
    )
582
+
583
+
584
def _check_provider_status(
    provider_name: str, provider_data: dict[str, Any], verbose: bool
) -> None:
    """Print the authentication status of one provider.

    Args:
        provider_name: Name of the provider being reported.
        provider_data: The provider's configuration entry; its
            ``auth_strategy`` value (default ``"unknown"``) is displayed and
            decides which checks run.
        verbose: Show all available token details instead of only the
            remaining time.
    """
    strategy_name = provider_data.get("auth_strategy", "unknown")

    click.echo(f"📋 Provider: {provider_name}")
    click.echo(f" Auth Strategy: {strategy_name}")

    if strategy_name == "local":
        click.echo(" Status: ✅ Local provider (no authentication needed)")
        return

    try:
        strategy = auth.get_auth_strategy(provider_name, provider_data)

        if not strategy.get_credentials():
            click.echo(" Status: ❌ No valid credentials found")
            click.echo(f" Action: Run authentication for {provider_name}")
            return

        click.echo(" Status: ✅ Valid credentials found")

        # Token details exist only for device-flow strategies that expose them.
        has_token_info = strategy_name == "device_flow" and hasattr(
            strategy, "get_token_info"
        )
        if not has_token_info:
            return

        token_info = strategy.get_token_info()
        if not token_info:
            return

        if verbose:
            click.echo(" Token Details:")
            if "time_remaining" in token_info:
                click.echo(f" Time Remaining: {token_info['time_remaining']}")
            if "expiry_time" in token_info:
                click.echo(f" Expires At: {token_info['expiry_time']}")
            if "scope" in token_info:
                click.echo(f" Scope: {token_info['scope']}")
        elif "time_remaining" in token_info:
            # Drop everything after the first "." (the fractional seconds).
            remaining = str(token_info["time_remaining"]).split(".")[0]
            click.echo(f" Time Remaining: {remaining}")

    except Exception as e:
        # Any failure while checking is shown as this provider's status line.
        click.echo(f" Status: ❌ Error checking credentials: {e}")
631
+
632
+
633
def _invoke_with_smart_retry(
    chain: Any,  # noqa: ANN401
    input_data: dict[str, Any],
    verbose: bool = False,
    max_retries: int = 3,
) -> str:
    """
    Invoke ``chain`` and retry with backoff when the error looks like a rate limit.

    An error counts as a rate limit when its lower-cased message contains
    "forbidden", "rate limit" or "too many requests". Between attempts the
    function sleeps ``2**attempt`` seconds plus up to one second of jitter.

    Args:
        chain: Object exposing ``invoke(input_data)``
        input_data: The payload passed to ``chain.invoke``
        verbose: Whether to echo retry progress to the console
        max_retries: Total number of attempts

    Returns:
        The result as a string (the ``content`` of a ``BaseMessage`` result)

    Raises:
        Exception: The original error, immediately for non-rate-limit failures
            or after the final attempt for rate-limit failures
        RuntimeError: If no attempt was made at all (``max_retries`` <= 0)
    """
    rate_limit_markers = ["forbidden", "rate limit", "too many requests"]
    final_attempt = max_retries - 1
    last_exception = None

    for attempt in range(max_retries):
        try:
            if attempt > 0:
                logger.info(
                    "Retry attempt %d/%d for GitHub Copilot", attempt + 1, max_retries
                )
                if verbose:
                    click.echo(
                        f"🔄 Retry attempt {attempt + 1}/{max_retries} "
                        f"for GitHub Copilot..."
                    )

            outcome = chain.invoke(input_data)
            text = outcome.content if isinstance(outcome, BaseMessage) else outcome
            return str(text)

        except Exception as exc:
            last_exception = exc
            lowered = str(exc).lower()

            if not any(marker in lowered for marker in rate_limit_markers):
                # Non-rate-limit error, don't retry
                logger.exception("Non-rate-limit error in GitHub Copilot call")
                raise

            if attempt >= final_attempt:
                # Out of attempts: report, then fall out of the loop to re-raise.
                logger.exception(
                    "Max retries (%d) reached for GitHub Copilot rate limiting",
                    max_retries,
                )
                if verbose:
                    click.echo(
                        f"❌ Max retries ({max_retries}) reached "
                        f"for GitHub Copilot rate limiting"
                    )
                continue

            # Exponential backoff with jitter; `random` is fine here because
            # the delay is not security-sensitive.
            delay = (2**attempt) + random.uniform(0, 1)  # noqa: S311

            logger.warning(
                "Rate limited by GitHub Copilot. Waiting %.1fs before retry",
                delay,
            )
            if verbose:
                click.echo(
                    f"⏳ Rate limited by GitHub Copilot. "
                    f"Waiting {delay:.1f}s before retry..."
                )

            time.sleep(delay)

    # If we get here, all retries failed
    logger.error("All retry attempts failed for GitHub Copilot")
    raise last_exception or RuntimeError("All retry attempts failed")
717
+
718
+
719
# Allow running the module directly as a script.
if __name__ == "__main__":
    cli()