causaliq_knowledge-0.1.0-py3-none-any.whl
- causaliq_knowledge/__init__.py +33 -0
- causaliq_knowledge/base.py +85 -0
- causaliq_knowledge/cli.py +207 -0
- causaliq_knowledge/llm/__init__.py +34 -0
- causaliq_knowledge/llm/gemini_client.py +203 -0
- causaliq_knowledge/llm/groq_client.py +148 -0
- causaliq_knowledge/llm/prompts.py +204 -0
- causaliq_knowledge/llm/provider.py +341 -0
- causaliq_knowledge/models.py +124 -0
- causaliq_knowledge-0.1.0.dist-info/METADATA +185 -0
- causaliq_knowledge-0.1.0.dist-info/RECORD +15 -0
- causaliq_knowledge-0.1.0.dist-info/WHEEL +5 -0
- causaliq_knowledge-0.1.0.dist-info/entry_points.txt +3 -0
- causaliq_knowledge-0.1.0.dist-info/licenses/LICENSE +21 -0
- causaliq_knowledge-0.1.0.dist-info/top_level.txt +1 -0
causaliq_knowledge/__init__.py
@@ -0,0 +1,33 @@
"""
causaliq-knowledge: LLM and human knowledge for causal discovery.
"""

from causaliq_knowledge.base import KnowledgeProvider
from causaliq_knowledge.models import EdgeDirection, EdgeKnowledge

__version__ = "0.1.0"
__author__ = "CausalIQ"
__email__ = "info@causaliq.com"

# Package metadata
__title__ = "causaliq-knowledge"
__description__ = "LLM and human knowledge for causal discovery"

__url__ = "https://github.com/causaliq/causaliq-knowledge"
__license__ = "MIT"

# Version tuple for programmatic access
VERSION = tuple(map(int, __version__.split(".")))

__all__ = [
    "__version__",
    "__author__",
    "__email__",
    "VERSION",
    # Core models
    "EdgeKnowledge",
    "EdgeDirection",
    # Abstract interface
    "KnowledgeProvider",
    # Note: Import LLMKnowledge from causaliq_knowledge.llm
]
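Usage sketch (not part of the wheel contents): the module exports the version
both as a string and as an integer tuple, so callers can gate behaviour with
ordered comparisons instead of parsing __version__ themselves.

# Sketch: consuming the metadata exported above.
import causaliq_knowledge

print(causaliq_knowledge.__version__)  # "0.1.0"
print(causaliq_knowledge.VERSION)      # (0, 1, 0)

if causaliq_knowledge.VERSION >= (0, 1, 0):
    from causaliq_knowledge import EdgeKnowledge, KnowledgeProvider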
causaliq_knowledge/base.py
@@ -0,0 +1,85 @@
"""Abstract base class for knowledge providers."""

from abc import ABC, abstractmethod
from typing import Optional

from causaliq_knowledge.models import EdgeKnowledge


class KnowledgeProvider(ABC):
    """Abstract interface for all knowledge sources.

    This is the base class that all knowledge providers must implement.
    Knowledge providers can be LLM-based, rule-based, human-input based,
    or any other source of causal knowledge.

    The primary method is `query_edge()` which asks about the causal
    relationship between two variables.

    Example:
        >>> class MyKnowledgeProvider(KnowledgeProvider):
        ...     def query_edge(self, node_a, node_b, context=None):
        ...         # Implementation here
        ...         return EdgeKnowledge(exists=True, confidence=0.8, ...)
        ...
        >>> provider = MyKnowledgeProvider()
        >>> result = provider.query_edge("smoking", "cancer")
    """

    @abstractmethod
    def query_edge(
        self,
        node_a: str,
        node_b: str,
        context: Optional[dict] = None,
    ) -> EdgeKnowledge:
        """Query whether a causal edge exists between two nodes.

        Args:
            node_a: Name of the first variable.
            node_b: Name of the second variable.
            context: Optional context dictionary that may include:
                - domain: The domain (e.g., "medicine", "economics")
                - descriptions: Dict mapping variable names to descriptions
                - additional_info: Any other relevant context

        Returns:
            EdgeKnowledge with:
                - exists: True, False, or None (uncertain)
                - direction: "a_to_b", "b_to_a", "undirected", or None
                - confidence: 0.0 to 1.0
                - reasoning: Human-readable explanation
                - model: Source identifier (optional)

        Raises:
            NotImplementedError: If not implemented by subclass.
        """
        pass

    def query_edges(
        self,
        edges: list[tuple[str, str]],
        context: Optional[dict] = None,
    ) -> list[EdgeKnowledge]:
        """Query multiple edges at once.

        Default implementation calls query_edge for each pair.
        Subclasses may override for batch optimization.

        Args:
            edges: List of (node_a, node_b) tuples to query.
            context: Optional context dictionary (shared across all queries).

        Returns:
            List of EdgeKnowledge results, one per edge pair.
        """
        return [self.query_edge(a, b, context) for a, b in edges]

    @property
    def name(self) -> str:
        """Return the name of this knowledge provider.

        Returns:
            Class name by default. Subclasses may override.
        """
        return self.__class__.__name__
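A fleshed-out version of the docstring example: a minimal rule-based provider.
This is a sketch, assuming EdgeKnowledge can be constructed from the fields the
query_edge docstring lists (exists, confidence, reasoning); models.py is not
reproduced in this diff, so treat the constructor call as illustrative.

# Sketch: a trivial rule-based KnowledgeProvider backed by a curated edge set.
# Assumes EdgeKnowledge accepts the fields named in the query_edge docstring.
from typing import Optional

from causaliq_knowledge.base import KnowledgeProvider
from causaliq_knowledge.models import EdgeKnowledge


class LookupKnowledge(KnowledgeProvider):
    """Answers edge queries from a hand-maintained set of known edges."""

    def __init__(self, known_edges: set[tuple[str, str]]) -> None:
        self.known_edges = known_edges

    def query_edge(
        self,
        node_a: str,
        node_b: str,
        context: Optional[dict] = None,
    ) -> EdgeKnowledge:
        found = (node_a, node_b) in self.known_edges
        return EdgeKnowledge(
            exists=True if found else None,  # None signals "uncertain"
            confidence=0.9 if found else 0.0,
            reasoning="listed as a known edge" if found else "not in edge list",
        )


provider = LookupKnowledge({("smoking", "cancer")})
print(provider.name)                             # "LookupKnowledge"
print(provider.query_edge("smoking", "cancer"))  # exists=True
print(provider.query_edges([("tea", "rain")]))   # default loops over query_edge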
causaliq_knowledge/cli.py
@@ -0,0 +1,207 @@
"""Command-line interface for causaliq-knowledge."""

from __future__ import annotations

import json
import sys
from typing import Optional

import click

from causaliq_knowledge import __version__


@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """CausalIQ Knowledge - LLM knowledge for causal discovery.

    Query LLMs about causal relationships between variables.
    """
    pass


@cli.command("query")
@click.argument("node_a")
@click.argument("node_b")
@click.option(
    "--model",
    "-m",
    multiple=True,
    default=["groq/llama-3.1-8b-instant"],
    help="LLM model(s) to query. Can be specified multiple times.",
)
@click.option(
    "--domain",
    "-d",
    default=None,
    help="Domain context (e.g., 'medicine', 'economics').",
)
@click.option(
    "--strategy",
    "-s",
    type=click.Choice(["weighted_vote", "highest_confidence"]),
    default="weighted_vote",
    help="Consensus strategy for multi-model queries.",
)
@click.option(
    "--json",
    "output_json",
    is_flag=True,
    help="Output result as JSON.",
)
@click.option(
    "--temperature",
    "-t",
    type=float,
    default=0.1,
    help="LLM temperature (0.0-1.0).",
)
def query_edge(
    node_a: str,
    node_b: str,
    model: tuple[str, ...],
    domain: Optional[str],
    strategy: str,
    output_json: bool,
    temperature: float,
) -> None:
    """Query LLMs about a causal relationship between two variables.

    NODE_A and NODE_B are the variable names to query about.

    Examples:

        cqknow query smoking lung_cancer

        cqknow query smoking lung_cancer --domain medicine

        cqknow query X Y --model groq/llama-3.1-8b-instant \
            --model gemini/gemini-2.5-flash
    """
    # Import here to avoid slow startup for --help
    from causaliq_knowledge.llm import LLMKnowledge

    # Build context
    context = None
    if domain:
        context = {"domain": domain}

    # Create provider
    try:
        provider = LLMKnowledge(
            models=list(model),
            consensus_strategy=strategy,
            temperature=temperature,
        )
    except Exception as e:
        click.echo(f"Error creating provider: {e}", err=True)
        sys.exit(1)

    # Query
    click.echo(
        f"Querying {len(model)} model(s) about: {node_a} -> {node_b}",
        err=True,
    )

    try:
        result = provider.query_edge(node_a, node_b, context=context)
    except Exception as e:
        click.echo(f"Error querying LLM: {e}", err=True)
        sys.exit(1)

    # Output
    if output_json:
        output = {
            "node_a": node_a,
            "node_b": node_b,
            "exists": result.exists,
            "direction": result.direction.value if result.direction else None,
            "confidence": result.confidence,
            "reasoning": result.reasoning,
            "model": result.model,
        }
        click.echo(json.dumps(output, indent=2))
    else:
        # Human-readable output
        exists_map = {True: "Yes", False: "No", None: "Uncertain"}
        exists_str = exists_map[result.exists]
        direction_str = result.direction.value if result.direction else "N/A"

        click.echo(f"\n{'='*60}")
        click.echo(f"Query: Does '{node_a}' cause '{node_b}'?")
        click.echo("=" * 60)
        click.echo(f"Exists: {exists_str}")
        click.echo(f"Direction: {direction_str}")
        click.echo(f"Confidence: {result.confidence:.2f}")
        click.echo(f"Model(s): {result.model or 'unknown'}")
        click.echo(f"{'='*60}")
        click.echo(f"Reasoning: {result.reasoning}")
        click.echo()

    # Show stats
    stats = provider.get_stats()
    if stats["total_cost"] > 0:
        click.echo(
            f"Cost: ${stats['total_cost']:.6f} "
            f"({stats['total_calls']} call(s))",
            err=True,
        )


@cli.command("models")
def list_models() -> None:
    """List supported LLM models.

    These are model identifiers that work with our direct API clients.
    Only models with direct API support are listed.
    """
    models = [
        (
            "Groq (Fast, Free Tier Available)",
            [
                "groq/llama-3.1-8b-instant",
                "groq/llama-3.1-70b-versatile",
                "groq/llama-3.2-1b-preview",
                "groq/llama-3.2-3b-preview",
                "groq/mixtral-8x7b-32768",
                "groq/gemma-7b-it",
                "groq/gemma2-9b-it",
            ],
        ),
        (
            "Google Gemini (Free Tier Available)",
            [
                "gemini/gemini-2.5-flash",
                "gemini/gemini-1.5-pro",
                "gemini/gemini-1.5-flash",
                "gemini/gemini-1.5-flash-8b",
            ],
        ),
    ]

    click.echo("\nSupported LLM Models (Direct API Access):\n")
    for provider, model_list in models:
        click.echo(f"  {provider}:")
        for m in model_list:
            click.echo(f"    - {m}")
        click.echo()
    click.echo("Required API Keys:")
    click.echo(
        "  GROQ_API_KEY - Get free API key at https://console.groq.com"
    )
    click.echo(
        "  GEMINI_API_KEY - Get free API key at https://aistudio.google.com"
    )
    click.echo()
    click.echo("Default model: groq/llama-3.1-8b-instant")
    click.echo()


def main() -> None:
    """Entry point for the CLI."""
    cli()


if __name__ == "__main__":  # pragma: no cover
    main()
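The docstring examples above use the cqknow console script (its declaration
lives in entry_points.txt, not shown here). The same commands can also be
exercised in-process with click's test runner, which is handy for smoke tests.
A sketch:

# Sketch: driving the CLI in-process via click's CliRunner.
from click.testing import CliRunner

from causaliq_knowledge.cli import cli

runner = CliRunner()

# Equivalent to: cqknow models
result = runner.invoke(cli, ["models"])
print(result.output)

# Equivalent to: cqknow query smoking lung_cancer --domain medicine --json
# (needs GROQ_API_KEY set; otherwise the command reports an error, exit code 1)
result = runner.invoke(
    cli, ["query", "smoking", "lung_cancer", "--domain", "medicine", "--json"]
)
print(result.exit_code)
print(result.output)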
causaliq_knowledge/llm/__init__.py
@@ -0,0 +1,34 @@
"""LLM integration module for causaliq-knowledge."""

from causaliq_knowledge.llm.gemini_client import (
    GeminiClient,
    GeminiConfig,
    GeminiResponse,
)
from causaliq_knowledge.llm.groq_client import (
    GroqClient,
    GroqConfig,
    GroqResponse,
)
from causaliq_knowledge.llm.prompts import EdgeQueryPrompt, parse_edge_response
from causaliq_knowledge.llm.provider import (
    CONSENSUS_STRATEGIES,
    LLMKnowledge,
    highest_confidence,
    weighted_vote,
)

__all__ = [
    "CONSENSUS_STRATEGIES",
    "EdgeQueryPrompt",
    "GeminiClient",
    "GeminiConfig",
    "GeminiResponse",
    "GroqClient",
    "GroqConfig",
    "GroqResponse",
    "LLMKnowledge",
    "highest_confidence",
    "parse_edge_response",
    "weighted_vote",
]
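These re-exports make causaliq_knowledge.llm.LLMKnowledge the one import needed
for LLM-backed queries. A construction sketch mirroring the call site in cli.py
(provider.py itself is not reproduced in this diff, so the keyword names below
are taken from that call site):

# Sketch: constructing the multi-model provider the same way cli.py does.
from causaliq_knowledge.llm import LLMKnowledge

provider = LLMKnowledge(
    models=["groq/llama-3.1-8b-instant", "gemini/gemini-2.5-flash"],
    consensus_strategy="weighted_vote",  # or "highest_confidence"
    temperature=0.1,
)

result = provider.query_edge(
    "smoking", "lung_cancer", context={"domain": "medicine"}
)
print(result.exists, result.confidence)
print(provider.get_stats())  # cli.py reads "total_cost" and "total_calls"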
causaliq_knowledge/llm/gemini_client.py
@@ -0,0 +1,203 @@
"""Direct Google Gemini API client - clean and reliable."""

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import httpx

logger = logging.getLogger(__name__)


@dataclass
class GeminiConfig:
    """Configuration for Gemini API client."""

    model: str = "gemini-2.5-flash"
    temperature: float = 0.1
    max_tokens: int = 500
    timeout: float = 30.0
    api_key: Optional[str] = None

    def __post_init__(self) -> None:
        """Set API key from environment if not provided."""
        if self.api_key is None:
            self.api_key = os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required")


@dataclass
class GeminiResponse:
    """Response from Gemini API."""

    content: str
    model: str
    input_tokens: int = 0
    output_tokens: int = 0
    cost: float = 0.0  # Gemini free tier
    raw_response: Optional[Dict] = None

    def parse_json(self) -> Optional[Dict[str, Any]]:
        """Parse content as JSON, handling common formatting issues."""
        try:
            # Clean up potential markdown code blocks
            text = self.content.strip()
            if text.startswith("```json"):
                text = text[7:]
            elif text.startswith("```"):
                text = text[3:]
            if text.endswith("```"):
                text = text[:-3]

            return json.loads(text.strip())  # type: ignore[no-any-return]
        except json.JSONDecodeError:
            return None


class GeminiClient:
    """Direct Gemini API client."""

    BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(self, config: Optional[GeminiConfig] = None):
        """Initialize Gemini client."""
        self.config = config or GeminiConfig()
        self._total_calls = 0

    def completion(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> GeminiResponse:
        """Make a chat completion request to Gemini."""

        # Convert OpenAI-style messages to Gemini format
        contents = []
        system_instruction = None

        for msg in messages:
            if msg["role"] == "system":
                # Gemini handles system prompts differently
                system_instruction = {"parts": [{"text": msg["content"]}]}
            elif msg["role"] == "user":
                contents.append(
                    {"role": "user", "parts": [{"text": msg["content"]}]}
                )
            elif msg["role"] == "assistant":
                contents.append(
                    {"role": "model", "parts": [{"text": msg["content"]}]}
                )

        # Build request payload in Gemini's format
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": kwargs.get(
                    "temperature", self.config.temperature
                ),
                "maxOutputTokens": kwargs.get(
                    "max_tokens", self.config.max_tokens
                ),
                "responseMimeType": "text/plain",
            },
        }

        # Add system instruction if present
        if system_instruction:
            payload["systemInstruction"] = system_instruction

        # API endpoint with model and key
        url = f"{self.BASE_URL}/{self.config.model}:generateContent"
        params = {"key": self.config.api_key}

        headers = {"Content-Type": "application/json"}

        logger.debug(f"Calling Gemini API with model: {self.config.model}")

        try:
            with httpx.Client(timeout=self.config.timeout) as client:
                response = client.post(
                    url, json=payload, headers=headers, params=params
                )
                response.raise_for_status()

                data = response.json()

                # Handle Gemini API errors
                if "error" in data:
                    error_msg = data["error"].get("message", "Unknown error")
                    raise ValueError(f"Gemini API error: {error_msg}")

                # Extract response data from Gemini format
                candidates = data.get("candidates", [])
                if not candidates:
                    raise ValueError("No candidates returned by Gemini API")

                candidate = candidates[0]
                if candidate.get("finishReason") == "SAFETY":
                    raise ValueError(
                        "Content was blocked by Gemini safety filters"
                    )

                # Extract text content
                parts = candidate.get("content", {}).get("parts", [])
                content = ""
                for part in parts:
                    if "text" in part:
                        content += part["text"]

                # Extract usage info
                usage = data.get("usageMetadata", {})
                input_tokens = usage.get("promptTokenCount", 0)
                output_tokens = usage.get("candidatesTokenCount", 0)

                self._total_calls += 1

                logger.debug(
                    f"Gemini response: {input_tokens} in, {output_tokens} out"
                )

                return GeminiResponse(
                    content=content,
                    model=self.config.model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    cost=0.0,  # Free tier
                    raw_response=data,
                )

        except httpx.HTTPStatusError as e:
            try:
                error_data = e.response.json()
                error_msg = error_data.get("error", {}).get(
                    "message", e.response.text
                )
            except Exception:
                error_msg = e.response.text

            logger.error(
                f"Gemini API HTTP error: {e.response.status_code} - "
                f"{error_msg}"
            )
            raise ValueError(
                f"Gemini API error: {e.response.status_code} - {error_msg}"
            )
        except httpx.TimeoutException:
            raise ValueError("Gemini API request timed out")
        except Exception as e:
            logger.error(f"Gemini API unexpected error: {e}")
            raise ValueError(f"Gemini API error: {str(e)}")

    def complete_json(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> tuple[Optional[Dict[str, Any]], GeminiResponse]:
        """Make a completion request and parse response as JSON."""
        response = self.completion(messages, **kwargs)
        parsed = response.parse_json()
        return parsed, response

    @property
    def call_count(self) -> int:
        """Number of API calls made."""
        return self._total_calls