causaliq-knowledge 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
  """Direct Groq API client - clean and reliable."""

- import json
  import logging
  import os
  from dataclasses import dataclass
@@ -8,12 +7,28 @@ from typing import Any, Dict, List, Optional

  import httpx

+ from causaliq_knowledge.llm.base_client import (
+     BaseLLMClient,
+     LLMConfig,
+     LLMResponse,
+ )
+
  logger = logging.getLogger(__name__)


  @dataclass
- class GroqConfig:
-     """Configuration for Groq API client."""
+ class GroqConfig(LLMConfig):
+     """Configuration for Groq API client.
+
+     Extends LLMConfig with Groq-specific defaults.
+
+     Attributes:
+         model: Groq model identifier (default: llama-3.1-8b-instant).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: Groq API key (falls back to GROQ_API_KEY env var).
+     """

      model: str = "llama-3.1-8b-instant"
      temperature: float = 0.1
@@ -29,50 +44,52 @@ class GroqConfig:
              raise ValueError("GROQ_API_KEY environment variable is required")


- @dataclass
- class GroqResponse:
-     """Response from Groq API."""
-
-     content: str
-     model: str
-     input_tokens: int = 0
-     output_tokens: int = 0
-     cost: float = 0.0  # Groq free tier
-     raw_response: Optional[Dict] = None
-
-     def parse_json(self) -> Optional[Dict[str, Any]]:
-         """Parse content as JSON, handling common formatting issues."""
-         try:
-             # Clean up potential markdown code blocks
-             text = self.content.strip()
-             if text.startswith("```json"):
-                 text = text[7:]
-             elif text.startswith("```"):
-                 text = text[3:]
-             if text.endswith("```"):
-                 text = text[:-3]
-
-             return json.loads(text.strip())  # type: ignore[no-any-return]
-         except json.JSONDecodeError as e:
-             logger.warning(f"Failed to parse JSON response: {e}")
-             return None
+ class GroqClient(BaseLLMClient):
+     """Direct Groq API client.

+     Implements the BaseLLMClient interface for Groq's API.
+     Uses httpx for HTTP requests.

- class GroqClient:
-     """Direct Groq API client."""
+     Example:
+         >>> config = GroqConfig(model="llama-3.1-8b-instant")
+         >>> client = GroqClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """

      BASE_URL = "https://api.groq.com/openai/v1"

-     def __init__(self, config: Optional[GroqConfig] = None):
-         """Initialize Groq client."""
+     def __init__(self, config: Optional[GroqConfig] = None) -> None:
+         """Initialize Groq client.
+
+         Args:
+             config: Groq configuration. If None, uses defaults with
+                 API key from GROQ_API_KEY environment variable.
+         """
          self.config = config or GroqConfig()
          self._total_calls = 0

+     @property
+     def provider_name(self) -> str:
+         """Return the provider name."""
+         return "groq"
+
      def completion(
          self, messages: List[Dict[str, str]], **kwargs: Any
-     ) -> GroqResponse:
-         """Make a chat completion request to Groq."""
+     ) -> LLMResponse:
+         """Make a chat completion request to Groq.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options (temperature, max_tokens).

+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails.
+         """
          # Build request payload
          payload = {
              "model": self.config.model,
@@ -112,7 +129,7 @@ class GroqClient:
                      f"Groq response: {input_tokens} in, {output_tokens} out"
                  )

-                 return GroqResponse(
+                 return LLMResponse(
                      content=content,
                      model=data.get("model", self.config.model),
                      input_tokens=input_tokens,
@@ -136,13 +153,71 @@ class GroqClient:

      def complete_json(
          self, messages: List[Dict[str, str]], **kwargs: Any
-     ) -> tuple[Optional[Dict[str, Any]], GroqResponse]:
-         """Make a completion request and parse response as JSON."""
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
          response = self.completion(messages, **kwargs)
          parsed = response.parse_json()
          return parsed, response

      @property
      def call_count(self) -> int:
-         """Number of API calls made."""
+         """Return the number of API calls made."""
          return self._total_calls
+
+     def is_available(self) -> bool:
+         """Check if Groq API is available.
+
+         Returns:
+             True if GROQ_API_KEY is configured.
+         """
+         return bool(self.config.api_key)
+
+     def list_models(self) -> List[str]:
+         """List available models from Groq API.
+
+         Queries the Groq API to get models accessible with the current
+         API key. Filters to only include text generation models.
+
+         Returns:
+             List of model identifiers (e.g., ['llama-3.1-8b-instant', ...]).
+
+         Raises:
+             ValueError: If the API request fails.
+         """
+         try:
+             with httpx.Client(timeout=self.config.timeout) as client:
+                 response = client.get(
+                     f"{self.BASE_URL}/models",
+                     headers={"Authorization": f"Bearer {self.config.api_key}"},
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 # Filter and sort models
+                 models = []
+                 for model in data.get("data", []):
+                     model_id = model.get("id", "")
+                     # Skip whisper (audio), guard, and safeguard models
+                     if any(
+                         x in model_id.lower()
+                         for x in ["whisper", "guard", "embed"]
+                     ):
+                         continue
+                     models.append(model_id)
+
+                 return sorted(models)
+
+         except httpx.HTTPStatusError as e:
+             raise ValueError(
+                 f"Groq API error: {e.response.status_code} - {e.response.text}"
+             )
+         except Exception as e:
+             raise ValueError(f"Failed to list Groq models: {e}")
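
Editorial note (not part of the package diff): the reshaped Groq client now exposes the shared BaseLLMClient surface shown above (completion, complete_json, is_available, list_models, call_count). A minimal usage sketch follows, assuming the module is importable as causaliq_knowledge.llm.groq_client (the file path is not shown in this diff) and that GROQ_API_KEY is set in the environment.

    # Illustrative sketch only; module path is an assumption, not shown above.
    from causaliq_knowledge.llm.groq_client import GroqClient, GroqConfig

    client = GroqClient(GroqConfig(model="llama-3.1-8b-instant"))

    if client.is_available():
        # completion() returns an LLMResponse with content and token counts.
        reply = client.completion([{"role": "user", "content": "Say hello."}])
        print(reply.content, reply.input_tokens, reply.output_tokens)

        # complete_json() additionally tries to parse the reply as JSON.
        parsed, raw = client.complete_json(
            [{"role": "user", "content": 'Reply with JSON: {"ok": true}'}]
        )
        print(parsed, raw.model, client.call_count)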
@@ -0,0 +1,122 @@
+ """Direct Mistral AI API client - OpenAI-compatible API."""
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ from causaliq_knowledge.llm.openai_compat_client import (
+     OpenAICompatClient,
+     OpenAICompatConfig,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class MistralConfig(OpenAICompatConfig):
+     """Configuration for Mistral AI API client.
+
+     Extends OpenAICompatConfig with Mistral-specific defaults.
+
+     Attributes:
+         model: Mistral model identifier (default: mistral-small-latest).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: Mistral API key (falls back to MISTRAL_API_KEY env var).
+     """
+
+     model: str = "mistral-small-latest"
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+     def __post_init__(self) -> None:
+         """Set API key from environment if not provided."""
+         if self.api_key is None:
+             self.api_key = os.getenv("MISTRAL_API_KEY")
+         if not self.api_key:
+             raise ValueError(
+                 "MISTRAL_API_KEY environment variable is required"
+             )
+
+
+ class MistralClient(OpenAICompatClient):
+     """Direct Mistral AI API client.
+
+     Mistral AI is a French company providing high-quality LLMs with an
+     OpenAI-compatible API.
+
+     Available models:
+     - mistral-small-latest: Fast, cost-effective
+     - mistral-medium-latest: Balanced performance
+     - mistral-large-latest: Most capable
+     - codestral-latest: Optimized for code
+
+     Example:
+         >>> config = MistralConfig(model="mistral-small-latest")
+         >>> client = MistralClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     BASE_URL = "https://api.mistral.ai/v1"
+     PROVIDER_NAME = "mistral"
+     ENV_VAR = "MISTRAL_API_KEY"
+
+     def __init__(self, config: Optional[MistralConfig] = None) -> None:
+         """Initialize Mistral client.
+
+         Args:
+             config: Mistral configuration. If None, uses defaults with
+                 API key from MISTRAL_API_KEY environment variable.
+         """
+         super().__init__(config)
+
+     def _default_config(self) -> MistralConfig:
+         """Return default Mistral configuration."""
+         return MistralConfig()
+
+     def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+         """Return Mistral pricing per 1M tokens.
+
+         Returns:
+             Dict mapping model prefixes to input/output costs.
+         """
+         # Mistral pricing as of Jan 2025
+         return {
+             "mistral-large": {"input": 2.00, "output": 6.00},
+             "mistral-medium": {"input": 2.70, "output": 8.10},
+             "mistral-small": {"input": 0.20, "output": 0.60},
+             "codestral": {"input": 0.20, "output": 0.60},
+             "open-mistral-nemo": {"input": 0.15, "output": 0.15},
+             "open-mixtral-8x22b": {"input": 2.00, "output": 6.00},
+             "open-mixtral-8x7b": {"input": 0.70, "output": 0.70},
+             "ministral-3b": {"input": 0.04, "output": 0.04},
+             "ministral-8b": {"input": 0.10, "output": 0.10},
+         }
+
+     def _filter_models(self, models: List[str]) -> List[str]:
+         """Filter to Mistral chat models only.
+
+         Args:
+             models: List of all model IDs from API.
+
+         Returns:
+             Filtered list of Mistral models.
+         """
+         filtered = []
+         for model_id in models:
+             # Include mistral and codestral models
+             if any(
+                 prefix in model_id.lower()
+                 for prefix in ["mistral", "codestral", "ministral", "mixtral"]
+             ):
+                 # Exclude embedding models
+                 if "embed" in model_id.lower():
+                     continue
+                 filtered.append(model_id)
+         return filtered
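
Editorial note (not part of the package diff): the _get_pricing table above maps model-name prefixes to USD costs per 1M tokens. The cost arithmetic itself lives in OpenAICompatClient, which this diff does not include, so the sketch below only illustrates the kind of prefix lookup such a table supports; the function name and matching rule are assumptions, not the package's code.

    # Illustrative sketch only; the real lookup is inside OpenAICompatClient,
    # which is not shown in this diff. The prefix-match rule is an assumption.
    from typing import Dict

    PRICING: Dict[str, Dict[str, float]] = {
        "mistral-small": {"input": 0.20, "output": 0.60},
        "mistral-large": {"input": 2.00, "output": 6.00},
    }

    def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
        """Return an estimated USD cost using per-1M-token prefix pricing."""
        for prefix, rates in PRICING.items():
            if model.startswith(prefix):
                return (
                    input_tokens * rates["input"] / 1_000_000
                    + output_tokens * rates["output"] / 1_000_000
                )
        return 0.0  # unknown model: no estimate

    # e.g. estimate_cost("mistral-small-latest", 1200, 400) -> 0.00048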
@@ -0,0 +1,240 @@
+ """Local Ollama API client for running Llama models locally."""
+
+ import logging
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ import httpx
+
+ from causaliq_knowledge.llm.base_client import (
+     BaseLLMClient,
+     LLMConfig,
+     LLMResponse,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class OllamaConfig(LLMConfig):
+     """Configuration for Ollama API client.
+
+     Extends LLMConfig with Ollama-specific defaults.
+
+     Attributes:
+         model: Ollama model identifier (default: llama3.2:1b).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 120.0, local).
+         api_key: Not used for Ollama (local server).
+         base_url: Ollama server URL (default: http://localhost:11434).
+     """
+
+     model: str = "llama3.2:1b"
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 120.0  # Local inference can be slow
+     api_key: Optional[str] = None  # Not needed for local Ollama
+     base_url: str = "http://localhost:11434"
+
+
+ class OllamaClient(BaseLLMClient):
+     """Local Ollama API client.
+
+     Implements the BaseLLMClient interface for locally running Ollama server.
+     Uses httpx for HTTP requests to the local Ollama API.
+
+     Ollama provides an OpenAI-compatible API for running open-source models
+     like Llama locally without requiring API keys or internet access.
+
+     Example:
+         >>> config = OllamaConfig(model="llama3.2:1b")
+         >>> client = OllamaClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     def __init__(self, config: Optional[OllamaConfig] = None) -> None:
+         """Initialize Ollama client.
+
+         Args:
+             config: Ollama configuration. If None, uses defaults connecting
+                 to localhost:11434 with llama3.2:1b model.
+         """
+         self.config = config or OllamaConfig()
+         self._total_calls = 0
+
+     @property
+     def provider_name(self) -> str:
+         """Return the provider name."""
+         return "ollama"
+
+     def completion(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> LLMResponse:
+         """Make a chat completion request to Ollama.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options (temperature, max_tokens).
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails or Ollama is not running.
+         """
+         # Build request payload (Ollama uses similar format to OpenAI)
+         payload: Dict[str, Any] = {
+             "model": self.config.model,
+             "messages": messages,
+             "stream": False,
+             "options": {
+                 "temperature": kwargs.get(
+                     "temperature", self.config.temperature
+                 ),
+                 "num_predict": kwargs.get(
+                     "max_tokens", self.config.max_tokens
+                 ),
+             },
+         }
+
+         url = f"{self.config.base_url}/api/chat"
+         headers = {"Content-Type": "application/json"}
+
+         logger.debug(f"Calling Ollama API with model: {self.config.model}")
+
+         try:
+             with httpx.Client(timeout=self.config.timeout) as client:
+                 response = client.post(url, json=payload, headers=headers)
+                 response.raise_for_status()
+
+                 data = response.json()
+
+                 # Extract response content
+                 content = data.get("message", {}).get("content", "")
+
+                 # Extract token counts (Ollama provides these)
+                 input_tokens = data.get("prompt_eval_count", 0)
+                 output_tokens = data.get("eval_count", 0)
+
+                 self._total_calls += 1
+
+                 logger.debug(
+                     f"Ollama response: {input_tokens} in, {output_tokens} out"
+                 )
+
+                 return LLMResponse(
+                     content=content,
+                     model=data.get("model", self.config.model),
+                     input_tokens=input_tokens,
+                     output_tokens=output_tokens,
+                     cost=0.0,  # Local inference is free
+                     raw_response=data,
+                 )
+
+         except httpx.ConnectError:
+             raise ValueError(
+                 "Could not connect to Ollama. "
+                 "Make sure Ollama is running (run 'ollama serve' or start "
+                 "the Ollama app). "
+                 f"Tried to connect to: {self.config.base_url}"
+             )
+         except httpx.HTTPStatusError as e:
+             if e.response.status_code == 404:
+                 raise ValueError(
+                     f"Model '{self.config.model}' not found. "
+                     f"Run 'ollama pull {self.config.model}' to download it."
+                 )
+             logger.error(
+                 f"Ollama API error: {e.response.status_code} - "
+                 f"{e.response.text}"
+             )
+             raise ValueError(
+                 f"Ollama API error: {e.response.status_code} - "
+                 f"{e.response.text}"
+             )
+         except httpx.TimeoutException:
+             raise ValueError(
+                 "Ollama API request timed out. Local inference can be slow - "
+                 "try increasing the timeout in OllamaConfig."
+             )
+         except Exception as e:
+             logger.error(f"Ollama API unexpected error: {e}")
+             raise ValueError(f"Ollama API error: {str(e)}")
+
+     def complete_json(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
+         response = self.completion(messages, **kwargs)
+         parsed = response.parse_json()
+         return parsed, response
+
+     @property
+     def call_count(self) -> int:
+         """Return the number of API calls made."""
+         return self._total_calls
+
+     def is_available(self) -> bool:
+         """Check if Ollama server is running and model is available.
+
+         Returns:
+             True if Ollama is running and the configured model exists.
+         """
+         try:
+             with httpx.Client(timeout=5.0) as client:
+                 # Check if server is running
+                 response = client.get(f"{self.config.base_url}/api/tags")
+                 if response.status_code != 200:
+                     return False
+
+                 # Check if model is available
+                 data = response.json()
+                 models = [m.get("name", "") for m in data.get("models", [])]
+                 # Ollama model names can have :latest suffix
+                 model_name = self.config.model
+                 return any(
+                     m == model_name or m.startswith(f"{model_name}:")
+                     for m in models
+                 )
+         except Exception:
+             return False
+
+     def list_models(self) -> List[str]:
+         """List installed models from Ollama.
+
+         Queries the local Ollama server to get installed models.
+         Unlike cloud providers, this returns only models the user
+         has explicitly pulled/installed.
+
+         Returns:
+             List of model identifiers (e.g., ['llama3.2:1b', ...]).
+
+         Raises:
+             ValueError: If Ollama server is not running.
+         """
+         try:
+             with httpx.Client(timeout=5.0) as client:
+                 response = client.get(f"{self.config.base_url}/api/tags")
+                 response.raise_for_status()
+                 data = response.json()
+
+                 models = [m.get("name", "") for m in data.get("models", [])]
+                 return sorted(models)
+
+         except httpx.ConnectError:
+             raise ValueError(
+                 "Ollama server not running. Start with: ollama serve"
+             )
+         except Exception as e:
+             raise ValueError(f"Failed to list Ollama models: {e}")
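
Editorial note (not part of the package diff): because OllamaClient.is_available() probes the local server and checks that the configured model has been pulled, a caller can prefer free local inference and fall back to a hosted provider. A minimal sketch follows, assuming the module paths causaliq_knowledge.llm.ollama_client and causaliq_knowledge.llm.groq_client, which are not shown in this diff.

    # Illustrative sketch only; the two client module paths are assumptions.
    from causaliq_knowledge.llm.base_client import BaseLLMClient
    from causaliq_knowledge.llm.groq_client import GroqClient
    from causaliq_knowledge.llm.ollama_client import OllamaClient, OllamaConfig

    def pick_client() -> BaseLLMClient:
        """Prefer local Ollama inference; fall back to Groq if unavailable."""
        local = OllamaClient(OllamaConfig(model="llama3.2:1b"))
        if local.is_available():  # server running and model pulled
            return local
        return GroqClient()  # requires GROQ_API_KEY in the environment

    client = pick_client()
    print(client.provider_name, client.list_models())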
@@ -0,0 +1,115 @@
+ """Direct OpenAI API client - clean and reliable."""
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ from causaliq_knowledge.llm.openai_compat_client import (
+     OpenAICompatClient,
+     OpenAICompatConfig,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class OpenAIConfig(OpenAICompatConfig):
+     """Configuration for OpenAI API client.
+
+     Extends OpenAICompatConfig with OpenAI-specific defaults.
+
+     Attributes:
+         model: OpenAI model identifier (default: gpt-4o-mini).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: OpenAI API key (falls back to OPENAI_API_KEY env var).
+     """
+
+     model: str = "gpt-4o-mini"
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+     def __post_init__(self) -> None:
+         """Set API key from environment if not provided."""
+         if self.api_key is None:
+             self.api_key = os.getenv("OPENAI_API_KEY")
+         if not self.api_key:
+             raise ValueError("OPENAI_API_KEY environment variable is required")
+
+
+ class OpenAIClient(OpenAICompatClient):
+     """Direct OpenAI API client.
+
+     Implements the BaseLLMClient interface for OpenAI's API.
+     Uses httpx for HTTP requests.
+
+     Example:
+         >>> config = OpenAIConfig(model="gpt-4o-mini")
+         >>> client = OpenAIClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     BASE_URL = "https://api.openai.com/v1"
+     PROVIDER_NAME = "openai"
+     ENV_VAR = "OPENAI_API_KEY"
+
+     def __init__(self, config: Optional[OpenAIConfig] = None) -> None:
+         """Initialize OpenAI client.
+
+         Args:
+             config: OpenAI configuration. If None, uses defaults with
+                 API key from OPENAI_API_KEY environment variable.
+         """
+         super().__init__(config)
+
+     def _default_config(self) -> OpenAIConfig:
+         """Return default OpenAI configuration."""
+         return OpenAIConfig()
+
+     def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+         """Return OpenAI pricing per 1M tokens.
+
+         Returns:
+             Dict mapping model prefixes to input/output costs.
+         """
+         # Order matters - more specific prefixes must come first
+         return {
+             "gpt-4o-mini": {"input": 0.15, "output": 0.60},
+             "gpt-4o": {"input": 2.50, "output": 10.00},
+             "gpt-4-turbo": {"input": 10.00, "output": 30.00},
+             "gpt-4": {"input": 30.00, "output": 60.00},
+             "gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
+             "o1-mini": {"input": 3.00, "output": 12.00},
+             "o1": {"input": 15.00, "output": 60.00},
+         }
+
+     def _filter_models(self, models: List[str]) -> List[str]:
+         """Filter to OpenAI chat models only.
+
+         Args:
+             models: List of all model IDs from API.
+
+         Returns:
+             Filtered list of GPT and o1/o3 models.
+         """
+         filtered = []
+         for model_id in models:
+             # Include GPT and o1/o3 models
+             if any(
+                 prefix in model_id
+                 for prefix in ["gpt-4", "gpt-3.5", "o1", "o3"]
+             ):
+                 # Exclude instruct variants and specific exclusions
+                 if any(
+                     x in model_id.lower()
+                     for x in ["instruct", "vision", "audio", "realtime"]
+                 ):
+                     continue
+                 filtered.append(model_id)
+         return filtered
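
Editorial note (not part of the package diff): the comment in _get_pricing above says ordering matters because more specific prefixes must come first; with a first-match scan, "gpt-4o-mini" would otherwise be billed at "gpt-4o" rates. The scan itself is presumed to live in OpenAICompatClient, which this diff does not include, so the sketch below only demonstrates that first-match behaviour under that assumption.

    # Illustrative sketch only; the real prefix scan is assumed to be in
    # OpenAICompatClient, which is not shown in this diff.
    from typing import Dict, Optional

    def match_rates(
        model: str, pricing: Dict[str, Dict[str, float]]
    ) -> Optional[Dict[str, float]]:
        """Return the first pricing entry whose prefix matches the model id."""
        for prefix, rates in pricing.items():  # dicts keep insertion order
            if model.startswith(prefix):
                return rates
        return None

    good = {"gpt-4o-mini": {"input": 0.15, "output": 0.60},
            "gpt-4o": {"input": 2.50, "output": 10.00}}
    bad = {"gpt-4o": {"input": 2.50, "output": 10.00},
           "gpt-4o-mini": {"input": 0.15, "output": 0.60}}

    print(match_rates("gpt-4o-mini", good))  # {'input': 0.15, 'output': 0.6}
    print(match_rates("gpt-4o-mini", bad))   # {'input': 2.5, 'output': 10.0}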