imsciences 0.9.6.9__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
imsciences/vis.py CHANGED
@@ -1,179 +1,719 @@
1
+ import inspect
2
+
3
+ import numpy as np
1
4
  import pandas as pd
2
5
  import plotly.express as px
3
- import plotly.graph_objs as go
6
+ import plotly.graph_objects as go
7
+
8
+
9
+ class datavis:
10
+ def __init__(self):
11
+ """Initialize DataVis with default theme settings."""
12
+ self.themes = {
13
+ "default": {
14
+ "template": "plotly_white",
15
+ "colorscale": "viridis",
16
+ "line_color": "#1f77b4",
17
+ "background_color": "white",
18
+ "grid_color": "lightgray",
19
+ "text_color": "black",
20
+ "font_family": "Raleway, sans-serif",
21
+ "font_size": 12,
22
+ },
23
+ "dark": {
24
+ "template": "plotly_dark",
25
+ "colorscale": "plasma",
26
+ "line_color": "#f07b16",
27
+ "background_color": "#2f3136",
28
+ "grid_color": "#1cd416",
29
+ "text_color": "white",
30
+ "font_family": "Raleway, sans-serif",
31
+ "font_size": 12,
32
+ },
33
+ "business": {
34
+ "template": "plotly_white",
35
+ "colorscale": "blues",
36
+ "line_color": "#0e4272",
37
+ "background_color": "white",
38
+ "grid_color": "#e6e6e6",
39
+ "text_color": "#921919",
40
+ "font_family": "Raleway, sans-serif",
41
+ "font_size": 11,
42
+ },
43
+ "scientific": {
44
+ "template": "plotly_dark",
45
+ "colorscale": "rdylbu",
46
+ "line_color": "#d62728",
47
+ "background_color": "white",
48
+ "grid_color": "#f0f0f0",
49
+ "text_color": "black",
50
+ "font_family": "Raleway, sans-serif",
51
+ "font_size": 10,
52
+ },
53
+ }
54
+ self.current_theme = "dark"
55
+
56
+ def help(self, method=None, *, show_examples=True):
57
+ """
58
+ Enhanced help system with detailed information about all methods.
59
+
60
+ Parameters
61
+ ----------
62
+ method : str, optional
63
+ Specific method to get help for. If None, shows overview of all methods.
64
+ show_examples : bool, default True
65
+ Whether to show usage examples.
66
+
67
+ Usage:
68
+ ------
69
+ vis.help() # Show all methods
70
+ vis.help('plot_one') # Show help for specific method
71
+ vis.help('plot_chart', show_examples=False) # Show help without examples
72
+
73
+ """
74
+ if method:
75
+ self._show_method_help(method, show_examples=show_examples)
76
+ else:
77
+ self._show_overview_help(show_examples=show_examples)
78
+
79
+ def _show_overview_help(self, *, show_examples=True):
80
+ """Display overview of all available methods."""
81
+ print("=" * 80)
82
+ print("DataVis Class - Comprehensive Data Visualization Tool")
83
+ print("=" * 80)
84
+ print(f"Current Theme: {self.current_theme}")
85
+ print(f"Available Themes: {', '.join(self.themes.keys())}")
86
+ print("\n📊 AVAILABLE METHODS:\n")
87
+
88
+ methods_info = [
89
+ {
90
+ "name": "plot_one",
91
+ "description": "Plot a single time series from a DataFrame",
92
+ "params": "df, column, date_column",
93
+ "use_case": "Single metric tracking over time",
94
+ },
95
+ {
96
+ "name": "plot_two",
97
+ "description": "Compare two metrics from different DataFrames",
98
+ "params": "data_config, same_axis=True",
99
+ "use_case": "Comparative analysis of two time series",
100
+ },
101
+ {
102
+ "name": "plot_chart",
103
+ "description": "Create various chart types (line, bar, scatter, etc.)",
104
+ "params": "data_config",
105
+ "use_case": "Flexible charting with multiple chart types",
106
+ },
107
+ {
108
+ "name": "plot_correlation",
109
+ "description": "Generate correlation heatmaps",
110
+ "params": 'df, columns=None, method="pearson"',
111
+ "use_case": "Analyze relationships between variables",
112
+ },
113
+ {
114
+ "name": "plot_sankey",
115
+ "description": "Create Sankey diagrams for flow visualization",
116
+ "params": "df, source_col, target_col, value_col, title=None, color_mapping=None",
117
+ "use_case": "Visualize flow/process data",
118
+ },
119
+ {
120
+ "name": "set_theme",
121
+ "description": "Set global theme for all charts",
122
+ "params": "theme_name",
123
+ "use_case": "Consistent styling across visualizations",
124
+ },
125
+ {
126
+ "name": "help",
127
+ "description": "Get detailed help for methods",
128
+ "params": "method=None, show_examples=True",
129
+ "use_case": "Learn how to use the visualization tools",
130
+ },
131
+ ]
132
+
133
+ for i, method in enumerate(methods_info, 1):
134
+ print(f"{i}. {method['name']}")
135
+ print(f" 📝 Description: {method['description']}")
136
+ print(f" ⚙️ Parameters: {method['params']}")
137
+ print(f" 🎯 Use Case: {method['use_case']}")
138
+ print()
139
+
140
+ if show_examples:
141
+ print("💡 QUICK START EXAMPLES:")
142
+ print(" vis.help('plot_one') # Get detailed help for plot_one")
143
+ print(" vis.set_theme('dark') # Switch to dark theme")
144
+ print(" vis.plot_one(df, 'sales', 'date') # Plot sales over time")
145
+ print(" vis.plot_correlation(df) # Create correlation heatmap")
146
+ print()
147
+
148
+ print("🔧 For detailed help on any method, use: vis.help('method_name')")
149
+ print("=" * 80)
150
+
151
+ def _show_method_help(self, method_name, *, show_examples=True):
152
+ """Display detailed help for a specific method."""
153
+ if not hasattr(self, method_name):
154
+ print(f"❌ Method '{method_name}' not found!")
155
+ print(
156
+ f"Available methods: {[m for m in dir(self) if not m.startswith('_') and callable(getattr(self, m))]}"
157
+ )
158
+ return
159
+
160
+ method = getattr(self, method_name)
161
+
162
+ print("=" * 60)
163
+ print(f"📊 DETAILED HELP: {method_name}")
164
+ print("=" * 60)
165
+
166
+ # Get docstring
167
+ doc = inspect.getdoc(method)
168
+ if doc:
169
+ print(f"📝 {doc}")
170
+ print()
171
+
172
+ # Get method signature
173
+ sig = inspect.signature(method)
174
+ print(f"🔧 Signature: {method_name}{sig}")
175
+ print()
176
+
177
+ # Method-specific examples
178
+ if show_examples:
179
+ examples = self._get_method_examples(method_name)
180
+ if examples:
181
+ print("💡 EXAMPLES:")
182
+ for example in examples:
183
+ print(f" {example}")
184
+ print()
185
+
186
+ print("=" * 60)
187
+
188
+ def _get_method_examples(self, method_name):
189
+ """Get examples for specific methods."""
190
+ examples = {
191
+ "plot_one": [
192
+ "vis.plot_one(df, 'sales', 'date') # Plot sales over time",
193
+ "vis.plot_one(stock_df, 'price', 'timestamp') # Stock price chart",
194
+ ],
195
+ "plot_two": [
196
+ "config = {'df1': df1, 'col1': 'sales', 'df2': df2, 'col2': 'revenue', 'date_column': 'date'}",
197
+ "vis.plot_two(config, same_axis=True) # Same y-axis",
198
+ "vis.plot_two(config, same_axis=False) # Separate y-axes",
199
+ ],
200
+ "plot_chart": [
201
+ "config = {'df': df, 'date_col': 'date', 'value_cols': ['sales'], 'chart_type': 'line'}",
202
+ "vis.plot_chart(config) # Line chart",
203
+ "config['chart_type'] = 'bar' # Change to bar chart",
204
+ ],
205
+ "plot_correlation": [
206
+ "vis.plot_correlation(df) # All numeric columns",
207
+ "vis.plot_correlation(df, columns=['sales', 'profit', 'cost']) # Specific columns",
208
+ "vis.plot_correlation(df, method='spearman') # Spearman correlation",
209
+ ],
210
+ "plot_sankey": [
211
+ "# Basic multi-layer Sankey",
212
+ "vis.plot_sankey(df, 'Source', 'Target', 'Value')",
213
+ "",
214
+ "# Sankey with custom colors and title",
215
+ "color_map = {",
216
+ " 'Brand Media': 'rgba(246, 107, 109, 0.6)',",
217
+ " 'TV': 'rgba(246, 107, 109, 0.6)',",
218
+ " 'default': 'rgba(175, 175, 175, 0.6)'",
219
+ "}",
220
+ "vis.plot_sankey(df, 'Source', 'Target', 'Value', title='Brand Media Effects', color_mapping=color_map)",
221
+ ],
222
+ "set_theme": [
223
+ "vis.set_theme('dark') # Switch to dark theme",
224
+ "vis.set_theme('business') # Professional business theme",
225
+ "vis.set_theme('scientific') # Scientific publication theme",
226
+ ],
227
+ }
228
+ return examples.get(method_name, [])
229
+
230
+ def set_theme(self, theme_name):
231
+ """
232
+ Set the global theme for all charts.
233
+
234
+ Parameters
235
+ ----------
236
+ theme_name : str
237
+ Theme name. Available options: 'default', 'dark', 'business', 'scientific'
238
+
239
+ Returns
240
+ -------
241
+ None
242
+
243
+ Examples
244
+ --------
245
+ vis.set_theme('dark') # Dark theme with plasma colors
246
+ vis.set_theme('business') # Professional business theme
247
+ vis.set_theme('scientific') # Scientific publication theme
248
+
249
+ """
250
+ if theme_name not in self.themes:
251
+ available_themes = ", ".join(self.themes.keys())
252
+ error_msg = (
253
+ f"Theme '{theme_name}' not found. Available themes: {available_themes}"
254
+ )
255
+ raise ValueError(error_msg)
256
+
257
+ self.current_theme = theme_name
258
+ print(f"✅ Theme set to: {theme_name}")
259
+
260
+ def _apply_theme(self, fig):
261
+ """Apply current theme to a figure."""
262
+ theme = self.themes[self.current_theme]
263
+
264
+ fig.update_layout(
265
+ template=theme["template"],
266
+ plot_bgcolor=theme["background_color"],
267
+ font={
268
+ "family": theme["font_family"],
269
+ "size": theme["font_size"],
270
+ "color": theme["text_color"],
271
+ },
272
+ xaxis={
273
+ "showline": True,
274
+ "linecolor": theme["text_color"],
275
+ "gridcolor": theme["grid_color"],
276
+ },
277
+ yaxis={
278
+ "showline": True,
279
+ "linecolor": theme["text_color"],
280
+ "gridcolor": theme["grid_color"],
281
+ "rangemode": "tozero",
282
+ },
283
+ )
284
+
285
+ return fig
286
+
287
+ def plot_correlation(self, df, columns=None, method="pearson", title=None):
288
+ """
289
+ Create a correlation heatmap for numeric columns in a DataFrame.
290
+
291
+ Parameters
292
+ ----------
293
+ df : pandas.DataFrame
294
+ Input DataFrame with numeric columns
295
+ columns : list, optional
296
+ Specific columns to include in correlation. If None, uses all numeric columns
297
+ method : str, default 'pearson'
298
+ Correlation method: 'pearson', 'kendall', 'spearman'
299
+ title : str, optional
300
+ Custom title for the heatmap
301
+
302
+ Returns
303
+ -------
304
+ plotly.graph_objects.Figure
305
+ The correlation heatmap figure
306
+
307
+ Example:
308
+ --------
309
+ # Basic correlation heatmap
310
+ fig = vis.plot_correlation(df)
311
+
312
+ # Specific columns with Spearman correlation
313
+ fig = vis.plot_correlation(df, columns=['sales', 'profit', 'cost'], method='spearman')
314
+
315
+ """
316
+ # Select numeric columns
317
+ if columns is None:
318
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
319
+ else:
320
+ numeric_cols = [
321
+ col
322
+ for col in columns
323
+ if col in df.columns and df[col].dtype in ["int64", "float64"]
324
+ ]
325
+
326
+ # Minimum columns required for correlation
327
+ min_correlation_cols = 2
328
+
329
+ if len(numeric_cols) < min_correlation_cols:
330
+ error_msg = "Need at least 2 numeric columns for correlation analysis"
331
+ raise ValueError(error_msg)
332
+
333
+ # Calculate correlation matrix
334
+ corr_matrix = df[numeric_cols].corr(method=method)
335
+
336
+ # Create heatmap
337
+ fig = px.imshow(
338
+ corr_matrix,
339
+ text_auto=True,
340
+ aspect="auto",
341
+ color_continuous_scale=self.themes[self.current_theme]["colorscale"],
342
+ title=title or f"{method.title()} Correlation Matrix",
343
+ )
344
+
345
+ # Apply theme
346
+ fig = self._apply_theme(fig)
347
+
348
+ # Update text color for better readability
349
+ fig.update_traces(
350
+ textfont={"color": self.themes[self.current_theme]["text_color"]}
351
+ )
352
+
353
+ return fig
4
354
 
