ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,860 @@
1
+ """
2
+ Business Intelligence & Analytics Tools
3
+
4
+ Advanced business analytics tools for cohort analysis, RFM segmentation,
5
+ causal inference, and automated insight generation.
6
+ """
7
+
8
+ import polars as pl
9
+ import numpy as np
10
+ import pandas as pd
11
+ from typing import Dict, Any, List, Optional, Tuple
12
+ from datetime import datetime, timedelta
13
+ import json
14
+
15
+ # Statistical packages
16
+ try:
17
+ from scipy import stats
18
+ from scipy.stats import chi2_contingency, ttest_ind, f_oneway
19
+ except ImportError:
20
+ pass
21
+
22
+ try:
23
+ from statsmodels.tsa.stattools import grangercausalitytests
24
+ from statsmodels.stats.proportion import proportions_ztest
25
+ STATSMODELS_AVAILABLE = True
26
+ except ImportError:
27
+ STATSMODELS_AVAILABLE = False
28
+
29
+ # Causal inference (optional)
30
+ try:
31
+ from econml.dml import CausalForestDML
32
+ from econml.dr import DRLearner
33
+ ECONML_AVAILABLE = True
34
+ except ImportError:
35
+ ECONML_AVAILABLE = False
36
+
37
+ # Customer analytics (optional)
38
+ try:
39
+ from lifetimes import BetaGeoFitter, GammaGammaFitter
40
+ from lifetimes.utils import summary_data_from_transaction_data
41
+ LIFETIMES_AVAILABLE = True
42
+ except ImportError:
43
+ LIFETIMES_AVAILABLE = False
44
+
45
+ # For Groq API calls
46
+ import os
47
+ from groq import Groq
48
+
49
+
50
def perform_cohort_analysis(
    data: pl.DataFrame,
    customer_id_column: str,
    date_column: str,
    value_column: Optional[str] = None,
    cohort_period: str = "monthly",
    metric: str = "retention"
) -> Dict[str, Any]:
    """
    Perform cohort analysis for customer retention, CLV, and churn analysis.

    Customers are assigned to a cohort based on the period of their first
    transaction; the chosen metric is then tracked against the number of
    periods elapsed since that first transaction.

    Args:
        data: Input DataFrame with transaction/event data
        customer_id_column: Column containing customer IDs
        date_column: Column containing dates
        value_column: Column containing transaction values (optional, for revenue cohorts)
        cohort_period: Period for cohorts ('daily', 'weekly', 'monthly', 'quarterly')
        metric: Metric to analyze ('retention', 'churn', 'revenue', 'frequency')

    Returns:
        Dictionary containing cohort analysis results, retention curves, and insights

    Raises:
        ValueError: If a required column is missing or cohort_period is unknown.
    """
    print(f"🔍 Performing cohort analysis ({metric})...")

    # Validate input
    required_cols = [customer_id_column, date_column]
    if metric == "revenue" and value_column:
        required_cols.append(value_column)

    for col in required_cols:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    # Convert to pandas for easier date manipulation
    df = data.to_pandas()

    # Parse dates
    df[date_column] = pd.to_datetime(df[date_column])

    # Cohort = date of each customer's first transaction
    df['cohort'] = df.groupby(customer_id_column)[date_column].transform('min')

    # Map the requested granularity to a pandas period alias.
    # (An unused strftime-format table that lived here — and contained the
    # invalid directive '%q' — has been removed as dead code.)
    period_map = {
        'daily': 'D',
        'weekly': 'W',
        'monthly': 'M',
        'quarterly': 'Q'
    }

    if cohort_period not in period_map:
        raise ValueError(f"Unknown cohort_period '{cohort_period}'. Use: {list(period_map.keys())}")

    df['cohort_period'] = df['cohort'].dt.to_period(period_map[cohort_period])
    df['transaction_period'] = df[date_column].dt.to_period(period_map[cohort_period])

    # Periods elapsed since cohort start (0 = the cohort's first period)
    df['period_number'] = (df['transaction_period'] - df['cohort_period']).apply(lambda x: x.n)

    result = {
        "metric": metric,
        "cohort_period": cohort_period,
        "total_customers": df[customer_id_column].nunique(),
        "cohorts": []
    }

    try:
        if metric in ("retention", "churn"):
            # Retention analysis. 'churn' was always documented as a valid
            # metric but previously matched no branch; since churn is simply
            # 1 - retention (reported below), both metrics share this branch.
            cohort_data = df.groupby(['cohort_period', 'period_number']).agg({
                customer_id_column: 'nunique'
            }).reset_index()

            cohort_data.columns = ['cohort_period', 'period_number', 'customers']

            # Cohort sizes are the distinct-customer counts at period 0
            cohort_sizes = cohort_data[cohort_data['period_number'] == 0].set_index('cohort_period')['customers']

            # Retention rate = active customers in period / cohort size
            cohort_data['cohort_size'] = cohort_data['cohort_period'].map(cohort_sizes)
            cohort_data['retention_rate'] = cohort_data['customers'] / cohort_data['cohort_size']

            # Pivot into a cohort x period matrix
            cohort_matrix = cohort_data.pivot(
                index='cohort_period',
                columns='period_number',
                values='retention_rate'
            )

            result["cohort_matrix"] = cohort_matrix.to_dict()
            result["avg_retention_by_period"] = cohort_matrix.mean().to_dict()

            # Churn is the complement of retention
            result["avg_churn_by_period"] = (1 - cohort_matrix.mean()).to_dict()

            # Retention curve (average across all cohorts)
            retention_curve = cohort_matrix.mean().to_list()
            result["retention_curve"] = retention_curve

        elif metric == "revenue" and value_column:
            # Revenue cohort analysis
            cohort_data = df.groupby(['cohort_period', 'period_number']).agg({
                value_column: 'sum',
                customer_id_column: 'nunique'
            }).reset_index()

            cohort_data.columns = ['cohort_period', 'period_number', 'revenue', 'customers']

            # Revenue per active customer in each period
            cohort_data['revenue_per_customer'] = cohort_data['revenue'] / cohort_data['customers']

            cohort_matrix = cohort_data.pivot(
                index='cohort_period',
                columns='period_number',
                values='revenue_per_customer'
            )

            result["cohort_matrix"] = cohort_matrix.to_dict()
            result["avg_revenue_by_period"] = cohort_matrix.mean().to_dict()

            # Cumulative revenue per customer across periods
            cumulative_revenue = cohort_matrix.fillna(0).cumsum(axis=1)
            result["cumulative_revenue"] = cumulative_revenue.mean().to_dict()

            # Crude LTV estimate: total per-customer revenue summed over all
            # observed periods, averaged across cohorts
            result["estimated_ltv"] = float(cohort_matrix.sum(axis=1).mean())

        elif metric == "frequency":
            # Frequency analysis (transactions per customer per period)
            cohort_data = df.groupby(['cohort_period', 'period_number', customer_id_column]).size().reset_index(name='transactions')

            cohort_summary = cohort_data.groupby(['cohort_period', 'period_number']).agg({
                'transactions': 'mean',
                customer_id_column: 'count'
            }).reset_index()

            cohort_summary.columns = ['cohort_period', 'period_number', 'avg_transactions', 'active_customers']

            cohort_matrix = cohort_summary.pivot(
                index='cohort_period',
                columns='period_number',
                values='avg_transactions'
            )

            result["cohort_matrix"] = cohort_matrix.to_dict()
            result["avg_frequency_by_period"] = cohort_matrix.mean().to_dict()

        # Cohort-level statistics (computed for every metric)
        cohort_stats = []
        for cohort in df['cohort_period'].unique():
            cohort_df = df[df['cohort_period'] == cohort]

            stats_dict = {
                "cohort": str(cohort),
                "size": int(cohort_df[customer_id_column].nunique()),
                "total_transactions": int(len(cohort_df)),
                "avg_transactions_per_customer": float(len(cohort_df) / cohort_df[customer_id_column].nunique())
            }

            if value_column:
                stats_dict["total_revenue"] = float(cohort_df[value_column].sum())
                stats_dict["avg_revenue_per_customer"] = float(cohort_df[value_column].sum() / cohort_df[customer_id_column].nunique())

            cohort_stats.append(stats_dict)

        result["cohort_statistics"] = cohort_stats

        # Calculate key insights
        result["insights"] = _generate_cohort_insights(result, metric)

        print(f"✅ Cohort analysis complete!")
        print(f" Total customers: {result['total_customers']}")
        print(f" Cohorts analyzed: {len(cohort_stats)}")

        return result

    except Exception as e:
        print(f"❌ Error during cohort analysis: {str(e)}")
        raise
