causaliq-knowledge 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
+ """
2
+ causaliq-knowledge: LLM and human knowledge for causal discovery.
3
+ """
4
+
5
+ from causaliq_knowledge.base import KnowledgeProvider
6
+ from causaliq_knowledge.models import EdgeDirection, EdgeKnowledge
7
+
8
+ __version__ = "0.1.0"
9
+ __author__ = "CausalIQ"
10
+ __email__ = "info@causaliq.com"
11
+
12
+ # Package metadata
13
+ __title__ = "causaliq-knowledge"
14
+ __description__ = "LLM and human knowledge for causal discovery"
15
+
16
+ __url__ = "https://github.com/causaliq/causaliq-knowledge"
17
+ __license__ = "MIT"
18
+
19
+ # Version tuple for programmatic access
20
+ VERSION = tuple(map(int, __version__.split(".")))
21
+
22
+ __all__ = [
23
+ "__version__",
24
+ "__author__",
25
+ "__email__",
26
+ "VERSION",
27
+ # Core models
28
+ "EdgeKnowledge",
29
+ "EdgeDirection",
30
+ # Abstract interface
31
+ "KnowledgeProvider",
32
+ # Note: Import LLMKnowledge from causaliq_knowledge.llm
33
+ ]
@@ -0,0 +1,85 @@
1
+ """Abstract base class for knowledge providers."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Optional
5
+
6
+ from causaliq_knowledge.models import EdgeKnowledge
7
+
8
+
9
class KnowledgeProvider(ABC):
    """Abstract interface for all knowledge sources.

    Every knowledge provider — LLM-backed, rule-based, human-in-the-loop,
    or any other source of causal knowledge — implements this interface.
    Concrete subclasses must supply ``query_edge()``, which reports what
    the source believes about the causal relationship between a pair of
    variables.

    Example:
        >>> class MyKnowledgeProvider(KnowledgeProvider):
        ...     def query_edge(self, node_a, node_b, context=None):
        ...         # Implementation here
        ...         return EdgeKnowledge(exists=True, confidence=0.8, ...)
        ...
        >>> provider = MyKnowledgeProvider()
        >>> result = provider.query_edge("smoking", "cancer")
    """

    @abstractmethod
    def query_edge(
        self,
        node_a: str,
        node_b: str,
        context: Optional[dict] = None,
    ) -> EdgeKnowledge:
        """Query whether a causal edge exists between two nodes.

        Args:
            node_a: Name of the first variable.
            node_b: Name of the second variable.
            context: Optional extra information for the provider. Keys
                may include "domain" (e.g. "medicine", "economics"),
                "descriptions" (variable name -> description), and
                "additional_info" (anything else relevant).

        Returns:
            An EdgeKnowledge carrying: exists (True, False, or None for
            uncertain), direction ("a_to_b", "b_to_a", "undirected", or
            None), confidence (0.0 to 1.0), reasoning (human-readable
            explanation), and an optional model/source identifier.

        Raises:
            NotImplementedError: If not implemented by subclass.
        """

    def query_edges(
        self,
        edges: list[tuple[str, str]],
        context: Optional[dict] = None,
    ) -> list[EdgeKnowledge]:
        """Query multiple edges at once.

        The default implementation simply loops over ``query_edge``;
        subclasses may override this to batch requests more efficiently.

        Args:
            edges: (node_a, node_b) pairs to evaluate.
            context: Context dictionary shared by every query.

        Returns:
            One EdgeKnowledge result per input pair, in order.
        """
        results: list[EdgeKnowledge] = []
        for node_a, node_b in edges:
            results.append(self.query_edge(node_a, node_b, context))
        return results

    @property
    def name(self) -> str:
        """Name of this knowledge provider (class name by default)."""
        return type(self).__name__
@@ -0,0 +1,207 @@
1
+ """Command-line interface for causaliq-knowledge."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from typing import Optional
8
+
9
+ import click
10
+
11
+ from causaliq_knowledge import __version__
12
+
13
+
14
# Root click command group; `query` and `models` attach via @cli.command.
@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """CausalIQ Knowledge - LLM knowledge for causal discovery.

    Query LLMs about causal relationships between variables.
    """
    # The group itself does nothing; subcommands carry the behavior.
    pass
22
+
23
+
24
+ @cli.command("query")
25
+ @click.argument("node_a")
26
+ @click.argument("node_b")
27
+ @click.option(
28
+ "--model",
29
+ "-m",
30
+ multiple=True,
31
+ default=["groq/llama-3.1-8b-instant"],
32
+ help="LLM model(s) to query. Can be specified multiple times.",
33
+ )
34
+ @click.option(
35
+ "--domain",
36
+ "-d",
37
+ default=None,
38
+ help="Domain context (e.g., 'medicine', 'economics').",
39
+ )
40
+ @click.option(
41
+ "--strategy",
42
+ "-s",
43
+ type=click.Choice(["weighted_vote", "highest_confidence"]),
44
+ default="weighted_vote",
45
+ help="Consensus strategy for multi-model queries.",
46
+ )
47
+ @click.option(
48
+ "--json",
49
+ "output_json",
50
+ is_flag=True,
51
+ help="Output result as JSON.",
52
+ )
53
+ @click.option(
54
+ "--temperature",
55
+ "-t",
56
+ type=float,
57
+ default=0.1,
58
+ help="LLM temperature (0.0-1.0).",
59
+ )
60
+ def query_edge(
61
+ node_a: str,
62
+ node_b: str,
63
+ model: tuple[str, ...],
64
+ domain: Optional[str],
65
+ strategy: str,
66
+ output_json: bool,
67
+ temperature: float,
68
+ ) -> None:
69
+ """Query LLMs about a causal relationship between two variables.
70
+
71
+ NODE_A and NODE_B are the variable names to query about.
72
+
73
+ Examples:
74
+
75
+ cqknow query smoking lung_cancer
76
+
77
+ cqknow query smoking lung_cancer --domain medicine
78
+
79
+ cqknow query X Y --model groq/llama-3.1-8b-instant \
80
+ --model gemini/gemini-2.5-flash
81
+ """
82
+ # Import here to avoid slow startup for --help
83
+ from causaliq_knowledge.llm import LLMKnowledge
84
+
85
+ # Build context
86
+ context = None
87
+ if domain:
88
+ context = {"domain": domain}
89
+
90
+ # Create provider
91
+ try:
92
+ provider = LLMKnowledge(
93
+ models=list(model),
94
+ consensus_strategy=strategy,
95
+ temperature=temperature,
96
+ )
97
+ except Exception as e:
98
+ click.echo(f"Error creating provider: {e}", err=True)
99
+ sys.exit(1)
100
+
101
+ # Query
102
+ click.echo(
103
+ f"Querying {len(model)} model(s) about: {node_a} -> {node_b}",
104
+ err=True,
105
+ )
106
+
107
+ try:
108
+ result = provider.query_edge(node_a, node_b, context=context)
109
+ except Exception as e:
110
+ click.echo(f"Error querying LLM: {e}", err=True)
111
+ sys.exit(1)
112
+
113
+ # Output
114
+ if output_json:
115
+ output = {
116
+ "node_a": node_a,
117
+ "node_b": node_b,
118
+ "exists": result.exists,
119
+ "direction": result.direction.value if result.direction else None,
120
+ "confidence": result.confidence,
121
+ "reasoning": result.reasoning,
122
+ "model": result.model,
123
+ }
124
+ click.echo(json.dumps(output, indent=2))
125
+ else:
126
+ # Human-readable output
127
+ exists_map = {True: "Yes", False: "No", None: "Uncertain"}
128
+ exists_str = exists_map[result.exists]
129
+ direction_str = result.direction.value if result.direction else "N/A"
130
+
131
+ click.echo(f"\n{'='*60}")
132
+ click.echo(f"Query: Does '{node_a}' cause '{node_b}'?")
133
+ click.echo("=" * 60)
134
+ click.echo(f"Exists: {exists_str}")
135
+ click.echo(f"Direction: {direction_str}")
136
+ click.echo(f"Confidence: {result.confidence:.2f}")
137
+ click.echo(f"Model(s): {result.model or 'unknown'}")
138
+ click.echo(f"{'='*60}")
139
+ click.echo(f"Reasoning: {result.reasoning}")
140
+ click.echo()
141
+
142
+ # Show stats
143
+ stats = provider.get_stats()
144
+ if stats["total_cost"] > 0:
145
+ click.echo(
146
+ f"Cost: ${stats['total_cost']:.6f} "
147
+ f"({stats['total_calls']} call(s))",
148
+ err=True,
149
+ )
150
+
151
+
152
+ @cli.command("models")
153
+ def list_models() -> None:
154
+ """List supported LLM models.
155
+
156
+ These are model identifiers that work with our direct API clients.
157
+ Only models with direct API support are listed.
158
+ """
159
+ models = [
160
+ (
161
+ "Groq (Fast, Free Tier Available)",
162
+ [
163
+ "groq/llama-3.1-8b-instant",
164
+ "groq/llama-3.1-70b-versatile",
165
+ "groq/llama-3.2-1b-preview",
166
+ "groq/llama-3.2-3b-preview",
167
+ "groq/mixtral-8x7b-32768",
168
+ "groq/gemma-7b-it",
169
+ "groq/gemma2-9b-it",
170
+ ],
171
+ ),
172
+ (
173
+ "Google Gemini (Free Tier Available)",
174
+ [
175
+ "gemini/gemini-2.5-flash",
176
+ "gemini/gemini-1.5-pro",
177
+ "gemini/gemini-1.5-flash",
178
+ "gemini/gemini-1.5-flash-8b",
179
+ ],
180
+ ),
181
+ ]
182
+
183
+ click.echo("\nSupported LLM Models (Direct API Access):\n")
184
+ for provider, model_list in models:
185
+ click.echo(f" {provider}:")
186
+ for m in model_list:
187
+ click.echo(f" - {m}")
188
+ click.echo()
189
+ click.echo("Required API Keys:")
190
+ click.echo(
191
+ " GROQ_API_KEY - Get free API key at https://console.groq.com"
192
+ )
193
+ click.echo(
194
+ " GEMINI_API_KEY - Get free API key at https://aistudio.google.com"
195
+ )
196
+ click.echo()
197
+ click.echo("Default model: groq/llama-3.1-8b-instant")
198
+ click.echo()
199
+
200
+
201
def main() -> None:
    """Entry point for the CLI."""
    # Delegates to the click group so console_scripts can target main().
    cli()


if __name__ == "__main__":  # pragma: no cover
    main()
@@ -0,0 +1,34 @@
1
+ """LLM integration module for causaliq-knowledge."""
2
+
3
+ from causaliq_knowledge.llm.gemini_client import (
4
+ GeminiClient,
5
+ GeminiConfig,
6
+ GeminiResponse,
7
+ )
8
+ from causaliq_knowledge.llm.groq_client import (
9
+ GroqClient,
10
+ GroqConfig,
11
+ GroqResponse,
12
+ )
13
+ from causaliq_knowledge.llm.prompts import EdgeQueryPrompt, parse_edge_response
14
+ from causaliq_knowledge.llm.provider import (
15
+ CONSENSUS_STRATEGIES,
16
+ LLMKnowledge,
17
+ highest_confidence,
18
+ weighted_vote,
19
+ )
20
+
21
+ __all__ = [
22
+ "CONSENSUS_STRATEGIES",
23
+ "EdgeQueryPrompt",
24
+ "GeminiClient",
25
+ "GeminiConfig",
26
+ "GeminiResponse",
27
+ "GroqClient",
28
+ "GroqConfig",
29
+ "GroqResponse",
30
+ "LLMKnowledge",
31
+ "highest_confidence",
32
+ "parse_edge_response",
33
+ "weighted_vote",
34
+ ]
@@ -0,0 +1,203 @@
1
+ """Direct Google Gemini API client - clean and reliable."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ import httpx
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@dataclass
class GeminiConfig:
    """Configuration for the Gemini API client.

    The API key may be supplied explicitly or, when left as None, is
    read from the GEMINI_API_KEY environment variable during
    post-initialisation.
    """

    # Gemini model identifier (no provider prefix)
    model: str = "gemini-2.5-flash"
    # Sampling temperature forwarded to the API
    temperature: float = 0.1
    # Upper bound on generated tokens per response
    max_tokens: int = 500
    # HTTP request timeout, in seconds
    timeout: float = 30.0
    # Explicit API key; None means "look it up in the environment"
    api_key: Optional[str] = None

    def __post_init__(self) -> None:
        """Resolve the API key, failing fast when none is available."""
        if self.api_key is None:
            self.api_key = os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required")
30
+
31
+
32
@dataclass
class GeminiResponse:
    """Response returned by a Gemini chat-completion call."""

    content: str
    model: str
    input_tokens: int = 0
    output_tokens: int = 0
    cost: float = 0.0  # Gemini free tier
    raw_response: Optional[Dict] = None

    def parse_json(self) -> Optional[Dict[str, Any]]:
        """Parse content as JSON, tolerating markdown code fences.

        Returns:
            The decoded object, or None when the cleaned text is not
            valid JSON.
        """
        text = self.content.strip()
        # Drop an opening fence — with or without the "json" tag.
        if text.startswith("```json"):
            text = text[len("```json"):]
        elif text.startswith("```"):
            text = text[len("```"):]
        # Drop a closing fence.
        if text.endswith("```"):
            text = text[: -len("```")]
        try:
            return json.loads(text.strip())  # type: ignore[no-any-return]
        except json.JSONDecodeError:
            return None
58
+
59
+
60
class GeminiClient:
    """Direct Gemini API client.

    Sends chat-completion requests straight to the Google Generative
    Language REST API via httpx (no SDK dependency). OpenAI-style
    message dicts are translated into Gemini's request schema, and all
    failure modes are normalised to ValueError for callers.
    """

    BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(self, config: Optional[GeminiConfig] = None):
        """Initialize Gemini client.

        Args:
            config: Client configuration; defaults to GeminiConfig(),
                which reads GEMINI_API_KEY from the environment.
        """
        self.config = config or GeminiConfig()
        self._total_calls = 0  # successful completion() calls so far

    def completion(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> GeminiResponse:
        """Make a chat completion request to Gemini.

        Args:
            messages: OpenAI-style message dicts with "role" and
                "content" keys (roles: system/user/assistant).
            **kwargs: Per-call overrides for "temperature" and
                "max_tokens"; anything else is ignored.

        Returns:
            GeminiResponse with the generated text and token usage.

        Raises:
            ValueError: On API error payloads, missing candidates,
                safety blocks, HTTP errors, timeouts, or any other
                unexpected failure.
        """

        # Convert OpenAI-style messages to Gemini format
        contents = []
        system_instruction = None

        for msg in messages:
            if msg["role"] == "system":
                # Gemini handles system prompts differently
                system_instruction = {"parts": [{"text": msg["content"]}]}
            elif msg["role"] == "user":
                contents.append(
                    {"role": "user", "parts": [{"text": msg["content"]}]}
                )
            elif msg["role"] == "assistant":
                # Gemini names the assistant role "model"
                contents.append(
                    {"role": "model", "parts": [{"text": msg["content"]}]}
                )

        # Build request payload in Gemini's format
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": kwargs.get(
                    "temperature", self.config.temperature
                ),
                "maxOutputTokens": kwargs.get(
                    "max_tokens", self.config.max_tokens
                ),
                "responseMimeType": "text/plain",
            },
        }

        # Add system instruction if present
        if system_instruction:
            payload["systemInstruction"] = system_instruction

        # API endpoint with model; the key travels as a query parameter
        url = f"{self.BASE_URL}/{self.config.model}:generateContent"
        params = {"key": self.config.api_key}

        headers = {"Content-Type": "application/json"}

        logger.debug(f"Calling Gemini API with model: {self.config.model}")

        try:
            with httpx.Client(timeout=self.config.timeout) as client:
                response = client.post(
                    url, json=payload, headers=headers, params=params
                )
                response.raise_for_status()

                data = response.json()

                # Handle error payloads delivered with a 200 status
                if "error" in data:
                    error_msg = data["error"].get("message", "Unknown error")
                    raise ValueError(f"Gemini API error: {error_msg}")

                # Extract response data from Gemini format
                candidates = data.get("candidates", [])
                if not candidates:
                    raise ValueError("No candidates returned by Gemini API")

                candidate = candidates[0]
                if candidate.get("finishReason") == "SAFETY":
                    raise ValueError(
                        "Content was blocked by Gemini safety filters"
                    )

                # Extract text content (may be split across several parts)
                parts = candidate.get("content", {}).get("parts", [])
                content = ""
                for part in parts:
                    if "text" in part:
                        content += part["text"]

                # Extract usage info
                usage = data.get("usageMetadata", {})
                input_tokens = usage.get("promptTokenCount", 0)
                output_tokens = usage.get("candidatesTokenCount", 0)

                self._total_calls += 1

                logger.debug(
                    f"Gemini response: {input_tokens} in, {output_tokens} out"
                )

                return GeminiResponse(
                    content=content,
                    model=self.config.model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    cost=0.0,  # Free tier
                    raw_response=data,
                )

        except httpx.HTTPStatusError as e:
            # Prefer the structured API error message; fall back to the
            # raw response body when the error payload is not JSON.
            try:
                error_data = e.response.json()
                error_msg = error_data.get("error", {}).get(
                    "message", e.response.text
                )
            except Exception:
                error_msg = e.response.text

            logger.error(
                f"Gemini API HTTP error: {e.response.status_code} - "
                f"{error_msg}"
            )
            raise ValueError(
                f"Gemini API error: {e.response.status_code} - {error_msg}"
            ) from e
        except httpx.TimeoutException:
            raise ValueError("Gemini API request timed out")
        except ValueError:
            # BUG FIX: the ValueErrors raised above (API error payload,
            # no candidates, safety block) previously fell into the
            # generic handler below and were double-wrapped as
            # "Gemini API error: Gemini API error: ..." and mislogged
            # as unexpected. Re-raise them unchanged.
            raise
        except Exception as e:
            logger.error(f"Gemini API unexpected error: {e}")
            raise ValueError(f"Gemini API error: {str(e)}") from e

    def complete_json(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> tuple[Optional[Dict[str, Any]], GeminiResponse]:
        """Make a completion request and parse the reply as JSON.

        Returns:
            (parsed, response): parsed is None when the reply is not
            valid JSON; response is the full GeminiResponse either way.
        """
        response = self.completion(messages, **kwargs)
        parsed = response.parse_json()
        return parsed, response

    @property
    def call_count(self) -> int:
        """Number of successful API calls made by this client."""
        return self._total_calls