aru-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aru/providers.py ADDED
@@ -0,0 +1,433 @@
1
+ """Multi-provider LLM abstraction for aru.
2
+
3
+ Supports provider/model format (e.g., "anthropic/claude-sonnet-4-5", "ollama/llama3.1").
4
+ Maps provider names to Agno model classes and handles provider-specific configuration.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Built-in provider definitions
16
+ # ---------------------------------------------------------------------------
17
+
18
@dataclass
class ProviderConfig:
    """Configuration for an LLM provider.

    Built-in instances are defined in BUILTIN_PROVIDERS; user entries from
    aru.json can add new providers or override fields of registered ones.
    """
    # Human-readable provider name; used as the prefix of display strings
    # such as "Anthropic/claude-sonnet-4-5".
    name: str
    # Name of the environment variable holding the API key. None means no
    # key is read from the environment.
    api_key_env: str | None = None
    # Endpoint override. Used by the OpenAI-compatible path and as the Ollama
    # host (which falls back to http://localhost:11434 when None).
    base_url: str | None = None
    # Model chosen when a reference names only the provider (e.g. "ollama").
    default_model: str | None = None
    # Known models keyed by the name users type; each entry may carry "id"
    # (the model ID sent to the API) and "max_tokens". Names not listed here
    # are passed through to the API unchanged.
    models: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Extra provider-specific settings. Keys starting with "_" are internal
    # (e.g. "_provider_type"); for Ollama the remaining keys are forwarded
    # as the model's "options" dict.
    options: dict[str, Any] = field(default_factory=dict)
27
+
28
+
29
# Built-in providers with sensible defaults.
# Each "models" entry maps the name a user types to the model ID actually sent
# to the API ("id") plus the default output-token budget ("max_tokens").
BUILTIN_PROVIDERS: dict[str, ProviderConfig] = {
    "anthropic": ProviderConfig(
        name="Anthropic",
        api_key_env="ANTHROPIC_API_KEY",
        default_model="claude-sonnet-4-5-20250929",
        models={
            "claude-sonnet-4-5": {"id": "claude-sonnet-4-5-20250929", "max_tokens": 16384},
            "claude-sonnet-4-6": {"id": "claude-sonnet-4-6", "max_tokens": 64000},
            "claude-opus-4": {"id": "claude-opus-4-20250514", "max_tokens": 32000},
            "claude-opus-4-6": {"id": "claude-opus-4-6", "max_tokens": 64000},
            # Anthropic's ID for Claude 3.5 Haiku puts the version first:
            # "claude-3-5-haiku-...", unlike the 4.x family naming.
            "claude-haiku-3-5": {"id": "claude-3-5-haiku-20241022", "max_tokens": 8192},
            "claude-haiku-4-5": {"id": "claude-haiku-4-5-20251001", "max_tokens": 8192},
            # Full IDs also work as-is
            "claude-sonnet-4-5-20250929": {"id": "claude-sonnet-4-5-20250929", "max_tokens": 16384},
            "claude-opus-4-20250514": {"id": "claude-opus-4-20250514", "max_tokens": 32000},
            "claude-3-5-haiku-20241022": {"id": "claude-3-5-haiku-20241022", "max_tokens": 8192},
            # Legacy mis-ordered spelling, kept so existing configs still
            # resolve, now mapped to the real API model ID.
            "claude-haiku-3-5-20241022": {"id": "claude-3-5-haiku-20241022", "max_tokens": 8192},
            "claude-haiku-4-5-20251001": {"id": "claude-haiku-4-5-20251001", "max_tokens": 8192},
        },
    ),
    "openai": ProviderConfig(
        name="OpenAI",
        api_key_env="OPENAI_API_KEY",
        default_model="gpt-4o",
        models={
            "gpt-4o": {"id": "gpt-4o", "max_tokens": 4096},
            "gpt-4o-mini": {"id": "gpt-4o-mini", "max_tokens": 4096},
            "gpt-4.1": {"id": "gpt-4.1", "max_tokens": 4096},
            "gpt-4.1-mini": {"id": "gpt-4.1-mini", "max_tokens": 4096},
            "gpt-4.1-nano": {"id": "gpt-4.1-nano", "max_tokens": 4096},
            "o3-mini": {"id": "o3-mini", "max_tokens": 4096},
        },
    ),
    "ollama": ProviderConfig(
        name="Ollama",
        base_url="http://localhost:11434",
        default_model="llama3.1",
        models={},  # Ollama models are dynamic - any installed model works
    ),
    "groq": ProviderConfig(
        name="Groq",
        api_key_env="GROQ_API_KEY",
        default_model="llama-3.3-70b-versatile",
        models={
            "llama-3.3-70b-versatile": {"id": "llama-3.3-70b-versatile", "max_tokens": 4096},
            "llama-3.1-8b-instant": {"id": "llama-3.1-8b-instant", "max_tokens": 4096},
            "mixtral-8x7b-32768": {"id": "mixtral-8x7b-32768", "max_tokens": 4096},
        },
    ),
    "openrouter": ProviderConfig(
        name="OpenRouter",
        api_key_env="OPENROUTER_API_KEY",
        default_model="anthropic/claude-sonnet-4-5",
        models={},  # OpenRouter supports hundreds of models dynamically
    ),
    "deepseek": ProviderConfig(
        name="DeepSeek",
        api_key_env="DEEPSEEK_API_KEY",
        default_model="deepseek-chat",
        models={
            "deepseek-chat": {"id": "deepseek-chat", "max_tokens": 8192},
            "deepseek-chat-v3-0324": {"id": "deepseek-chat-v3-0324", "max_tokens": 16384},
            "deepseek-reasoner": {"id": "deepseek-reasoner", "max_tokens": 16384},
        },
    ),
}
95
+
96
# Bare short names accepted for backward compatibility. Each expands to a
# full "provider/model" reference before any other resolution happens.
MODEL_ALIASES: dict[str, str] = dict(
    sonnet="anthropic/claude-sonnet-4-5",
    opus="anthropic/claude-opus-4",
    haiku="anthropic/claude-haiku-3-5",
)
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Provider registry (built-ins + user overrides from aru.json)
106
+ # ---------------------------------------------------------------------------
107
+
108
# Live registry: provider key -> ProviderConfig (built-ins plus user entries).
_providers: dict[str, ProviderConfig] = {}


def _init_providers():
    """Reset the provider registry to the built-in defaults.

    Every built-in ProviderConfig is deep-copied. As a result, in-place edits
    to a registered config (its fields or its ``models``/``options`` dicts)
    can never leak back into BUILTIN_PROVIDERS, and calling this function
    again genuinely restores the pristine defaults.

    The registry dict is refilled in place rather than rebound, so any code
    holding a reference to ``_providers`` sees the reset.
    """
    import copy

    _providers.clear()
    _providers.update(
        (key, copy.deepcopy(config)) for key, config in BUILTIN_PROVIDERS.items()
    )


_init_providers()
118
+
119
+
120
def register_provider(key: str, config: ProviderConfig):
    """Store *config* in the registry under *key*, replacing any previous entry."""
    _providers.update({key: config})
123
+
124
+
125
def get_provider(key: str) -> ProviderConfig | None:
    """Look up a registered provider, returning None when *key* is unknown."""
    try:
        return _providers[key]
    except KeyError:
        return None
128
+
129
+
130
def list_providers() -> dict[str, ProviderConfig]:
    """Return a shallow snapshot of the registry.

    Adding or removing keys in the returned dict does not affect the registry;
    the ProviderConfig values themselves are the same objects.
    """
    return _providers.copy()
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # Load user provider overrides from config
137
+ # ---------------------------------------------------------------------------
138
+
139
def load_providers_from_config(config_data: dict[str, Any]):
    """Merge user-defined providers from aru.json into the registry.

    An entry whose key matches a registered provider overrides only the fields
    it specifies; "models" and "options" are merged key by key. Any other
    entry registers a new provider. A new provider's "type" field selects the
    Agno model class used by create_model and is stored as
    ``options["_provider_type"]``.

    Registered configs are replaced with fresh ProviderConfig objects rather
    than mutated in place, and the user's dicts are copied. As a result,
    neither the built-in defaults nor ``config_data`` are modified. A missing
    or non-dict "providers" section and non-dict provider entries are ignored.

    Expected format in aru.json:
    {
        "providers": {
            "ollama": {
                "base_url": "http://localhost:11434",
                "models": {
                    "deepseek-coder-v2": {"id": "deepseek-coder-v2:latest"}
                }
            },
            "my-custom": {
                "type": "openai",
                "name": "My Custom Provider",
                "api_key_env": "MY_API_KEY",
                "base_url": "https://my-api.example.com/v1",
                "models": {
                    "my-model": {"id": "my-model-v1"}
                }
            }
        },
        "models": {
            "default": "anthropic/claude-sonnet-4-5",
            "small": "anthropic/claude-haiku-4-5"
        }
    }
    """
    providers_data = (config_data or {}).get("providers") or {}
    if not isinstance(providers_data, dict):
        return

    for key, pdata in providers_data.items():
        if not isinstance(pdata, dict):
            continue

        existing = _providers.get(key)
        if existing is not None:
            # Extend a registered provider: user fields win, and the
            # models/options dicts are merged into copies of the existing ones.
            models = dict(existing.models)
            models.update(pdata.get("models") or {})
            options = dict(existing.options)
            options.update(pdata.get("options") or {})
            _providers[key] = ProviderConfig(
                name=pdata.get("name", existing.name),
                api_key_env=pdata.get("api_key_env", existing.api_key_env),
                base_url=pdata.get("base_url", existing.base_url),
                default_model=pdata.get("default_model", existing.default_model),
                models=models,
                options=options,
            )
        else:
            # New provider. The "type" field tells create_model which Agno
            # class to use; it is kept in a private options key.
            options = dict(pdata.get("options") or {})
            if "type" in pdata:
                options["_provider_type"] = pdata["type"]
            _providers[key] = ProviderConfig(
                name=pdata.get("name", key),
                api_key_env=pdata.get("api_key_env"),
                base_url=pdata.get("base_url"),
                default_model=pdata.get("default_model"),
                models=dict(pdata.get("models") or {}),
                options=options,
            )
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # Model resolution
205
+ # ---------------------------------------------------------------------------
206
+
207
def resolve_model_ref(model_ref: str) -> tuple[str, str]:
    """Split a model reference into ``(provider_key, model_name)``.

    Resolution order:
      1. Legacy aliases ("sonnet", "opus", "haiku") are expanded first.
      2. "provider/model" is split on the first slash only, so
         "openrouter/anthropic/claude-sonnet-4-5" gives
         ("openrouter", "anthropic/claude-sonnet-4-5").
      3. A bare registered provider key such as "anthropic" maps to that
         provider's default model ("" when it has none).
      4. Anything else is treated as an Anthropic model name, for backward
         compatibility.
    """
    ref = MODEL_ALIASES.get(model_ref, model_ref)

    head, slash, tail = ref.partition("/")
    if slash:
        return head, tail

    if ref in _providers:
        return ref, _providers[ref].default_model or ""

    return "anthropic", ref
234
+
235
+
236
+ def _get_actual_model_id(provider: ProviderConfig, model_name: str) -> str:
237
+ """Get the actual model ID to send to the API.
238
+
239
+ If the model name is in the provider's model registry, use its 'id' field.
240
+ Otherwise, pass the model name through as-is (supports dynamic models like Ollama).
241
+ """
242
+ if model_name in provider.models:
243
+ return provider.models[model_name].get("id", model_name)
244
+ return model_name
245
+
246
+
247
+ def _get_max_tokens(provider: ProviderConfig, model_name: str, default: int = 4096) -> int:
248
+ """Get max_tokens for a model, falling back to default."""
249
+ if model_name in provider.models:
250
+ return provider.models[model_name].get("max_tokens", default)
251
+ return default
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # Model creation — the core function
256
+ # ---------------------------------------------------------------------------
257
+
258
def create_model(
    model_ref: str,
    max_tokens: int | None = None,
    cache_system_prompt: bool = True,
    **kwargs,
):
    """Create an Agno model instance from a provider/model reference.

    Args:
        model_ref: Provider/model string (e.g., "anthropic/claude-sonnet-4-5",
            "ollama/llama3.1"), a bare provider key (uses its default model),
            or a legacy alias such as "sonnet".
        max_tokens: Override max tokens. When None (or 0), the model's
            registered value is used, falling back to 4096.
        cache_system_prompt: Whether to cache the system prompt (only used
            for Anthropic models).
        **kwargs: Extra provider-specific parameters, passed to the Agno model
            constructor; they override the computed values.

    Returns:
        An Agno model instance ready for use with Agent()

    Raises:
        ValueError: If the provider is unknown, or no model name could be
            determined (e.g. a bare provider key whose config has no
            default_model, or a reference like "ollama/").
        ImportError: If the Agno model module for the selected provider cannot
            be imported; those imports happen lazily, at creation time.
    """
    provider_key, model_name = resolve_model_ref(model_ref)
    provider = _providers.get(provider_key)

    if provider is None:
        available = ", ".join(sorted(_providers.keys()))
        raise ValueError(f"Unknown provider '{provider_key}'. Available: {available}")

    if not model_name:
        # An empty model id would otherwise reach the SDK and fail only at
        # request time with a far less helpful error.
        raise ValueError(
            f"No model specified for provider '{provider_key}' and it has no "
            f"default_model; use '{provider_key}/<model>'."
        )

    model_id = _get_actual_model_id(provider, model_name)
    effective_max_tokens = max_tokens or _get_max_tokens(provider, model_name, 4096)

    # Custom providers carry their backend "type" in a private options key;
    # built-ins are dispatched on their registry key.
    provider_type = provider.options.get("_provider_type", provider_key)

    return _create_provider_model(
        provider_type=provider_type,
        provider=provider,
        model_id=model_id,
        max_tokens=effective_max_tokens,
        cache_system_prompt=cache_system_prompt,
        **kwargs,
    )
299
+
300
+
301
def _create_provider_model(
    provider_type: str,
    provider: ProviderConfig,
    model_id: str,
    max_tokens: int,
    cache_system_prompt: bool,
    **kwargs,
):
    """Build the Agno model object that matches *provider_type*.

    "anthropic", "openai", "ollama", "groq", "openrouter" and "deepseek" use
    their dedicated Agno classes. Any other type is treated as an
    OpenAI-compatible endpoint via OpenAIChat. Agno modules are imported
    lazily, and caller ``kwargs`` are applied last so they override the
    computed parameters.
    """
    if provider_type == "ollama":
        from agno.models.ollama import Ollama

        # Ollama takes a host rather than an API key or max_tokens. Public
        # (non "_"-prefixed) provider options become its "options" dict.
        ollama_params = {
            "id": model_id,
            "host": provider.base_url or "http://localhost:11434",
        }
        passthrough = {
            opt_name: opt_value
            for opt_name, opt_value in (provider.options or {}).items()
            if not opt_name.startswith("_")
        }
        if passthrough:
            ollama_params["options"] = passthrough
        ollama_params.update(kwargs)
        return Ollama(**ollama_params)

    if provider_type == "anthropic":
        from agno.models.anthropic import Claude as model_cls
    elif provider_type == "groq":
        from agno.models.groq import Groq as model_cls
    elif provider_type == "openrouter":
        from agno.models.openrouter import OpenRouter as model_cls
    elif provider_type == "deepseek":
        from agno.models.deepseek import DeepSeek as model_cls
    else:
        # "openai" and any unrecognised type: OpenAI-compatible client.
        from agno.models.openai import OpenAIChat as model_cls

    params = {"id": model_id, "max_tokens": max_tokens}
    if provider_type == "anthropic" and cache_system_prompt:
        params["cache_system_prompt"] = True

    secret = _resolve_api_key(provider)
    if secret:
        params["api_key"] = secret

    # Only the OpenAI-compatible client is pointed at a custom base_url.
    honours_base_url = provider_type not in ("anthropic", "groq", "openrouter", "deepseek")
    if honours_base_url and provider.base_url:
        params["base_url"] = provider.base_url

    params.update(kwargs)
    return model_cls(**params)
384
+
385
+
386
+ def _resolve_api_key(provider: ProviderConfig) -> str | None:
387
+ """Resolve API key from environment variable."""
388
+ if provider.api_key_env:
389
+ return os.environ.get(provider.api_key_env)
390
+ return None
391
+
392
+
393
+ # ---------------------------------------------------------------------------
394
+ # Convenience: list available models for display
395
+ # ---------------------------------------------------------------------------
396
+
397
def get_available_models() -> dict[str, str]:
    """Map every known model reference to a display string.

    Legacy aliases come first and display the resolved API model ID. Every
    provider's registered models follow, plus its default model. The first
    entry for a given reference wins, so later duplicates are skipped.
    """
    catalog: dict[str, str] = {}

    for alias, target in MODEL_ALIASES.items():
        pkey, mname = resolve_model_ref(target)
        cfg = _providers.get(pkey)
        if cfg is None:
            continue
        catalog[alias] = f"{cfg.name}/{_get_actual_model_id(cfg, mname)}"

    for pkey, cfg in _providers.items():
        candidates = list(cfg.models or ())
        if cfg.default_model:
            candidates.append(cfg.default_model)
        for mname in candidates:
            catalog.setdefault(f"{pkey}/{mname}", f"{cfg.name}/{mname}")

    return catalog
425
+
426
+
427
def get_model_display(model_ref: str) -> str:
    """Render *model_ref* as "<provider name>/<model name>".

    The model name is shown as resolved (not mapped to its API ID). The
    reference is returned unchanged when its provider is not registered.
    """
    pkey, mname = resolve_model_ref(model_ref)
    cfg = _providers.get(pkey)
    return model_ref if cfg is None else f"{cfg.name}/{mname}"
aru/tools/__init__.py ADDED
File without changes