llms-py 2.0.35__py3-none-any.whl → 3.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llms/main.py CHANGED
@@ -9,6 +9,8 @@
  import argparse
  import asyncio
  import base64
+ from datetime import datetime
+ import hashlib
  import json
  import mimetypes
  import os
@@ -34,11 +36,12 @@ try:
  except ImportError:
  HAS_PIL = False

- VERSION = "2.0.35"
+ VERSION = "3.0.0b1"
  _ROOT = None
  g_config_path = None
  g_ui_path = None
  g_config = None
+ g_providers = None
  g_handlers = {}
  g_verbose = False
  g_logprefix = ""
@@ -310,7 +313,7 @@ def convert_image_if_needed(image_bytes, mimetype="image/png"):
  return image_bytes, mimetype


- async def process_chat(chat):
+ async def process_chat(chat, provider_id=None):
  if not chat:
  raise Exception("No chat provided")
  if "stream" not in chat:
@@ -331,6 +334,8 @@ async def process_chat(chat):
  image_url = item["image_url"]
  if "url" in image_url:
  url = image_url["url"]
+ if url.startswith("/~cache/"):
+ url = get_cache_path(url[8:])
  if is_url(url):
  _log(f"Downloading image: {url}")
  async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
@@ -377,6 +382,8 @@ async def process_chat(chat):
  input_audio = item["input_audio"]
  if "data" in input_audio:
  url = input_audio["data"]
+ if url.startswith("/~cache/"):
+ url = get_cache_path(url[8:])
  mimetype = get_file_mime_type(get_filename(url))
  if is_url(url):
  _log(f"Downloading audio: {url}")
@@ -388,6 +395,8 @@ async def process_chat(chat):
  mimetype = response.headers["Content-Type"]
  # convert to base64
  input_audio["data"] = base64.b64encode(content).decode("utf-8")
+ if provider_id == "alibaba":
+ input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
  input_audio["format"] = mimetype.rsplit("/", 1)[1]
  elif is_file_path(url):
  _log(f"Reading audio: {url}")
@@ -395,6 +404,8 @@ async def process_chat(chat):
  content = f.read()
  # convert to base64
  input_audio["data"] = base64.b64encode(content).decode("utf-8")
+ if provider_id == "alibaba":
+ input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
  input_audio["format"] = mimetype.rsplit("/", 1)[1]
  elif is_base_64(url):
  pass # use base64 data as-is
@@ -404,6 +415,8 @@ async def process_chat(chat):
  file = item["file"]
  if "file_data" in file:
  url = file["file_data"]
+ if url.startswith("/~cache/"):
+ url = get_cache_path(url[8:])
  mimetype = get_file_mime_type(get_filename(url))
  if is_url(url):
  _log(f"Downloading file: {url}")
@@ -449,24 +462,26 @@ async def response_json(response):
  return body


- class OpenAiProvider:
- def __init__(self, base_url, api_key=None, models=None, **kwargs):
- if models is None:
- models = {}
- self.base_url = base_url.strip("/")
- self.api_key = api_key
- self.models = models
+ class OpenAiCompatible:
+ sdk = "@ai-sdk/openai-compatible"

- # check if base_url ends with /v{\d} to handle providers with different versions (e.g. z.ai uses /v4)
- last_segment = base_url.rsplit("/", 1)[1]
- if last_segment.startswith("v") and last_segment[1:].isdigit():
- self.chat_url = f"{base_url}/chat/completions"
- else:
- self.chat_url = f"{base_url}/v1/chat/completions"
+ def __init__(self, **kwargs):
+ required_args = ["id", "api"]
+ for arg in required_args:
+ if arg not in kwargs:
+ raise ValueError(f"Missing required argument: {arg}")
+
+ self.id = kwargs.get("id")
+ self.api = kwargs.get("api").strip("/")
+ self.api_key = kwargs.get("api_key")
+ self.name = kwargs.get("name", self.id.replace("-", " ").title().replace(" ", ""))
+ self.set_models(**kwargs)
+
+ self.chat_url = f"{self.api}/chat/completions"

  self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
- if api_key is not None:
- self.headers["Authorization"] = f"Bearer {api_key}"
+ if self.api_key is not None:
+ self.headers["Authorization"] = f"Bearer {self.api_key}"

  self.frequency_penalty = float(kwargs["frequency_penalty"]) if "frequency_penalty" in kwargs else None
  self.max_completion_tokens = int(kwargs["max_completion_tokens"]) if "max_completion_tokens" in kwargs else None
@@ -486,36 +501,141 @@ class OpenAiProvider:
  self.verbosity = kwargs.get("verbosity")
  self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
  self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
- self.pricing = kwargs.get("pricing")
- self.default_pricing = kwargs.get("default_pricing")
  self.check = kwargs.get("check")

- @classmethod
- def test(cls, base_url=None, api_key=None, models=None, **kwargs):
- if models is None:
- models = {}
- return base_url and api_key and len(models) > 0
+ def set_models(self, **kwargs):
+ models = kwargs.get("models", {})
+ self.map_models = kwargs.get("map_models", {})
+ # if 'map_models' is provided, only include models in `map_models[model_id] = provider_model_id`
+ if self.map_models:
+ self.models = {}
+ for provider_model_id in self.map_models.values():
+ if provider_model_id in models:
+ self.models[provider_model_id] = models[provider_model_id]
+ else:
+ self.models = models
+
+ include_models = kwargs.get("include_models") # string regex pattern
+ # only include models that match the regex pattern
+ if include_models:
+ _log(f"Filtering {len(self.models)} models, only including models that match regex: {include_models}")
+ self.models = {k: v for k, v in self.models.items() if re.search(include_models, k)}
+
+ exclude_models = kwargs.get("exclude_models") # string regex pattern
+ # exclude models that match the regex pattern
+ if exclude_models:
+ _log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
+ self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}
+
+ def test(self, **kwargs):
+ ret = self.api and self.api_key and (len(self.models) > 0)
+ if not ret:
+ _log(f"Provider {self.name} Missing: {self.api}, {self.api_key}, {len(self.models)}")
+ return ret

  async def load(self):
- pass
+ if not self.models:
+ await self.load_models()

- def model_pricing(self, model):
+ def model_cost(self, model):
  provider_model = self.provider_model(model) or model
- if self.pricing and provider_model in self.pricing:
- return self.pricing[provider_model]
- return self.default_pricing or None
+ for model_id, model_info in self.models.items():
+ if model_id.lower() == provider_model.lower():
+ return model_info.get("cost")
+ return None

  def provider_model(self, model):
- if model in self.models:
- return self.models[model]
+ # convert model to lowercase for case-insensitive comparison
+ model_lower = model.lower()
+
+ # if model is a map model id, return the provider model id
+ for model_id, provider_model in self.map_models.items():
+ if model_id.lower() == model_lower:
+ return provider_model
+
+ # if model is a provider model id, try again with just the model name
+ for provider_model in self.map_models.values():
+ if provider_model.lower() == model_lower:
+ return provider_model
+
+ # if model is a model id, try again with just the model id or name
+ for model_id, provider_model_info in self.models.items():
+ id = provider_model_info.get("id") or model_id
+ if model_id.lower() == model_lower or id.lower() == model_lower:
+ return id
+ name = provider_model_info.get("name")
+ if name and name.lower() == model_lower:
+ return id
+
+ # fallback to trying again with just the model short name
+ for model_id, provider_model_info in self.models.items():
+ id = provider_model_info.get("id") or model_id
+ if "/" in id:
+ model_name = id.split("/")[-1]
+ if model_name.lower() == model_lower:
+ return id
+
+ # if model is a full provider model id, try again with just the model name
+ if "/" in model:
+ last_part = model.split("/")[-1]
+ return self.provider_model(last_part)
  return None

+ def validate_modalities(self, chat):
+ model_id = chat.get("model")
+ if not model_id or not self.models:
+ return
+
+ model_info = None
+ # Try to find model info using provider_model logic (already resolved to ID)
+ if model_id in self.models:
+ model_info = self.models[model_id]
+ else:
+ # Fallback scan
+ for m_id, m_info in self.models.items():
+ if m_id == model_id or m_info.get("id") == model_id:
+ model_info = m_info
+ break
+
+ print(f"DEBUG: Validate modalities: model={model_id}, found_info={model_info is not None}")
+ if model_info:
+ print(f"DEBUG: Modalities: {model_info.get('modalities')}")
+
+ if not model_info:
+ return
+
+ modalities = model_info.get("modalities", {})
+ input_modalities = modalities.get("input", [])
+
+ # Check for unsupported modalities
+ has_audio = False
+ has_image = False
+ for message in chat.get("messages", []):
+ content = message.get("content")
+ if isinstance(content, list):
+ for item in content:
+ type_ = item.get("type")
+ if type_ == "input_audio" or "input_audio" in item:
+ has_audio = True
+ elif type_ == "image_url" or "image_url" in item:
+ has_image = True
+
+ if has_audio and "audio" not in input_modalities:
+ raise Exception(
+ f"Model '{model_id}' does not support audio input. Supported modalities: {', '.join(input_modalities)}"
+ )
+
+ if has_image and "image" not in input_modalities:
+ raise Exception(
+ f"Model '{model_id}' does not support image input. Supported modalities: {', '.join(input_modalities)}"
+ )
+
  def to_response(self, response, chat, started_at):
  if "metadata" not in response:
  response["metadata"] = {}
  response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
  if chat is not None and "model" in chat:
- pricing = self.model_pricing(chat["model"])
+ pricing = self.model_cost(chat["model"])
  if pricing and "input" in pricing and "output" in pricing:
  response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
  _log(json.dumps(response, indent=2))
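
The rewritten provider_model() above resolves model references case-insensitively through several fallbacks: map_models aliases, the provider's model ids, the models.dev "id" and "name" fields, and finally the short name after a "/". A minimal sketch of that behaviour, using hypothetical model ids:

    # provider.models = {"gpt-5.1": {"id": "gpt-5.1", "name": "GPT-5.1"}}
    # provider.map_models = {"my-alias": "gpt-5.1"}
    provider.provider_model("my-alias")        # -> "gpt-5.1" (map_models alias)
    provider.provider_model("GPT-5.1")         # -> "gpt-5.1" (name match, case-insensitive)
    provider.provider_model("openai/gpt-5.1")  # -> "gpt-5.1" (retries with the part after "/")
    provider.provider_model("unknown")         # -> None
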
@@ -524,6 +644,8 @@ class OpenAiProvider:
  async def chat(self, chat):
  chat["model"] = self.provider_model(chat["model"]) or chat["model"]

+ self.validate_modalities(chat)
+
  # with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
  # f.write(json.dumps(chat, indent=2))

@@ -562,7 +684,7 @@ class OpenAiProvider:
  if self.enable_thinking is not None:
  chat["enable_thinking"] = self.enable_thinking

- chat = await process_chat(chat)
+ chat = await process_chat(chat, provider_id=self.id)
  _log(f"POST {self.chat_url}")
  _log(chat_summary(chat))
  # remove metadata if any (conflicts with some providers, e.g. Z.ai)
@@ -576,78 +698,341 @@ class OpenAiProvider:
  return self.to_response(await response_json(response), chat, started_at)


- class OllamaProvider(OpenAiProvider):
- def __init__(self, base_url, models, all_models=False, **kwargs):
- super().__init__(base_url=base_url, models=models, **kwargs)
- self.all_models = all_models
+ class OpenAiProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/openai"
+
+ def __init__(self, **kwargs):
+ if "api" not in kwargs:
+ kwargs["api"] = "https://api.openai.com/v1"
+ super().__init__(**kwargs)
+
+
+ class AnthropicProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/anthropic"
+
+ def __init__(self, **kwargs):
+ if "api" not in kwargs:
+ kwargs["api"] = "https://api.anthropic.com/v1"
+ super().__init__(**kwargs)
+
+ # Anthropic uses x-api-key header instead of Authorization
+ if self.api_key:
+ self.headers = self.headers.copy()
+ if "Authorization" in self.headers:
+ del self.headers["Authorization"]
+ self.headers["x-api-key"] = self.api_key
+
+ if "anthropic-version" not in self.headers:
+ self.headers = self.headers.copy()
+ self.headers["anthropic-version"] = "2023-06-01"
+ self.chat_url = f"{self.api}/messages"
+
+ async def chat(self, chat):
+ chat["model"] = self.provider_model(chat["model"]) or chat["model"]
+
+ chat = await process_chat(chat, provider_id=self.id)
+
+ # Transform OpenAI format to Anthropic format
+ anthropic_request = {
+ "model": chat["model"],
+ "messages": [],
+ }
+
+ # Extract system message (Anthropic uses top-level 'system' parameter)
+ system_messages = []
+ for message in chat.get("messages", []):
+ if message.get("role") == "system":
+ content = message.get("content", "")
+ if isinstance(content, str):
+ system_messages.append(content)
+ elif isinstance(content, list):
+ for item in content:
+ if item.get("type") == "text":
+ system_messages.append(item.get("text", ""))
+
+ if system_messages:
+ anthropic_request["system"] = "\n".join(system_messages)
+
+ # Transform messages (exclude system messages)
+ for message in chat.get("messages", []):
+ if message.get("role") == "system":
+ continue
+
+ anthropic_message = {"role": message.get("role"), "content": []}
+
+ content = message.get("content", "")
+ if isinstance(content, str):
+ anthropic_message["content"] = content
+ elif isinstance(content, list):
+ for item in content:
+ if item.get("type") == "text":
+ anthropic_message["content"].append({"type": "text", "text": item.get("text", "")})
+ elif item.get("type") == "image_url" and "image_url" in item:
+ # Transform OpenAI image_url format to Anthropic format
+ image_url = item["image_url"].get("url", "")
+ if image_url.startswith("data:"):
+ # Extract media type and base64 data
+ parts = image_url.split(";base64,", 1)
+ if len(parts) == 2:
+ media_type = parts[0].replace("data:", "")
+ base64_data = parts[1]
+ anthropic_message["content"].append(
+ {
+ "type": "image",
+ "source": {"type": "base64", "media_type": media_type, "data": base64_data},
+ }
+ )
+
+ anthropic_request["messages"].append(anthropic_message)
+
+ # Handle max_tokens (required by Anthropic, uses max_tokens not max_completion_tokens)
+ if "max_completion_tokens" in chat:
+ anthropic_request["max_tokens"] = chat["max_completion_tokens"]
+ elif "max_tokens" in chat:
+ anthropic_request["max_tokens"] = chat["max_tokens"]
+ else:
+ # Anthropic requires max_tokens, set a default
+ anthropic_request["max_tokens"] = 4096
+
+ # Copy other supported parameters
+ if "temperature" in chat:
+ anthropic_request["temperature"] = chat["temperature"]
+ if "top_p" in chat:
+ anthropic_request["top_p"] = chat["top_p"]
+ if "top_k" in chat:
+ anthropic_request["top_k"] = chat["top_k"]
+ if "stop" in chat:
+ anthropic_request["stop_sequences"] = chat["stop"] if isinstance(chat["stop"], list) else [chat["stop"]]
+ if "stream" in chat:
+ anthropic_request["stream"] = chat["stream"]
+ if "tools" in chat:
+ anthropic_request["tools"] = chat["tools"]
+ if "tool_choice" in chat:
+ anthropic_request["tool_choice"] = chat["tool_choice"]
+
+ _log(f"POST {self.chat_url}")
+ _log(f"Anthropic Request: {json.dumps(anthropic_request, indent=2)}")
+
+ async with aiohttp.ClientSession() as session:
+ started_at = time.time()
+ async with session.post(
+ self.chat_url,
+ headers=self.headers,
+ data=json.dumps(anthropic_request),
+ timeout=aiohttp.ClientTimeout(total=120),
+ ) as response:
+ return self.to_response(await response_json(response), chat, started_at)
+
+ def to_response(self, response, chat, started_at):
+ """Convert Anthropic response format to OpenAI-compatible format."""
+ # Transform Anthropic response to OpenAI format
+ openai_response = {
+ "id": response.get("id", ""),
+ "object": "chat.completion",
+ "created": int(started_at),
+ "model": response.get("model", ""),
+ "choices": [],
+ "usage": {},
+ }
+
+ # Transform content blocks to message content
+ content_parts = []
+ thinking_parts = []
+
+ for block in response.get("content", []):
+ if block.get("type") == "text":
+ content_parts.append(block.get("text", ""))
+ elif block.get("type") == "thinking":
+ # Store thinking blocks separately (some models include reasoning)
+ thinking_parts.append(block.get("thinking", ""))
+
+ # Combine all text content
+ message_content = "\n".join(content_parts) if content_parts else ""
+
+ # Create the choice object
+ choice = {
+ "index": 0,
+ "message": {"role": "assistant", "content": message_content},
+ "finish_reason": response.get("stop_reason", "stop"),
+ }
+
+ # Add thinking as metadata if present
+ if thinking_parts:
+ choice["message"]["thinking"] = "\n".join(thinking_parts)
+
+ openai_response["choices"].append(choice)
+
+ # Transform usage
+ if "usage" in response:
+ usage = response["usage"]
+ openai_response["usage"] = {
+ "prompt_tokens": usage.get("input_tokens", 0),
+ "completion_tokens": usage.get("output_tokens", 0),
+ "total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
+ }
+
+ # Add metadata
+ if "metadata" not in openai_response:
+ openai_response["metadata"] = {}
+ openai_response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
+
+ if chat is not None and "model" in chat:
+ cost = self.model_cost(chat["model"])
+ if cost and "input" in cost and "output" in cost:
+ openai_response["metadata"]["pricing"] = f"{cost['input']}/{cost['output']}"
+
+ _log(json.dumps(openai_response, indent=2))
+ return openai_response
+
+
+ class MistralProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/mistral"
+
+ def __init__(self, **kwargs):
+ if "api" not in kwargs:
+ kwargs["api"] = "https://api.mistral.ai/v1"
+ super().__init__(**kwargs)
+
+
+ class GroqProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/groq"
+
+ def __init__(self, **kwargs):
+ if "api" not in kwargs:
+ kwargs["api"] = "https://api.groq.com/openai/v1"
+ super().__init__(**kwargs)
+
+
+ class XaiProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/xai"
+
+ def __init__(self, **kwargs):
+ if "api" not in kwargs:
+ kwargs["api"] = "https://api.x.ai/v1"
+ super().__init__(**kwargs)
+
+
+ class CodestralProvider(OpenAiCompatible):
+ sdk = "codestral"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+
+ class OllamaProvider(OpenAiCompatible):
+ sdk = "ollama"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # Ollama's OpenAI-compatible endpoint is at /v1/chat/completions
+ self.chat_url = f"{self.api}/v1/chat/completions"

  async def load(self):
- if self.all_models:
- await self.load_models(default_models=self.models)
+ if not self.models:
+ await self.load_models()

  async def get_models(self):
  ret = {}
  try:
  async with aiohttp.ClientSession() as session:
- _log(f"GET {self.base_url}/api/tags")
+ _log(f"GET {self.api}/api/tags")
  async with session.get(
- f"{self.base_url}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
+ f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
  ) as response:
  data = await response_json(response)
  for model in data.get("models", []):
  name = model["model"]
  if name.endswith(":latest"):
  name = name[:-7]
- ret[name] = name
+ model_id = name.replace(":", "-")
+ ret[model_id] = name
  _log(f"Loaded Ollama models: {ret}")
  except Exception as e:
  _log(f"Error getting Ollama models: {e}")
  # return empty dict if ollama is not available
  return ret

- async def load_models(self, default_models):
+ async def load_models(self):
  """Load models if all_models was requested"""
- if self.all_models:
- self.models = await self.get_models()
- if default_models:
- self.models = {**default_models, **self.models}
-
- @classmethod
- def test(cls, base_url=None, models=None, all_models=False, **kwargs):
- if models is None:
- models = {}
- return base_url and (len(models) > 0 or all_models)
-
-
- class GoogleOpenAiProvider(OpenAiProvider):
- def __init__(self, api_key, models, **kwargs):
- super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
- self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
-
- @classmethod
- def test(cls, api_key=None, models=None, **kwargs):
- if models is None:
- models = {}
- return api_key and len(models) > 0
-
-
- class GoogleProvider(OpenAiProvider):
- def __init__(self, models, api_key, safety_settings=None, thinking_config=None, curl=False, **kwargs):
- super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
- self.safety_settings = safety_settings
- self.thinking_config = thinking_config
- self.curl = curl
+
+ # Map models to provider models {model_id:model_id}
+ model_map = await self.get_models()
+ if self.map_models:
+ map_model_values = set(self.map_models.values())
+ to = {}
+ for k, v in model_map.items():
+ if k in self.map_models:
+ to[k] = v
+ if v in map_model_values:
+ to[k] = v
+ model_map = to
+ else:
+ self.map_models = model_map
+ models = {}
+ for k, v in model_map.items():
+ models[k] = {
+ "id": k,
+ "name": v.replace(":", " "),
+ "modalities": {"input": ["text"], "output": ["text"]},
+ "cost": {
+ "input": 0,
+ "output": 0,
+ },
+ }
+ self.models = models
+
+ def test(self, **kwargs):
+ return True
+
+
+ class LMStudioProvider(OllamaProvider):
+ sdk = "lmstudio"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.chat_url = f"{self.api}/chat/completions"
+
+ async def get_models(self):
+ ret = {}
+ try:
+ async with aiohttp.ClientSession() as session:
+ _log(f"GET {self.api}/models")
+ async with session.get(
+ f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
+ ) as response:
+ data = await response_json(response)
+ for model in data.get("data", []):
+ id = model["id"]
+ ret[id] = id
+ _log(f"Loaded LMStudio models: {ret}")
+ except Exception as e:
+ _log(f"Error getting LMStudio models: {e}")
+ # return empty dict if ollama is not available
+ return ret
+
+
+ # class GoogleOpenAiProvider(OpenAiCompatible):
+ # sdk = "google-openai-compatible"
+
+ # def __init__(self, api_key, **kwargs):
+ # super().__init__(api="https://generativelanguage.googleapis.com", api_key=api_key, **kwargs)
+ # self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
+
+
+ class GoogleProvider(OpenAiCompatible):
+ sdk = "@ai-sdk/google"
+
+ def __init__(self, **kwargs):
+ new_kwargs = {"api": "https://generativelanguage.googleapis.com", **kwargs}
+ super().__init__(**new_kwargs)
+ self.safety_settings = kwargs.get("safety_settings")
+ self.thinking_config = kwargs.get("thinking_config")
+ self.curl = kwargs.get("curl")
  self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
  # Google fails when using Authorization header, use query string param instead
  if "Authorization" in self.headers:
  del self.headers["Authorization"]

- @classmethod
- def test(cls, api_key=None, models=None, **kwargs):
- if models is None:
- models = {}
- return api_key is not None and len(models) > 0
-
  async def chat(self, chat):
  chat["model"] = self.provider_model(chat["model"]) or chat["model"]

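Every provider class above now carries an sdk attribute naming the models.dev npm package it maps to, and the thin subclasses only set their default api endpoint. Adding another OpenAI-compatible backend follows the same pattern; a minimal sketch with a hypothetical sdk id and endpoint:

    class ExampleProvider(OpenAiCompatible):
        sdk = "@ai-sdk/example"  # hypothetical models.dev npm id

        def __init__(self, **kwargs):
            if "api" not in kwargs:
                kwargs["api"] = "https://api.example.com/v1"
            super().__init__(**kwargs)

A new class also needs to be listed in ALL_PROVIDERS (defined in the next hunk) before create_provider can dispatch to it.
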
@@ -843,6 +1228,28 @@ class GoogleProvider(OpenAiProvider):
  return self.to_response(response, chat, started_at)


+ ALL_PROVIDERS = [
+ OpenAiCompatible,
+ OpenAiProvider,
+ AnthropicProvider,
+ MistralProvider,
+ GroqProvider,
+ XaiProvider,
+ CodestralProvider,
+ GoogleProvider,
+ OllamaProvider,
+ LMStudioProvider,
+ ]
+
+
+ def get_provider_model(model_name):
+ for provider in g_handlers.values():
+ provider_model = provider.provider_model(model_name)
+ if provider_model:
+ return provider_model
+ return None
+
+
  def get_models():
  ret = []
  for provider in g_handlers.values():
@@ -856,21 +1263,32 @@ def get_models():
  def get_active_models():
  ret = []
  existing_models = set()
- for id, provider in g_handlers.items():
- for model in provider.models:
- if model not in existing_models:
- existing_models.add(model)
- provider_model = provider.models[model]
- pricing = provider.model_pricing(model)
- ret.append({"id": model, "provider": id, "provider_model": provider_model, "pricing": pricing})
+ for provider_id, provider in g_handlers.items():
+ for model in provider.models.values():
+ name = model.get("name")
+ if not name:
+ _log(f"Provider {provider_id} model {model} has no name")
+ continue
+ if name not in existing_models:
+ existing_models.add(name)
+ item = model.copy()
+ item.update({"provider": provider_id})
+ ret.append(item)
  ret.sort(key=lambda x: x["id"])
  return ret


+ def api_providers():
+ ret = []
+ for id, provider in g_handlers.items():
+ ret.append({"id": id, "name": provider.name, "models": provider.models})
+ return ret
+
+
  async def chat_completion(chat):
  model = chat["model"]
  # get first provider that has the model
- candidate_providers = [name for name, provider in g_handlers.items() if model in provider.models]
+ candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
  if len(candidate_providers) == 0:
  raise (Exception(f"Model {model} not found"))

@@ -991,10 +1409,18 @@ def config_str(key):
  return key in g_config and g_config[key] or None


- def init_llms(config):
+ def load_config(config, providers, verbose=None):
+ global g_config, g_providers, g_verbose
+ g_config = config
+ g_providers = providers
+ if verbose:
+ g_verbose = verbose
+
+
+ def init_llms(config, providers):
  global g_config, g_handlers

- g_config = config
+ load_config(config, providers)
  g_handlers = {}
  # iterate over config and replace $ENV with env value
  for key, value in g_config.items():
@@ -1005,34 +1431,71 @@ def init_llms(config):
  # printdump(g_config)
  providers = g_config["providers"]

- for name, orig in providers.items():
+ for id, orig in providers.items():
  definition = orig.copy()
- provider_type = definition["type"]
  if "enabled" in definition and not definition["enabled"]:
  continue

- # Replace API keys with environment variables if they start with $
- if "api_key" in definition:
- value = definition["api_key"]
- if isinstance(value, str) and value.startswith("$"):
- definition["api_key"] = os.environ.get(value[1:], "")
-
- # Create a copy of definition without the 'type' key for constructor kwargs
- constructor_kwargs = {k: v for k, v in definition.items() if k != "type" and k != "enabled"}
- constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
-
- if provider_type == "OpenAiProvider" and OpenAiProvider.test(**constructor_kwargs):
- g_handlers[name] = OpenAiProvider(**constructor_kwargs)
- elif provider_type == "OllamaProvider" and OllamaProvider.test(**constructor_kwargs):
- g_handlers[name] = OllamaProvider(**constructor_kwargs)
- elif provider_type == "GoogleProvider" and GoogleProvider.test(**constructor_kwargs):
- g_handlers[name] = GoogleProvider(**constructor_kwargs)
- elif provider_type == "GoogleOpenAiProvider" and GoogleOpenAiProvider.test(**constructor_kwargs):
- g_handlers[name] = GoogleOpenAiProvider(**constructor_kwargs)
+ provider_id = definition.get("id", id)
+ if "id" not in definition:
+ definition["id"] = provider_id
+ provider = g_providers.get(provider_id)
+ constructor_kwargs = create_provider_kwargs(definition, provider)
+ provider = create_provider(constructor_kwargs)

+ if provider and provider.test(**constructor_kwargs):
+ g_handlers[id] = provider
  return g_handlers


+ def create_provider_kwargs(definition, provider=None):
+ if provider:
+ provider = provider.copy()
+ provider.update(definition)
+ else:
+ provider = definition.copy()
+
+ # Replace API keys with environment variables if they start with $
+ if "api_key" in provider:
+ value = provider["api_key"]
+ if isinstance(value, str) and value.startswith("$"):
+ provider["api_key"] = os.environ.get(value[1:], "")
+
+ if "api_key" not in provider and "env" in provider:
+ for env_var in provider["env"]:
+ val = os.environ.get(env_var)
+ if val:
+ provider["api_key"] = val
+ break
+
+ # Create a copy of provider
+ constructor_kwargs = dict(provider.items())
+ # Create a copy of all list and dict values
+ for key, value in constructor_kwargs.items():
+ if isinstance(value, (list, dict)):
+ constructor_kwargs[key] = value.copy()
+ constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
+ return constructor_kwargs
+
+
+ def create_provider(provider):
+ if not isinstance(provider, dict):
+ return None
+ provider_label = provider.get("id", provider.get("name", "unknown"))
+ npm_sdk = provider.get("npm")
+ if not npm_sdk:
+ _log(f"Provider {provider_label} is missing 'npm' sdk")
+ return None
+
+ for provider_type in ALL_PROVIDERS:
+ if provider_type.sdk == npm_sdk:
+ kwargs = create_provider_kwargs(provider)
+ return provider_type(**kwargs)
+
+ _log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
+ return None
+
+
  async def load_llms():
  global g_handlers
  _log("Loading providers...")
@@ -1076,6 +1539,23 @@ async def save_default_config(config_path):
  g_config = json.loads(config_json)


+ async def update_providers(home_providers_path):
+ global g_providers
+ text = await get_text("https://models.dev/api.json")
+ all_providers = json.loads(text)
+
+ filtered_providers = {}
+ for id, provider in all_providers.items():
+ if id in g_config["providers"]:
+ filtered_providers[id] = provider
+
+ os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
+ with open(home_providers_path, "w", encoding="utf-8") as f:
+ json.dump(filtered_providers, f)
+
+ g_providers = filtered_providers
+
+
  def provider_status():
  enabled = list(g_handlers.keys())
  disabled = [provider for provider in g_config["providers"] if provider not in enabled]
@@ -1100,6 +1580,10 @@ def home_llms_path(filename):
  return f"{os.environ.get('HOME')}/.llms/{filename}"


+ def get_cache_path(filename):
+ return home_llms_path(f"cache/{filename}")
+
+
  def get_config_path():
  home_config_path = home_llms_path("llms.json")
  check_paths = [
@@ -1137,7 +1621,7 @@ def enable_provider(provider):
  else:
  msg = f"WARNING: {provider} is not configured with an API Key"
  save_config(g_config)
- init_llms(g_config)
+ init_llms(g_config, g_providers)
  return provider_config, msg

@@ -1145,7 +1629,7 @@ def disable_provider(provider):
  provider_config = g_config["providers"][provider]
  provider_config["enabled"] = False
  save_config(g_config)
- init_llms(g_config)
+ init_llms(g_config, g_providers)


  def resolve_root():
@@ -1340,7 +1824,8 @@ async def check_models(provider_name, model_names=None):
  else:
  # Check only specified models
  for model_name in model_names:
- if model_name in provider.models:
+ provider_model = provider.provider_model(model_name)
+ if provider_model:
  models_to_check.append(model_name)
  else:
  print(f"Model '{model_name}' not found in provider '{provider_name}'")
@@ -1355,69 +1840,76 @@ async def check_models(provider_name, model_names=None):

  # Test each model
  for model in models_to_check:
- # Create a simple ping chat request
- chat = (provider.check or g_config["defaults"]["check"]).copy()
- chat["model"] = model
+ await check_provider_model(provider, model)

- started_at = time.time()
- try:
- # Try to get a response from the model
- response = await provider.chat(chat)
- duration_ms = int((time.time() - started_at) * 1000)
+ print()

- # Check if we got a valid response
- if response and "choices" in response and len(response["choices"]) > 0:
- print(f" ✓ {model:<40} ({duration_ms}ms)")
- else:
- print(f" ✗ {model:<40} Invalid response format")
- except HTTPError as e:
- duration_ms = int((time.time() - started_at) * 1000)
- error_msg = f"HTTP {e.status}"
- try:
- # Try to parse error body for more details
- error_body = json.loads(e.body) if e.body else {}
- if "error" in error_body:
- error = error_body["error"]
- if isinstance(error, dict):
- if "message" in error and isinstance(error["message"], str):
- # OpenRouter
- error_msg = error["message"]
- if "code" in error:
- error_msg = f"{error['code']} {error_msg}"
- if "metadata" in error and "raw" in error["metadata"]:
- error_msg += f" - {error['metadata']['raw']}"
- if "provider" in error:
- error_msg += f" ({error['provider']})"
- elif isinstance(error, str):
- error_msg = error
- elif "message" in error_body:
- if isinstance(error_body["message"], str):
- error_msg = error_body["message"]
- elif (
- isinstance(error_body["message"], dict)
- and "detail" in error_body["message"]
- and isinstance(error_body["message"]["detail"], list)
- ):
- # codestral error format
- error_msg = error_body["message"]["detail"][0]["msg"]
- if (
- "loc" in error_body["message"]["detail"][0]
- and len(error_body["message"]["detail"][0]["loc"]) > 0
- ):
- error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
- except Exception as parse_error:
- _log(f"Error parsing error body: {parse_error}")
- error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
- print(f" ✗ {model:<40} {error_msg}")
- except asyncio.TimeoutError:
- duration_ms = int((time.time() - started_at) * 1000)
- print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
- except Exception as e:
- duration_ms = int((time.time() - started_at) * 1000)
- error_msg = str(e)[:100]
- print(f" ✗ {model:<40} {error_msg}")

- print()
+ async def check_provider_model(provider, model):
+ # Create a simple ping chat request
+ chat = (provider.check or g_config["defaults"]["check"]).copy()
+ chat["model"] = model
+
+ success = False
+ started_at = time.time()
+ try:
+ # Try to get a response from the model
+ response = await provider.chat(chat)
+ duration_ms = int((time.time() - started_at) * 1000)
+
+ # Check if we got a valid response
+ if response and "choices" in response and len(response["choices"]) > 0:
+ success = True
+ print(f" ✓ {model:<40} ({duration_ms}ms)")
+ else:
+ print(f" ✗ {model:<40} Invalid response format")
+ except HTTPError as e:
+ duration_ms = int((time.time() - started_at) * 1000)
+ error_msg = f"HTTP {e.status}"
+ try:
+ # Try to parse error body for more details
+ error_body = json.loads(e.body) if e.body else {}
+ if "error" in error_body:
+ error = error_body["error"]
+ if isinstance(error, dict):
+ if "message" in error and isinstance(error["message"], str):
+ # OpenRouter
+ error_msg = error["message"]
+ if "code" in error:
+ error_msg = f"{error['code']} {error_msg}"
+ if "metadata" in error and "raw" in error["metadata"]:
+ error_msg += f" - {error['metadata']['raw']}"
+ if "provider" in error:
+ error_msg += f" ({error['provider']})"
+ elif isinstance(error, str):
+ error_msg = error
+ elif "message" in error_body:
+ if isinstance(error_body["message"], str):
+ error_msg = error_body["message"]
+ elif (
+ isinstance(error_body["message"], dict)
+ and "detail" in error_body["message"]
+ and isinstance(error_body["message"]["detail"], list)
+ ):
+ # codestral error format
+ error_msg = error_body["message"]["detail"][0]["msg"]
+ if (
+ "loc" in error_body["message"]["detail"][0]
+ and len(error_body["message"]["detail"][0]["loc"]) > 0
+ ):
+ error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
+ except Exception as parse_error:
+ _log(f"Error parsing error body: {parse_error}")
+ error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
+ print(f" ✗ {model:<40} {error_msg}")
+ except asyncio.TimeoutError:
+ duration_ms = int((time.time() - started_at) * 1000)
+ print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
+ except Exception as e:
+ duration_ms = int((time.time() - started_at) * 1000)
+ error_msg = str(e)[:100]
+ print(f" ✗ {model:<40} {error_msg}")
+ return success


  def text_from_resource(filename):
@@ -1453,7 +1945,8 @@ async def text_from_resource_or_url(filename):
  async def save_home_configs():
  home_config_path = home_llms_path("llms.json")
  home_ui_path = home_llms_path("ui.json")
- if os.path.exists(home_config_path) and os.path.exists(home_ui_path):
+ home_providers_path = home_llms_path("providers.json")
+ if os.path.exists(home_config_path) and os.path.exists(home_ui_path) and os.path.exists(home_providers_path):
  return

  llms_home = os.path.dirname(home_config_path)
@@ -1470,14 +1963,43 @@ async def save_home_configs():
  with open(home_ui_path, "w", encoding="utf-8") as f:
  f.write(ui_json)
  _log(f"Created default ui config at {home_ui_path}")
+
+ if not os.path.exists(home_providers_path):
+ providers_json = await text_from_resource_or_url("providers.json")
+ with open(home_providers_path, "w", encoding="utf-8") as f:
+ f.write(providers_json)
+ _log(f"Created default providers config at {home_providers_path}")
  except Exception:
  print("Could not create llms.json. Create one with --init or use --config <path>")
  exit(1)


+ def load_config_json(config_json):
+ if config_json is None:
+ return None
+ config = json.loads(config_json)
+ if not config or "version" not in config or config["version"] < 3:
+ preserve_keys = ["auth", "defaults", "limits", "convert"]
+ new_config = json.loads(text_from_resource("llms.json"))
+ if config:
+ for key in preserve_keys:
+ if key in config:
+ new_config[key] = config[key]
+ config = new_config
+ # move old config to YYYY-MM-DD.bak
+ new_path = f"{g_config_path}.{datetime.now().strftime('%Y-%m-%d')}.bak"
+ if os.path.exists(new_path):
+ os.remove(new_path)
+ os.rename(g_config_path, new_path)
+ print(f"llms.json migrated. old config moved to {new_path}")
+ # save new config
+ save_config(g_config)
+ return config
+
+
  async def reload_providers():
  global g_config, g_handlers
- g_handlers = init_llms(g_config)
+ g_handlers = init_llms(g_config, g_providers)
  await load_llms()
  _log(f"{len(g_handlers)} providers loaded")
  return g_handlers
@@ -1538,10 +2060,11 @@ async def watch_config_files(config_path, ui_path, interval=1):


  def main():
- global _ROOT, g_verbose, g_default_model, g_logprefix, g_config, g_config_path, g_ui_path
+ global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_ui_path

  parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
  parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
+ parser.add_argument("--providers", default=None, help="Path to models.dev providers file", metavar="FILE")
  parser.add_argument("-m", "--model", default=None, help="Model to use")

  parser.add_argument("--chat", default=None, help="OpenAI Chat Completion Request to send", metavar="REQUEST")
@@ -1573,6 +2096,7 @@ def main():
  parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")

  parser.add_argument("--init", action="store_true", help="Create a default llms.json")
+ parser.add_argument("--update", action="store_true", help="Update local models.dev providers.json")

  parser.add_argument("--root", default=None, help="Change root directory for UI files", metavar="PATH")
  parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
@@ -1597,6 +2121,7 @@ def main():

  home_config_path = home_llms_path("llms.json")
  home_ui_path = home_llms_path("ui.json")
+ home_providers_path = home_llms_path("providers.json")

  if cli_args.init:
  if os.path.exists(home_config_path):
@@ -1610,14 +2135,26 @@ def main():
  else:
  asyncio.run(save_text_url(github_url("ui.json"), home_ui_path))
  print(f"Created default ui config at {home_ui_path}")
+
+ if os.path.exists(home_providers_path):
+ print(f"providers.json already exists at {home_providers_path}")
+ else:
+ asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
+ print(f"Created default providers config at {home_providers_path}")
  exit(0)

+ if cli_args.providers:
+ if not os.path.exists(cli_args.providers):
+ print(f"providers.json not found at {cli_args.providers}")
+ exit(1)
+ g_providers = json.loads(text_from_file(cli_args.providers))
+
  if cli_args.config:
  # read contents
  g_config_path = cli_args.config
  with open(g_config_path, encoding="utf-8") as f:
  config_json = f.read()
- g_config = json.loads(config_json)
+ g_config = load_config_json(config_json)

  config_dir = os.path.dirname(g_config_path)
  # look for ui.json in same directory as config
@@ -1631,12 +2168,24 @@ def main():
  f.write(ui_json)
  _log(f"Created default ui config at {home_ui_path}")
  g_ui_path = home_ui_path
+
+ if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
+ g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
+
  else:
  # ensure llms.json and ui.json exist in home directory
  asyncio.run(save_home_configs())
  g_config_path = home_config_path
  g_ui_path = home_ui_path
- g_config = json.loads(text_from_file(g_config_path))
+ g_config = load_config_json(text_from_file(g_config_path))
+
+ if not g_providers:
+ g_providers = json.loads(text_from_file(home_providers_path))
+
+ if cli_args.update:
+ asyncio.run(update_providers(home_providers_path))
+ print(f"Updated {home_providers_path}")
+ exit(0)

  asyncio.run(reload_providers())

@@ -1654,13 +2203,35 @@ def main():
  if cli_args.list:
  # Show list of enabled providers and their models
  enabled = []
+ provider_count = 0
+ model_count = 0
+
+ max_model_length = 0
  for name, provider in g_handlers.items():
  if len(filter_list) > 0 and name not in filter_list:
  continue
+ for model in provider.models:
+ max_model_length = max(max_model_length, len(model))
+
+ for name, provider in g_handlers.items():
+ if len(filter_list) > 0 and name not in filter_list:
+ continue
+ provider_count += 1
  print(f"{name}:")
  enabled.append(name)
  for model in provider.models:
- print(f" {model}")
+ model_count += 1
+ model_cost_info = None
+ if "cost" in provider.models[model]:
+ model_cost = provider.models[model]["cost"]
+ if "input" in model_cost and "output" in model_cost:
+ if model_cost["input"] == 0 and model_cost["output"] == 0:
+ model_cost_info = " 0"
+ else:
+ model_cost_info = f"{model_cost['input']:5} / {model_cost['output']}"
+ print(f" {model:{max_model_length}} {model_cost_info or ''}")
+
+ print(f"\n{model_count} models available from {provider_count} providers")

  print_status()
  exit(0)
@@ -1775,6 +2346,11 @@ def main():

  app.router.add_get("/models", active_models_handler)

+ async def active_providers_handler(request):
+ return web.json_response(api_providers())
+
+ app.router.add_get("/providers", active_providers_handler)
+
  async def status_handler(request):
  enabled, disabled = provider_status()
  return web.json_response(
@@ -1810,6 +2386,135 @@ def main():

  app.router.add_post("/providers/{provider}", provider_handler)

+ async def upload_handler(request):
+ # Check authentication if enabled
+ is_authenticated, user_data = check_auth(request)
+ if not is_authenticated:
+ return web.json_response(
+ {
+ "error": {
+ "message": "Authentication required",
+ "type": "authentication_error",
+ "code": "unauthorized",
+ }
+ },
+ status=401,
+ )
+
+ reader = await request.multipart()
+
+ # Read first file field
+ field = await reader.next()
+ while field and field.name != "file":
+ field = await reader.next()
+
+ if not field:
+ return web.json_response({"error": "No file provided"}, status=400)
+
+ filename = field.filename or "file"
+ content = await field.read()
+ mimetype = get_file_mime_type(filename)
+
+ # If image, resize if needed
+ if mimetype.startswith("image/"):
+ content, mimetype = convert_image_if_needed(content, mimetype)
+
+ # Calculate SHA256
+ sha256_hash = hashlib.sha256(content).hexdigest()
+ ext = filename.rsplit(".", 1)[1] if "." in filename else ""
+ if not ext:
+ ext = mimetypes.guess_extension(mimetype) or ""
+ if ext.startswith("."):
+ ext = ext[1:]
+
+ if not ext:
+ ext = "bin"
+
+ save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+ # Use first 2 chars for subdir to avoid too many files in one dir
+ subdir = sha256_hash[:2]
+ relative_path = f"{subdir}/{save_filename}"
+ full_path = get_cache_path(relative_path)
+
+ # if file and its .info.json already exists, return it
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
+ if os.path.exists(full_path) and os.path.exists(info_path):
+ return web.json_response(json.load(open(info_path)))
+
+ os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+ with open(full_path, "wb") as f:
+ f.write(content)
+
+ response_data = {
+ "date": int(time.time()),
+ "url": f"/~cache/{relative_path}",
+ "size": len(content),
+ "type": mimetype,
+ "name": filename,
+ }
+
+ # If image, get dimensions
+ if HAS_PIL and mimetype.startswith("image/"):
+ try:
+ with Image.open(BytesIO(content)) as img:
+ response_data["width"] = img.width
+ response_data["height"] = img.height
+ except Exception:
+ pass
+
+ # Save metadata
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
+ with open(info_path, "w") as f:
+ json.dump(response_data, f)
+
+ return web.json_response(response_data)
+
+ app.router.add_post("/upload", upload_handler)
+
+ async def cache_handler(request):
+ path = request.match_info["tail"]
+ full_path = get_cache_path(path)
+
+ if "info" in request.query:
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
+ if not os.path.exists(info_path):
+ return web.Response(text="404: Not Found", status=404)
+
+ # Check for directory traversal for info path
+ try:
+ cache_root = Path(get_cache_path(""))
+ requested_path = Path(info_path).resolve()
+ if not str(requested_path).startswith(str(cache_root)):
+ return web.Response(text="403: Forbidden", status=403)
+ except Exception:
+ return web.Response(text="403: Forbidden", status=403)
+
+ with open(info_path, "r") as f:
+ content = f.read()
+ return web.Response(text=content, content_type="application/json")
+
+ if not os.path.exists(full_path):
+ return web.Response(text="404: Not Found", status=404)
+
+ # Check for directory traversal
+ try:
+ cache_root = Path(get_cache_path(""))
+ requested_path = Path(full_path).resolve()
+ if not str(requested_path).startswith(str(cache_root)):
+ return web.Response(text="403: Forbidden", status=403)
+ except Exception:
+ return web.Response(text="403: Forbidden", status=403)
+
+ with open(full_path, "rb") as f:
+ content = f.read()
+
+ mimetype = get_file_mime_type(full_path)
+ return web.Response(body=content, content_type=mimetype)
+
+ app.router.add_get("/~cache/{tail:.*}", cache_handler)
+
  # OAuth handlers
  async def github_auth_handler(request):
  """Initiate GitHub OAuth flow"""
@@ -2149,10 +2854,9 @@ def main():

  if cli_args.default is not None:
  default_model = cli_args.default
- all_models = get_models()
- if default_model not in all_models:
+ provider_model = get_provider_model(default_model)
+ if provider_model is None:
  print(f"Model {default_model} not found")
- print(f"Available models: {', '.join(all_models)}")
  exit(1)
  default_text = g_config["defaults"]["text"]
  default_text["model"] = default_model