238
+
239
+
240
def _generate_cohort_insights(result: Dict[str, Any], metric: str) -> List[str]:
    """Derive short, human-readable takeaways from a cohort-analysis result dict."""
    notes: List[str] = []

    if metric == "retention" and "retention_curve" in result:
        curve = result["retention_curve"]
        # Drop between the first two periods (period 0 is the cohort itself).
        if len(curve) > 1:
            first_period_drop = (curve[0] - curve[1]) * 100
            notes.append(f"Initial retention drop: {first_period_drop:.1f}% in first period")
        # Retention three periods after acquisition, when the curve is long enough.
        if len(curve) > 3:
            notes.append(f"3-period retention: {curve[3] * 100:.1f}%")

    if metric == "revenue" and "estimated_ltv" in result:
        notes.append(f"Estimated customer lifetime value: ${result['estimated_ltv']:.2f}")

    return notes
259
+
260
+
261
def perform_rfm_analysis(
    data: pl.DataFrame,
    customer_id_column: str,
    date_column: str,
    value_column: str,
    reference_date: Optional[str] = None,
    rfm_bins: int = 5
) -> Dict[str, Any]:
    """
    Perform RFM (Recency, Frequency, Monetary) analysis for customer segmentation.

    Args:
        data: Input DataFrame with transaction data
        customer_id_column: Column containing customer IDs
        date_column: Column containing transaction dates
        value_column: Column containing transaction values
        reference_date: Reference date for recency calculation (default: max date in data)
        rfm_bins: Number of bins for RFM scoring (typically 3, 4, or 5)

    Returns:
        Dictionary containing RFM scores, segments, and customer profiles

    Raises:
        ValueError: If a required column is missing.
    """
    print(f"🔍 Performing RFM analysis...")

    # Validate input
    required_cols = [customer_id_column, date_column, value_column]
    for col in required_cols:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    # Convert to pandas
    df = data.to_pandas()
    df[date_column] = pd.to_datetime(df[date_column])

    # Set reference date (defaults to the most recent transaction date)
    if reference_date:
        ref_date = pd.to_datetime(reference_date)
    else:
        ref_date = df[date_column].max()

    print(f" Reference date: {ref_date.strftime('%Y-%m-%d')}")

    # Calculate RFM metrics per customer:
    #   recency   = days since last transaction
    #   frequency = number of transactions
    #   monetary  = total spend
    rfm = df.groupby(customer_id_column).agg({
        date_column: lambda x: (ref_date - x.max()).days,  # Recency
        customer_id_column: 'count',  # Frequency
        value_column: 'sum'  # Monetary
    })

    rfm.columns = ['recency', 'frequency', 'monetary']

    # RFM Scoring (1-5, where 5 is best)
    # Note: For recency, lower is better, so we reverse the scoring.
    # NOTE(review): with duplicates='drop', qcut may produce fewer bins than
    # len(labels) on heavily-tied data, which raises ValueError — confirm the
    # expected cardinality of these columns upstream.
    rfm['r_score'] = pd.qcut(rfm['recency'], rfm_bins, labels=range(rfm_bins, 0, -1), duplicates='drop')
    rfm['f_score'] = pd.qcut(rfm['frequency'].rank(method='first'), rfm_bins, labels=range(1, rfm_bins+1), duplicates='drop')
    rfm['m_score'] = pd.qcut(rfm['monetary'].rank(method='first'), rfm_bins, labels=range(1, rfm_bins+1), duplicates='drop')

    # Convert categorical scores to plain ints
    rfm['r_score'] = rfm['r_score'].astype(int)
    rfm['f_score'] = rfm['f_score'].astype(int)
    rfm['m_score'] = rfm['m_score'].astype(int)

    # RFM Score (concatenated, e.g. '545')
    rfm['rfm_score'] = rfm['r_score'].astype(str) + rfm['f_score'].astype(str) + rfm['m_score'].astype(str)

    # RFM Total Score (sum of the three scores)
    rfm['rfm_total'] = rfm['r_score'] + rfm['f_score'] + rfm['m_score']

    # Segment customers based on RFM scores (first matching rule wins)
    def segment_customer(row):
        r, f, m = row['r_score'], row['f_score'], row['m_score']

        if r >= 4 and f >= 4 and m >= 4:
            return "Champions"
        elif r >= 4 and f >= 3:
            return "Loyal Customers"
        elif r >= 4 and f < 3:
            return "Potential Loyalists"
        elif r >= 3 and f >= 3 and m >= 3:
            return "Recent Customers"
        elif r >= 3 and m >= 4:
            return "Big Spenders"
        elif r < 3 and f >= 4:
            return "At Risk"
        elif r < 3 and f < 3 and m >= 4:
            return "Can't Lose Them"
        elif r < 2:
            return "Lost"
        else:
            return "Needs Attention"

    rfm['segment'] = rfm.apply(segment_customer, axis=1)

    # Results
    result = {
        "total_customers": len(rfm),
        "reference_date": ref_date.strftime('%Y-%m-%d'),
        "rfm_bins": rfm_bins,
        "rfm_data": rfm.reset_index().to_dict('records'),
        "segment_summary": {},
        "rfm_statistics": {}
    }

    # BUGFIX: a previous revision first computed an unused `segment_stats`
    # aggregation here that referenced `customer_id_column` as a column of
    # `rfm`; the earlier groupby moved that ID into the index, so the
    # aggregation raised KeyError on every call. The dead code was removed.

    # Per-segment summary
    for segment in rfm['segment'].unique():
        segment_data = rfm[rfm['segment'] == segment]
        result["segment_summary"][segment] = {
            "count": int(len(segment_data)),
            "percentage": float(len(segment_data) / len(rfm) * 100),
            "avg_recency": float(segment_data['recency'].mean()),
            "avg_frequency": float(segment_data['frequency'].mean()),
            "avg_monetary": float(segment_data['monetary'].mean()),
            "total_revenue": float(segment_data['monetary'].sum())
        }

    # Overall RFM statistics
    result["rfm_statistics"] = {
        "recency": {
            "mean": float(rfm['recency'].mean()),
            "median": float(rfm['recency'].median()),
            "min": int(rfm['recency'].min()),
            "max": int(rfm['recency'].max())
        },
        "frequency": {
            "mean": float(rfm['frequency'].mean()),
            "median": float(rfm['frequency'].median()),
            "min": int(rfm['frequency'].min()),
            "max": int(rfm['frequency'].max())
        },
        "monetary": {
            "mean": float(rfm['monetary'].mean()),
            "median": float(rfm['monetary'].median()),
            "min": float(rfm['monetary'].min()),
            "max": float(rfm['monetary'].max()),
            "total": float(rfm['monetary'].sum())
        }
    }

    # Top customers by RFM score
    result["top_customers"] = rfm.nlargest(20, 'rfm_total').reset_index().to_dict('records')

    # Actionable insights
    result["recommendations"] = _generate_rfm_recommendations(result)

    print(f"✅ RFM analysis complete!")
    print(f" Total customers: {result['total_customers']}")
    print(f" Segments: {len(result['segment_summary'])}")
    print(f" Top segment: {max(result['segment_summary'].items(), key=lambda x: x[1]['count'])[0]}")

    return result