5
- class datavis:
6
-
7
- def help(self):
355
+ def plot_sankey(self, df, source_col, target_col, value_col, **kwargs):
8
356
  """
9
- Displays a help menu listing all the available functions with their descriptions, usage, and examples.
10
- """
11
- print("1. plot_one")
12
- print(" - Description: Plots a specified column from a DataFrame with white background and black axes.")
13
- print(" - Usage: plot_one(df1, col1, date_column)")
14
- print(" - Example: plot_one(df, 'sales', 'date')\n")
15
-
16
- print("2. plot_two")
17
- print(" - Description: Plots specified columns from two DataFrames, optionally on the same or separate y-axes.")
18
- print(" - Usage: plot_two(df1, col1, df2, col2, date_column, same_axis=True)")
19
- print(" - Example: plot_two(df1, 'sales_vol', df2, 'sales_revenue', 'date', same_axis=False)\n")
20
-
21
- print("3. plot_chart")
22
- print(" - Description: Plots various chart types using Plotly, including line, bar, scatter, area, pie, etc.")
23
- print(" - Usage: plot_chart(df, date_col, value_cols, chart_type='line', title='Chart', x_title='Date', y_title='Values')")
24
- print(" - Example: plot_chart(df, 'date', ['sales', 'revenue'], chart_type='line', title='Sales and Revenue')\n")
25
-
357
+ Create a multi-layer Sankey diagram from a single DataFrame.
358
+
359
+ Parameters
360
+ ----------
361
+ df : pandas.DataFrame
362
+ Input DataFrame with source, target, and value columns
363
+ source_col : str
364
+ Column name for source nodes
365
+ target_col : str
366
+ Column name for target nodes
367
+ value_col : str
368
+ Column name for flow values (must be numeric)
369
+ title : str, optional
370
+ Custom title for the diagram
371
+ color_mapping : dict, optional
372
+ Dictionary mapping source/target names to colors
373
+ Format: {'Brand Media': 'rgba(246, 107, 109, 0.6)', 'default': 'rgba(175, 175, 175, 0.6)'}
374
+ Pass as keyword argument: color_mapping={...}
375
+ **kwargs : dict
376
+ Additional keyword arguments including title and color_mapping
377
+
378
+ Returns
379
+ -------
380
+ plotly.graph_objects.Figure
381
+ The multi-layer Sankey diagram figure
382
+
383
+ DataFrame Format Requirements:
384
+ -----------------------------
385
+ Single DataFrame with all flow data:
386
+ | Source | Target | Value |
387
+ |-------------|-----------|-------|
388
+ | Brand Media | TV | 100 |
389
+ | Brand Media | Radio | 50 |
390
+ | TV | BU_North | 60 |
391
+ | TV | BU_South | 40 |
392
+ | Radio | BU_North | 30 |
393
+
394
+ Example:
395
+ --------
396
+ # Basic multi-layer Sankey
397
+ fig = vis.plot_sankey(df, 'Source', 'Target', 'Value')
398
+
399
+ # Sankey with custom colors
400
+ color_map = {
401
+ 'Brand Media': 'rgba(246, 107, 109, 0.6)',
402
+ 'TV': 'rgba(246, 107, 109, 0.6)',
403
+ 'Radio': 'rgba(246, 107, 109, 0.6)',
404
+ 'default': 'rgba(175, 175, 175, 0.6)'
405
+ }
406
+ fig = vis.plot_sankey(df, 'Source', 'Target', 'Value',
407
+ title='Brand Media Effects', color_mapping=color_map)
408
+
409
+ """
410
+ # Extract keyword arguments
411
+ title = kwargs.get("title")
412
+ color_mapping = kwargs.get("color_mapping")
413
+
414
+ # Validate required columns
415
+ required_cols = [source_col, target_col, value_col]
416
+ missing_cols = [col for col in required_cols if col not in df.columns]
417
+ if missing_cols:
418
+ error_msg = f"Missing columns: {missing_cols}"
419
+ raise ValueError(error_msg)
420
+
421
+ # Ensure value column is numeric
422
+ if not pd.api.types.is_numeric_dtype(df[value_col]):
423
+ error_msg = f"Value column '{value_col}' must be numeric"
424
+ raise ValueError(error_msg)
425
+
426
+ # Create working copy and remove any rows with missing values or zero values
427
+ work_df = df[required_cols].dropna()
428
+ work_df = work_df[work_df[value_col] != 0] # Remove zero values
429
+
430
+ if work_df.empty:
431
+ error_msg = "No valid data rows found after removing missing/zero values"
432
+ raise ValueError(error_msg)
433
+
434
+ # Get all unique nodes
435
+ all_sources = set(work_df[source_col].unique())
436
+ all_targets = set(work_df[target_col].unique())
437
+
438
+ # Create layers for proper node positioning
439
+ # Layer 1: sources that don't appear as targets (starting nodes)
440
+ layer_1_nodes = list(all_sources - all_targets)
441
+ # Final layer: targets that don't appear as sources (ending nodes)
442
+ final_layer_nodes = list(all_targets - all_sources)
443
+ # Intermediate: nodes that are both source and target
444
+ intermediate_nodes = list(all_sources & all_targets)
445
+
446
+ # Create ordered node list for proper left-to-right flow
447
+ all_nodes = layer_1_nodes + intermediate_nodes + final_layer_nodes
448
+ node_dict = {node: i for i, node in enumerate(all_nodes)}
449
+
450
+ # Create source, target, and value lists for Sankey
451
+ source_indices = [node_dict[source] for source in work_df[source_col]]
452
+ target_indices = [node_dict[target] for target in work_df[target_col]]
453
+ values = work_df[value_col].tolist()
454
+
455
+ # Apply color mapping if provided
456
+ if color_mapping:
457
+ link_colors = []
458
+ for _, row in work_df.iterrows():
459
+ source_name = row[source_col]
460
+ target_name = row[target_col]
461
+ if source_name in color_mapping:
462
+ link_colors.append(color_mapping[source_name])
463
+ elif target_name in color_mapping:
464
+ link_colors.append(color_mapping[target_name])
465
+ else:
466
+ link_colors.append(
467
+ color_mapping.get("default", "rgba(175, 175, 175, 0.6)")
468
+ )
469
+ else:
470
+ # Use theme-based default colors
471
+ link_colors = (
472
+ "rgba(255,255,255,0.3)"
473
+ if self.current_theme == "dark"
474
+ else "rgba(0,0,0,0.3)"
475
+ )
476
+
477
+ # Create Sankey diagram
478
+ fig = go.Figure(
479
+ data=[
480
+ go.Sankey(
481
+ node={
482
+ "pad": 15,
483
+ "thickness": 20,
484
+ "line": {"color": "black", "width": 0.5},
485
+ "label": all_nodes,
486
+ "color": self.themes[self.current_theme]["line_color"],
487
+ },
488
+ link={
489
+ "source": source_indices,
490
+ "target": target_indices,
491
+ "value": values,
492
+ "color": link_colors,
493
+ },
494
+ )
495
+ ]
496
+ )
497
+
498
+ fig.update_layout(
499
+ title_text=title or f"Sankey Diagram - {source_col} to {target_col}",
500
+ font_size=self.themes[self.current_theme]["font_size"],
501
+ font_color=self.themes[self.current_theme]["text_color"],
502
+ paper_bgcolor=self.themes[self.current_theme]["background_color"],
503
+ )
504
+
505
+ return fig
506
+
26
507
  def plot_one(self, df1, col1, date_column):
27
508
  """
