discovery_engine_api-0.1.52-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
discovery/client.py ADDED
@@ -0,0 +1,857 @@
+"""Discovery Engine Python SDK."""
+
+import asyncio
+import json
+import os
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+try:
+    import pandas as pd
+except ImportError:
+    pd = None
+
+from discovery.types import (
+    Column,
+    CorrelationEntry,
+    DataInsights,
+    EngineResult,
+    FeatureImportance,
+    FeatureImportanceScore,
+    FileInfo,
+    Pattern,
+    PatternGroup,
+    RunStatus,
+    Summary,
+)
+
+
+class Engine:
+    """Engine for the Discovery Engine API."""
+
+    # Production API URL (can be overridden via DISCOVERY_API_URL env var for testing)
+    # This points to the Modal-deployed FastAPI API
+    _DEFAULT_BASE_URL = "https://leap-labs-production--discovery-api.modal.run"
+
+    # Dashboard URL for web UI and /api/* endpoints
+    _DEFAULT_DASHBOARD_URL = "https://disco.leap-labs.com"
+
+    def __init__(self, api_key: str):
+        """
+        Initialize the Discovery Engine.
+
+        Args:
+            api_key: Your API key
+        """
+
+        print("Initializing Discovery Engine...")
+        self.api_key = api_key
+        # Use DISCOVERY_API_URL env var if set (for testing/custom deployments),
+        # otherwise use the production default
+        self.base_url = os.getenv("DISCOVERY_API_URL", self._DEFAULT_BASE_URL).rstrip("/")
+        # Dashboard URL for /api/* endpoints and web UI links
+        self.dashboard_url = os.getenv(
+            "DISCOVERY_DASHBOARD_URL", self._DEFAULT_DASHBOARD_URL
+        ).rstrip("/")
+        self._organization_id: Optional[str] = None
+        self._client: Optional[httpx.AsyncClient] = None
+        self._dashboard_client: Optional[httpx.AsyncClient] = None
+        self._org_fetched = False
+
+    async def _ensure_organization_id(self) -> str:
+        """
+        Ensure we have an organization ID, fetching from API if needed.
+
+        The organization ID is required for API requests to identify which
+        organization the user belongs to (multi-tenancy support).
+
+        Returns:
+            Organization ID string
+
+        Raises:
+            ValueError: If no organization is found or API request fails
+        """
+        if self._organization_id:
+            return self._organization_id
+
+        if not self._org_fetched:
+            # Fetch user's organizations and use the first one
+            try:
+                orgs = await self.get_organizations()
+                if orgs:
+                    self._organization_id = orgs[0]["id"]
+            except ValueError as e:
+                # Re-raise with more context
+                raise ValueError(
+                    f"Failed to fetch organization: {e}. "
+                    "Please ensure your API key is valid and you belong to an organization."
+                ) from e
+            self._org_fetched = True
+
+        if not self._organization_id:
+            raise ValueError(
+                "No organization found for your account. "
+                "Please contact support if this issue persists."
+            )
+
+        return self._organization_id
+
+    async def _get_client(self) -> httpx.AsyncClient:
+        """Get or create the HTTP client."""
+        if self._client is None:
+            headers = {"Authorization": f"Bearer {self.api_key}"}
+            self._client = httpx.AsyncClient(
+                base_url=self.base_url,
+                headers=headers,
+                timeout=60.0,
+            )
+        return self._client
+
+    async def _get_client_with_org(self) -> httpx.AsyncClient:
+        """
+        Get HTTP client with organization header set.
+
+        The organization ID is required for API requests to identify which
+        organization the user belongs to (multi-tenancy support).
+        """
+        client = await self._get_client()
+
+        # Ensure we have an organization ID
+        org_id = await self._ensure_organization_id()
+
+        # Set the organization header
+        client.headers["X-Organization-ID"] = org_id
+
+        return client
+
+    async def _get_dashboard_client(self) -> httpx.AsyncClient:
+        """Get or create the HTTP client for dashboard API calls."""
+        if self._dashboard_client is None:
+            headers = {"Authorization": f"Bearer {self.api_key}"}
+            self._dashboard_client = httpx.AsyncClient(
+                base_url=self.dashboard_url,
+                headers=headers,
+                timeout=60.0,
+            )
+        return self._dashboard_client
+
+    async def close(self):
+        """Close the HTTP clients."""
+        if self._client:
+            await self._client.aclose()
+            self._client = None
+        if self._dashboard_client:
+            await self._dashboard_client.aclose()
+            self._dashboard_client = None
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit."""
+        await self.close()
+
+    async def get_organizations(self) -> List[Dict[str, Any]]:
+        """
+        Get the organizations you belong to.
+
+        Returns:
+            List of organizations with id, name, and slug
+
+        Raises:
+            ValueError: If the API request fails
+        """
+        client = await self._get_client()
+
+        try:
+            response = await client.get("/v1/me/organizations")
+            response.raise_for_status()
+            return response.json()
+        except httpx.HTTPStatusError as e:
+            raise ValueError(
+                f"Failed to fetch organizations: {e.response.status_code} {e.response.text}"
+            ) from e
+        except httpx.RequestError as e:
+            raise ValueError(f"Failed to connect to API: {str(e)}") from e
+
+    async def upload_file(
+        self, file: Union[str, Path, "pd.DataFrame"], filename: Optional[str] = None
+    ) -> FileInfo:
+        """
+        Upload a file to the API.
+
+        Args:
+            file: File path, Path object, or pandas DataFrame
+            filename: Optional filename (for DataFrame uploads)
+
+        Returns:
+            FileInfo with file_path, file_hash, file_size, mime_type
+        """
+        client = await self._get_client_with_org()
+
+        if pd is not None and isinstance(file, pd.DataFrame):
+            # Convert DataFrame to CSV in memory
+            import io
+
+            buffer = io.BytesIO()
+            file.to_csv(buffer, index=False)
+            buffer.seek(0)
+            file_content = buffer.getvalue()
+            filename = filename or "dataset.csv"
+            mime_type = "text/csv"
+        else:
+            # Read file from disk
+            file_path = Path(file)
+            if not file_path.exists():
+                raise FileNotFoundError(f"File not found: {file_path}")
+            file_content = file_path.read_bytes()
+            filename = filename or file_path.name
+            mime_type = (
+                "text/csv" if file_path.suffix == ".csv" else "application/vnd.apache.parquet"
+            )
+
+        # Upload file
+        files = {"file": (filename, file_content, mime_type)}
+        response = await client.post("/v1/upload", files=files)
+        response.raise_for_status()
+
+        data = response.json()
+        return FileInfo(
+            file_path=data["file_path"],
+            file_hash=data["file_hash"],
+            file_size=data["file_size"],
+            mime_type=data["mime_type"],
+        )
+
+    async def create_dataset(
+        self,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        total_rows: int = 0,
+        dataset_size_mb: Optional[float] = None,
+        author: Optional[str] = None,
+        source_url: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Create a dataset record.
+
+        Args:
+            title: Dataset title
+            description: Dataset description
+            total_rows: Number of rows in the dataset
+            dataset_size_mb: Dataset size in MB
+            author: Optional author attribution
+            source_url: Optional source URL
+
+        Returns:
+            Dataset record with ID
+        """
+        client = await self._get_client_with_org()
+
+        response = await client.post(
+            "/v1/run-datasets",
+            json={
+                "title": title,
+                "description": description,
+                "total_rows": total_rows,
+                "dataset_size_mb": dataset_size_mb,
+                "author": author,
+                "source_url": source_url,
+            },
+        )
+        response.raise_for_status()
+        return response.json()
+
+    async def create_file_record(self, dataset_id: str, file_info: FileInfo) -> Dict[str, Any]:
+        """
+        Create a file record for a dataset.
+
+        Args:
+            dataset_id: Dataset ID
+            file_info: FileInfo from upload_file()
+
+        Returns:
+            File record with ID
+        """
+        client = await self._get_client_with_org()
+
+        response = await client.post(
+            f"/v1/run-datasets/{dataset_id}/files",
+            json={
+                "mime_type": file_info.mime_type,
+                "file_path": file_info.file_path,
+                "file_hash": file_info.file_hash,
+                "file_size": file_info.file_size,
+            },
+        )
+        response.raise_for_status()
+        return response.json()
+
+    async def create_columns(
+        self, dataset_id: str, columns: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Create column records for a dataset.
+
+        Args:
+            dataset_id: Dataset ID
+            columns: List of column definitions with full metadata
+
+        Returns:
+            List of column records with IDs
+        """
+        client = await self._get_client_with_org()
+
+        response = await client.post(
+            f"/v1/run-datasets/{dataset_id}/columns",
+            json=columns,
+        )
+        response.raise_for_status()
+        return response.json()
+
+    async def create_run(
+        self,
+        dataset_id: str,
+        target_column_id: str,
+        task: str = "regression",
+        depth_iterations: int = 1,
+        visibility: str = "public",
+        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
+        target_column_override: Optional[str] = None,
+        auto_report_use_llm_evals: bool = True,
+        author: Optional[str] = None,
+        source_url: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Create a run and enqueue it for processing.
+
+        Args:
+            dataset_id: Dataset ID
+            target_column_id: Target column ID
+            task: Task type (regression, binary_classification, multiclass_classification)
+            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
+            visibility: Dataset visibility ("public" or "private")
+            timeseries_groups: Optional list of timeseries column groups
+            target_column_override: Optional override for target column name
+            auto_report_use_llm_evals: Use LLM evaluations
+            author: Optional dataset author
+            source_url: Optional source URL
+
+        Returns:
+            Run record with ID and job information
+        """
+        client = await self._get_client_with_org()
+
+        payload = {
+            "run_target_column_id": target_column_id,
+            "task": task,
+            "depth_iterations": depth_iterations,
+            "visibility": visibility,
+            "auto_report_use_llm_evals": auto_report_use_llm_evals,
+        }
+
+        if timeseries_groups:
+            payload["timeseries_groups"] = timeseries_groups
+        if target_column_override:
+            payload["target_column_override"] = target_column_override
+        if author:
+            payload["author"] = author
+        if source_url:
+            payload["source_url"] = source_url
+
+        response = await client.post(
+            f"/v1/run-datasets/{dataset_id}/runs",
+            json=payload,
+        )
+        response.raise_for_status()
+        return response.json()
+
+    async def get_results(self, run_id: str) -> EngineResult:
+        """
+        Get complete analysis results for a run.
+
+        This returns all data that the Discovery dashboard displays:
+        - LLM-generated summary with key insights
+        - All discovered patterns with conditions, citations, and explanations
+        - Column/feature information with statistics and importance scores
+        - Correlation matrix
+        - Global feature importance
+
+        Args:
+            run_id: The run ID
+
+        Returns:
+            EngineResult with complete analysis data
+        """
+        # Use dashboard client for /api/* endpoints (hosted on Next.js dashboard, not Modal API)
+        dashboard_client = await self._get_dashboard_client()
+
+        # Call dashboard API for results
+        response = await dashboard_client.get(f"/api/runs/{run_id}/results")
+        response.raise_for_status()
+
+        data = response.json()
+        return self._parse_analysis_result(data)
+
+    async def get_run_status(self, run_id: str) -> RunStatus:
+        """
+        Get the status of a run.
+
+        Args:
+            run_id: Run ID
+
+        Returns:
+            RunStatus with current status information
+        """
+        client = await self._get_client_with_org()
+
+        response = await client.get(f"/v1/runs/{run_id}/results")
+        response.raise_for_status()
+
+        data = response.json()
+        return RunStatus(
+            run_id=data["run_id"],
+            status=data["status"],
+            job_id=data.get("job_id"),
+            job_status=data.get("job_status"),
+            error_message=data.get("error_message"),
+        )
+
+    async def wait_for_completion(
+        self,
+        run_id: str,
+        poll_interval: float = 5.0,
+        timeout: Optional[float] = None,
+    ) -> EngineResult:
+        """
+        Wait for a run to complete and return the results.
+
+        Args:
+            run_id: Run ID
+            poll_interval: Seconds between status checks (default: 5)
+            timeout: Maximum seconds to wait (None = no timeout)
+
+        Returns:
+            EngineResult with complete analysis data
+
+        Raises:
+            TimeoutError: If the run doesn't complete within the timeout
+            RuntimeError: If the run fails
+        """
+        start_time = time.time()
+        last_status = None
+        poll_count = 0
+
+        print(f"⏳ Waiting for run {run_id} to complete...")
+
+        while True:
+            result = await self.get_results(run_id)
+            elapsed = time.time() - start_time
+            poll_count += 1
+
+            # Log status changes or every 3rd poll (every ~15 seconds)
+            if result.status != last_status or poll_count % 3 == 0:
+                status_msg = f"Status: {result.status}"
+                if result.job_status:
+                    status_msg += f" (job: {result.job_status})"
+                if elapsed > 0:
+                    status_msg += f" | Elapsed: {elapsed:.1f}s"
+                print(f" {status_msg}")
+
+            last_status = result.status
+
+            if result.status == "completed":
+                print(f"✓ Run completed in {elapsed:.1f}s")
+                return result
+            elif result.status == "failed":
+                error_msg = result.error_message or "Unknown error"
+                print(f"✗ Run failed: {error_msg}")
+                raise RuntimeError(f"Run {run_id} failed: {error_msg}")
+
+            if timeout and elapsed > timeout:
+                raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds")
+
+            await asyncio.sleep(poll_interval)
+
+    async def run_async(
+        self,
+        file: Union[str, Path, "pd.DataFrame"],
+        target_column: str,
+        depth_iterations: int = 1,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        column_descriptions: Optional[Dict[str, str]] = None,
+        excluded_columns: Optional[List[str]] = None,
+        task: Optional[str] = None,
+        visibility: str = "public",
+        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
+        target_column_override: Optional[str] = None,
+        auto_report_use_llm_evals: bool = True,
+        author: Optional[str] = None,
+        source_url: Optional[str] = None,
+        wait: bool = False,
+        wait_timeout: Optional[float] = None,
+        **kwargs,
+    ) -> EngineResult:
+        """
+        Run analysis on a dataset (async).
+
+        This method calls the dashboard API which handles the entire workflow:
+        file upload, dataset creation, column inference, run creation, and credit deduction.
+
+        Args:
+            file: File path, Path object, or pandas DataFrame
+            target_column: Name of the target column
+            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
+            title: Optional dataset title
+            description: Optional dataset description
+            column_descriptions: Optional dict mapping column names to descriptions
+            excluded_columns: Optional list of column names to exclude from analysis
+            task: Task type (regression, binary, multiclass) - auto-detected if None
+            visibility: Dataset visibility ("public" or "private", default: "public")
+            timeseries_groups: Optional list of timeseries column groups
+            target_column_override: Optional override for target column name
+            auto_report_use_llm_evals: Use LLM evaluations (default: True)
+            author: Optional dataset author
+            source_url: Optional source URL
+            wait: If True, wait for analysis to complete and return full results
+            wait_timeout: Maximum seconds to wait for completion (only if wait=True)
+
+        Returns:
+            EngineResult with run_id and (if wait=True) complete results
+        """
+        # Prepare file for upload
+        if pd is not None and isinstance(file, pd.DataFrame):
+            # Convert DataFrame to CSV in memory
+            import io
+
+            print(f"📊 Preparing DataFrame ({len(file)} rows, {len(file.columns)} columns)...")
+            buffer = io.BytesIO()
+            file.to_csv(buffer, index=False)
+            buffer.seek(0)
+            file_content = buffer.getvalue()
+            filename = (title + ".csv") if title else "dataset.csv"
+            mime_type = "text/csv"
+            file_size_mb = len(file_content) / (1024 * 1024)
+            print(f" File size: {file_size_mb:.2f} MB")
+        else:
+            # Read file from disk
+            file_path = Path(file)
+            if not file_path.exists():
+                raise FileNotFoundError(f"File not found: {file_path}")
+            print(f"📁 Reading file: {file_path.name}...")
+            file_content = file_path.read_bytes()
+            filename = file_path.name
+            mime_type = (
+                "text/csv" if file_path.suffix == ".csv" else "application/vnd.apache.parquet"
+            )
+            file_size_mb = len(file_content) / (1024 * 1024)
+            print(f" File size: {file_size_mb:.2f} MB")
+
+        # Prepare multipart form data
+        files = {"file": (filename, file_content, mime_type)}
+        data: Dict[str, Any] = {
+            "target_column": target_column,
+            "depth_iterations": str(depth_iterations),
+            "visibility": visibility,
+        }
+
+        if description:
+            data["description"] = description
+        if author:
+            data["author"] = author
+        if source_url:
+            data["source_url"] = source_url
+        if column_descriptions:
+            data["column_descriptions"] = json.dumps(column_descriptions)
+        if excluded_columns:
+            data["excluded_columns"] = json.dumps(excluded_columns)
+        if timeseries_groups:
+            data["timeseries_groups"] = json.dumps(timeseries_groups)
+
+        # Call dashboard API to create report
+        print(
+            f"🚀 Uploading file and creating run (depth: {depth_iterations}, target: {target_column})..."
+        )
+        # Use dashboard client for /api/* endpoints (hosted on Next.js dashboard, not Modal API)
+        dashboard_client = await self._get_dashboard_client()
+        # httpx automatically handles multipart/form-data when both files and data are provided
+        response = await dashboard_client.post("/api/reports/create", files=files, data=data)
+        response.raise_for_status()
+
+        result_data = response.json()
+
+        # Check if duplicate
+        if result_data.get("duplicate"):
+            # For duplicates, get the run_id and fetch results
+            report_id = result_data.get("report_id")
+            run_id = result_data.get("run_id")
+
+            if not report_id or not run_id:
+                raise ValueError("Duplicate report found but missing report_id or run_id")
+
+            print(f"ℹ️ Duplicate report found (run_id: {run_id})")
+
+            # Construct dashboard URL for the processing page
+            progress_url = f"{self.dashboard_url}/reports/new/{run_id}/processing"
+            print(f"🔗 View progress: {progress_url}")
+
+            # If wait is True, fetch the full results for the existing report
+            if wait:
+                return await self.get_results(run_id)
+
+            # Otherwise return a minimal result with the run_id
+            return EngineResult(
+                run_id=run_id,
+                status="completed",
+                report_id=report_id,
+            )
+
+        run_id = result_data["run_id"]
+        print(f"✓ Run created: {run_id}")
+
+        # Construct dashboard URL for the processing page
+        progress_url = f"{self.dashboard_url}/reports/new/{run_id}/processing"
+        print(f"🔗 View progress: {progress_url}")
+
+        if wait:
+            # Wait for completion and return full results
+            return await self.wait_for_completion(run_id, timeout=wait_timeout)
+
+        # Return minimal result with pending status
+        return EngineResult(
+            run_id=run_id,
+            status="pending",
+        )
+
+    def run(
+        self,
+        file: Union[str, Path, "pd.DataFrame"],
+        target_column: str,
+        depth_iterations: int = 1,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        column_descriptions: Optional[Dict[str, str]] = None,
+        excluded_columns: Optional[List[str]] = None,
+        task: Optional[str] = None,
+        visibility: str = "public",
+        timeseries_groups: Optional[List[Dict[str, Any]]] = None,
+        target_column_override: Optional[str] = None,
+        auto_report_use_llm_evals: bool = True,
+        author: Optional[str] = None,
+        source_url: Optional[str] = None,
+        wait: bool = False,
+        wait_timeout: Optional[float] = None,
+        **kwargs,
+    ) -> EngineResult:
+        """
+        Run analysis on a dataset (synchronous wrapper).
+
+        This is a synchronous wrapper around run_async().
+
+        Args:
+            file: File path, Path object, or pandas DataFrame
+            target_column: Name of the target column
+            depth_iterations: Number of iterative feature removal cycles (1 = fastest)
+            title: Optional dataset title
+            description: Optional dataset description
+            column_descriptions: Optional dict mapping column names to descriptions
+            excluded_columns: Optional list of column names to exclude from analysis
+            task: Task type (regression, binary_classification, multiclass_classification) - auto-detected if None
+            visibility: Dataset visibility ("public" or "private", default: "public")
+            timeseries_groups: Optional list of timeseries column groups
+            target_column_override: Optional override for target column name
+            auto_report_use_llm_evals: Use LLM evaluations (default: True)
+            author: Optional dataset author
+            source_url: Optional source URL
+            wait: If True, wait for analysis to complete and return full results
+            wait_timeout: Maximum seconds to wait for completion (only if wait=True)
+            **kwargs: Additional arguments passed to run_async()
+
+        Returns:
+            EngineResult with run_id and (if wait=True) complete results
+        """
+        coro = self.run_async(
+            file,
+            target_column,
+            depth_iterations,
+            title=title,
+            description=description,
+            column_descriptions=column_descriptions,
+            excluded_columns=excluded_columns,
+            task=task,
+            visibility=visibility,
+            timeseries_groups=timeseries_groups,
+            target_column_override=target_column_override,
+            auto_report_use_llm_evals=auto_report_use_llm_evals,
+            author=author,
+            source_url=source_url,
+            wait=wait,
+            wait_timeout=wait_timeout,
+            **kwargs,
+        )
+
+        # Try to run the coroutine
+        # If we're in a Jupyter notebook with a running event loop, asyncio.run() will fail
+        try:
+            return asyncio.run(coro)
+        except RuntimeError as e:
+            # Check if the error is about a running event loop
+            if "cannot be called from a running event loop" in str(e).lower():
+                # We're in a Jupyter/IPython environment with a running event loop
+                # Try to use nest_asyncio if available
+                try:
+                    import nest_asyncio
+
+                    # Apply nest_asyncio (it's safe to call multiple times)
+                    nest_asyncio.apply()
+                    # Now we can use asyncio.run() even with a running loop
+                    return asyncio.run(coro)
+                except ImportError:
+                    raise RuntimeError(
+                        "Cannot use engine.run() in a Jupyter notebook or environment with a running event loop. "
+                        "Please use 'await engine.run_async(...)' instead, or install nest_asyncio "
+                        "(pip install nest-asyncio) to enable nested event loops."
+                    ) from e
+            # Re-raise if it's a different RuntimeError
+            raise
+
+    def _parse_analysis_result(self, data: Dict[str, Any]) -> EngineResult:
+        """Parse API response into EngineResult dataclass."""
+        # Parse summary
+        summary = None
+        if data.get("summary"):
+            summary = self._parse_summary(data["summary"])
+
+        # Parse patterns
+        patterns = []
+        for p in data.get("patterns", []):
+            patterns.append(
+                Pattern(
+                    id=p["id"],
+                    task=p.get("task", "regression"),
+                    target_column=p.get("target_column", ""),
+                    direction=p.get("direction", "max"),
+                    p_value=p.get("p_value", 0),
+                    conditions=p.get("conditions", []),
+                    lift_value=p.get("lift_value", 0),
+                    support_count=p.get("support_count", 0),
+                    support_percentage=p.get("support_percentage", 0),
+                    pattern_type=p.get("pattern_type", "validated"),
+                    novelty_type=p.get("novelty_type", "confirmatory"),
+                    target_score=p.get("target_score", 0),
+                    target_class=p.get("target_class"),
+                    target_mean=p.get("target_mean"),
+                    target_std=p.get("target_std"),
+                    description=p.get("description", ""),
+                    novelty_explanation=p.get("novelty_explanation", ""),
+                    citations=p.get("citations", []),
+                )
+            )
+
+        # Parse columns
+        columns = []
+        for c in data.get("columns", []):
+            columns.append(
+                Column(
+                    id=c["id"],
+                    name=c["name"],
+                    display_name=c.get("display_name", c["name"]),
+                    type=c.get("type", "continuous"),
+                    data_type=c.get("data_type", "float"),
+                    enabled=c.get("enabled", True),
+                    description=c.get("description"),
+                    mean=c.get("mean"),
+                    median=c.get("median"),
+                    std=c.get("std"),
+                    min=c.get("min"),
+                    max=c.get("max"),
+                    iqr_min=c.get("iqr_min"),
+                    iqr_max=c.get("iqr_max"),
+                    mode=c.get("mode"),
+                    approx_unique=c.get("approx_unique"),
+                    null_percentage=c.get("null_percentage"),
+                    feature_importance_score=c.get("feature_importance_score"),
+                )
+            )
+
+        # Parse correlation matrix
+        correlation_matrix = []
+        for entry in data.get("correlation_matrix", []):
+            correlation_matrix.append(
+                CorrelationEntry(
+                    feature_x=entry["feature_x"],
+                    feature_y=entry["feature_y"],
+                    value=entry["value"],
+                )
+            )
+
+        # Parse feature importance
+        feature_importance = None
+        if data.get("feature_importance"):
+            fi = data["feature_importance"]
+            scores = [
+                FeatureImportanceScore(feature=s["feature"], score=s["score"])
+                for s in fi.get("scores", [])
+            ]
+            feature_importance = FeatureImportance(
+                kind=fi.get("kind", "global"),
+                baseline=fi.get("baseline", 0),
+                scores=scores,
+            )
+
+        return EngineResult(
+            run_id=data["run_id"],
+            report_id=data.get("report_id"),
+            status=data.get("status", "unknown"),
+            dataset_title=data.get("dataset_title"),
+            dataset_description=data.get("dataset_description"),
+            total_rows=data.get("total_rows"),
+            target_column=data.get("target_column"),
+            task=data.get("task"),
+            summary=summary,
+            patterns=patterns,
+            columns=columns,
+            correlation_matrix=correlation_matrix,
+            feature_importance=feature_importance,
+            job_id=data.get("job_id"),
+            job_status=data.get("job_status"),
+            error_message=data.get("error_message"),
+        )
+
+    def _parse_summary(self, data: Dict[str, Any]) -> Summary:
+        """Parse summary data into Summary dataclass."""
+        # Parse data insights
+        data_insights = None
+        if data.get("data_insights"):
+            di = data["data_insights"]
+            data_insights = DataInsights(
+                important_features=di.get("important_features", []),
+                important_features_explanation=di.get("important_features_explanation", ""),
+                strong_correlations=di.get("strong_correlations", []),
+                strong_correlations_explanation=di.get("strong_correlations_explanation", ""),
+                notable_relationships=di.get("notable_relationships", []),
+            )
+
+        return Summary(
+            overview=data.get("overview", ""),
+            key_insights=data.get("key_insights", []),
+            novel_patterns=PatternGroup(
+                pattern_ids=data.get("novel_patterns", {}).get("pattern_ids", []),
+                explanation=data.get("novel_patterns", {}).get("explanation", ""),
+            ),
+            surprising_findings=PatternGroup(
+                pattern_ids=data.get("surprising_findings", {}).get("pattern_ids", []),
+                explanation=data.get("surprising_findings", {}).get("explanation", ""),
+            ),
+            statistically_significant=PatternGroup(
+                pattern_ids=data.get("statistically_significant", {}).get("pattern_ids", []),
+                explanation=data.get("statistically_significant", {}).get("explanation", ""),
+            ),
+            data_insights=data_insights,
+            selected_pattern_id=data.get("selected_pattern_id"),
+        )
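
Usage note (illustrative, not part of the packaged file): based on the client code above, a minimal end-to-end call looks roughly like the sketch below. The API key, file path, and target column name are placeholders.

    from discovery.client import Engine

    engine = Engine(api_key="YOUR_API_KEY")  # placeholder key

    # Synchronous wrapper; inside Jupyter or other async code use `await engine.run_async(...)` instead.
    result = engine.run(
        "data.csv",             # path to a CSV/Parquet file, or a pandas DataFrame
        target_column="price",  # placeholder target column name
        wait=True,              # poll until the analysis completes
    )
    print(result.run_id, result.status)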