418
+
419
+
420
def _generate_rfm_recommendations(result: Dict[str, Any]) -> Dict[str, List[str]]:
    """Return suggested marketing actions for each RFM segment present in *result*.

    Only segments that actually appear in ``result["segment_summary"]`` are
    included in the returned mapping.
    """
    # Static playbook: segment name -> recommended actions.
    playbook: Dict[str, List[str]] = {
        "Champions": [
            "Reward with exclusive perks and early access to new products",
            "Request reviews and referrals",
            "Engage for product development feedback"
        ],
        "Loyal Customers": [
            "Upsell higher value products",
            "Offer loyalty rewards",
            "Encourage referrals with incentives"
        ],
        "Potential Loyalists": [
            "Recommend related products",
            "Offer membership or loyalty program",
            "Engage with personalized communication"
        ],
        "Recent Customers": [
            "Provide onboarding support",
            "Build relationships with targeted content",
            "Offer starter discounts for repeat purchases"
        ],
        "Big Spenders": [
            "Target with premium products",
            "Increase engagement frequency",
            "Offer VIP treatment"
        ],
        "At Risk": [
            "Send win-back campaigns",
            "Offer special discounts or incentives",
            "Gather feedback on their experience"
        ],
        "Can't Lose Them": [
            "Aggressive win-back campaigns",
            "Personalized outreach",
            "Offer significant incentives"
        ],
        "Lost": [
            "Run re-engagement campaigns",
            "Survey for feedback",
            "Consider removing from active campaigns"
        ],
        "Needs Attention": [
            "Offer limited-time promotions",
            "Share valuable content",
            "Re-engage with surveys"
        ]
    }

    observed = result["segment_summary"]
    return {segment: actions for segment, actions in playbook.items() if segment in observed}