28
- Plots specified column from a DataFrame with white background and black axes,
29
- using a specified date column as the X-axis.
509
+ Plots specified column from a DataFrame with themed styling.
510
+
511
+ Uses a specified date column as the X-axis.
512
+
513
+ Parameters
514
+ ----------
515
+ df1 : pandas.DataFrame
516
+ Input DataFrame
517
+ col1 : str
518
+ Column name from the DataFrame to plot
519
+ date_column : str
520
+ The name of the date column to use for the X-axis
521
+
522
+ Returns
523
+ -------
524
+ plotly.graph_objects.Figure
525
+ The line plot figure
30
526
 
31
- :param df1: DataFrame
32
- :param col1: Column name from the DataFrame
33
- :param date_column: The name of the date column to use for the X-axis
34
527
  """
35
528
  # Check if columns exist in the DataFrame
36
529
  if col1 not in df1.columns or date_column not in df1.columns:
37
- raise ValueError("Column not found in DataFrame")
530
+ error_msg = "Column not found in DataFrame"
531
+ raise ValueError(error_msg)
38
532
 
39
533
  # Check if the date column is in datetime format, if not convert it
40
534
  if not pd.api.types.is_datetime64_any_dtype(df1[date_column]):
41
535
  try:
42
- # Convert with dayfirst=True to interpret dates correctly
43
536
  df1[date_column] = pd.to_datetime(df1[date_column], dayfirst=True)
44
- except Exception as e:
45
- raise ValueError(f"Error converting {date_column} to datetime: {e}")
537
+ except (ValueError, TypeError) as e:
538
+ error_msg = f"Error converting {date_column} to datetime: {e}"
539
+ raise ValueError(error_msg) from e
46
540
 
47
541
  # Plotting using Plotly Express
48
542
  fig = px.line(df1, x=date_column, y=col1)
49
543
 
50
- # Update layout for white background and black axes lines, and setting y-axis to start at 0
51
- fig.update_layout(
52
- plot_bgcolor='white',
53
- xaxis=dict(
54
- showline=True,
55
- linecolor='black'
56
- ),
57
- yaxis=dict(
58
- showline=True,
59
- linecolor='black',
60
- rangemode='tozero' # Setting Y-axis to start at 0 if suitable
61
- )
62
- )
544
+ # Apply theme
545
+ fig = self._apply_theme(fig)
63
546
 
64
547
  return fig
65
548
 
66
- def plot_two(self, df1, col1, df2, col2, date_column, same_axis=True):
549
+ def plot_two(self, data_config, *, same_axis=True):
67
550
  """
