validmind 1.7.0__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69):
  1. validmind/__init__.py +8 -1
  2. validmind/{client.pyx → client.py} +48 -41
  3. validmind/data_validation/{threshold_tests.pyx → threshold_tests.py} +1 -2
  4. validmind/datasets/__init__.py +0 -0
  5. validmind/datasets/classification/{customer_churn.pyx → customer_churn.py} +1 -1
  6. validmind/datasets/classification/datasets/bank_customer_churn.csv +8001 -0
  7. validmind/datasets/classification/datasets/taiwan_credit.csv +30001 -0
  8. validmind/datasets/classification/{taiwan_credit.pyx → taiwan_credit.py} +1 -1
  9. validmind/datasets/regression/__init__.py +55 -1
  10. validmind/datasets/regression/datasets/fred_loan_rates.csv +3552 -0
  11. validmind/datasets/regression/datasets/fred_loan_rates_test_1.csv +126 -0
  12. validmind/datasets/regression/datasets/fred_loan_rates_test_2.csv +126 -0
  13. validmind/datasets/regression/datasets/fred_loan_rates_test_3.csv +126 -0
  14. validmind/datasets/regression/datasets/fred_loan_rates_test_4.csv +126 -0
  15. validmind/datasets/regression/datasets/fred_loan_rates_test_5.csv +126 -0
  16. validmind/datasets/regression/datasets/lending_club_loan_rates.csv +138 -0
  17. validmind/datasets/regression/fred.py +132 -0
  18. validmind/datasets/regression/lending_club.py +70 -0
  19. validmind/datasets/regression/models/fred_loan_rates_model_1.pkl +0 -0
  20. validmind/datasets/regression/models/fred_loan_rates_model_2.pkl +0 -0
  21. validmind/datasets/regression/models/fred_loan_rates_model_3.pkl +0 -0
  22. validmind/datasets/regression/models/fred_loan_rates_model_4.pkl +0 -0
  23. validmind/datasets/regression/models/fred_loan_rates_model_5.pkl +0 -0
  24. validmind/model_validation/sklearn/{threshold_tests.pyx → threshold_tests.py} +9 -9
  25. validmind/model_validation/statsmodels/{metrics.pyx → metrics.py} +123 -138
  26. validmind/test_plans/__init__.py +0 -4
  27. validmind/test_plans/{binary_classifier.pyx → binary_classifier.py} +0 -15
  28. validmind/test_plans/{statsmodels_timeseries.pyx → statsmodels_timeseries.py} +2 -2
  29. validmind/test_plans/{tabular_datasets.pyx → tabular_datasets.py} +0 -13
  30. validmind/test_plans/{time_series.pyx → time_series.py} +3 -3
  31. validmind/test_suites/__init__.py +73 -0
  32. validmind/test_suites/test_suites.py +48 -0
  33. validmind/vm_models/__init__.py +2 -0
  34. validmind/vm_models/{dataset.pyx → dataset.py} +17 -8
  35. validmind/vm_models/test_suite.py +57 -0
  36. {validmind-1.7.0.dist-info → validmind-1.8.1.dist-info}/METADATA +1 -3
  37. validmind-1.8.1.dist-info/RECORD +63 -0
  38. validmind/api_client.c +0 -9481
  39. validmind/api_client.cpython-310-x86_64-linux-gnu.so +0 -0
  40. validmind/client.c +0 -7198
  41. validmind/client.cpython-310-x86_64-linux-gnu.so +0 -0
  42. validmind/datasets/regression/fred.pyx +0 -7
  43. validmind/datasets/regression/lending_club.pyx +0 -7
  44. validmind/model_utils.c +0 -9281
  45. validmind/model_utils.cpython-310-x86_64-linux-gnu.so +0 -0
  46. validmind/utils.c +0 -10284
  47. validmind/utils.cpython-310-x86_64-linux-gnu.so +0 -0
  48. validmind-1.7.0.dist-info/RECORD +0 -53
  49. /validmind/{api_client.pyx → api_client.py} +0 -0
  50. /validmind/data_validation/{metrics.pyx → metrics.py} +0 -0
  51. /validmind/{model_utils.pyx → model_utils.py} +0 -0
  52. /validmind/model_validation/{model_metadata.pyx → model_metadata.py} +0 -0
  53. /validmind/model_validation/sklearn/{metrics.pyx → metrics.py} +0 -0
  54. /validmind/model_validation/statsmodels/{threshold_tests.pyx → threshold_tests.py} +0 -0
  55. /validmind/model_validation/{utils.pyx → utils.py} +0 -0
  56. /validmind/{utils.pyx → utils.py} +0 -0
  57. /validmind/vm_models/{dataset_utils.pyx → dataset_utils.py} +0 -0
  58. /validmind/vm_models/{figure.pyx → figure.py} +0 -0
  59. /validmind/vm_models/{metric.pyx → metric.py} +0 -0
  60. /validmind/vm_models/{metric_result.pyx → metric_result.py} +0 -0
  61. /validmind/vm_models/{model.pyx → model.py} +0 -0
  62. /validmind/vm_models/{plot_utils.pyx → plot_utils.py} +0 -0
  63. /validmind/vm_models/{result_summary.pyx → result_summary.py} +0 -0
  64. /validmind/vm_models/{test_context.pyx → test_context.py} +0 -0
  65. /validmind/vm_models/{test_plan.pyx → test_plan.py} +0 -0
  66. /validmind/vm_models/{test_plan_result.pyx → test_plan_result.py} +0 -0
  67. /validmind/vm_models/{test_result.pyx → test_result.py} +0 -0
  68. /validmind/vm_models/{threshold_test.pyx → threshold_test.py} +0 -0
  69. {validmind-1.7.0.dist-info → validmind-1.8.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,138 @@
1
+ DATE,loan_rate_A,loan_rate_B,loan_rate_C,loan_rate_D
2
+ 2007-08-01,7.7666666666666700,9.497692307692310,10.9475,12.267
3
+ 2007-09-01,7.841428571428570,9.276666666666670,10.829166666666700,12.436666666666700
4
+ 2007-10-01,7.83,9.433333333333330,10.825925925925900,12.737368421052600
5
+ 2007-11-01,7.779090909090910,9.467777777777780,10.967037037037000,12.609444444444400
6
+ 2007-12-01,7.695833333333330,9.3875,10.805,12.47888888888890
7
+ 2008-01-01,7.961333333333330,9.693125,11.288055555555600,13.008200000000000
8
+ 2008-02-01,8.130333333333330,10.035730337078700,11.525636363636400,13.235925925925900
9
+ 2008-03-01,8.126285714285710,10.05709677419360,11.60367816091950,13.145975609756100
10
+ 2008-04-01,8.092083333333330,10.16122807017540,11.61359375,13.258260869565200
11
+ 2008-05-01,8.151428571428570,10.077916666666700,11.634814814814800,12.974
12
+ 2008-06-01,8.2055,10.068181818181800,11.649117647058800,13.290416666666700
13
+ 2008-07-01,8.220434782608700,10.158148148148100,11.78128205128210,13.3132
14
+ 2008-08-01,8.287,10.443636363636400,11.86875,13.460833333333300
15
+ 2008-09-01,8.08,10.433333333333300,11.803333333333300,13.486666666666700
16
+ 2008-10-01,8.612857142857140,10.757333333333300,12.13025641025640,13.681764705882400
17
+ 2008-11-01,8.558787878787880,10.911111111111100,12.349166666666700,13.83516129032260
18
+ 2008-12-01,9.199166666666670,11.432987012987000,12.764520547945200,14.423714285714300
19
+ 2009-01-01,8.956065573770490,11.618840579710100,13.162727272727300,14.634807692307700
20
+ 2009-02-01,8.898615384615390,12.01972972972970,13.168222222222200,14.635967741935500
21
+ 2009-03-01,8.821111111111110,12.035806451612900,13.182000000000000,14.603529411764700
22
+ 2009-04-01,8.685268817204300,12.012857142857100,13.133461538461500,14.720821917808200
23
+ 2009-05-01,8.944868421052630,11.619561403508800,13.155242718446600,14.541750000000000
24
+ 2009-06-01,8.933404255319150,11.6745,13.11355140186920,14.507916666666700
25
+ 2009-07-01,8.965098039215690,11.591623931623900,13.06843137254900,14.669272727272700
26
+ 2009-08-01,8.472935779816510,11.800933333333300,13.355959595959600,15.127391304347800
27
+ 2009-09-01,8.443076923076920,11.84740506329110,13.55984,15.162857142857100
28
+ 2009-10-01,8.436535433070870,11.86804347826090,13.482191780821900,15.3096875
29
+ 2009-11-01,8.363513513513510,11.807222222222200,13.499345238095200,15.217190082644600
30
+ 2009-12-01,8.409424460431660,11.87390134529150,13.519559748427700,15.220645161290300
31
+ 2010-01-01,8.144044117647060,11.607019230769200,13.482960526315800,15.173942307692300
32
+ 2010-02-01,7.517205882352940,10.687883817427400,13.35064705882350,15.231326530612200
33
+ 2010-03-01,7.457604790419160,10.6644,13.337486910994800,15.195703703703700
34
+ 2010-04-01,7.498095238095240,10.649088050314500,13.424000000000000,15.150000000000000
35
+ 2010-05-01,7.4599090909090900,10.723706293706300,13.427931034482800,15.29125
36
+ 2010-06-01,7.52636752136752,11.225825242718400,13.756233766233800,15.530058823529400
37
+ 2010-07-01,7.45978102189781,11.22090909090910,13.86310606060610,15.569890710382500
38
+ 2010-08-01,7.452052631578950,11.248456973293800,13.835163636363600,15.517061855670100
39
+ 2010-09-01,7.498099173553720,11.170552325581400,13.797773279352200,15.597903225806500
40
+ 2010-10-01,6.891980519480520,10.575714285714300,13.492008368200800,15.294080459770100
41
+ 2010-11-01,6.3666201117318400,9.736618497109830,12.948222222222200,14.853270440251600
42
+ 2010-12-01,6.319193954659950,9.722658959537570,12.96254752851710,14.805481927710800
43
+ 2011-01-01,6.582382133995040,10.1541475826972,13.06860655737710,15.108478260869600
44
+ 2011-02-01,6.80948717948718,10.467368421052600,13.297196261682200,15.395933333333300
45
+ 2011-03-01,6.834241573033710,10.474313253012000,13.317777777777800,15.337074468085100
46
+ 2011-04-01,6.847228260869570,10.450381526104400,13.228860294117600,15.29950226244340
47
+ 2011-05-01,7.165149253731340,11.099935760171300,13.817382352941200,16.48267716535430
48
+ 2011-06-01,7.018872651356990,11.108265682656800,13.876910112359600,16.51079831932770
49
+ 2011-07-01,7.083340563991320,11.067542955326500,13.886373937677100,16.48506172839510
50
+ 2011-08-01,7.0728650646950100,11.086491228070200,13.819787234042600,16.492789855072500
51
+ 2011-09-01,7.084527131782950,11.254157814871000,14.151633802816900,16.680548523206800
52
+ 2011-10-01,7.351312410841660,11.598674121405800,14.406036036036000,17.25779661016950
53
+ 2011-11-01,7.4149581239531,11.555904628331000,14.424929906542100,17.232988505747100
54
+ 2011-12-01,7.589288888888890,11.589472222222200,14.50002105263160,17.223344155844200
55
+ 2012-01-01,7.525056980056980,11.635782556750300,14.443313953488400,17.08759375
56
+ 2012-02-01,7.610733695652170,11.597185929648200,14.58754756871040,17.420031746031700
57
+ 2012-03-01,7.579030100334450,11.886394763343400,14.80486220472440,18.253422818791900
58
+ 2012-04-01,7.5876,11.925473321858900,14.861176470588200,18.16556786703600
59
+ 2012-05-01,7.476845102505700,12.016376554174100,14.810244252873600,18.233905472636800
60
+ 2012-06-01,7.5403636363636400,12.15405750798720,14.900022573363400,18.17427304964540
61
+ 2012-07-01,7.6588717339667500,12.227395115842200,15.313709677419400,18.569943181818200
62
+ 2012-08-01,7.574986474301170,12.34835970464140,15.457079288025900,18.649875862069000
63
+ 2012-09-01,7.596716681376880,12.335429844098000,15.458238267148000,18.6446126340882
64
+ 2012-10-01,7.683201877934270,12.275302245250400,15.453846710050600,18.670659574468100
65
+ 2012-11-01,7.61423155737705,12.337397568662800,15.451,18.65142300194930
66
+ 2012-12-01,7.822746823069400,12.1065440464666,15.75559186189890,18.592207478890200
67
+ 2013-01-01,7.752121546961330,12.137295401403000,15.77893272635310,18.554923913043500
68
+ 2013-02-01,7.698670360110800,12.109676284306800,15.784549549549500,18.576124721603600
69
+ 2013-03-01,7.720320610687020,12.052285819793200,15.80080592105260,18.641330998248700
70
+ 2013-04-01,7.946095025983670,12.05531282405810,15.81847364568080,18.610517928286900
71
+ 2013-05-01,7.573445527015060,12.067179799209000,15.781325174825200,18.6713679245283
72
+ 2013-06-01,7.520518828451880,11.944434831147100,15.796356209150300,18.66170493685420
73
+ 2013-07-01,7.605866466616650,11.357245404868400,15.23495371752760,18.689437963944900
74
+ 2013-08-01,7.606510948905110,11.37418853255590,15.196370835609000,18.714716883116900
75
+ 2013-09-01,7.695094055680960,11.49977115613830,15.379319657231900,18.779754016064300
76
+ 2013-10-01,7.8066564417177900,11.753876332622600,15.673991344073100,18.88925198690980
77
+ 2013-11-01,7.837439117199390,11.892656587473000,15.553595594020500,18.89677792612370
78
+ 2013-12-01,7.888719008264460,11.915244541484700,15.146307539188900,18.39415185783520
79
+ 2014-01-01,7.884319131161240,11.903197773972600,15.017733510402800,18.154655041698300
80
+ 2014-02-01,7.84595130748422,11.806563223714700,14.829327827483400,18.013829250720500
81
+ 2014-03-01,7.881642651296830,11.75884096438840,14.674028332003200,17.929929601072700
82
+ 2014-04-01,7.886479457218240,11.759009712435700,14.685187592867800,17.975035360678900
83
+ 2014-05-01,7.529020283199390,11.264750000000000,14.191762042738100,17.158339091150700
84
+ 2014-06-01,7.351805118994160,11.008507223114000,13.938348755912000,16.76545196134170
85
+ 2014-07-01,7.351720283533260,10.991443699731900,13.962106592877800,16.78448344701380
86
+ 2014-08-01,7.305786460087570,10.92469630557300,13.931857451403900,16.821789838337200
87
+ 2014-09-01,7.1805243268776600,10.92935610522180,13.932417061611400,16.81521568627450
88
+ 2014-10-01,7.313578708946770,10.99294086810650,13.902686392661200,16.741794034723300
89
+ 2014-11-01,7.341760975609760,10.609953738248000,13.599219115824300,16.518561545801500
90
+ 2014-12-01,7.2349226160758900,10.451943852167700,13.558856439127400,16.457166246851400
91
+ 2015-01-01,7.284862446138550,10.523150939274800,13.580668403647700,16.49141867927720
92
+ 2015-02-01,7.126818078064370,10.230664872566100,13.3875424992544,16.665386680988200
93
+ 2015-03-01,7.053031877213700,10.09241054613940,13.314671177266600,16.752682622787300
94
+ 2015-04-01,7.028140703517590,10.05959205467750,13.339224530168200,16.770982191041600
95
+ 2015-05-01,6.942790697674420,10.075693915181300,13.309596621846900,16.8081057178116
96
+ 2015-06-01,6.924908338261380,10.016840254294600,13.311009585460000,16.775532012897300
97
+ 2015-07-01,6.846263779016990,9.973752109959010,13.315842231914600,16.76153384747220
98
+ 2015-08-01,6.895340965654560,9.938369829683700,13.294216701173200,16.7434092634776
99
+ 2015-09-01,6.883837661010020,9.956684323369910,13.28493008739080,16.776011560693600
100
+ 2015-10-01,6.857467413674820,9.9480314478252,13.308447903156100,16.75174123337360
101
+ 2015-11-01,6.818325145839730,10.012865733184400,13.114165689017200,16.679007398273700
102
+ 2015-12-01,6.813621571972180,9.920643664396380,13.113938948995400,16.720045469308200
103
+ 2016-01-01,6.854723867456380,9.953370355808930,13.313586302637700,16.981329494896600
104
+ 2016-02-01,6.844791755508170,9.962073382887870,13.51001321003960,17.7306925087108
105
+ 2016-03-01,6.51158106060606,9.981660457385230,13.518506973417600,17.82086815920400
106
+ 2016-04-01,6.459687188434700,10.000419839473100,13.54169860126460,17.77998596913210
107
+ 2016-05-01,6.905898270251760,9.948223350253810,13.501587610160700,17.805547243003600
108
+ 2016-06-01,7.12222120518688,10.17323210606470,13.780876188418300,18.284878048780500
109
+ 2016-07-01,7.409693532818530,10.48121790772040,14.012340662161000,18.673088637251100
110
+ 2016-08-01,7.415675891431610,10.430686851479000,14.018968108344300,18.725051626729000
111
+ 2016-09-01,7.309269304403320,10.420682593856700,13.99552615321250,18.80164383561640
112
+ 2016-10-01,7.155384800384800,10.461156528477300,14.06901384239570,18.808550010269000
113
+ 2016-11-01,7.026321948889420,10.626164935530100,14.117570806100200,18.7734665882814
114
+ 2016-12-01,6.989425942156000,10.574707292338900,14.143388450148100,18.831436605317000
115
+ 2017-01-01,6.922970447631470,10.593491575712600,14.140238904713800,18.479066305818700
116
+ 2017-02-01,6.9446164739202300,10.640543691703500,14.111862055097500,18.06909398034400
117
+ 2017-03-01,6.893785237086530,10.600121526226400,14.140748368044800,18.110767457939600
118
+ 2017-04-01,6.89110229645094,10.622533764284900,14.136295552367300,18.084069239500600
119
+ 2017-05-01,6.936799809795530,10.557189495458200,14.272662721893500,18.665688052068800
120
+ 2017-06-01,7.015996470069130,10.566750889988700,14.37513654096230,19.046772499403200
121
+ 2017-07-01,7.0019612448620100,10.566760062456600,14.345889606150700,19.107410634495100
122
+ 2017-08-01,6.985308885383810,10.576519034772200,14.36054723695700,19.087508
123
+ 2017-09-01,7.08636862745098,10.51879065667700,14.307342195158200,19.013268023587600
124
+ 2017-10-01,7.113857827476040,10.522347111638200,14.222900293532800,18.95540403155130
125
+ 2017-11-01,6.925136533803130,10.513319831983200,14.159044066620400,18.923875141517100
126
+ 2017-12-01,6.8353083235638900,10.52431454716500,14.15896581775930,18.952280408231600
127
+ 2018-01-01,6.77865500934785,10.516662401754700,14.194534883720900,18.959176493330800
128
+ 2018-02-01,6.770555486361940,10.53045831202050,14.188890264490700,19.05002990175140
129
+ 2018-03-01,6.715374727894680,10.540809179770500,14.160395894428200,19.323733236680000
130
+ 2018-04-01,6.704338970023060,10.562063789868700,14.168836298629400,19.367843168957200
131
+ 2018-05-01,6.768701702826520,10.630962232202000,14.457164179104500,19.322808271248300
132
+ 2018-06-01,6.812492979436540,10.706152809367400,14.628350747613000,19.307152249134900
133
+ 2018-07-01,7.1668788793826500,11.116104527240200,15.085822900484300,19.769436012142500
134
+ 2018-08-01,7.218996976568410,11.161285669241700,15.142618464766600,19.85760335716510
135
+ 2018-09-01,7.2012805625971100,11.191917953668000,15.139769364664900,19.74845895522390
136
+ 2018-10-01,7.228498120038990,11.208417521704800,15.129104518736200,19.79216299416720
137
+ 2018-11-01,7.536896956485290,11.390483083511800,15.12686942339060,19.632696590118300
138
+ 2018-12-01,7.715209008701590,11.45963094824590,15.107476048212600,19.558346273291900
@@ -0,0 +1,132 @@
1
+ import os
2
+ import pickle
3
+ import pandas as pd
4
+
5
# Locations of the packaged CSV datasets and pre-trained model pickles,
# resolved relative to this module so imports work regardless of CWD.
current_path = os.path.dirname(os.path.abspath(__file__))
dataset_path = os.path.join(current_path, "datasets")
models_path = os.path.join(current_path, "models")


# Default modelling configuration for the FRED loan-rates dataset.
target_column = "MORTGAGE30US"
feature_columns = ["FEDFUNDS", "GS10", "UNRATE"]
frequency = "MS"  # monthly, start-of-month timestamps
split_option = "train_test"
transform_func = "diff"
15
+
16
+
17
def load_data():
    """Read the packaged FRED loan-rates CSV into a DataFrame.

    Returns a DataFrame indexed by the parsed DATE column, restricted to
    the module-level target and feature columns.
    """
    csv_path = os.path.join(dataset_path, "fred_loan_rates.csv")
    frame = pd.read_csv(csv_path, parse_dates=["DATE"], index_col="DATE")
    return frame[[target_column] + feature_columns]
22
+
23
+
24
def preprocess(df, split_option="train_test_val", train_size=0.6, test_size=0.2):
    """Split a time series DataFrame chronologically.

    Parameters:
        df (pandas.DataFrame): Time-indexed DataFrame to split.
        split_option (str): 'train_test_val' (default) or 'train_test'.
        train_size (float): Fraction of rows for the training set (default 0.6).
        test_size (float): Fraction of rows for the validation set when
            split_option is 'train_test_val' (default 0.2).

    Returns:
        (train_df, test_df) for 'train_test', or
        (train_df, validation_df, test_df) for 'train_test_val'.

    Raises:
        ValueError: If split_option is not one of the two supported values.
    """
    # Validate up front so bad options fail fast.
    if split_option not in ("train_test_val", "train_test"):
        raise ValueError(
            "Invalid split_option. Must be 'train_test_val' or 'train_test'."
        )

    # Rows must be in chronological order (the time column is the index).
    ordered = df.sort_index()
    n_train = int(len(ordered) * train_size)

    if split_option == "train_test":
        return ordered.iloc[:n_train], ordered.iloc[n_train:]

    # 'train_test_val': the validation slice sits between train and test.
    n_val = int(len(ordered) * test_size)
    cut = n_train + n_val
    return ordered.iloc[:n_train], ordered.iloc[n_train:cut], ordered.iloc[cut:]
67
+
68
+
69
def transform(df, transform_func="diff"):
    """Apply a named transformation to the DataFrame.

    Only 'diff' (first differences, leading NaN row dropped) is supported;
    any other name returns the input unchanged.
    """
    if transform_func != "diff":
        return df
    return df.diff().dropna()
73
+
74
+
75
def load_model(model_name):
    """Load a pickled model plus its train/test datasets by name.

    Parameters:
        model_name (str): Base name of the pickle file in models_path
            (without the '.pkl' extension).

    Returns:
        (model, train_df, test_df), or (None, None, None) when no pickle
        with that name exists (a message is printed in that case).
    """
    model_path = os.path.join(models_path, model_name + ".pkl")

    # Guard clause: bail out early when the pickle is missing.
    if not os.path.isfile(model_path):
        print(f"No model file found with the name: {model_name}")
        return None, None, None

    with open(model_path, "rb") as f:
        model = pickle.load(f)

    # Training data is reconstructed from the fitted model itself; the test
    # split is a separate pre-generated CSV keyed by the model name.
    return model, load_train_dataset(model_path), load_test_dataset(model_name)
90
+
91
+
92
def load_train_dataset(model_path):
    """Rebuild the training DataFrame embedded in a pickled statsmodels fit.

    Parameters:
        model_path (str): Path to a pickled fitted-model object exposing
            .model.endog, .model.exog, .model.endog_names, .model.exog_names
            and .model.data.row_labels.

    Returns:
        pandas.DataFrame with the target column first, followed by the
        explanatory columns, indexed by the original row labels.
    """
    with open(model_path, "rb") as f:
        fitted = pickle.load(f)

    spec = fitted.model
    index = spec.data.row_labels

    # Endogenous (target) variable, named after the model's target.
    target = pd.DataFrame({spec.endog_names: spec.endog}, index=index)

    # Exogenous (explanatory) variables with their original column names.
    features = pd.DataFrame(spec.exog, index=index, columns=spec.exog_names)

    # Target column first, then features — the original training layout.
    return pd.concat([target, features], axis=1)
113
+
114
+
115
def load_test_dataset(model_name):
    """Load the pre-generated test split that pairs with a packaged model.

    Parameters:
        model_name (str): One of 'fred_loan_rates_model_1' .. '_5'.

    Returns:
        pandas.DataFrame of first-differenced test data indexed by DATE,
        or None when model_name has no associated test split.
    """
    # Packaged model N pairs with pre-generated test split N; a mapping
    # replaces the original five-branch if/elif chain.
    test_files = {
        f"fred_loan_rates_model_{i}": f"fred_loan_rates_test_{i}.csv"
        for i in range(1, 6)
    }
    filename = test_files.get(model_name)
    if filename is None:
        return None

    data_file = os.path.join(dataset_path, filename)
    df = pd.read_csv(data_file, parse_dates=["DATE"], index_col="DATE")
    # Match the training pipeline: models were fit on first differences.
    return df.diff().dropna()
@@ -0,0 +1,70 @@
1
+ import os
2
+
3
+ import pandas as pd
4
+
5
current_path = os.path.dirname(os.path.abspath(__file__))
# FIX: the CSV ships inside the package under datasets/ (see wheel
# contents: validmind/datasets/regression/datasets/lending_club_loan_rates.csv).
# The previous relative "../../../notebooks/datasets/time_series" path only
# resolved from a source checkout and breaks for installed packages.
dataset_path = os.path.join(current_path, "datasets")

# Default modelling configuration for the Lending Club loan-rates dataset.
# NOTE(review): target_column is a one-element list here but a plain string
# in the sibling fred module — confirm downstream consumers accept both.
target_column = ["loan_rate_A"]
feature_columns = ["loan_rate_B", "loan_rate_C", "loan_rate_D"]
frequency = "MS"  # monthly, start-of-month timestamps
split_option = "train_test"
14
+
15
+
16
def load_data():
    """Load the Lending Club loan-rates time series as a DATE-indexed DataFrame."""
    csv_path = os.path.join(dataset_path, "lending_club_loan_rates.csv")
    return pd.read_csv(csv_path, parse_dates=["DATE"], index_col="DATE")
20
+
21
+
22
def preprocess(df, split_option="train_test_val", train_size=0.6, test_size=0.2):
    """Split a time series DataFrame chronologically.

    Parameters:
        df (pandas.DataFrame): Time-indexed DataFrame to split.
        split_option (str): 'train_test_val' (default) or 'train_test'.
        train_size (float): Fraction of rows for the training set (default 0.6).
        test_size (float): Fraction of rows for the validation set when
            split_option is 'train_test_val' (default 0.2).

    Returns:
        (train_df, test_df) for 'train_test', or
        (train_df, validation_df, test_df) for 'train_test_val'.

    Raises:
        ValueError: If split_option is not one of the two supported values.
    """
    # The time column is assumed to be the index; enforce chronological order.
    frame = df.sort_index()
    n_rows = len(frame)
    n_train = int(n_rows * train_size)

    if split_option == "train_test":
        return frame.iloc[:n_train], frame.iloc[n_train:]

    if split_option == "train_test_val":
        # The validation slice sits between the train and test slices.
        n_val = int(n_rows * test_size)
        boundary = n_train + n_val
        return (
            frame.iloc[:n_train],
            frame.iloc[n_train:boundary],
            frame.iloc[boundary:],
        )

    raise ValueError(
        "Invalid split_option. Must be 'train_test_val' or 'train_test'."
    )
65
+
66
+
67
def transform(df, transform_func="diff"):
    """Apply a named transformation; only 'diff' (first differences with the
    leading NaN row dropped) is supported, anything else is a no-op."""
    return df.diff().dropna() if transform_func == "diff" else df
@@ -203,12 +203,12 @@ class OverfitDiagnosis(ThresholdTest):
203
203
  prediction_column = f"{target_column}_pred"
204
204
 
205
205
  # Add prediction column in the training dataset
206
- train_df = self.model.train_ds.df.copy(deep=True)
206
+ train_df = self.model.train_ds.df.copy()
207
207
  train_class_pred = self.model.class_predictions(self.model.y_train_predict)
208
208
  train_df[prediction_column] = train_class_pred
209
209
 
210
210
  # Add prediction column in the test dataset
211
- test_df = self.model.test_ds.df.copy(deep=True)
211
+ test_df = self.model.test_ds.df.copy()
212
212
  test_class_pred = self.model.class_predictions(self.model.y_test_predict)
213
213
  test_df[prediction_column] = test_class_pred
214
214
 
@@ -293,7 +293,7 @@ class OverfitDiagnosis(ThresholdTest):
293
293
 
294
294
  results_train = pd.DataFrame(results_train)
295
295
  results_test = pd.DataFrame(results_test)
296
- results = results_train.copy(deep=True)
296
+ results = results_train.copy()
297
297
  results.rename(
298
298
  columns={"shape": "training records", "accuracy": "training accuracy"},
299
299
  inplace=True,
@@ -458,11 +458,11 @@ class WeakspotsDiagnosis(ThresholdTest):
458
458
  target_column = self.model.train_ds.target_column
459
459
  prediction_column = f"{target_column}_pred"
460
460
 
461
- train_df = self.model.train_ds.df.copy(deep=True)
461
+ train_df = self.model.train_ds.df.copy()
462
462
  train_class_pred = self.model.class_predictions(self.model.y_train_predict)
463
463
  train_df[prediction_column] = train_class_pred
464
464
 
465
- test_df = self.model.test_ds.df.copy(deep=True)
465
+ test_df = self.model.test_ds.df.copy()
466
466
  test_class_pred = self.model.class_predictions(self.model.y_test_predict)
467
467
  test_df[prediction_column] = test_class_pred
468
468
 
@@ -704,10 +704,10 @@ class RobustnessDiagnosis(ThresholdTest):
704
704
  if self.model.train_ds.target_column in features_list:
705
705
  features_list.remove(self.model.train_ds.target_column)
706
706
 
707
- train_df = self.model.train_ds.x.copy(deep=True)
707
+ train_df = self.model.train_ds.x.copy()
708
708
  train_y_true = self.model.train_ds.y
709
709
 
710
- test_df = self.model.test_ds.x.copy(deep=True)
710
+ test_df = self.model.test_ds.x.copy()
711
711
  test_y_true = self.model.test_ds.y
712
712
 
713
713
  test_results = []
@@ -719,8 +719,8 @@ class RobustnessDiagnosis(ThresholdTest):
719
719
 
720
720
  # Iterate scaling factor for the standard deviation list
721
721
  for x_std_dev in x_std_dev_list:
722
- temp_train_df = train_df.copy(deep=True)
723
- temp_test_df = test_df.copy(deep=True)
722
+ temp_train_df = train_df.copy()
723
+ temp_test_df = test_df.copy()
724
724
 
725
725
  # Add noise to numeric features columns provided by user
726
726
  for feature in features_list: