annie-sdk 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ .venv/
2
+ __pycache__/
3
+ *.egg-info/
4
+ dist/
5
+ build/
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.4
2
+ Name: annie-sdk
3
+ Version: 0.2.0
4
+ Summary: Python SDK for querying databases using natural language via the Annie API
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: httpx>=0.28.0
7
+ Requires-Dist: pydantic>=2.0.0
8
+ Provides-Extra: all
9
+ Requires-Dist: mysql-connector-python>=9.0.0; extra == 'all'
10
+ Requires-Dist: psycopg2-binary>=2.9.0; extra == 'all'
11
+ Provides-Extra: dev
12
+ Requires-Dist: pytest>=8.0.0; extra == 'dev'
13
+ Requires-Dist: ruff>=0.8.0; extra == 'dev'
14
+ Provides-Extra: mysql
15
+ Requires-Dist: mysql-connector-python>=9.0.0; extra == 'mysql'
16
+ Provides-Extra: postgres
17
+ Requires-Dist: psycopg2-binary>=2.9.0; extra == 'postgres'
@@ -0,0 +1,142 @@
1
+ """
2
+ Annie SDK
3
+
4
+ A Python SDK for querying databases using natural language.
5
+ The SDK connects to your database through connectors, sends your
6
+ query and schema context to the Annie API, and executes the
7
+ resulting SQL locally against your database.
8
+
9
+ Your data never leaves your infrastructure — only schema metadata
10
+ is sent to the API.
11
+
12
+ Quick start:
13
+ ```python
14
+ from annie_sdk import Agent, PostgresConnector, ConnectorTable, ConnectorTableColumn
15
+
16
+ connector = PostgresConnector(
17
+ connection_string="postgresql://user:pass@localhost:5432/mydb",
18
+ tables=[
19
+ ConnectorTable(
20
+ name="orders",
21
+ description="Customer orders",
22
+ columns=[
23
+ ConnectorTableColumn(name="id", type="integer"),
24
+ ConnectorTableColumn(name="amount", type="decimal"),
25
+ ConnectorTableColumn(name="status", type="string"),
26
+ ],
27
+ ),
28
+ ],
29
+ )
30
+
31
+ with Agent(connector=connector, api_key="your-api-key") as agent:
32
+ response = agent.run("Show me top 10 orders by amount")
33
+ print(response.sql)
34
+ print(response.data)
35
+ ```
36
+ """
37
+
38
+ # Agent
39
+ from .agent import Agent
40
+
41
+ # Connectors
42
+ from .connectors.base import (
43
+ BaseConnector,
44
+ ConnectorRelationship,
45
+ ConnectorTable,
46
+ ConnectorTableColumn,
47
+ )
48
+ from .connectors.mock import MockConnector
49
+ from .connectors.mysql import MySQLConnector
50
+ from .connectors.postgres import PostgresConnector
51
+
52
+ # Exceptions
53
+ from .exceptions import (
54
+ AnnieError,
55
+ APIError,
56
+ AuthenticationError,
57
+ ConnectionError,
58
+ RateLimitError,
59
+ TimeoutError,
60
+ ValidationError,
61
+ )
62
+
63
+ # Models (DSL)
64
+ from .models import (
65
+ AggregationFunction,
66
+ AggregationSpec,
67
+ BySpec,
68
+ DataModel,
69
+ DateBinning,
70
+ DateBinningInterval,
71
+ FilterCondition,
72
+ FilterStep,
73
+ LimitStep,
74
+ NumberBinning,
75
+ NumberBinningStrategy,
76
+ Operator,
77
+ OrderByColumn,
78
+ OrderByStep,
79
+ SelectStep,
80
+ SortDirection,
81
+ SummarizeStep,
82
+ VisualizationType,
83
+ )
84
+
85
+ # Response
86
+ from .response import (
87
+ AgentResponse,
88
+ ChartContent,
89
+ ChartFormat,
90
+ DataContent,
91
+ ErrorContent,
92
+ ResponseType,
93
+ TextContent,
94
+ )
95
+
96
+ __all__ = [
97
+ # Agent
98
+ "Agent",
99
+ # Connectors
100
+ "BaseConnector",
101
+ "ConnectorRelationship",
102
+ "ConnectorTable",
103
+ "ConnectorTableColumn",
104
+ "MockConnector",
105
+ "MySQLConnector",
106
+ "PostgresConnector",
107
+ # Exceptions
108
+ "AnnieError",
109
+ "APIError",
110
+ "AuthenticationError",
111
+ "ConnectionError",
112
+ "RateLimitError",
113
+ "TimeoutError",
114
+ "ValidationError",
115
+ # Models
116
+ "AggregationFunction",
117
+ "AggregationSpec",
118
+ "BySpec",
119
+ "DataModel",
120
+ "DateBinning",
121
+ "DateBinningInterval",
122
+ "FilterCondition",
123
+ "FilterStep",
124
+ "LimitStep",
125
+ "NumberBinning",
126
+ "NumberBinningStrategy",
127
+ "Operator",
128
+ "OrderByColumn",
129
+ "OrderByStep",
130
+ "SelectStep",
131
+ "SortDirection",
132
+ "SummarizeStep",
133
+ "VisualizationType",
134
+ # Response
135
+ "AgentResponse",
136
+ "ChartContent",
137
+ "ChartFormat",
138
+ "DataContent",
139
+ "ErrorContent",
140
+ "ResponseType",
141
+ "TextContent",
142
+ ]
@@ -0,0 +1,410 @@
1
+ """
2
+ Annie SDK Agent
3
+
4
+ The main entry point for the Annie SDK. The Agent handles communication
5
+ with the Annie API and local SQL execution via connectors.
6
+
7
+ Flow:
8
+ 1. User calls agent.run("natural language query")
9
+ 2. Agent sends NL query + schema context to Annie API
10
+ 3. Annie API returns DSL + SQL in the connector's dialect
11
+ 4. Agent executes SQL locally via the connector
12
+ 5. Results are formatted as AgentResponse
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import os
19
+ import random
20
+ from decimal import Decimal
21
+ from typing import Any
22
+
23
+ import httpx
24
+
25
+ from .connectors.base import BaseConnector
26
+ from .exceptions import APIError, AuthenticationError, RateLimitError, TimeoutError, ValidationError
27
+ from .response import AgentResponse, ChartFormat
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
# Base URL used by Agent when neither the `api_url` argument nor the
# ANNIE_API_URL environment variable is provided.
DEFAULT_API_URL = "https://api.pandas-ai.com"
# Passed as `timeout` to the httpx.Client that Agent creates (httpx interprets
# a bare number as seconds).
DEFAULT_TIMEOUT = 60

# Fallback replies for non-data questions: Agent.run picks one at random when
# the API answers with type "text" but no message, unless the caller supplied
# its own `denial_messages`.
DEFAULT_DENIAL_MESSAGES = [
    "I can only help with questions about your data. Try asking something like 'Show revenue by month'.",
    "That doesn't seem to be a data question. Try asking about your data, like 'Top 10 customers by sales'.",
    "I'm a data assistant. Ask me about your data — for example, 'What's the average order value?'",
]
39
+
40
+
41
class Agent:
    """Annie SDK Agent.

    Handles natural language queries against your database using the
    Annie API for NL-to-SQL conversion and local execution for data privacy.

    Args:
        connector: Database connector instance
        api_key: Annie API key (or set ANNIE_API_KEY env var)
        api_url: Annie API URL (or set ANNIE_API_URL env var)
        denial_messages: Custom messages for non-data queries

    Raises:
        AuthenticationError: If no API key is passed and ANNIE_API_KEY is unset.
        ValidationError: If the connector has no tables configured.

    Example:
        ```python
        from annie_sdk import Agent, PostgresConnector, ConnectorTable

        connector = PostgresConnector(
            connection_string="postgresql://user:pass@localhost/db",
            tables=[ConnectorTable(name="orders", description="Customer orders")],
        )

        with Agent(connector=connector, api_key="your-key") as agent:
            response = agent.run("Show me top 10 orders by amount")
            print(response.data)
        ```
    """

    def __init__(
        self,
        connector: BaseConnector,
        api_key: str | None = None,
        api_url: str | None = None,
        denial_messages: list[str] | None = None,
    ):
        # Resolve API key: explicit argument wins over the environment.
        self._api_key = api_key or os.environ.get("ANNIE_API_KEY")
        if not self._api_key:
            raise AuthenticationError(
                "API key is required. Pass api_key parameter or set ANNIE_API_KEY environment variable."
            )

        # Resolve API URL; the trailing slash is stripped so that endpoint
        # paths such as "/sdk/query" join cleanly onto the base URL.
        self._api_url = (
            api_url or os.environ.get("ANNIE_API_URL") or DEFAULT_API_URL
        ).rstrip("/")

        self._connector = connector
        self._denial_messages = denial_messages or DEFAULT_DENIAL_MESSAGES

        # Validate before opening the HTTP client so a failed construction
        # never leaves an unclosed client behind.
        if not connector.tables:
            raise ValidationError(
                "Connector must have at least one table configured",
                field="connector.tables",
            )

        # Initialize HTTP client
        self._client = httpx.Client(
            base_url=self._api_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            timeout=DEFAULT_TIMEOUT,
        )

    def run(self, query: str, *, explain: bool = False) -> AgentResponse:
        """Run a natural language query against your database.

        Args:
            query: Natural language query (e.g., "Show me top 10 customers by revenue")
            explain: If True, also generate AI insights about the results

        Returns:
            AgentResponse with data, SQL, DSL, and optional insights. Empty
            queries, API-reported failures and local SQL execution failures
            are returned as error responses rather than raised.

        Raises:
            TimeoutError, APIError, AuthenticationError, RateLimitError:
                Propagated from the query API call (transport/HTTP failures).
        """
        if not query or not query.strip():
            return AgentResponse.from_error("Query cannot be empty")

        # 1. Build schema context from connector
        schema_context = self._connector.get_schema_context()

        # 2. Call Annie API: NL -> DSL + SQL
        api_response = self._call_query_api(query, schema_context)

        # 3. Handle non-data responses (guardrail denials)
        if api_response.get("type") == "text":
            message = api_response.get("message") or random.choice(self._denial_messages)
            response = AgentResponse(
                success=True,
                visualization="table",
            )
            response.add_text(message)
            return response

        # 4. Handle API errors. `or` (not a .get default) so an explicit JSON
        #    null message still produces a readable error.
        if not api_response.get("success", True):
            error_msg = api_response.get("message") or "API returned an error"
            return AgentResponse.from_error(error_msg)

        sql = api_response.get("sql")
        dsl = api_response.get("dsl")
        # An explicit null visualization falls back to "table" as well.
        visualization = api_response.get("visualization") or "table"

        if not sql:
            return AgentResponse.from_error("API did not return SQL")

        # 5. Execute SQL locally via connector
        try:
            if not self._connector.is_connected():
                self._connector.connect()

            results = self._connector.execute(sql)
        except Exception as e:
            # Boundary with user-supplied connectors: report the failure in
            # the response (with the SQL attached) instead of raising.
            response = AgentResponse(
                sql=sql,
                dsl=dsl,
                visualization=visualization,
                success=False,
            )
            response.add_error(f"Query execution failed: {e}")
            return response

        # 6. Format results as AgentResponse
        response = AgentResponse(
            sql=sql,
            dsl=dsl,
            visualization=visualization,
            success=True,
        )

        if results:
            columns = list(results[0].keys())
            rows = [list(row.values()) for row in results]
            response.add_data(columns=columns, rows=rows)
        else:
            response.add_data(columns=[], rows=[])

        # 6b. Generate chart from results
        if results and visualization not in ("table", "kpi"):
            chart_config = self._generate_chart(results, dsl or {}, visualization)
            if chart_config:
                response.add_chart(
                    content=chart_config["content"],
                    format=ChartFormat(chart_config.get("format", "chartjs")),
                    title=chart_config.get("title"),
                )

        # 7. Optionally get AI explanation (only the first 50 rows are sent)
        if explain and results:
            try:
                explanation = self._call_explain_api(
                    query=query,
                    results=results[:50],
                    sql=sql,
                    dsl=dsl,
                    schema=schema_context,
                )
                if explanation:
                    response.add_text(explanation)
            except Exception as e:
                # Don't fail the whole response if explain fails
                logger.warning("Failed to get explanation: %s", e)

        return response

    def close(self) -> None:
        """Close the agent and its connector.

        Cleanup is best-effort: a failure in one step is logged at debug
        level and never prevents the other step or raises to the caller.
        """
        try:
            self._connector.disconnect()
        except Exception:
            logger.debug("Ignoring error while disconnecting connector", exc_info=True)
        try:
            self._client.close()
        except Exception:
            logger.debug("Ignoring error while closing HTTP client", exc_info=True)

    # =========================================================================
    # Context manager
    # =========================================================================

    def __enter__(self) -> Agent:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    # =========================================================================
    # Private methods
    # =========================================================================

    @staticmethod
    def _parse_retry_after(value: str | None) -> int | None:
        """Return a Retry-After header as whole seconds, or None.

        The header may legally carry an HTTP-date instead of a number of
        seconds; anything that is not a non-negative integer yields None so
        that a rate-limit response is never masked by a parsing error.
        """
        if not value:
            return None
        try:
            seconds = int(value)
        except (TypeError, ValueError):
            return None
        return seconds if seconds >= 0 else None

    @staticmethod
    def _to_json_safe(value: Any) -> Any:
        """Recursively convert a query result value into JSON-encodable data.

        Decimals become floats, objects exposing ``isoformat()`` (dates,
        times, datetimes) become ISO strings, tuples become lists, and
        non-finite numbers become None because JSON cannot represent them.
        Any other unknown type is deliberately stringified: the payload is
        only used as context for an explanation, so a lossy value is better
        than failing the whole request.
        """
        if value is None or isinstance(value, (str, bool, int)):
            return value
        if isinstance(value, Decimal):
            value = float(value)
        if isinstance(value, float):
            # NaN is the only value not equal to itself.
            if value != value or value in (float("inf"), float("-inf")):
                return None
            return value
        if isinstance(value, dict):
            return {str(k): Agent._to_json_safe(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [Agent._to_json_safe(v) for v in value]
        isoformat = getattr(value, "isoformat", None)
        if callable(isoformat):
            return isoformat()
        return str(value)

    def _call_query_api(
        self,
        query: str,
        schema: dict[str, Any],
    ) -> dict[str, Any]:
        """Call the Annie /sdk/query endpoint.

        Args:
            query: Natural language query text.
            schema: Schema context produced by the connector.

        Returns:
            The decoded JSON object returned by the API.

        Raises:
            TimeoutError: The HTTP request timed out.
            AuthenticationError: The API answered 401.
            RateLimitError: The API answered 429.
            APIError: Connection failure, any other status >= 400, or a body
                that is not a JSON object.
        """
        payload = {
            "query": query,
            "schema": schema,
            "dialect": self._connector.dialect,
        }

        try:
            response = self._client.post("/sdk/query", json=payload)
        except httpx.TimeoutException as e:
            raise TimeoutError(f"API request timed out: {e}") from e
        except httpx.ConnectError as e:
            raise APIError(f"Failed to connect to Annie API at {self._api_url}: {e}") from e

        if response.status_code == 401:
            raise AuthenticationError("Invalid API key")
        if response.status_code == 429:
            raise RateLimitError(
                "Rate limit exceeded",
                retry_after=self._parse_retry_after(response.headers.get("retry-after")),
            )
        if response.status_code >= 400:
            raise APIError(
                f"Annie API error: {response.status_code}",
                status_code=response.status_code,
            )

        try:
            data = response.json()
        except Exception as e:
            body_preview = response.text[:200] if response.text else "(empty)"
            raise APIError(
                f"Failed to parse API response (status {response.status_code}): {e}. "
                f"Response: {body_preview}"
            ) from e

        # Callers use dict access on the result; reject arrays/scalars here
        # with a clear error instead of failing later with AttributeError.
        if not isinstance(data, dict):
            raise APIError(
                "Failed to parse API response "
                f"(status {response.status_code}): expected a JSON object, "
                f"got {type(data).__name__}"
            )
        return data

    def _call_explain_api(
        self,
        query: str,
        results: list[dict[str, Any]],
        sql: str | None,
        dsl: dict[str, Any] | None,
        schema: dict[str, Any] | None,
    ) -> str | None:
        """Call the Annie /sdk/explain endpoint.

        Never raises: any failure (transport, non-200 status, bad body) is
        logged as a warning and reported as None.

        Returns:
            The explanation text when the API reports success, else None.
        """
        payload = {
            "query": query,
            # Rows may hold Decimal/date values that the JSON encoder
            # rejects; normalise them so the request can actually be sent.
            "results": self._to_json_safe(results),
            "sql": sql,
            "dsl": dsl,
            "schema": schema,
        }

        try:
            response = self._client.post("/sdk/explain", json=payload)

            if response.status_code != 200:
                logger.warning("Explain API returned %s", response.status_code)
                return None

            data = response.json()
            return data.get("explanation") if data.get("success") else None

        except Exception as e:
            logger.warning("Explain API call failed: %s", e)
            return None

    def _generate_chart(
        self,
        results: list[dict[str, Any]],
        dsl: dict[str, Any],
        visualization: str,
    ) -> dict[str, Any] | None:
        """Generate a ChartJS configuration from query results.

        Creates a ChartJS-compatible configuration based on the data structure
        and the visualization type returned by the API.

        Args:
            results: Query results as list of dicts
            dsl: The DSL used for the query (currently unused)
            visualization: Visualization type from the API (bar, line, pie, etc.);
                unrecognised types are rendered as "bar"

        Returns:
            Chart configuration dict or None if not applicable (fewer than
            two rows, no string label column, or no numeric value column)
        """
        if not results or len(results) < 2:
            return None

        # Find label column (first string-typed) and value columns (numeric).
        # Scan up to 5 rows to handle cases where the first row has NULLs.
        label_col = None
        value_cols = []

        for col in results[0].keys():
            sample_val = next(
                (row.get(col) for row in results[:5] if row.get(col) is not None),
                None,
            )

            if isinstance(sample_val, str):
                if label_col is None:
                    label_col = col
            elif isinstance(sample_val, (int, float, Decimal)) and not isinstance(
                sample_val, bool
            ):
                # bool is a subclass of int, but a true/false flag is not a
                # meaningful numeric series to plot.
                value_cols.append(col)

        if not label_col or not value_cols:
            return None

        # Extract data
        labels = [str(row.get(label_col) or "") for row in results]

        colors = [
            "rgba(54, 162, 235, 0.8)",
            "rgba(255, 99, 132, 0.8)",
            "rgba(255, 206, 86, 0.8)",
            "rgba(75, 192, 192, 0.8)",
            "rgba(153, 102, 255, 0.8)",
            "rgba(255, 159, 64, 0.8)",
            "rgba(199, 199, 199, 0.8)",
            "rgba(83, 102, 255, 0.8)",
        ]
        # One colour per label, cycling the palette so charts with more
        # entries than palette colours do not depend on how the renderer
        # treats a too-short colour array.
        per_label_colors = [colors[j % len(colors)] for j in range(len(labels))]

        chart_type = visualization if visualization in ("bar", "line", "pie", "doughnut") else "bar"

        datasets = []
        for i, val_col in enumerate(value_cols[:2]):  # Limit to 2 datasets
            # NULL values are plotted as 0.
            values = [float(row.get(val_col) or 0) for row in results]
            dataset: dict[str, Any] = {
                "label": val_col.replace("_", " ").title(),
                "data": values,
                "borderWidth": 1,
            }

            if chart_type in ("pie", "doughnut") or len(value_cols) == 1:
                # Pie/doughnut or a single dataset: different color per slice/bar
                dataset["backgroundColor"] = list(per_label_colors)
            else:
                # Multiple datasets: same color per dataset, opaque border
                series_color = colors[i % len(colors)]
                dataset["backgroundColor"] = series_color
                dataset["borderColor"] = series_color.replace("0.8", "1")

            datasets.append(dataset)

        title = (
            f"{value_cols[0].replace('_', ' ').title()} by "
            f"{label_col.replace('_', ' ').title()}"
        )

        return {
            "format": "chartjs",
            "title": title,
            "content": {
                "type": chart_type,
                "data": {
                    "labels": labels,
                    "datasets": datasets,
                },
                "options": {
                    "responsive": True,
                    "plugins": {
                        "legend": {"position": "top"},
                        "title": {"display": True, "text": title},
                    },
                },
            },
        }
@@ -0,0 +1,16 @@
1
+ """Annie SDK Connectors."""
2
+
3
+ from .base import BaseConnector, ConnectorRelationship, ConnectorTable, ConnectorTableColumn
4
+ from .mock import MockConnector
5
+ from .mysql import MySQLConnector
6
+ from .postgres import PostgresConnector
7
+
8
+ __all__ = [
9
+ "BaseConnector",
10
+ "ConnectorRelationship",
11
+ "ConnectorTable",
12
+ "ConnectorTableColumn",
13
+ "MockConnector",
14
+ "MySQLConnector",
15
+ "PostgresConnector",
16
+ ]