68
- Plots specified columns from two different DataFrames with both different and the same lengths,
69
- using a specified date column as the X-axis, and charting on either the same or separate y-axes.
70
-
71
- :param df1: First DataFrame
72
- :param col1: Column name from the first DataFrame
73
- :param df2: Second DataFrame
74
- :param col2: Column name from the second DataFrame
75
- :param date_column: The name of the date column to use for the X-axis
76
- :param same_axis: If True, plot both traces on the same y-axis; otherwise, use separate y-axes.
77
- :return: Plotly figure
551
+ Plots specified columns from two different DataFrames with themed styling.
552
+
553
+ Parameters
554
+ ----------
555
+ data_config : dict
556
+ Dictionary with keys: 'df1', 'col1', 'df2', 'col2', 'date_column'
557
+ same_axis : bool, default True
558
+ If True, plot both traces on the same y-axis; otherwise, use separate y-axes
559
+
560
+ Returns
561
+ -------
562
+ plotly.graph_objects.Figure
563
+ The comparison plot figure
564
+
78
565
  """
566
+ # Extract parameters from config
567
+ df1 = data_config["df1"]
568
+ col1 = data_config["col1"]
569
+ df2 = data_config["df2"]
570
+ col2 = data_config["col2"]
571
+ date_column = data_config["date_column"]
572
+
79
573
  # Validate inputs
80
574
  if col1 not in df1.columns or date_column not in df1.columns:
81
- raise ValueError(f"Column {col1} or {date_column} not found in the first DataFrame.")
575
+ error_msg = (
576
+ f"Column {col1} or {date_column} not found in the first DataFrame."
577
+ )
578
+ raise ValueError(error_msg)
82
579
  if col2 not in df2.columns or date_column not in df2.columns:
83
- raise ValueError(f"Column {col2} or {date_column} not found in the second DataFrame.")
580
+ error_msg = (
581
+ f"Column {col2} or {date_column} not found in the second DataFrame."
582
+ )
583
+ raise ValueError(error_msg)
84
584
 
85
585
  # Ensure date columns are in datetime format
86
- df1[date_column] = pd.to_datetime(df1[date_column], errors='coerce')
87
- df2[date_column] = pd.to_datetime(df2[date_column], errors='coerce')
586
+ df1[date_column] = pd.to_datetime(df1[date_column], errors="coerce")
587
+ df2[date_column] = pd.to_datetime(df2[date_column], errors="coerce")
88
588
 
89
589
  # Drop rows with invalid dates
90
590
  df1 = df1.dropna(subset=[date_column])
91
591
  df2 = df2.dropna(subset=[date_column])
92
592
 
93
- # Create traces for the first and second DataFrames
94
- trace1 = go.Scatter(x=df1[date_column], y=df1[col1], mode='lines', name=col1, yaxis='y1')
593
+ # Create traces
594
+ trace1 = go.Scatter(
595
+ x=df1[date_column],
596
+ y=df1[col1],
597
+ mode="lines",
598
+ name=col1,
599
+ yaxis="y1",
600
+ line={"color": self.themes[self.current_theme]["line_color"]},
601
+ )
95
602
 
96
603
  if same_axis:
97
- trace2 = go.Scatter(x=df2[date_column], y=df2[col2], mode='lines', name=col2, yaxis='y1')
604
+ trace2 = go.Scatter(
605
+ x=df2[date_column],
606
+ y=df2[col2],
607
+ mode="lines",
608
+ name=col2,
609
+ yaxis="y1",
610
+ )
98
611
  else:
99
- trace2 = go.Scatter(x=df2[date_column], y=df2[col2], mode='lines', name=col2, yaxis='y2')
100
-
101
- # Define layout for the plot
102
- layout = go.Layout(
103
- title="Comparison Plot",
104
- xaxis=dict(title=date_column, showline=True, linecolor='black'),
105
- yaxis=dict(
106
- title=col1 if same_axis else f"{col1} (y1)",
107
- showline=True,
108
- linecolor='black',
109
- rangemode='tozero'
110
- ),
111
- yaxis2=dict(
112
- title=f"{col2} (y2)" if not same_axis else "",
113
- overlaying='y',
114
- side='right',
115
- showline=True,
116
- linecolor='black',
117
- rangemode='tozero'
118
- ),
119
- showlegend=True,
120
- plot_bgcolor='white' # Set the plot background color to white
121
- )
612
+ trace2 = go.Scatter(
613
+ x=df2[date_column],
614
+ y=df2[col2],
615
+ mode="lines",
616
+ name=col2,
617
+ yaxis="y2",
618
+ )
619
+
620
+ # Create figure
621
+ fig = go.Figure(data=[trace1, trace2])
122
622
 
123
- # Create the figure with the defined layout and traces
124
- fig = go.Figure(data=[trace1, trace2], layout=layout)
623
+ # Apply theme
624
+ fig = self._apply_theme(fig)
625
+
626
+ # Update layout for dual axis if needed
627
+ if not same_axis:
628
+ fig.update_layout(
629
+ yaxis2={
630
+ "title": f"{col2} (y2)",
631
+ "overlaying": "y",
632
+ "side": "right",
633
+ "showline": True,
634
+ "linecolor": self.themes[self.current_theme]["text_color"],
635
+ "rangemode": "tozero",
636
+ }
637
+ )
638
+
639
+ fig.update_layout(title="Comparison Plot", showlegend=True)
125
640
 
126
641
  return fig
127
642
 
128
- def plot_chart(self, df, date_col, value_cols, chart_type='line', title='Chart', x_title='Date', y_title='Values', **kwargs):
643
+ def plot_chart(self, data_config):
129
644
  """
