owlplanner 2025.5.5__py3-none-any.whl → 2025.5.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,980 @@
1
+ """
2
+ Plotly implementation of plot backend.
3
+ """
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import plotly.graph_objects as go
8
+ from plotly.subplots import make_subplots
9
+ # import plotly.io as pio
10
+
11
+ import io
12
+ from scipy import stats
13
+
14
+ from .base import PlotBackend
15
+ from .. import utils as u
16
+
17
+
18
class PlotlyBackend(PlotBackend):
    """Plotly implementation of the plot backend.

    Every ``plot_*`` method builds and returns Plotly figure objects without
    displaying them; :meth:`jupyter_renderer` shows a figure. Dollar series are
    drawn in thousands, either nominal (``value == "nominal"``) or divided by
    the factors ``gamma_n`` and labelled in first-year dollars.
    """

    # Display names of the rate series, indexed like the first axis of tau_kn.
    _RATE_NAMES = (
        "S&P500 (incl. div.)",
        "Baa Corp. Bonds",
        "10-y T-Notes",
        "Inflation",
    )

    # Rate methods whose plot titles include the rate_frm-rate_to year range.
    _RANGED_RATE_METHODS = ("historical", "histochastic", "historical average")

    def __init__(self):
        """Initialize the plotly backend."""
        # Template applied to every figure produced by this backend.
        self.template = "plotly_white"
        # Default layout. NOTE(review): the plot_* methods of this class set their
        # own layouts and never read this attribute; kept as public state.
        self.layout = dict(
            showlegend=True,
            legend=dict(
                traceorder="reversed",
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                bgcolor="rgba(255, 255, 255, 0.5)"
            ),
            xaxis=dict(
                showgrid=True,
                griddash="dot",
                gridcolor="lightgray",
                zeroline=True,
                zerolinecolor="gray",
                zerolinewidth=1
            ),
            yaxis=dict(
                showgrid=True,
                griddash="dot",
                gridcolor="lightgray",
                zeroline=True,
                zerolinecolor="gray",
                zerolinewidth=1
            )
        )
        # Setting pio.renderers.default = "browser" would open each graph in a separate tab.

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _html_title(title):
        """Return ``title`` with newlines converted to Plotly ``<br>`` line breaks."""
        return title.replace("\n", "<br>")

    @staticmethod
    def _value_axis_title(value, year_n):
        """Return the y-axis label for amounts in $k, nominal or in first-year dollars."""
        if value == "nominal":
            return "$k (nominal)"
        return f"$k ({year_n[0]}$)"

    @staticmethod
    def _legend_below(y, *, horizontal=True, reversed_order=False):
        """Return a transparent legend layout centered under the plot area.

        Args:
            y: Paper coordinate of the legend's bottom edge (negative = below the x-axis).
            horizontal: Lay legend entries out in a row.
            reversed_order: List traces in reverse order of insertion.
        """
        legend = dict(
            yanchor="bottom",
            y=y,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0, 0, 0, 0)",
        )
        if horizontal:
            legend["orientation"] = "h"
        if reversed_order:
            legend["traceorder"] = "reversed"
        return legend

    @staticmethod
    def _axis_ref(axis, index):
        """Return the Plotly reference for axis number ``index`` (1-based): "x", "x2", "x3", ..."""
        return axis if index == 1 else f"{axis}{index}"

    @staticmethod
    def _add_stacked_areas(fig, x, series):
        """Add one stacked, filled area trace to ``fig`` for each ``(name, y)`` pair in ``series``."""
        for name, y in series:
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                name=name,
                stackgroup="one",
                fill="tonexty",
                opacity=0.6
            ))

    @staticmethod
    def _per_person_nonzero(series, inames):
        """Split ``{name: per-individual arrays}`` into ``{"name iname": values in $k}``.

        Series whose total is not above 1e-3 ($k) are omitted.
        """
        out = {}
        for sname, per_person in series.items():
            for i, iname in enumerate(inames):
                data = per_person[i] / 1000
                if np.sum(data) > 1.0e-3:
                    out[f"{sname} {iname}"] = data
        return out

    @staticmethod
    def _kde_trace(x, y, npoints=50):
        """Return a contour trace of the 2-D Gaussian KDE of the pair (x, y).

        Falls back to a plain scatter of the points when the KDE cannot be
        built, e.g. a constant series (singular covariance) or too few samples.
        """
        try:
            kde = stats.gaussian_kde(np.vstack([x, y]))
        except (np.linalg.LinAlgError, ValueError):
            return go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(size=6, opacity=0.5),
                showlegend=False
            )
        x_grid = np.linspace(x.min(), x.max(), npoints)
        y_grid = np.linspace(y.min(), y.max(), npoints)
        X, Y = np.meshgrid(x_grid, y_grid)
        Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
        return go.Contour(
            x=x_grid,
            y=y_grid,
            z=Z,
            showscale=False,
            colorscale="Viridis",
            showlegend=False
        )

    # ------------------------------------------------------------------ plots

    def jupyter_renderer(self, fig):
        """Display ``fig`` (e.g. in a Jupyter notebook) and return ``fig.show()``'s result."""
        return fig.show()

    def plot_profile(self, year_n, xi_n, title, inames):
        """Plot the profile ``xi_n`` over the years ``year_n``.

        ``title`` newlines become line breaks; ``inames`` is unused here.
        Returns a ``go.Figure``.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=year_n, y=xi_n, name="profile", line=dict(width=2)))
        fig.update_layout(
            title=self._html_title(title),
            yaxis_title="ξ",
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.5, reversed_order=True),
            margin=dict(b=150)
        )
        return fig

    def plot_gross_income(self, year_n, G_n, gamma_n, value, title, tax_brackets):
        """Plot taxable income together with tax-bracket thresholds over time.

        Args:
            year_n: Plan years (x-axis).
            G_n: Taxable income per year, in dollars.
            gamma_n: Factors per year; ``gamma_n[:-1]`` aligns with ``year_n``.
            value: "nominal" for nominal dollars, anything else for first-year dollars.
            title: Figure title; newlines become line breaks.
            tax_brackets: ``{name: thresholds per year}``; presumably in first-year
                dollars, since they are multiplied by ``gamma_n[:-1]`` for the nominal view.
        Returns:
            go.Figure
        """
        fig = go.Figure()

        if value == "nominal":
            income = G_n / 1000
            infladjust = gamma_n[:-1]
        else:
            income = G_n / gamma_n[:-1] / 1000
            infladjust = 1

        fig.add_trace(go.Scatter(x=year_n, y=income, name="taxable income", line=dict(width=2)))

        for key, bracket_data in tax_brackets.items():
            fig.add_trace(go.Scatter(
                x=year_n,
                y=bracket_data * infladjust / 1000,
                name=key,
                line=dict(width=1, dash="dot")
            ))

        fig.update_layout(
            title=self._html_title(title),
            yaxis_title=self._value_axis_title(value, year_n),
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.5),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=",.0f")
        return fig

    def plot_net_spending(self, year_n, g_n, xi_n, xiBar_n, gamma_n, value, title, inames):
        """Plot net spending against its target over time.

        The target is the first-year spending ``g_n[0] / xi_n[0]`` scaled by
        ``xiBar_n`` (nominal view) or ``xi_n`` (first-year dollars view).
        Returns a ``go.Figure``; ``inames`` is unused here.
        """
        fig = go.Figure()

        base = g_n[0] / xi_n[0]
        if value == "nominal":
            net_data = g_n / 1000
            target_data = base * xiBar_n / 1000
        else:
            net_data = g_n / gamma_n[:-1] / 1000
            target_data = base * xi_n / 1000

        fig.add_trace(go.Scatter(x=year_n, y=net_data, name="net", line=dict(width=2)))
        fig.add_trace(go.Scatter(x=year_n, y=target_data, name="target", line=dict(width=1, dash="dot")))

        fig.update_layout(
            title=self._html_title(title),
            yaxis_title=self._value_axis_title(value, year_n),
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.4),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=",.0f")
        return fig

    def plot_taxes(self, year_n, T_n, M_n, gamma_n, value, title, inames):
        """Plot income taxes ``T_n`` and Medicare costs ``M_n`` over time.

        Returns a ``go.Figure``; ``inames`` is unused here.
        """
        fig = go.Figure()

        if value == "nominal":
            income_tax_data = T_n / 1000
            medicare_data = M_n / 1000
        else:
            income_tax_data = T_n / gamma_n[:-1] / 1000
            medicare_data = M_n / gamma_n[:-1] / 1000

        fig.add_trace(go.Scatter(x=year_n, y=income_tax_data, name="income taxes", line=dict(width=2)))
        fig.add_trace(go.Scatter(x=year_n, y=medicare_data, name="Medicare", line=dict(width=2, dash="dot")))

        fig.update_layout(
            title=self._html_title(title),
            yaxis_title=self._value_axis_title(value, year_n),
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.4),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=",.0f")
        return fig

    def plot_rates(self, name, tau_kn, year_n, year_frac_left, N_k, rate_method, rate_frm=None, rate_to=None, tag=""):
        """Plot the return and inflation rates used over the time horizon.

        Args:
            name: Plan name, first title line.
            tau_kn: Rates indexed ``[k, n]``; plotted as ``100 * tau_kn``.
            year_n: Plan years.
            year_frac_left: When not 1, the first (partial) year is omitted.
            N_k: Number of rate series to plot (at most ``len(_RATE_NAMES)``).
            rate_method: Method name shown in the title.
            rate_frm, rate_to: Year range shown for ranged (historical) methods.
            tag: Optional suffix for the title.
        Returns:
            go.Figure; each legend entry shows the series mean and sample std.
        """
        fig = go.Figure()

        title = name + "<br>Return & Inflation Rates (" + str(rate_method)
        if rate_method in self._RANGED_RATE_METHODS:
            title += f" {rate_frm}-{rate_to}"
        title += ")"
        if tag:
            title += " - " + tag

        line_styles = ["solid", "dot", "dash", "longdash"]

        for k in range(N_k):
            # Don't plot partial rates for the current year if mid-year.
            if year_frac_left == 1:
                data = 100 * tau_kn[k]
                years = year_n
            else:
                data = 100 * tau_kn[k, 1:]
                years = year_n[1:]

            mean_val = np.mean(data)
            std_val = np.std(data, ddof=1)  # ddof=1 matches pandas' sample std.
            label = f"{self._RATE_NAMES[k]} <{mean_val:.1f} +/- {std_val:.1f}%>"

            fig.add_trace(go.Scatter(
                x=years,
                y=data,
                name=label,
                # Cycle through the available styles rather than indexing past them.
                line=dict(width=2, dash=line_styles[k % len(line_styles)])
            ))

        fig.update_layout(
            title=title,
            yaxis_title="%",
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.60),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=".1f")
        return fig

    def plot_rates_distributions(self, frm, to, SP500, BondsBaa, TNotes, Inflation, FROM):
        """Plot side-by-side histograms of four rate series over a year range.

        Args:
            frm, to: Year range; series are sliced ``[frm - FROM:to - FROM]``,
                so year ``to`` itself is excluded.
            SP500, BondsBaa, TNotes, Inflation: Rate series indexed by year from FROM.
            FROM: Year corresponding to index 0 of the series.
        Returns:
            go.Figure with one histogram per series, each labelled with its mean.
        """
        fig = make_subplots(
            rows=1, cols=4,
            subplot_titles=("S&P500", "BondsBaa", "TNotes", "Inflation"),
            shared_yaxes=True
        )

        # Roughly one bin per four years of data.
        nbins = int((to - frm) / 4)

        start, stop = frm - FROM, to - FROM
        series = [np.array(s[start:stop]) for s in (SP500, BondsBaa, TNotes, Inflation)]

        for i, dat in enumerate(series):
            label = f"<>: {u.pc(np.mean(dat), 2, 1)}"
            fig.add_trace(
                go.Histogram(
                    x=dat,
                    nbinsx=nbins,
                    name=label,
                    showlegend=False,
                    marker_color="orange"
                ),
                row=1, col=i + 1
            )
            # Mean label at the top center of this subplot (domain, not data, coordinates).
            fig.add_annotation(
                x=0.5, y=0.95,
                xref=f"{self._axis_ref('x', i + 1)} domain",
                yref="paper",
                text=label,
                showarrow=False,
                font=dict(size=10),
                bgcolor="rgba(255, 255, 255, 0.7)"
            )

        fig.update_layout(
            title=f"Rates from {frm} to {to}",
            template=self.template,
            showlegend=False,
            height=400,
            width=1200
        )

        grid = dict(showgrid=True, gridcolor="lightgray", zeroline=True, zerolinecolor="gray", zerolinewidth=1)
        for col in range(1, 5):
            fig.update_xaxes(title_text="%", row=1, col=col, **grid)
            fig.update_yaxes(row=1, col=col, **grid)

        return fig

    def plot_rates_correlations(self, pname, tau_kn, N_n, rate_method, rate_frm=None, rate_to=None,
                                tag="", share_range=False):
        """Plot a pair-plot matrix of the correlations between the rate series.

        Diagonal: histograms; upper triangle: scatter plots; lower triangle:
        2-D KDE contours (scatter when the KDE cannot be computed).

        Args:
            pname: Plan name, first title line.
            tau_kn: Rates indexed ``[k, n]``; plotted as ``100 * tau_kn``.
            N_n: Sample count shown in the title.
            rate_method, rate_frm, rate_to, tag: Title information.
            share_range: Give all off-diagonal y-axes the common data range +/- 5.
        Returns:
            go.Figure
        """
        rate_names = self._RATE_NAMES
        df = pd.DataFrame({rname: 100 * tau_kn[k] for k, rname in enumerate(rate_names)})

        n_vars = len(rate_names)
        fig = make_subplots(
            rows=n_vars, cols=n_vars,
            shared_xaxes=True,
            vertical_spacing=0.05,
            horizontal_spacing=0.05
        )

        if share_range:
            yrange = [df.min().min() - 5, df.max().max() + 5]

        for i in range(n_vars):
            for j in range(n_vars):
                row, col = i + 1, j + 1
                if i == j:
                    fig.add_trace(
                        go.Histogram(x=df[rate_names[i]], marker_color="orange", showlegend=False),
                        row=row, col=col
                    )
                    # Histogram y-axis: independent of the others, automatic scaling.
                    fig.update_yaxes(
                        showticklabels=True,
                        row=row,
                        col=col,
                        range=[0, None],
                        autorange=True,
                        matches=None,
                        scaleanchor=None,
                        constrain=None
                    )
                    continue

                if i < j:
                    fig.add_trace(
                        go.Scatter(
                            x=df[rate_names[j]],
                            y=df[rate_names[i]],
                            mode="markers",
                            marker=dict(size=6, opacity=0.5),
                            showlegend=False
                        ),
                        row=row, col=col
                    )
                else:
                    fig.add_trace(self._kde_trace(df[rate_names[j]], df[rate_names[i]]), row=row, col=col)

                if share_range:
                    fig.update_yaxes(range=yrange, row=row, col=col)

        title = pname + "<br>"
        title += f"Rates Correlations (N={N_n}) {rate_method}"
        if rate_method in self._RANGED_RATE_METHODS:
            title += f" ({rate_frm}-{rate_to})"
        if tag:
            title += " - " + tag

        fig.update_layout(
            title=title,
            template=self.template,
            height=800,
            width=800,
            showlegend=False
        )

        zero_line = dict(color="gray", width=1, dash="dot")
        for i in range(n_vars):
            for j in range(n_vars):
                row, col = i + 1, j + 1
                # Zero lines spanning each subplot; row/col target the right axis pair.
                fig.add_vline(x=0, row=row, col=col, line=zero_line)
                if i != j:
                    fig.add_hline(y=0, row=row, col=col, line=zero_line)

                if i == n_vars - 1:  # Bottom row
                    fig.update_xaxes(title_text=rate_names[j], row=row, col=col)
                if j == 0:  # Left column
                    fig.update_yaxes(title_text=rate_names[i], row=row, col=col)

        return fig

    def plot_histogram_results(self, objective, df, N, year_n, n_d=None, N_i=1, phi_j=None):
        """Show a histogram of values from historical data or Monte Carlo simulations.

        Args:
            objective: "maxBequest" overlays every column in one histogram;
                otherwise the column named ``objective`` is plotted (next to
                "partial" when both are present).
            df: One row per successful case, possibly with a "partial" column.
                NOTE(review): modified in place — "partial" may be dropped and
                all values are divided by 1000; kept as-is in case callers
                rely on it.
            N: Number of cases attempted; success rate is ``len(df) / N``.
            year_n: Plan years; ``year_n[0]`` labels dollars, ``year_n[-1]`` is the final year.
            n_d: When ``N_i == 2``, the partial value is dated ``year_n[n_d - 1]``.
            N_i: Number of individuals.
            phi_j: Presumably spousal beneficiary fractions (TODO confirm); when
                all are within 0.01 of 1, the "partial" column is dropped.
        Returns:
            ``(fig, description)`` where ``description`` is an ``io.StringIO``
            summary; ``fig`` is None when ``df`` is empty.
        """
        description = io.StringIO()

        pSuccess = u.pc(len(df) / N)
        print(f"Success rate: {pSuccess} on {N} samples.", file=description)
        title = f"N = {N}, P = {pSuccess}"

        means = df.mean(axis=0, numeric_only=True)
        medians = df.median(axis=0, numeric_only=True)

        # my[0] dates the partial value, my[1] the final one.
        my = 2 * [year_n[-1]]
        if N_i == 2 and n_d is not None and n_d < len(year_n):
            my[0] = year_n[n_d - 1]

        # Drop the partial column when it carries no information.
        full_beneficiary = phi_j is not None and bool(np.all((1 - np.asarray(phi_j)) < 0.01))
        null_partial = "partial" in medians.index and medians["partial"] < 1
        if full_beneficiary or null_partial:
            if null_partial:
                print(f"Optimized solutions all have null partial bequest in year {my[0]}.",
                      file=description)
            if "partial" in df.columns:
                df.drop("partial", axis=1, inplace=True)
            means = df.mean(axis=0, numeric_only=True)
            medians = df.median(axis=0, numeric_only=True)

        # Convert to thousands (in place, see docstring).
        df /= 1000

        if len(df) == 0:
            return None, description

        thisyear = year_n[0]
        has_partial = "partial" in df.columns

        if objective == "maxBequest":
            # Single figure with the (partial and) final bequests overlaid.
            fig = go.Figure()
            # Label each column with its own year, even after "partial" was dropped.
            years = my if has_partial else my[1:]
            for q, col in enumerate(df.columns):
                dmedian = u.d(medians.iloc[q], latex=False)
                dmean = u.d(means.iloc[q], latex=False)
                fig.add_trace(go.Histogram(
                    x=df[col],
                    name=f"{years[q]}: M: {dmedian}, <x>: {dmean}",
                    opacity=0.7,
                    marker_color="orange"
                ))

            fig.update_layout(
                title=objective,
                xaxis_title=f"{thisyear} $k",
                yaxis_title="Count",
                template=self.template,
                barmode="overlay",
                showlegend=True,
                legend=self._legend_below(-0.50, horizontal=False)
            )

            leads = [f"partial {my[0]}", f" final {my[1]}"] if has_partial else [f" final {my[1]}"]

        elif len(means) == 2:
            # Partial value and objective as two separate histograms.
            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=[f"partial {my[0]}", objective],
                horizontal_spacing=0.1
            )

            cols = ["partial", objective]
            leads = [f"partial {my[0]}", objective]

            for q, col in enumerate(cols):
                dmedian = u.d(medians.iloc[q], latex=False)
                dmean = u.d(means.iloc[q], latex=False)
                label = f"M: {dmedian}, <x>: {dmean}"

                fig.add_trace(
                    go.Histogram(
                        x=df[col],
                        name=label,
                        marker_color="orange",
                        showlegend=False
                    ),
                    row=1, col=q + 1
                )
                # Statistics in the top-left corner of this subplot (domain coordinates).
                fig.add_annotation(
                    x=0.01, y=0.99,
                    xref=f"{self._axis_ref('x', q + 1)} domain",
                    yref="paper",
                    xanchor="left",
                    yanchor="top",
                    text=label,
                    showarrow=False,
                    font=dict(size=10),
                    bgcolor="rgba(0, 0, 0, 0)"
                )

            fig.update_layout(
                title=title,
                template=self.template,
                height=400,
                width=800
            )
            fig.update_yaxes(title_text="Count", row=1, col=1)
            fig.update_yaxes(title_text="Count", row=1, col=2)

        else:
            # Single histogram of the objective.
            fig = go.Figure()

            dmedian = u.d(medians.iloc[0], latex=False)
            dmean = u.d(means.iloc[0], latex=False)
            fig.add_trace(go.Histogram(
                x=df[objective],
                name=f"M: {dmedian}, <x>: {dmean}",
                marker_color="orange"
            ))

            fig.update_layout(
                title=objective,
                xaxis_title=f"{thisyear} $k",
                yaxis_title="Count",
                template=self.template,
                showlegend=True,
                legend=self._legend_below(-0.50, horizontal=False)
            )

            leads = [objective]

        # Text summary; amounts are converted back to dollars for display.
        for q, lead in enumerate(leads):
            column = df.iloc[:, q]
            print(f"{lead:>12}: Median ({thisyear} $): {u.d(medians.iloc[q])}", file=description)
            print(f"{lead:>12}: Mean ({thisyear} $): {u.d(means.iloc[q])}", file=description)
            mmin = 1000 * column.min()
            mmax = 1000 * column.max()
            print(f"{lead:>12}: Range: {u.d(mmin)} - {u.d(mmax)}", file=description)
            nzeros = int((column < 0.001).sum())
            print(f"{lead:>12}: N zero solns: {nzeros}", file=description)

        return fig, description

    def plot_asset_distribution(self, year_n, inames, b_ijkn, gamma_n, value, name, tag):
        """Plot stacked asset balances over time, one figure per account type.

        ``b_ijkn[i][j][k]`` holds the balances of individual i, account j
        (taxable, tax-deferred, tax-free) and asset k (stocks, C bonds,
        T notes, common), with one more entry than ``year_n``. Series whose
        total is not above 1 dollar are omitted.

        Returns:
            list of go.Figure, in the order taxable, tax-deferred, tax-free.
        """
        if value == "nominal":
            infladjust = 1
        else:
            infladjust = gamma_n
        yformat = self._value_axis_title(value, year_n)

        # Balances include the year after the last plan year.
        years_n = np.append(np.array(year_n), [year_n[-1] + 1])

        jDic = {"taxable": 0, "tax-deferred": 1, "tax-free": 2}
        kDic = {"stocks": 0, "C bonds": 1, "T notes": 2, "common": 3}

        figures = []
        for jkey, j in jDic.items():
            fig = go.Figure()

            series = []
            for kkey, k in kDic.items():
                namek = f"{kkey} / {jkey}"
                values = np.zeros((len(inames), len(years_n)))
                for i in range(len(inames)):
                    values[i] = b_ijkn[i][j][k] / infladjust
                for i, iname in enumerate(inames):
                    if np.sum(values[i]) > 1.0:  # Only show non-zero series
                        series.append((f"{namek} {iname}", values[i] / 1000))

            self._add_stacked_areas(fig, years_n, series)

            title = f"{name}<br>Assets Distribution - {jkey}"
            if tag:
                title += f" - {tag}"

            fig.update_layout(
                title=title,
                yaxis_title=yformat,
                template=self.template,
                showlegend=True,
                legend=self._legend_below(-0.65),
                margin=dict(b=150)
            )
            fig.update_yaxes(tickformat=",.0f")
            figures.append(fig)

        return figures

    def plot_allocations(self, year_n, inames, alpha_ijkn, ARCoord, title):
        """Plot stacked asset allocations (%) over time.

        One figure is produced per individual and per account type in the
        coordination ``ARCoord`` ("spouses", "individual", or "account", the
        latter giving taxable, tax-deferred and tax-free). Except under
        "spouses" coordination, each title ends with the individual's name.

        Raises:
            ValueError: if ``ARCoord`` is not one of the values above.
        Returns:
            list of go.Figure
        """
        if ARCoord in ("spouses", "individual"):
            acList = [ARCoord]
        elif ARCoord == "account":
            acList = ["taxable", "tax-deferred", "tax-free"]
        else:
            raise ValueError(f"Unknown coordination {ARCoord}.")

        assetDic = {"stocks": 0, "C bonds": 1, "T notes": 2, "common": 3}

        title = self._html_title(title)
        figures = []
        for i, iname in enumerate(inames):
            for acIndex, acType in enumerate(acList):
                fig = go.Figure()

                series = [
                    (f"{key} / {acType}", 100 * alpha_ijkn[i, acIndex, kIndex, :len(year_n)])
                    for key, kIndex in assetDic.items()
                ]
                self._add_stacked_areas(fig, year_n, series)

                plot_title = f"{title} - {acType}"
                if ARCoord != "spouses":
                    # Distinguish the per-individual figures of a couple.
                    plot_title += f" {iname}"

                fig.update_layout(
                    title=plot_title,
                    yaxis_title="%",
                    template=self.template,
                    showlegend=True,
                    legend=self._legend_below(-0.5, reversed_order=True),
                    margin=dict(b=150)
                )
                fig.update_yaxes(tickformat=".0f")
                figures.append(fig)

        return figures

    def plot_accounts(self, year_n, savings_in, gamma_n, value, title, inames):
        """Plot stacked account balances over time.

        ``savings_in`` maps an account name to per-individual balances with one
        more entry than ``year_n``. Returns a ``go.Figure``.
        """
        fig = go.Figure()

        # Balances include the year after the last plan year.
        year_n_full = np.append(year_n, [year_n[-1] + 1])

        if value == "nominal":
            savings = savings_in
        else:
            savings = {k: v / gamma_n for k, v in savings_in.items()}

        self._add_stacked_areas(fig, year_n_full, self._per_person_nonzero(savings, inames).items())

        fig.update_layout(
            title=self._html_title(title),
            yaxis_title=self._value_axis_title(value, year_n),
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.5, reversed_order=True),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=",.0f")
        return fig

    def plot_sources(self, year_n, sources_in, gamma_n, value, title, inames):
        """Plot stacked income sources over time.

        ``sources_in`` maps a source name to per-individual amounts aligned
        with ``year_n``. Returns a ``go.Figure``.
        """
        fig = go.Figure()

        if value == "nominal":
            sources = sources_in
        else:
            sources = {k: v / gamma_n[:-1] for k, v in sources_in.items()}

        self._add_stacked_areas(fig, year_n, self._per_person_nonzero(sources, inames).items())

        fig.update_layout(
            title=self._html_title(title),
            yaxis_title=self._value_axis_title(value, year_n),
            template=self.template,
            showlegend=True,
            legend=self._legend_below(-0.75, reversed_order=True),
            margin=dict(b=150)
        )
        fig.update_yaxes(tickformat=",.0f")
        return fig