478
+
479
+
480
def detect_causal_relationships(
    data: pl.DataFrame,
    treatment_column: str,
    outcome_column: str,
    covariates: Optional[List[str]] = None,
    method: str = "granger",
    max_lag: int = 5,
    confidence_level: float = 0.95
) -> Dict[str, Any]:
    """
    Detect causal relationships using Granger causality, propensity matching, or uplift modeling.

    Args:
        data: Input DataFrame
        treatment_column: Column indicating treatment/intervention
        outcome_column: Column indicating outcome variable
        covariates: List of covariate columns for adjustment
        method: Method for causal inference ('granger', 'propensity', 'uplift', 'dowhy')
        max_lag: Maximum lag for Granger causality test
        confidence_level: Confidence level for statistical tests

    Returns:
        Dictionary containing causal inference results and effect estimates

    Raises:
        ValueError: If a column is missing, the method is unknown, or an
            optional dependency required by the chosen method is not installed.
    """
    print(f"🔍 Detecting causal relationships using {method} method...")

    # Validate input
    required_cols = [treatment_column, outcome_column]
    if covariates:
        required_cols.extend(covariates)

    for col in required_cols:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    result = {
        "method": method,
        "treatment": treatment_column,
        "outcome": outcome_column,
        "covariates": covariates or [],
        "causal_effect": None,
        "statistical_significance": None
    }

    try:
        if method == "granger":
            # BUGFIX: this branch used to be guarded by
            # `method == "granger" and STATSMODELS_AVAILABLE`, so a missing
            # statsmodels install fell through to the final else and raised a
            # misleading "Unknown method 'granger'". Fail explicitly instead.
            if not STATSMODELS_AVAILABLE:
                raise ValueError(
                    "statsmodels is required for method='granger'. "
                    "Install with: pip install statsmodels"
                )

            # Granger causality test for time series
            print(f" Testing Granger causality with max lag = {max_lag}...")

            # Convert to pandas
            df = data.select([treatment_column, outcome_column]).to_pandas()

            # Ensure numeric; drop rows that fail coercion
            df = df.apply(pd.to_numeric, errors='coerce').dropna()

            # Column order matters: tests whether treatment Granger-causes outcome
            test_result = grangercausalitytests(
                df[[outcome_column, treatment_column]],
                max_lag,
                verbose=False
            )

            # Extract the SSR F-test statistic and p-value for each lag
            granger_results = []
            for lag in range(1, max_lag + 1):
                ssr_ftest = test_result[lag][0]['ssr_ftest']
                granger_results.append({
                    "lag": lag,
                    "f_statistic": float(ssr_ftest[0]),
                    "p_value": float(ssr_ftest[1]),
                    "significant": ssr_ftest[1] < (1 - confidence_level)
                })

            result["granger_causality"] = granger_results
            result["causal_effect"] = any(r["significant"] for r in granger_results)
            result["statistical_significance"] = min(r["p_value"] for r in granger_results)

        elif method == "propensity":
            # Propensity score matching
            print(" Performing propensity score matching...")

            df = data.to_pandas()

            # Ensure treatment is binary
            treatment = df[treatment_column]
            if treatment.nunique() > 2:
                raise ValueError(f"Treatment column must be binary for propensity matching")

            outcome = df[outcome_column]

            if not covariates:
                # No covariates: plain difference in group means + t-test
                treated = outcome[treatment == 1]
                control = outcome[treatment == 0]

                # Average treatment effect
                ate = treated.mean() - control.mean()

                # Two-sample t-test for significance
                t_stat, p_value = ttest_ind(treated, control)

                result["average_treatment_effect"] = float(ate)
                result["t_statistic"] = float(t_stat)
                result["p_value"] = float(p_value)
                result["statistical_significance"] = float(p_value)
                result["causal_effect"] = float(ate)
                # Normal-approximation 95% CI for the difference in means
                result["confidence_interval"] = [
                    float(ate - 1.96 * np.sqrt(treated.var()/len(treated) + control.var()/len(control))),
                    float(ate + 1.96 * np.sqrt(treated.var()/len(treated) + control.var()/len(control)))
                ]
            else:
                # With covariates: logistic-regression propensity scores and
                # 1:1 nearest-neighbor matching (simplified, with replacement)
                from sklearn.linear_model import LogisticRegression
                from sklearn.neighbors import NearestNeighbors

                X = df[covariates].apply(pd.to_numeric, errors='coerce').fillna(0)

                # Estimate propensity scores
                ps_model = LogisticRegression(max_iter=1000)
                ps_model.fit(X, treatment)
                propensity_scores = ps_model.predict_proba(X)[:, 1]

                df['propensity_score'] = propensity_scores

                treated_df = df[treatment == 1]
                control_df = df[treatment == 0]

                # Match each treated unit to its nearest control on the score
                nn = NearestNeighbors(n_neighbors=1)
                nn.fit(control_df[['propensity_score']])

                distances, indices = nn.kneighbors(treated_df[['propensity_score']])
                matched_control = control_df.iloc[indices.flatten()]

                # ATE on the matched sample
                ate = treated_df[outcome_column].mean() - matched_control[outcome_column].mean()

                result["average_treatment_effect"] = float(ate)
                result["n_matched_pairs"] = len(treated_df)
                result["causal_effect"] = float(ate)

        elif method == "uplift":
            # Uplift modeling (treatment effect heterogeneity)
            print(" Calculating uplift/treatment effect...")

            df = data.to_pandas()

            treatment = df[treatment_column]
            outcome = df[outcome_column]

            # Difference in mean outcome between treated and control groups
            treated_outcome = outcome[treatment == 1].mean()
            control_outcome = outcome[treatment == 0].mean()

            uplift = treated_outcome - control_outcome

            # Statistical significance via two-sample t-test
            t_stat, p_value = ttest_ind(
                outcome[treatment == 1],
                outcome[treatment == 0]
            )

            result["uplift"] = float(uplift)
            result["treated_mean"] = float(treated_outcome)
            result["control_mean"] = float(control_outcome)
            result["relative_uplift"] = float(uplift / control_outcome * 100) if control_outcome != 0 else 0
            result["t_statistic"] = float(t_stat)
            result["p_value"] = float(p_value)
            result["statistical_significance"] = float(p_value)
            result["causal_effect"] = float(uplift)

        elif method == "dowhy":
            # DoWhy causal inference - formal causal graph approach
            try:
                import dowhy
                from dowhy import CausalModel
            except ImportError:
                raise ValueError("dowhy not installed. Install with: pip install dowhy>=0.11")

            print(" Building DoWhy causal model...")

            df = data.to_pandas()

            # Construct a simple causal graph: covariates -> treatment -> outcome
            if covariates:
                graph_dot = f'digraph {{ {treatment_column} -> {outcome_column};'
                for cov in covariates:
                    graph_dot += f' {cov} -> {treatment_column}; {cov} -> {outcome_column};'
                graph_dot += ' }'
            else:
                graph_dot = f'digraph {{ {treatment_column} -> {outcome_column}; }}'

            # NOTE(review): both `common_causes` and `graph` are passed here;
            # DoWhy normally expects one or the other — confirm against the
            # installed dowhy version's CausalModel documentation.
            model = CausalModel(
                data=df,
                treatment=treatment_column,
                outcome=outcome_column,
                common_causes=covariates,
                graph=graph_dot
            )

            # Identify causal effect
            identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

            # Estimate using linear regression (lightweight)
            estimate = model.estimate_effect(
                identified_estimand,
                method_name="backdoor.linear_regression"
            )

            # Refutation test (placebo treatment); best-effort
            try:
                refutation = model.refute_estimate(
                    identified_estimand,
                    estimate,
                    method_name="placebo_treatment_refuter",
                    placebo_type="permute",
                    num_simulations=20
                )
                refutation_result = {
                    "new_effect": float(refutation.new_effect) if hasattr(refutation, 'new_effect') else None,
                    "p_value": float(refutation.refutation_result.get('p_value', 1.0)) if hasattr(refutation, 'refutation_result') and isinstance(refutation.refutation_result, dict) else None
                }
            except Exception:
                refutation_result = {"note": "Refutation test could not be completed"}

            result["causal_effect"] = float(estimate.value)
            result["estimand"] = str(identified_estimand)
            result["estimation_method"] = "backdoor.linear_regression"
            result["refutation"] = refutation_result
            result["statistical_significance"] = None  # DoWhy uses refutation instead

        else:
            raise ValueError(f"Unknown method '{method}'. Use 'granger', 'propensity', 'uplift', or 'dowhy'")

        print(f"✅ Causal analysis complete!")
        if result.get("causal_effect") is not None:
            print(f" Estimated causal effect: {result['causal_effect']:.4f}")

        return result

    except Exception as e:
        print(f"❌ Error during causal analysis: {str(e)}")
        raise