130
- Plot various types of charts using Plotly.
131
-
132
- Args:
133
- df (pandas.DataFrame): DataFrame containing the data.
134
- date_col (str): The name of the column with date information.
135
- value_cols (list): List of columns to plot.
136
- chart_type (str): Type of chart to plot ('line', 'bar', 'scatter', etc.).
137
- title (str): Title of the chart.
138
- x_title (str): Title of the x-axis.
139
- y_title (str): Title of the y-axis.
140
- **kwargs: Additional keyword arguments for customization.
141
-
142
- Returns:
143
- plotly.graph_objects.Figure: The Plotly figure object.
645
+ Plot various types of charts using Plotly with themed styling.
646
+
647
+ Parameters
648
+ ----------
649
+ data_config : dict
650
+ Configuration dictionary with keys:
651
+ - df: DataFrame containing the data
652
+ - date_col: The name of the column with date information
653
+ - value_cols: List of columns to plot
654
+ - chart_type: Type of chart ('line', 'bar', 'scatter', etc.)
655
+ - title: Title of the chart
656
+ - x_title: Title of the x-axis
657
+ - y_title: Title of the y-axis
658
+ - kwargs: Additional keyword arguments
659
+
660
+ Returns
661
+ -------
662
+ plotly.graph_objects.Figure
663
+ The chart figure
664
+
144
665
  """
145
- import pandas as pd
146
- import plotly.graph_objects as go
666
+ # Extract parameters with defaults
667
+ dataframe = data_config["df"]
668
+ date_col = data_config["date_col"]
669
+ value_cols = data_config["value_cols"]
670
+ chart_type = data_config.get("chart_type", "line")
671
+ title = data_config.get("title", "Chart")
672
+ x_title = data_config.get("x_title", "Date")
673
+ y_title = data_config.get("y_title", "Values")
674
+ kwargs = data_config.get("kwargs", {})
147
675
 
148
676
  # Ensure the date column is in datetime format
149
- df[date_col] = pd.to_datetime(df[date_col])
677
+ dataframe[date_col] = pd.to_datetime(dataframe[date_col])
150
678
 
151
679
  # Validate input columns
152
- value_cols = [col for col in value_cols if col in df.columns and col != date_col]
680
+ value_cols = [
681
+ col for col in value_cols if col in dataframe.columns and col != date_col
682
+ ]
153
683
  if not value_cols:
154
- raise ValueError("No valid columns provided for plotting.")
684
+ error_msg = "No valid columns provided for plotting."
685
+ raise ValueError(error_msg)
155
686
 
156
687
  # Initialize the figure
157
688
  fig = go.Figure()
158
689
 
159
690
  # Define a mapping for chart types to corresponding Plotly trace types
160
691
  chart_trace_map = {
161
- 'line': lambda col: go.Scatter(x=df[date_col], y=df[col], mode='lines', name=col, **kwargs),
162
- 'bar': lambda col: go.Bar(x=df[date_col], y=df[col], name=col, **kwargs),
163
- 'scatter': lambda col: go.Scatter(x=df[date_col], y=df[col], mode='markers', name=col, **kwargs),
164
- 'area': lambda col: go.Scatter(x=df[date_col], y=df[col], mode='lines', fill='tozeroy', name=col, **kwargs),
165
- 'pie': lambda col: go.Pie(labels=df[date_col], values=df[col], name=col, **kwargs),
166
- 'box': lambda col: go.Box(y=df[col], name=col, **kwargs),
167
- 'bubble': lambda _: go.Scatter(
168
- x=df[value_cols[0]], y=df[value_cols[1]], mode='markers',
169
- marker=dict(size=df[value_cols[2]]), name='Bubble Chart', **kwargs
692
+ "line": lambda col: go.Scatter(
693
+ x=dataframe[date_col],
694
+ y=dataframe[col],
695
+ mode="lines",
696
+ name=col,
697
+ **kwargs,
698
+ ),
699
+ "bar": lambda col: go.Bar(
700
+ x=dataframe[date_col], y=dataframe[col], name=col, **kwargs
701
+ ),
702
+ "scatter": lambda col: go.Scatter(
703
+ x=dataframe[date_col],
704
+ y=dataframe[col],
705
+ mode="markers",
706
+ name=col,
707
+ **kwargs,
708
+ ),
709
+ "area": lambda col: go.Scatter(
710
+ x=dataframe[date_col],
711
+ y=dataframe[col],
712
+ mode="lines",
713
+ fill="tozeroy",
714
+ name=col,
715
+ **kwargs,
170
716
  ),
171
- 'funnel': lambda col: go.Funnel(y=df[date_col], x=df[col], **kwargs),
172
- 'waterfall': lambda col: go.Waterfall(x=df[date_col], y=df[col], measure=df[value_cols[1]], **kwargs),
173
- 'scatter3d': lambda _: go.Scatter3d(
174
- x=df[value_cols[0]], y=df[value_cols[1]], z=df[value_cols[2]],
175
- mode='markers', **kwargs
176
- )
177
717
  }
178
718
 
179
719
  # Generate traces for the selected chart type
@@ -182,15 +722,18 @@ class datavis:
182
722
  trace = chart_trace_map[chart_type](col)
183
723
  fig.add_trace(trace)
184
724
  else:
185
- raise ValueError(f"Unsupported chart type: {chart_type}")
725
+ error_msg = f"Unsupported chart type: {chart_type}"
726
+ raise ValueError(error_msg)
727
+
728
+ # Apply theme
729
+ fig = self._apply_theme(fig)
186
730
 
187
- # Update the layout of the figure
731
+ # Update the layout
188
732
  fig.update_layout(
189
733
  title=title,
190
734
  xaxis_title=x_title,
191
735
  yaxis_title=y_title,
192
- legend_title='Series',
193
- template='plotly_dark'
736
+ legend_title="Series",
194
737
  )
195
738
 
196
- return fig
739
+ return fig