janito 2.5.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. janito/agent/setup_agent.py +231 -223
  2. janito/agent/templates/profiles/system_prompt_template_software_developer.txt.j2 +39 -0
  3. janito/cli/chat_mode/bindings.py +1 -26
  4. janito/cli/chat_mode/session.py +282 -294
  5. janito/cli/chat_mode/session_profile_select.py +125 -55
  6. janito/cli/chat_mode/shell/commands/tools.py +51 -48
  7. janito/cli/chat_mode/toolbar.py +42 -68
  8. janito/cli/cli_commands/list_tools.py +41 -56
  9. janito/cli/cli_commands/show_system_prompt.py +70 -49
  10. janito/cli/core/runner.py +6 -1
  11. janito/cli/core/setters.py +43 -34
  12. janito/cli/main_cli.py +25 -1
  13. janito/cli/prompt_core.py +76 -69
  14. janito/cli/rich_terminal_reporter.py +22 -1
  15. janito/cli/single_shot_mode/handler.py +95 -94
  16. janito/drivers/driver_registry.py +27 -29
  17. janito/drivers/openai/driver.py +436 -494
  18. janito/llm/agent.py +54 -68
  19. janito/provider_registry.py +178 -178
  20. janito/providers/anthropic/model_info.py +41 -22
  21. janito/providers/anthropic/provider.py +80 -67
  22. janito/providers/provider_static_info.py +18 -17
  23. janito/tools/adapters/local/__init__.py +66 -65
  24. janito/tools/adapters/local/adapter.py +79 -18
  25. janito/tools/adapters/local/create_directory.py +9 -9
  26. janito/tools/adapters/local/create_file.py +12 -12
  27. janito/tools/adapters/local/delete_text_in_file.py +16 -16
  28. janito/tools/adapters/local/find_files.py +2 -2
  29. janito/tools/adapters/local/get_file_outline/core.py +5 -5
  30. janito/tools/adapters/local/get_file_outline/search_outline.py +4 -4
  31. janito/tools/adapters/local/open_html_in_browser.py +15 -15
  32. janito/tools/adapters/local/python_file_run.py +4 -4
  33. janito/tools/adapters/local/read_files.py +40 -0
  34. janito/tools/adapters/local/remove_directory.py +5 -5
  35. janito/tools/adapters/local/remove_file.py +4 -4
  36. janito/tools/adapters/local/replace_text_in_file.py +21 -21
  37. janito/tools/adapters/local/run_bash_command.py +1 -1
  38. janito/tools/adapters/local/search_text/pattern_utils.py +2 -2
  39. janito/tools/adapters/local/search_text/traverse_directory.py +10 -10
  40. janito/tools/adapters/local/validate_file_syntax/core.py +7 -7
  41. janito/tools/adapters/local/validate_file_syntax/css_validator.py +2 -2
  42. janito/tools/adapters/local/validate_file_syntax/html_validator.py +7 -7
  43. janito/tools/adapters/local/validate_file_syntax/js_validator.py +2 -2
  44. janito/tools/adapters/local/validate_file_syntax/json_validator.py +2 -2
  45. janito/tools/adapters/local/validate_file_syntax/markdown_validator.py +2 -2
  46. janito/tools/adapters/local/validate_file_syntax/ps1_validator.py +2 -2
  47. janito/tools/adapters/local/validate_file_syntax/python_validator.py +2 -2
  48. janito/tools/adapters/local/validate_file_syntax/xml_validator.py +2 -2
  49. janito/tools/adapters/local/validate_file_syntax/yaml_validator.py +2 -2
  50. janito/tools/adapters/local/view_file.py +12 -12
  51. janito/tools/path_security.py +204 -0
  52. janito/tools/tool_use_tracker.py +12 -12
  53. janito/tools/tools_adapter.py +66 -34
  54. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/METADATA +412 -412
  55. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/RECORD +59 -58
  56. janito/drivers/anthropic/driver.py +0 -113
  57. janito/tools/adapters/local/get_file_outline/python_outline_v2.py +0 -156
  58. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/WHEEL +0 -0
  59. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/entry_points.txt +0 -0
  60. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/licenses/LICENSE +0 -0
  61. {janito-2.5.1.dist-info → janito-2.6.0.dist-info}/top_level.txt +0 -0