725
+
726
+
727
def generate_business_insights(
    data: pl.DataFrame,
    analysis_type: str,
    analysis_results: Dict[str, Any],
    additional_context: Optional[str] = None,
    groq_api_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Generate natural language business insights using Groq LLM.

    Summarizes the DataFrame (shape, dtypes, numeric stats), bundles it with
    prior analysis results into a prompt, and asks the Groq chat API for a
    structured business report, which is then loosely parsed into sections.

    Args:
        data: Input DataFrame (for context)
        analysis_type: Type of analysis ('rfm', 'cohort', 'causal', 'general')
        analysis_results: Results from previous analysis (dict)
        additional_context: Additional business context
        groq_api_key: Groq API key (if not in environment)

    Returns:
        Dictionary containing natural language insights and recommendations

    Raises:
        ValueError: If no Groq API key is available.
    """
    print(f"🔍 Generating business insights for {analysis_type} analysis...")

    # Get API key (explicit argument wins over environment)
    api_key = groq_api_key or os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("Groq API key not found. Set GROQ_API_KEY environment variable or pass groq_api_key parameter")

    client = Groq(api_key=api_key)

    # Prepare data summary for the prompt
    data_summary = {
        "shape": data.shape,
        "columns": data.columns,
        "dtypes": {col: str(dtype) for col, dtype in zip(data.columns, data.dtypes)},
        "sample_stats": {}
    }

    # Add numeric column stats
    for col in data.columns:
        if data[col].dtype in [pl.Int32, pl.Int64, pl.Float32, pl.Float64]:
            # BUGFIX: polars aggregations return None for all-null columns,
            # and float(None) raises TypeError — skip such columns instead.
            col_mean = data[col].mean()
            if col_mean is None:
                continue
            data_summary["sample_stats"][col] = {
                "mean": float(col_mean),
                "median": float(data[col].median()),
                "std": float(data[col].std()),
                "min": float(data[col].min()),
                "max": float(data[col].max())
            }

    # Create prompt based on analysis type
    prompt = f"""You are a senior business analyst. Analyze the following data and provide actionable business insights.

Analysis Type: {analysis_type.upper()}

Data Summary:
{json.dumps(data_summary, indent=2)}

Analysis Results:
{json.dumps(analysis_results, indent=2)}

Additional Context:
{additional_context or 'None provided'}

Please provide:
1. Key findings (3-5 bullet points)
2. Business implications
3. Actionable recommendations (3-5 specific actions)
4. Risk factors or caveats
5. Suggested next steps

Format your response as a structured business report."""

    try:
        # Call Groq API
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {
                    "role": "system",
                    "content": "You are a senior business analyst specializing in data-driven insights and strategic recommendations. Provide clear, actionable insights based on data analysis."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            temperature=0.3,
            max_tokens=2000
        )

        insights_text = response.choices[0].message.content

        result = {
            "analysis_type": analysis_type,
            "insights_summary": insights_text,
            "generated_at": datetime.now().isoformat(),
            "model": "llama-3.3-70b-versatile",
            "data_context": data_summary
        }

        # Best-effort extraction of the numbered sections from the free-text
        # reply. NOTE(review): this relies on the model echoing the '1.'-'5.'
        # headings from the prompt; unmatched lines before the first heading
        # are dropped.
        sections = {}
        current_section = None

        for line in insights_text.split('\n'):
            line = line.strip()
            if line.startswith('1.') or line.lower().startswith('key findings'):
                current_section = 'key_findings'
                sections[current_section] = []
            elif line.startswith('2.') or line.lower().startswith('business implications'):
                current_section = 'implications'
                sections[current_section] = []
            elif line.startswith('3.') or line.lower().startswith('actionable recommendations'):
                current_section = 'recommendations'
                sections[current_section] = []
            elif line.startswith('4.') or line.lower().startswith('risk'):
                current_section = 'risks'
                sections[current_section] = []
            elif line.startswith('5.') or line.lower().startswith('next steps'):
                current_section = 'next_steps'
                sections[current_section] = []
            elif current_section and line:
                sections[current_section].append(line)

        result["structured_insights"] = sections

        print(f"✅ Business insights generated!")
        print(f" Sections: {', '.join(sections.keys())}")

        return result

    except Exception as e:
        print(f"❌ Error generating insights: {str(e)}")
        raise