discovery_engine_api-0.1.52-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- discovery/__init__.py +34 -0
- discovery/client.py +857 -0
- discovery/types.py +256 -0
- discovery_engine_api-0.1.52.dist-info/METADATA +354 -0
- discovery_engine_api-0.1.52.dist-info/RECORD +6 -0
- discovery_engine_api-0.1.52.dist-info/WHEEL +4 -0
discovery/client.py
ADDED
@@ -0,0 +1,857 @@
"""Discovery Engine Python SDK."""

import asyncio
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import httpx

try:
    import pandas as pd
except ImportError:
    pd = None

from discovery.types import (
    Column,
    CorrelationEntry,
    DataInsights,
    EngineResult,
    FeatureImportance,
    FeatureImportanceScore,
    FileInfo,
    Pattern,
    PatternGroup,
    RunStatus,
    Summary,
)


class Engine:
    """Engine for the Discovery Engine API."""

    # Production API URL (can be overridden via DISCOVERY_API_URL env var for testing)
    # This points to the Modal-deployed FastAPI API
    _DEFAULT_BASE_URL = "https://leap-labs-production--discovery-api.modal.run"

    # Dashboard URL for web UI and /api/* endpoints
    _DEFAULT_DASHBOARD_URL = "https://disco.leap-labs.com"

    def __init__(self, api_key: str):
        """
        Initialize the Discovery Engine.

        Args:
            api_key: Your API key
        """

        print("Initializing Discovery Engine...")
        self.api_key = api_key
        # Use DISCOVERY_API_URL env var if set (for testing/custom deployments),
        # otherwise use the production default
        self.base_url = os.getenv("DISCOVERY_API_URL", self._DEFAULT_BASE_URL).rstrip("/")
        # Dashboard URL for /api/* endpoints and web UI links
        self.dashboard_url = os.getenv(
            "DISCOVERY_DASHBOARD_URL", self._DEFAULT_DASHBOARD_URL
        ).rstrip("/")
        self._organization_id: Optional[str] = None
        self._client: Optional[httpx.AsyncClient] = None
        self._dashboard_client: Optional[httpx.AsyncClient] = None
        self._org_fetched = False

    async def _ensure_organization_id(self) -> str:
        """
        Ensure we have an organization ID, fetching from API if needed.

        The organization ID is required for API requests to identify which
        organization the user belongs to (multi-tenancy support).

        Returns:
            Organization ID string

        Raises:
            ValueError: If no organization is found or API request fails
        """
        if self._organization_id:
            return self._organization_id

        if not self._org_fetched:
            # Fetch user's organizations and use the first one
            try:
                orgs = await self.get_organizations()
                if orgs:
                    self._organization_id = orgs[0]["id"]
            except ValueError as e:
                # Re-raise with more context
                raise ValueError(
                    f"Failed to fetch organization: {e}. "
                    "Please ensure your API key is valid and you belong to an organization."
                ) from e
            self._org_fetched = True

        if not self._organization_id:
            raise ValueError(
                "No organization found for your account. "
                "Please contact support if this issue persists."
            )

        return self._organization_id

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._client is None:
            headers = {"Authorization": f"Bearer {self.api_key}"}
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                headers=headers,
                timeout=60.0,
            )
        return self._client

    async def _get_client_with_org(self) -> httpx.AsyncClient:
        """
        Get HTTP client with organization header set.

        The organization ID is required for API requests to identify which
        organization the user belongs to (multi-tenancy support).
        """
        client = await self._get_client()

        # Ensure we have an organization ID
        org_id = await self._ensure_organization_id()

        # Set the organization header
        client.headers["X-Organization-ID"] = org_id

        return client

    async def _get_dashboard_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client for dashboard API calls."""
        if self._dashboard_client is None:
            headers = {"Authorization": f"Bearer {self.api_key}"}
            self._dashboard_client = httpx.AsyncClient(
                base_url=self.dashboard_url,
                headers=headers,
                timeout=60.0,
            )
        return self._dashboard_client

    async def close(self):
        """Close the HTTP clients."""
        if self._client:
            await self._client.aclose()
            self._client = None
        if self._dashboard_client:
            await self._dashboard_client.aclose()
            self._dashboard_client = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    async def get_organizations(self) -> List[Dict[str, Any]]:
        """
        Get the organizations you belong to.

        Returns:
            List of organizations with id, name, and slug

        Raises:
            ValueError: If the API request fails
        """
        client = await self._get_client()

        try:
            response = await client.get("/v1/me/organizations")
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise ValueError(
                f"Failed to fetch organizations: {e.response.status_code} {e.response.text}"
            ) from e
        except httpx.RequestError as e:
            raise ValueError(f"Failed to connect to API: {str(e)}") from e

    async def upload_file(
        self, file: Union[str, Path, "pd.DataFrame"], filename: Optional[str] = None
    ) -> FileInfo:
        """
        Upload a file to the API.

        Args:
            file: File path, Path object, or pandas DataFrame
            filename: Optional filename (for DataFrame uploads)

        Returns:
            FileInfo with file_path, file_hash, file_size, mime_type
        """
        client = await self._get_client_with_org()

        if pd is not None and isinstance(file, pd.DataFrame):
            # Convert DataFrame to CSV in memory
            import io

            buffer = io.BytesIO()
            file.to_csv(buffer, index=False)
            buffer.seek(0)
            file_content = buffer.getvalue()
            filename = filename or "dataset.csv"
            mime_type = "text/csv"
        else:
            # Read file from disk
            file_path = Path(file)
            if not file_path.exists():
                raise FileNotFoundError(f"File not found: {file_path}")
            file_content = file_path.read_bytes()
            filename = filename or file_path.name
            mime_type = (
                "text/csv" if file_path.suffix == ".csv" else "application/vnd.apache.parquet"
            )

        # Upload file
        files = {"file": (filename, file_content, mime_type)}
        response = await client.post("/v1/upload", files=files)
        response.raise_for_status()

        data = response.json()
        return FileInfo(
            file_path=data["file_path"],
            file_hash=data["file_hash"],
            file_size=data["file_size"],
            mime_type=data["mime_type"],
        )

    async def create_dataset(
        self,
        title: Optional[str] = None,
        description: Optional[str] = None,
        total_rows: int = 0,
        dataset_size_mb: Optional[float] = None,
        author: Optional[str] = None,
        source_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a dataset record.

        Args:
            title: Dataset title
            description: Dataset description
            total_rows: Number of rows in the dataset
            dataset_size_mb: Dataset size in MB
            author: Optional author attribution
            source_url: Optional source URL

        Returns:
            Dataset record with ID
        """
        client = await self._get_client_with_org()

        response = await client.post(
            "/v1/run-datasets",
            json={
                "title": title,
                "description": description,
                "total_rows": total_rows,
                "dataset_size_mb": dataset_size_mb,
                "author": author,
                "source_url": source_url,
            },
        )
        response.raise_for_status()
        return response.json()

    async def create_file_record(self, dataset_id: str, file_info: FileInfo) -> Dict[str, Any]:
        """
        Create a file record for a dataset.

        Args:
            dataset_id: Dataset ID
            file_info: FileInfo from upload_file()

        Returns:
            File record with ID
        """
        client = await self._get_client_with_org()

        response = await client.post(
            f"/v1/run-datasets/{dataset_id}/files",
            json={
                "mime_type": file_info.mime_type,
                "file_path": file_info.file_path,
                "file_hash": file_info.file_hash,
                "file_size": file_info.file_size,
            },
        )
        response.raise_for_status()
        return response.json()

    async def create_columns(
        self, dataset_id: str, columns: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Create column records for a dataset.

        Args:
            dataset_id: Dataset ID
            columns: List of column definitions with full metadata

        Returns:
            List of column records with IDs
        """
        client = await self._get_client_with_org()

        response = await client.post(
            f"/v1/run-datasets/{dataset_id}/columns",
            json=columns,
        )
        response.raise_for_status()
        return response.json()

    async def create_run(
        self,
        dataset_id: str,
        target_column_id: str,
        task: str = "regression",
        depth_iterations: int = 1,
        visibility: str = "public",
        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
        target_column_override: Optional[str] = None,
        auto_report_use_llm_evals: bool = True,
        author: Optional[str] = None,
        source_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a run and enqueue it for processing.

        Args:
            dataset_id: Dataset ID
            target_column_id: Target column ID
            task: Task type (regression, binary_classification, multiclass_classification)
            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
            visibility: Dataset visibility ("public" or "private")
            timeseries_groups: Optional list of timeseries column groups
            target_column_override: Optional override for target column name
            auto_report_use_llm_evals: Use LLM evaluations
            author: Optional dataset author
            source_url: Optional source URL

        Returns:
            Run record with ID and job information
        """
        client = await self._get_client_with_org()

        payload = {
            "run_target_column_id": target_column_id,
            "task": task,
            "depth_iterations": depth_iterations,
            "visibility": visibility,
            "auto_report_use_llm_evals": auto_report_use_llm_evals,
        }

        if timeseries_groups:
            payload["timeseries_groups"] = timeseries_groups
        if target_column_override:
            payload["target_column_override"] = target_column_override
        if author:
            payload["author"] = author
        if source_url:
            payload["source_url"] = source_url

        response = await client.post(
            f"/v1/run-datasets/{dataset_id}/runs",
            json=payload,
        )
        response.raise_for_status()
        return response.json()

    async def get_results(self, run_id: str) -> EngineResult:
        """
        Get complete analysis results for a run.

        This returns all data that the Discovery dashboard displays:
        - LLM-generated summary with key insights
        - All discovered patterns with conditions, citations, and explanations
        - Column/feature information with statistics and importance scores
        - Correlation matrix
        - Global feature importance

        Args:
            run_id: The run ID

        Returns:
            EngineResult with complete analysis data
        """
        # Use dashboard client for /api/* endpoints (hosted on Next.js dashboard, not Modal API)
        dashboard_client = await self._get_dashboard_client()

        # Call dashboard API for results
        response = await dashboard_client.get(f"/api/runs/{run_id}/results")
        response.raise_for_status()

        data = response.json()
        return self._parse_analysis_result(data)

    async def get_run_status(self, run_id: str) -> RunStatus:
        """
        Get the status of a run.

        Args:
            run_id: Run ID

        Returns:
            RunStatus with current status information
        """
        client = await self._get_client_with_org()

        response = await client.get(f"/v1/runs/{run_id}/results")
        response.raise_for_status()

        data = response.json()
        return RunStatus(
            run_id=data["run_id"],
            status=data["status"],
            job_id=data.get("job_id"),
            job_status=data.get("job_status"),
            error_message=data.get("error_message"),
        )

    async def wait_for_completion(
        self,
        run_id: str,
        poll_interval: float = 5.0,
        timeout: Optional[float] = None,
    ) -> EngineResult:
        """
        Wait for a run to complete and return the results.

        Args:
            run_id: Run ID
            poll_interval: Seconds between status checks (default: 5)
            timeout: Maximum seconds to wait (None = no timeout)

        Returns:
            EngineResult with complete analysis data

        Raises:
            TimeoutError: If the run doesn't complete within the timeout
            RuntimeError: If the run fails
        """
        start_time = time.time()
        last_status = None
        poll_count = 0

        print(f"⏳ Waiting for run {run_id} to complete...")

        while True:
            result = await self.get_results(run_id)
            elapsed = time.time() - start_time
            poll_count += 1

            # Log status changes or every 3rd poll (every ~15 seconds)
            if result.status != last_status or poll_count % 3 == 0:
                status_msg = f"Status: {result.status}"
                if result.job_status:
                    status_msg += f" (job: {result.job_status})"
                if elapsed > 0:
                    status_msg += f" | Elapsed: {elapsed:.1f}s"
                print(f"   {status_msg}")

            last_status = result.status

            if result.status == "completed":
                print(f"✓ Run completed in {elapsed:.1f}s")
                return result
            elif result.status == "failed":
                error_msg = result.error_message or "Unknown error"
                print(f"✗ Run failed: {error_msg}")
                raise RuntimeError(f"Run {run_id} failed: {error_msg}")

            if timeout and elapsed > timeout:
                raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds")

            await asyncio.sleep(poll_interval)

    async def run_async(
        self,
        file: Union[str, Path, "pd.DataFrame"],
        target_column: str,
        depth_iterations: int = 1,
        title: Optional[str] = None,
        description: Optional[str] = None,
        column_descriptions: Optional[Dict[str, str]] = None,
        excluded_columns: Optional[List[str]] = None,
        task: Optional[str] = None,
        visibility: str = "public",
        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
        target_column_override: Optional[str] = None,
        auto_report_use_llm_evals: bool = True,
        author: Optional[str] = None,
        source_url: Optional[str] = None,
        wait: bool = False,
        wait_timeout: Optional[float] = None,
        **kwargs,
    ) -> EngineResult:
        """
        Run analysis on a dataset (async).

        This method calls the dashboard API which handles the entire workflow:
        file upload, dataset creation, column inference, run creation, and credit deduction.

        Args:
            file: File path, Path object, or pandas DataFrame
            target_column: Name of the target column
            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
            title: Optional dataset title
            description: Optional dataset description
            column_descriptions: Optional dict mapping column names to descriptions
            excluded_columns: Optional list of column names to exclude from analysis
            task: Task type (regression, binary, multiclass) - auto-detected if None
            visibility: Dataset visibility ("public" or "private", default: "public")
            timeseries_groups: Optional list of timeseries column groups
            target_column_override: Optional override for target column name
            auto_report_use_llm_evals: Use LLM evaluations (default: True)
            author: Optional dataset author
            source_url: Optional source URL
            wait: If True, wait for analysis to complete and return full results
            wait_timeout: Maximum seconds to wait for completion (only if wait=True)

        Returns:
            EngineResult with run_id and (if wait=True) complete results
        """
        # Prepare file for upload
        if pd is not None and isinstance(file, pd.DataFrame):
            # Convert DataFrame to CSV in memory
            import io

            print(f"📊 Preparing DataFrame ({len(file)} rows, {len(file.columns)} columns)...")
            buffer = io.BytesIO()
            file.to_csv(buffer, index=False)
            buffer.seek(0)
            file_content = buffer.getvalue()
            filename = (title + ".csv") if title else "dataset.csv"
            mime_type = "text/csv"
            file_size_mb = len(file_content) / (1024 * 1024)
            print(f"   File size: {file_size_mb:.2f} MB")
        else:
            # Read file from disk
            file_path = Path(file)
            if not file_path.exists():
                raise FileNotFoundError(f"File not found: {file_path}")
            print(f"📁 Reading file: {file_path.name}...")
            file_content = file_path.read_bytes()
            filename = file_path.name
            mime_type = (
                "text/csv" if file_path.suffix == ".csv" else "application/vnd.apache.parquet"
            )
            file_size_mb = len(file_content) / (1024 * 1024)
            print(f"   File size: {file_size_mb:.2f} MB")

        # Prepare multipart form data
        files = {"file": (filename, file_content, mime_type)}
        data: Dict[str, Any] = {
            "target_column": target_column,
            "depth_iterations": str(depth_iterations),
            "visibility": visibility,
        }

        if description:
            data["description"] = description
        if author:
            data["author"] = author
        if source_url:
            data["source_url"] = source_url
        if column_descriptions:
            data["column_descriptions"] = json.dumps(column_descriptions)
        if excluded_columns:
            data["excluded_columns"] = json.dumps(excluded_columns)
        if timeseries_groups:
            data["timeseries_groups"] = json.dumps(timeseries_groups)

        # Call dashboard API to create report
        print(
            f"🚀 Uploading file and creating run (depth: {depth_iterations}, target: {target_column})..."
        )
        # Use dashboard client for /api/* endpoints (hosted on Next.js dashboard, not Modal API)
        dashboard_client = await self._get_dashboard_client()
        # httpx automatically handles multipart/form-data when both files and data are provided
        response = await dashboard_client.post("/api/reports/create", files=files, data=data)
        response.raise_for_status()

        result_data = response.json()

        # Check if duplicate
        if result_data.get("duplicate"):
            # For duplicates, get the run_id and fetch results
            report_id = result_data.get("report_id")
            run_id = result_data.get("run_id")

            if not report_id or not run_id:
                raise ValueError("Duplicate report found but missing report_id or run_id")

            print(f"ℹ️ Duplicate report found (run_id: {run_id})")

            # Construct dashboard URL for the processing page
            progress_url = f"{self.dashboard_url}/reports/new/{run_id}/processing"
            print(f"🔗 View progress: {progress_url}")

            # If wait is True, fetch the full results for the existing report
            if wait:
                return await self.get_results(run_id)

            # Otherwise return a minimal result with the run_id
            return EngineResult(
                run_id=run_id,
                status="completed",
                report_id=report_id,
            )

        run_id = result_data["run_id"]
        print(f"✓ Run created: {run_id}")

        # Construct dashboard URL for the processing page
        progress_url = f"{self.dashboard_url}/reports/new/{run_id}/processing"
        print(f"🔗 View progress: {progress_url}")

        if wait:
            # Wait for completion and return full results
            return await self.wait_for_completion(run_id, timeout=wait_timeout)

        # Return minimal result with pending status
        return EngineResult(
            run_id=run_id,
            status="pending",
        )

    def run(
        self,
        file: Union[str, Path, "pd.DataFrame"],
        target_column: str,
        depth_iterations: int = 1,
        title: Optional[str] = None,
        description: Optional[str] = None,
        column_descriptions: Optional[Dict[str, str]] = None,
        excluded_columns: Optional[List[str]] = None,
        task: Optional[str] = None,
        visibility: str = "public",
        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
        target_column_override: Optional[str] = None,
        auto_report_use_llm_evals: bool = True,
        author: Optional[str] = None,
        source_url: Optional[str] = None,
        wait: bool = False,
        wait_timeout: Optional[float] = None,
        **kwargs,
    ) -> EngineResult:
        """
        Run analysis on a dataset (synchronous wrapper).

        This is a synchronous wrapper around run_async().

        Args:
            file: File path, Path object, or pandas DataFrame
            target_column: Name of the target column
            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
            title: Optional dataset title
            description: Optional dataset description
            column_descriptions: Optional dict mapping column names to descriptions
            excluded_columns: Optional list of column names to exclude from analysis
            task: Task type (regression, binary_classification, multiclass_classification) - auto-detected if None
            visibility: Dataset visibility ("public" or "private", default: "public")
            timeseries_groups: Optional list of timeseries column groups
            target_column_override: Optional override for target column name
            auto_report_use_llm_evals: Use LLM evaluations (default: True)
            author: Optional dataset author
            source_url: Optional source URL
            wait: If True, wait for analysis to complete and return full results
            wait_timeout: Maximum seconds to wait for completion (only if wait=True)
            **kwargs: Additional arguments passed to run_async()

        Returns:
            EngineResult with run_id and (if wait=True) complete results
        """
        coro = self.run_async(
            file,
            target_column,
            depth_iterations,
            title=title,
            description=description,
            column_descriptions=column_descriptions,
            excluded_columns=excluded_columns,
            task=task,
            visibility=visibility,
            timeseries_groups=timeseries_groups,
            target_column_override=target_column_override,
            auto_report_use_llm_evals=auto_report_use_llm_evals,
            author=author,
            source_url=source_url,
            wait=wait,
            wait_timeout=wait_timeout,
            **kwargs,
        )

        # Try to run the coroutine
        # If we're in a Jupyter notebook with a running event loop, asyncio.run() will fail
        try:
            return asyncio.run(coro)
        except RuntimeError as e:
            # Check if the error is about a running event loop
            if "cannot be called from a running event loop" in str(e).lower():
                # We're in a Jupyter/IPython environment with a running event loop
                # Try to use nest_asyncio if available
                try:
                    import nest_asyncio

                    # Apply nest_asyncio (it's safe to call multiple times)
                    nest_asyncio.apply()
                    # Now we can use asyncio.run() even with a running loop
                    return asyncio.run(coro)
                except ImportError:
                    raise RuntimeError(
                        "Cannot use engine.run() in a Jupyter notebook or environment with a running event loop. "
                        "Please use 'await engine.run_async(...)' instead, or install nest_asyncio "
                        "(pip install nest-asyncio) to enable nested event loops."
                    ) from e
            # Re-raise if it's a different RuntimeError
            raise

    def _parse_analysis_result(self, data: Dict[str, Any]) -> EngineResult:
        """Parse API response into EngineResult dataclass."""
        # Parse summary
        summary = None
        if data.get("summary"):
            summary = self._parse_summary(data["summary"])

        # Parse patterns
        patterns = []
        for p in data.get("patterns", []):
            patterns.append(
                Pattern(
                    id=p["id"],
                    task=p.get("task", "regression"),
                    target_column=p.get("target_column", ""),
                    direction=p.get("direction", "max"),
                    p_value=p.get("p_value", 0),
                    conditions=p.get("conditions", []),
                    lift_value=p.get("lift_value", 0),
                    support_count=p.get("support_count", 0),
                    support_percentage=p.get("support_percentage", 0),
                    pattern_type=p.get("pattern_type", "validated"),
                    novelty_type=p.get("novelty_type", "confirmatory"),
                    target_score=p.get("target_score", 0),
                    target_class=p.get("target_class"),
                    target_mean=p.get("target_mean"),
                    target_std=p.get("target_std"),
                    description=p.get("description", ""),
                    novelty_explanation=p.get("novelty_explanation", ""),
                    citations=p.get("citations", []),
                )
            )

        # Parse columns
        columns = []
        for c in data.get("columns", []):
            columns.append(
                Column(
                    id=c["id"],
                    name=c["name"],
                    display_name=c.get("display_name", c["name"]),
                    type=c.get("type", "continuous"),
                    data_type=c.get("data_type", "float"),
                    enabled=c.get("enabled", True),
                    description=c.get("description"),
                    mean=c.get("mean"),
                    median=c.get("median"),
                    std=c.get("std"),
                    min=c.get("min"),
                    max=c.get("max"),
                    iqr_min=c.get("iqr_min"),
                    iqr_max=c.get("iqr_max"),
                    mode=c.get("mode"),
                    approx_unique=c.get("approx_unique"),
                    null_percentage=c.get("null_percentage"),
                    feature_importance_score=c.get("feature_importance_score"),
                )
            )

        # Parse correlation matrix
        correlation_matrix = []
        for entry in data.get("correlation_matrix", []):
            correlation_matrix.append(
                CorrelationEntry(
                    feature_x=entry["feature_x"],
                    feature_y=entry["feature_y"],
                    value=entry["value"],
                )
            )

        # Parse feature importance
        feature_importance = None
        if data.get("feature_importance"):
            fi = data["feature_importance"]
            scores = [
                FeatureImportanceScore(feature=s["feature"], score=s["score"])
                for s in fi.get("scores", [])
            ]
            feature_importance = FeatureImportance(
                kind=fi.get("kind", "global"),
                baseline=fi.get("baseline", 0),
                scores=scores,
            )

        return EngineResult(
            run_id=data["run_id"],
            report_id=data.get("report_id"),
            status=data.get("status", "unknown"),
            dataset_title=data.get("dataset_title"),
            dataset_description=data.get("dataset_description"),
            total_rows=data.get("total_rows"),
            target_column=data.get("target_column"),
            task=data.get("task"),
            summary=summary,
            patterns=patterns,
            columns=columns,
            correlation_matrix=correlation_matrix,
            feature_importance=feature_importance,
            job_id=data.get("job_id"),
            job_status=data.get("job_status"),
            error_message=data.get("error_message"),
        )

    def _parse_summary(self, data: Dict[str, Any]) -> Summary:
        """Parse summary data into Summary dataclass."""
        # Parse data insights
        data_insights = None
        if data.get("data_insights"):
            di = data["data_insights"]
            data_insights = DataInsights(
                important_features=di.get("important_features", []),
                important_features_explanation=di.get("important_features_explanation", ""),
                strong_correlations=di.get("strong_correlations", []),
                strong_correlations_explanation=di.get("strong_correlations_explanation", ""),
                notable_relationships=di.get("notable_relationships", []),
            )

        return Summary(
            overview=data.get("overview", ""),
            key_insights=data.get("key_insights", []),
            novel_patterns=PatternGroup(
                pattern_ids=data.get("novel_patterns", {}).get("pattern_ids", []),
                explanation=data.get("novel_patterns", {}).get("explanation", ""),
            ),
            surprising_findings=PatternGroup(
                pattern_ids=data.get("surprising_findings", {}).get("pattern_ids", []),
                explanation=data.get("surprising_findings", {}).get("explanation", ""),
            ),
            statistically_significant=PatternGroup(
                pattern_ids=data.get("statistically_significant", {}).get("pattern_ids", []),
                explanation=data.get("statistically_significant", {}).get("explanation", ""),
            ),
            data_insights=data_insights,
            selected_pattern_id=data.get("selected_pattern_id"),
        )
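
For orientation, here is a minimal usage sketch of the Engine class added in this file. It is not part of the package: the API key, file name, and target column are placeholders, and the import path assumes Engine is pulled from discovery.client as defined above (the package's discovery/__init__.py may also re-export it).

# Usage sketch only (placeholder values; not shipped in the wheel).
import asyncio

from discovery.client import Engine


async def main() -> None:
    # The async context manager closes the underlying httpx clients on exit.
    async with Engine(api_key="YOUR_API_KEY") as engine:
        result = await engine.run_async(
            "dataset.csv",            # CSV/Parquet path or a pandas DataFrame
            target_column="target",   # name of the column to analyse
            wait=True,                # poll via wait_for_completion() until the run finishes
        )
        print(result.run_id, result.status)
        if result.summary:
            print(result.summary.overview)


asyncio.run(main())

In a plain script (no already-running event loop), the synchronous wrapper engine.run(...) with the same arguments behaves equivalently; in notebooks the code above, or nest_asyncio, is needed as noted in the run() docstring.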