janito/llm/agent.py CHANGED
@@ -299,38 +299,13 @@ class LLMAgent:
299
299
  role: str = "user",
300
300
  config=None,
301
301
  ):
302
- if (
303
- hasattr(self, "driver")
304
- and self.driver
305
- and hasattr(self.driver, "clear_output_queue")
306
- ):
307
- self.driver.clear_output_queue()
308
- # Drain input queue before sending new messages
309
- if (
310
- hasattr(self, "driver")
311
- and self.driver
312
- and hasattr(self.driver, "clear_input_queue")
313
- ):
314
- self.driver.clear_input_queue()
315
- """
316
- Main agent conversation loop supporting function/tool calls and conversation history extension, now as a blocking event-driven loop with event publishing.
317
-
318
- Args:
319
- prompt: The user prompt as a string (optional if messages is provided).
320
- messages: A list of message dicts (optional if prompt is provided).
321
- role: The role for the prompt (default: 'user').
322
- config: Optional driver config (defaults to provider config).
323
-
324
- Returns:
325
- The final ResponseReceived event (or error event) when the conversation is complete.
326
- """
302
+ self._clear_driver_queues()
327
303
  self._validate_and_update_history(prompt, messages, role)
328
304
  self._ensure_system_prompt()
329
305
  if config is None:
330
306
  config = self.llm_provider.driver_config
331
307
  loop_count = 1
332
308
  import threading
333
-
334
309
  cancel_event = threading.Event()
335
310
  while True:
336
311
  self._print_verbose_chat_loop(loop_count)
@@ -339,27 +314,32 @@ class LLMAgent:
339
314
  try:
340
315
  result, added_tool_results = self._process_next_response()
341
316
  except KeyboardInterrupt:
342
- # Propagate the interrupt to the caller, but signal the driver to cancel first
343
317
  cancel_event.set()
344
318
  raise
345
319
  if getattr(self, "verbose_agent", False):
346
- print(
347
- f"[agent] [DEBUG] Returned from _process_next_response: result={result}, added_tool_results={added_tool_results}"
348
- )
349
- if result is None:
350
- if getattr(self, "verbose_agent", False):
351
- print(
352
- f"[agent] [INFO] Exiting chat loop: _process_next_response returned None result (likely timeout or error). Returning (None, False)."
353
- )
354
- return None, False
355
- if not added_tool_results:
356
- if getattr(self, "verbose_agent", False):
357
- print(
358
- f"[agent] [INFO] Exiting chat loop: _process_next_response returned added_tool_results=False (final response or no more tool calls). Returning result: {result}"
359
- )
320
+ print(f"[agent] [DEBUG] Returned from _process_next_response: result={result}, added_tool_results={added_tool_results}")
321
+ if self._should_exit_chat_loop(result, added_tool_results):
360
322
  return result
361
323
  loop_count += 1
362
324
 
325
+ def _clear_driver_queues(self):
326
+ if hasattr(self, "driver") and self.driver:
327
+ if hasattr(self.driver, "clear_output_queue"):
328
+ self.driver.clear_output_queue()
329
+ if hasattr(self.driver, "clear_input_queue"):
330
+ self.driver.clear_input_queue()
331
+
332
+ def _should_exit_chat_loop(self, result, added_tool_results):
333
+ if result is None:
334
+ if getattr(self, "verbose_agent", False):
335
+ print("[agent] [INFO] Exiting chat loop: _process_next_response returned None result (likely timeout or error). Returning (None, False).")
336
+ return True
337
+ if not added_tool_results:
338
+ if getattr(self, "verbose_agent", False):
339
+ print(f"[agent] [INFO] Exiting chat loop: _process_next_response returned added_tool_results=False (final response or no more tool calls). Returning result: {result}")
340
+ return True
341
+ return False
342
+
363
343
  def _print_verbose_chat_loop(self, loop_count):
364
344
  if getattr(self, "verbose_agent", False):
365
345
  print(
@@ -436,42 +416,48 @@ class LLMAgent:
436
416
  Reset all driver config fields to the model's defaults for the current provider (overwriting any user customizations).
437
417
  """
438
418
  provider = self.llm_provider
439
- # Find model spec
440
- model_spec = None
441
- if hasattr(provider, "MODEL_SPECS"):
442
- model_spec = provider.MODEL_SPECS.get(model_name)
443
- if not model_spec:
444
- raise ValueError(f"Model '{model_name}' not found in provider MODEL_SPECS.")
445
- # Overwrite all config fields with model defaults
419
+ model_spec = self._get_model_spec(provider, model_name)
446
420
  config = getattr(provider, "driver_config", None)
447
421
  if config is None:
448
422
  return
423
+ self._apply_model_defaults_to_config(config, model_spec, model_name)
424
+ self._update_driver_model_config(model_name, config)
425
+
426
+ def _get_model_spec(self, provider, model_name):
427
+ if hasattr(provider, "MODEL_SPECS"):
428
+ model_spec = provider.MODEL_SPECS.get(model_name)
429
+ if model_spec:
430
+ return model_spec
431
+ raise ValueError(f"Model '{model_name}' not found in provider MODEL_SPECS.")
432
+
433
+ def _apply_model_defaults_to_config(self, config, model_spec, model_name):
449
434
  config.model = model_name
450
- # Standard fields, with safe conversion for int fields
451
- def safe_int(val):
452
- try:
453
- if val is None or val == "N/A":
454
- return None
455
- return int(val)
456
- except Exception:
457
- return None
458
- def safe_float(val):
459
- try:
460
- if val is None or val == "N/A":
461
- return None
462
- return float(val)
463
- except Exception:
464
- return None
465
- config.temperature = safe_float(getattr(model_spec, "default_temp", None))
466
- config.max_tokens = safe_int(getattr(model_spec, "max_response", None))
467
- config.max_completion_tokens = safe_int(getattr(model_spec, "max_cot", None))
468
- # Optionally reset other fields to None/defaults
435
+ config.temperature = self._safe_float(getattr(model_spec, "default_temp", None))
436
+ config.max_tokens = self._safe_int(getattr(model_spec, "max_response", None))
437
+ config.max_completion_tokens = self._safe_int(getattr(model_spec, "max_cot", None))
469
438
  config.top_p = None
470
439
  config.presence_penalty = None
471
440
  config.frequency_penalty = None
472
441
  config.stop = None
473
442
  config.reasoning_effort = None
474
- # Update driver if present
443
+
444
+ def _safe_int(self, val):
445
+ try:
446
+ if val is None or val == "N/A":
447
+ return None
448
+ return int(val)
449
+ except Exception:
450
+ return None
451
+
452
+ def _safe_float(self, val):
453
+ try:
454
+ if val is None or val == "N/A":
455
+ return None
456
+ return float(val)
457
+ except Exception:
458
+ return None
459
+
460
+ def _update_driver_model_config(self, model_name, config):
475
461
  if self.driver is not None:
476
462
  if hasattr(self.driver, "model_name"):
477
463
  self.driver.model_name = model_name
@@ -1,178 +1,178 @@
1
- """
2
- ProviderRegistry: Handles provider listing and selection logic for janito CLI.
3
- """
4
-
5
- from rich.table import Table
6
- from janito.cli.console import shared_console
7
- from janito.providers.registry import LLMProviderRegistry
8
- from janito.providers.provider_static_info import STATIC_PROVIDER_METADATA
9
- from janito.llm.auth import LLMAuthManager
10
- import sys
11
- from janito.exceptions import MissingProviderSelectionException
12
-
13
-
14
- class ProviderRegistry:
15
- def list_providers(self):
16
- """List all supported LLM providers as a table using rich, showing if auth is configured and supported model names."""
17
- providers = self._get_provider_names()
18
- table = self._create_table()
19
- rows = self._get_all_provider_rows(providers)
20
- self._add_rows_to_table(table, rows)
21
- self._print_table(table)
22
-
23
- def _get_provider_names(self):
24
- return list(STATIC_PROVIDER_METADATA.keys())
25
-
26
- def _create_table(self):
27
- table = Table(title="Supported LLM Providers")
28
- table.add_column("Provider", style="cyan")
29
- table.add_column("Maintainer", style="yellow", justify="center")
30
- table.add_column("Model Names", style="magenta")
31
- return table
32
-
33
- def _get_all_provider_rows(self, providers):
34
- rows = []
35
- for p in providers:
36
- info = self._get_provider_info(p)
37
- # info is (provider_name, maintainer, model_names, skip)
38
- if len(info) == 4 and info[3]:
39
- continue # skip providers flagged as not implemented
40
- rows.append(info[:3])
41
- rows.sort(key=self._maintainer_sort_key)
42
- return rows
43
-
44
- def _add_rows_to_table(self, table, rows):
45
- for idx, (p, maintainer, model_names) in enumerate(rows):
46
- table.add_row(p, maintainer, model_names)
47
- if idx != len(rows) - 1:
48
- table.add_section()
49
-
50
- def _print_table(self, table):
51
- """Print the table using rich when running in a terminal; otherwise fall back to a plain ASCII listing.
52
- This avoids UnicodeDecodeError when the parent process captures the output with a non-UTF8 encoding.
53
- """
54
- import sys
55
-
56
- if sys.stdout.isatty():
57
- # Safe to use rich's unicode output when attached to an interactive terminal.
58
- shared_console.print(table)
59
- return
60
-
61
- # Fallback: plain ASCII output (render without rich formatting)
62
- print("Supported LLM Providers")
63
- # Build header from column titles
64
- header_titles = [column.header or "" for column in table.columns]
65
- print(" | ".join(header_titles))
66
- # rich.table.Row objects in recent Rich versions don't expose a public `.cells` attribute.
67
- # Instead, cell content is stored in each column's private `_cells` list.
68
- for row_index, _ in enumerate(table.rows):
69
- cells_text = [str(column._cells[row_index]) for column in table.columns]
70
- ascii_row = " | ".join(cells_text).encode("ascii", "ignore").decode("ascii")
71
- print(ascii_row)
72
-
73
- def _get_provider_info(self, provider_name):
74
- static_info = STATIC_PROVIDER_METADATA.get(provider_name, {})
75
- maintainer_val = static_info.get("maintainer", "-")
76
- maintainer = (
77
- "[red]🚨 Needs maintainer[/red]"
78
- if maintainer_val == "Needs maintainer"
79
- else f"👤 {maintainer_val}"
80
- )
81
- model_names = "-"
82
- unavailable_reason = None
83
- skip = False
84
- try:
85
- provider_class = LLMProviderRegistry.get(provider_name)
86
- creds = LLMAuthManager().get_credentials(provider_name)
87
- provider_instance = None
88
- instantiation_failed = False
89
- try:
90
- provider_instance = provider_class()
91
- except NotImplementedError:
92
- skip = True
93
- unavailable_reason = "Not implemented"
94
- model_names = f"[red]❌ Not implemented[/red]"
95
- except Exception as e:
96
- instantiation_failed = True
97
- unavailable_reason = (
98
- f"Unavailable (import error or missing dependency): {str(e)}"
99
- )
100
- model_names = f"[red]❌ {unavailable_reason}[/red]"
101
- if not instantiation_failed and provider_instance is not None:
102
- available, unavailable_reason = self._get_availability(
103
- provider_instance
104
- )
105
- if (
106
- not available
107
- and unavailable_reason
108
- and "not implemented" in str(unavailable_reason).lower()
109
- ):
110
- skip = True
111
- if available:
112
- model_names = self._get_model_names(provider_name)
113
- else:
114
- model_names = f"[red]❌ {unavailable_reason}[/red]"
115
- except Exception as import_error:
116
- model_names = f"[red]❌ Unavailable (cannot import provider module): {str(import_error)}[/red]"
117
- return (provider_name, maintainer, model_names, skip)
118
-
119
- def _get_availability(self, provider_instance):
120
- try:
121
- available = getattr(provider_instance, "available", True)
122
- unavailable_reason = getattr(provider_instance, "unavailable_reason", None)
123
- except Exception as e:
124
- available = False
125
- unavailable_reason = f"Error reading runtime availability: {str(e)}"
126
- return available, unavailable_reason
127
-
128
- def _get_model_names(self, provider_name):
129
- provider_to_specs = {
130
- "openai": "janito.providers.openai.model_info",
131
- "azure_openai": "janito.providers.azure_openai.model_info",
132
- "google": "janito.providers.google.model_info",
133
-
134
- "deepseek": "janito.providers.deepseek.model_info",
135
- }
136
- if provider_name in provider_to_specs:
137
- try:
138
- mod = __import__(
139
- provider_to_specs[provider_name], fromlist=["MODEL_SPECS"]
140
- )
141
- return ", ".join(mod.MODEL_SPECS.keys())
142
- except Exception:
143
- return "(Error)"
144
- return "-"
145
-
146
- def _maintainer_sort_key(self, row):
147
- maint = row[1]
148
- is_needs_maint = "Needs maintainer" in maint
149
- return (is_needs_maint, row[2] != "✅ Auth")
150
-
151
- def get_provider(self, provider_name):
152
- """Return the provider class for the given provider name. Returns None if not found."""
153
- from janito.providers.registry import LLMProviderRegistry
154
-
155
- if not provider_name:
156
- print("Error: Provider name must be specified.")
157
- return None
158
- provider_class = LLMProviderRegistry.get(provider_name)
159
- if provider_class is None:
160
- available = ', '.join(LLMProviderRegistry.list_providers())
161
- print(f"Error: Provider '{provider_name}' is not recognized. Available providers: {available}.")
162
- return None
163
- return provider_class
164
-
165
- def get_instance(self, provider_name, config=None):
166
- """Return an instance of the provider for the given provider name, optionally passing a config object. Returns None if not found."""
167
- provider_class = self.get_provider(provider_name)
168
- if provider_class is None:
169
- return None
170
- if config is not None:
171
- return provider_class(config=config)
172
- return provider_class()
173
-
174
-
175
- # For backward compatibility
176
- def list_providers():
177
- """Legacy function for listing providers, now uses ProviderRegistry class."""
178
- ProviderRegistry().list_providers()
1
+ """
2
+ ProviderRegistry: Handles provider listing and selection logic for janito CLI.
3
+ """
4
+
5
+ from rich.table import Table
6
+ from janito.cli.console import shared_console
7
+ from janito.providers.registry import LLMProviderRegistry
8
+ from janito.providers.provider_static_info import STATIC_PROVIDER_METADATA
9
+ from janito.llm.auth import LLMAuthManager
10
+ import sys
11
+ from janito.exceptions import MissingProviderSelectionException
12
+
13
+
14
+ class ProviderRegistry:
15
+ def list_providers(self):
16
+ """List all supported LLM providers as a table using rich, showing if auth is configured and supported model names."""
17
+ providers = self._get_provider_names()
18
+ table = self._create_table()
19
+ rows = self._get_all_provider_rows(providers)
20
+ self._add_rows_to_table(table, rows)
21
+ self._print_table(table)
22
+
23
+ def _get_provider_names(self):
24
+ return list(STATIC_PROVIDER_METADATA.keys())
25
+
26
+ def _create_table(self):
27
+ table = Table(title="Supported LLM Providers")
28
+ table.add_column("Provider", style="cyan")
29
+ table.add_column("Maintainer", style="yellow", justify="center")
30
+ table.add_column("Model Names", style="magenta")
31
+ return table
32
+
33
+ def _get_all_provider_rows(self, providers):
34
+ rows = []
35
+ for p in providers:
36
+ info = self._get_provider_info(p)
37
+ # info is (provider_name, maintainer, model_names, skip)
38
+ if len(info) == 4 and info[3]:
39
+ continue # skip providers flagged as not implemented
40
+ rows.append(info[:3])
41
+ rows.sort(key=self._maintainer_sort_key)
42
+ return rows
43
+
44
+ def _add_rows_to_table(self, table, rows):
45
+ for idx, (p, maintainer, model_names) in enumerate(rows):
46
+ table.add_row(p, maintainer, model_names)
47
+ if idx != len(rows) - 1:
48
+ table.add_section()
49
+
50
+ def _print_table(self, table):
51
+ """Print the table using rich when running in a terminal; otherwise fall back to a plain ASCII listing.
52
+ This avoids UnicodeDecodeError when the parent process captures the output with a non-UTF8 encoding.
53
+ """
54
+ import sys
55
+
56
+ if sys.stdout.isatty():
57
+ # Safe to use rich's unicode output when attached to an interactive terminal.
58
+ shared_console.print(table)
59
+ return
60
+
61
+ # Fallback: plain ASCII output (render without rich formatting)
62
+ print("Supported LLM Providers")
63
+ # Build header from column titles
64
+ header_titles = [column.header or "" for column in table.columns]
65
+ print(" | ".join(header_titles))
66
+ # rich.table.Row objects in recent Rich versions don't expose a public `.cells` attribute.
67
+ # Instead, cell content is stored in each column's private `_cells` list.
68
+ for row_index, _ in enumerate(table.rows):
69
+ cells_text = [str(column._cells[row_index]) for column in table.columns]
70
+ ascii_row = " | ".join(cells_text).encode("ascii", "ignore").decode("ascii")
71
+ print(ascii_row)
72
+
73
+ def _get_provider_info(self, provider_name):
74
+ static_info = STATIC_PROVIDER_METADATA.get(provider_name, {})
75
+ maintainer_val = static_info.get("maintainer", "-")
76
+ maintainer = (
77
+ "[red]🚨 Needs maintainer[/red]"
78
+ if maintainer_val == "Needs maintainer"
79
+ else f"👤 {maintainer_val}"
80
+ )
81
+ model_names = "-"
82
+ unavailable_reason = None
83
+ skip = False
84
+ try:
85
+ provider_class = LLMProviderRegistry.get(provider_name)
86
+ creds = LLMAuthManager().get_credentials(provider_name)
87
+ provider_instance = None
88
+ instantiation_failed = False
89
+ try:
90
+ provider_instance = provider_class()
91
+ except NotImplementedError:
92
+ skip = True
93
+ unavailable_reason = "Not implemented"
94
+ model_names = f"[red]❌ Not implemented[/red]"
95
+ except Exception as e:
96
+ instantiation_failed = True
97
+ unavailable_reason = (
98
+ f"Unavailable (import error or missing dependency): {str(e)}"
99
+ )
100
+ model_names = f"[red]❌ {unavailable_reason}[/red]"
101
+ if not instantiation_failed and provider_instance is not None:
102
+ available, unavailable_reason = self._get_availability(
103
+ provider_instance
104
+ )
105
+ if (
106
+ not available
107
+ and unavailable_reason
108
+ and "not implemented" in str(unavailable_reason).lower()
109
+ ):
110
+ skip = True
111
+ if available:
112
+ model_names = self._get_model_names(provider_name)
113
+ else:
114
+ model_names = f"[red]❌ {unavailable_reason}[/red]"
115
+ except Exception as import_error:
116
+ model_names = f"[red]❌ Unavailable (cannot import provider module): {str(import_error)}[/red]"
117
+ return (provider_name, maintainer, model_names, skip)
118
+
119
+ def _get_availability(self, provider_instance):
120
+ try:
121
+ available = getattr(provider_instance, "available", True)
122
+ unavailable_reason = getattr(provider_instance, "unavailable_reason", None)
123
+ except Exception as e:
124
+ available = False
125
+ unavailable_reason = f"Error reading runtime availability: {str(e)}"
126
+ return available, unavailable_reason
127
+
128
+ def _get_model_names(self, provider_name):
129
+ provider_to_specs = {
130
+ "openai": "janito.providers.openai.model_info",
131
+ "azure_openai": "janito.providers.azure_openai.model_info",
132
+ "google": "janito.providers.google.model_info",
133
+ "anthropic": "janito.providers.anthropic.model_info",
134
+ "deepseek": "janito.providers.deepseek.model_info",
135
+ }
136
+ if provider_name in provider_to_specs:
137
+ try:
138
+ mod = __import__(
139
+ provider_to_specs[provider_name], fromlist=["MODEL_SPECS"]
140
+ )
141
+ return ", ".join(mod.MODEL_SPECS.keys())
142
+ except Exception:
143
+ return "(Error)"
144
+ return "-"
145
+
146
+ def _maintainer_sort_key(self, row):
147
+ maint = row[1]
148
+ is_needs_maint = "Needs maintainer" in maint
149
+ return (is_needs_maint, row[2] != "✅ Auth")
150
+
151
+ def get_provider(self, provider_name):
152
+ """Return the provider class for the given provider name. Returns None if not found."""
153
+ from janito.providers.registry import LLMProviderRegistry
154
+
155
+ if not provider_name:
156
+ print("Error: Provider name must be specified.")
157
+ return None
158
+ provider_class = LLMProviderRegistry.get(provider_name)
159
+ if provider_class is None:
160
+ available = ', '.join(LLMProviderRegistry.list_providers())
161
+ print(f"Error: Provider '{provider_name}' is not recognized. Available providers: {available}.")
162
+ return None
163
+ return provider_class
164
+
165
+ def get_instance(self, provider_name, config=None):
166
+ """Return an instance of the provider for the given provider name, optionally passing a config object. Returns None if not found."""
167
+ provider_class = self.get_provider(provider_name)
168
+ if provider_class is None:
169
+ return None
170
+ if config is not None:
171
+ return provider_class(config=config)
172
+ return provider_class()
173
+
174
+
175
+ # For backward compatibility
176
+ def list_providers():
177
+ """Legacy function for listing providers, now uses ProviderRegistry class."""
178
+ ProviderRegistry().list_providers()
@@ -1,22 +1,41 @@
1
- from janito.llm.model import LLMModelInfo
2
-
3
- MODEL_SPECS = {
4
- "claude-3-opus-20240229": LLMModelInfo(
5
- name="claude-3-opus-20240229",
6
- max_response=200000,
7
- default_temp=0.7,
8
- driver="AnthropicModelDriver",
9
- ),
10
- "claude-3-sonnet-20240229": LLMModelInfo(
11
- name="claude-3-sonnet-20240229",
12
- max_response=200000,
13
- default_temp=0.7,
14
- driver="AnthropicModelDriver",
15
- ),
16
- "claude-3-haiku-20240307": LLMModelInfo(
17
- name="claude-3-haiku-20240307",
18
- max_response=200000,
19
- default_temp=0.7,
20
- driver="AnthropicModelDriver",
21
- ),
22
- }
1
+ from janito.llm.model import LLMModelInfo
2
+
3
+ MODEL_SPECS = {
4
+ "claude-opus-4-20250514": LLMModelInfo(
5
+ name="claude-opus-4-20250514",
6
+ max_response=32000,
7
+ default_temp=0.7,
8
+ driver="OpenAIModelDriver",
9
+ ),
10
+ "claude-sonnet-4-20250514": LLMModelInfo(
11
+ name="claude-sonnet-4-20250514",
12
+ max_response=64000,
13
+ default_temp=0.7,
14
+ driver="OpenAIModelDriver",
15
+ ),
16
+ "claude-3-7-sonnet-20250219": LLMModelInfo(
17
+ name="claude-3-7-sonnet-20250219",
18
+ max_response=64000,
19
+ default_temp=0.7,
20
+ driver="OpenAIModelDriver",
21
+ ),
22
+ "claude-3-5-haiku-20241022": LLMModelInfo(
23
+ name="claude-3-5-haiku-20241022",
24
+ max_response=8192,
25
+ default_temp=0.7,
26
+ driver="OpenAIModelDriver",
27
+ ),
28
+ "claude-3-5-sonnet-20241022": LLMModelInfo(
29
+ name="claude-3-5-sonnet-20241022",
30
+ max_response=8192,
31
+ default_temp=0.7,
32
+ driver="OpenAIModelDriver",
33
+ ),
34
+ "claude-3-haiku-20240307": LLMModelInfo(
35
+ name="claude-3-haiku-20240307",
36
+ max_response=4096,
37
+ default_temp=0.7,
38
+ driver="OpenAIModelDriver",
39
+ ),
40
+ }
41
+