jrpybestpracccc 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. jrpybestpracccc/__init__.py +3 -0
  2. jrpybestpracccc/__version__.py +1 -0
  3. jrpybestpracccc/extra_scripts/example_projects/bad_project/.gitignore +3 -0
  4. jrpybestpracccc/extra_scripts/example_projects/bad_project/data/.gitignore +0 -0
  5. jrpybestpracccc/extra_scripts/example_projects/bad_project/experiment.ipynb +193 -0
  6. jrpybestpracccc/extra_scripts/example_projects/bad_project/setup.cfg +7 -0
  7. jrpybestpracccc/extra_scripts/example_projects/good_project/.gitignore +3 -0
  8. jrpybestpracccc/extra_scripts/example_projects/good_project/data/.gitignore +0 -0
  9. jrpybestpracccc/extra_scripts/example_projects/good_project/experiment.ipynb +1024 -0
  10. jrpybestpracccc/extra_scripts/example_projects/good_project/setup.cfg +7 -0
  11. jrpybestpracccc/extra_scripts/example_projects/good_project/src/__init__.py +0 -0
  12. jrpybestpracccc/extra_scripts/example_projects/good_project/src/config/filenames.py +3 -0
  13. jrpybestpracccc/extra_scripts/example_projects/good_project/src/config/theme.py +147 -0
  14. jrpybestpracccc/extra_scripts/example_projects/good_project/src/data/__init__.py +2 -0
  15. jrpybestpracccc/extra_scripts/example_projects/good_project/src/data/preprocessing.py +341 -0
  16. jrpybestpracccc/extra_scripts/example_projects/good_project/src/data/storage.py +99 -0
  17. jrpybestpracccc/extra_scripts/example_projects/good_project/src/modelling/__init__.py +4 -0
  18. jrpybestpracccc/extra_scripts/example_projects/good_project/src/modelling/ingredients.py +58 -0
  19. jrpybestpracccc/extra_scripts/example_projects/good_project/src/modelling/predictions.py +18 -0
  20. jrpybestpracccc/extra_scripts/example_projects/good_project/src/modelling/scoring.py +18 -0
  21. jrpybestpracccc/extra_scripts/example_projects/good_project/src/modelling/training.py +42 -0
  22. jrpybestpracccc-0.1.1.dist-info/METADATA +14 -0
  23. jrpybestpracccc-0.1.1.dist-info/RECORD +24 -0
  24. jrpybestpracccc-0.1.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,7 @@
1
[flake8]
max-line-length = 88
# flake8 parses per-file-ignores as newline-separated "path:codes" entries.
# The previous Python-list syntax (brackets and trailing commas) is not
# valid flake8 configuration and silently broke every ignore entry.
per-file-ignores =
    utils/data/__init__.py:F401
    utils/modelling/__init__.py:F401
    utils/deployment/__init__.py:F401
@@ -0,0 +1,3 @@
1
"""Commonly used filenames and filepaths"""

from pathlib import Path

# Project data folder. A Path (rather than a plain string) so that both
# styles of path construction used elsewhere in the package work:
#   - os.path.join(DATA_PATH, file)          (storage.py)
#   - output_path / output_filename          (preprocessing.create_output)
# With a bare str, the "/" form raised a TypeError.
DATA_PATH = Path("data")
@@ -0,0 +1,147 @@
1
"""CCC branding theme for matplotlib and plotly.

Importing this module configures matplotlib's rcParams and registers
(and activates) a plotly template named "ccc".
"""
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
from cycler import cycler

# CCC brand palette (hex), keyed by human-readable colour name.
Colors = {
    "vibrant purple": "#7041FF",
    "orange": "#FFAC00",
    "aubergine": "#280049",
    "lavender": "#8C57CC",
    "purple 4": "#AB6B99",
    "purple 5": "#CA7880",
    "red": "#FF2200",
    "yellow": "#FFFF4B",
    "green": "#A2D800",
    "forest green": "#1A5F31",
    "sea green": "#369992",
    "sky blue": "#AEC5EB",
}

# Brand purple is the accent used for nearly all text, ticks and lines.
_PURPLE = Colors["vibrant purple"]
_FONT_FAMILY = "Century Gothic, sans-serif"

# Style matplotlib to match CCC branding: default colour cycle plus
# figure/axes/grid/legend settings, applied in one rcParams update.
plt.rcParams.update(
    {
        "axes.prop_cycle": cycler(color=list(Colors.values())),
        "figure.figsize": [6, 3],
        "figure.constrained_layout.use": True,
        "figure.dpi": 120,
        "axes.grid": True,
        "axes.grid.axis": "y",
        "axes.labelweight": "bold",
        "axes.titlecolor": _PURPLE,
        "axes.titleweight": "bold",
        "axes.labelcolor": _PURPLE,
        "axes.edgecolor": _PURPLE,
        "axes.spines.left": False,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "grid.linewidth": 0.4,
        "grid.color": "silver",
        "xtick.color": _PURPLE,
        "ytick.color": _PURPLE,
        "ytick.left": False,
        "legend.frameon": False,
        "legend.labelcolor": _PURPLE,
        "font.family": "century gothic",
    }
)

# Fixed colours for the named modelling scenarios.
SCENARIO_COLORS = {
    "Baseline": Colors["orange"],
    "Pathway": Colors["vibrant purple"],
    "Historical": Colors["aubergine"],
}


def _brand_font(size=None):
    """Return a plotly font spec in the brand font/colour.

    The size key is included only when a size is given (the legend font
    deliberately carries no explicit size).
    """
    font = {"family": _FONT_FAMILY, "color": _PURPLE}
    if size is not None:
        font["size"] = size
    return font


# Plotly template mirroring the matplotlib styling above.
ccc_template = go.layout.Template(
    layout=go.Layout(
        width=720,
        height=472,
        font=_brand_font(12),
        title=dict(font=_brand_font(14), x=0.5, xanchor="center"),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(
            showgrid=False,
            showline=True,
            linewidth=1,
            linecolor=_PURPLE,
            tickcolor=_PURPLE,
            title_font=_brand_font(14),
            tickfont=_brand_font(12),
            mirror=False,
            zeroline=False,
        ),
        yaxis=dict(
            showgrid=True,
            gridwidth=0.4,
            gridcolor="silver",
            showline=False,
            zeroline=False,
            tickcolor=_PURPLE,
            ticks="",
            title_font=_brand_font(14),
            tickfont=_brand_font(12),
        ),
        legend=dict(
            bgcolor="rgba(0,0,0,0)",
            bordercolor="rgba(0,0,0,0)",
            font=_brand_font(),
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=0.01,
        ),
        colorway=list(Colors.values()),
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family=_FONT_FAMILY,
            bordercolor=_PURPLE,
        ),
        margin=dict(l=60, r=20, t=60, b=50),
    )
)

# Register the template and make it the default for all plotly figures.
pio.templates["ccc"] = ccc_template
pio.templates.default = "ccc"
@@ -0,0 +1,2 @@
1
+ from .storage import load_data, save_data
2
+ from .preprocessing import calibrate_transformer, preprocess, transform_data
@@ -0,0 +1,341 @@
1
+ """Functions for dealing with data preprocessing, including cleaning, scaling and
2
+ one-hot encoding
3
+ """
4
+
5
+ import pandas as pd
6
+ from sklearn.compose import ColumnTransformer
7
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
8
+
9
+ from ..config.filenames import DATA_PATH
10
+
11
+
12
def clean_data(data):
    """Remove missing values and clean up variable names

    :parameters:
    data: pd.DataFrame
        Pandas DataFrame with the data to be cleaned

    :returns:
        pd.DataFrame object with the tidy data

    NOTE(review): ``clean_names`` is not a plain-pandas method — it is
    provided by pyjanitor, so the ``janitor`` package must be imported
    somewhere for this attribute to exist. Confirm the dependency.
    """
    # Drop rows with missing values WITHOUT mutating the caller's frame.
    # The original used dropna(inplace=True), which silently modified the
    # input DataFrame as a side effect.
    data = data.dropna()
    # Standardise the column names (pyjanitor's DataFrame.clean_names).
    return data.clean_names()
25
+
26
+
27
def join_on_archetypes(
    df: pd.DataFrame, archetype_df: pd.DataFrame
) -> pd.DataFrame:
    """
    Join one of the two capex datasets with the archetypes dataset.

    Args
    ------
    df: the capex DataFrame to join with.
    archetype_df: the archetype DataFrame.

    Returns
    ------
    df: the new DataFrame with the archetype data joined on.
    """
    # Inner join on the shared key column: rows whose heating system is
    # missing from either table are dropped.
    return df.merge(archetype_df, on="Heating system", how="inner")
53
+
54
+
55
def get_size_scaled_capex(capex_by_power_df: pd.DataFrame) -> pd.DataFrame:
    """
    Scale the capex amounts in the capex-by-power data by heating
    system size.

    Args
    -----
    capex_by_power_df: the capex by power DataFrame.

    Returns
    ------
    capex_by_power_df: the scaled capex DataFrame.
    """
    # Year columns are the ones whose names look like calendar years
    # ("20xx"); everything else is descriptive.
    year_cols = [c for c in capex_by_power_df.columns if c.startswith("20")]

    # Row-wise multiply every year's per-unit capex by the system size.
    sizes = capex_by_power_df["Heating system size"]
    capex_by_power_df[year_cols] = capex_by_power_df[year_cols].mul(sizes, axis=0)

    return capex_by_power_df
79
+
80
+
81
def concatenate_dfs(
    capex_by_power_df: pd.DataFrame,
    fixed_capex_df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Stack the capex-by-power DataFrame on top of the fixed capex
    DataFrame, producing one DataFrame with all capex data.

    Args
    -----
    capex_by_power_df: DataFrame containing capex data scaled by power.
    fixed_capex_df: DataFrame containing fixed capex data.

    Returns
    ------
    all_capex_df: The concatenated DataFrame containing all capex data.
    """
    # Row-wise concatenation; original indices are kept, matching the
    # pd.concat default.
    return pd.concat([capex_by_power_df, fixed_capex_df])
103
+
104
+
105
def adjust_prices(all_capex_df: pd.DataFrame) -> pd.DataFrame:
    """
    Adjust the capex values for each year by multiplying them with the
    corresponding price index.

    Args
    -----
    all_capex_df: DataFrame containing capex data with year columns and a
        'Price index' column.

    Returns
    ------
    all_capex_df: DataFrame with year columns adjusted by the price index.
    """
    # Identify the year columns ("20xx") to be re-priced.
    year_cols = [c for c in all_capex_df.columns if c.startswith("20")]

    # Row-wise multiply each year's capex by that row's price index.
    price_index = all_capex_df["Price index"]
    all_capex_df[year_cols] = all_capex_df[year_cols].mul(price_index, axis=0)

    return all_capex_df
128
+
129
+
130
def get_grouped_total_capex(all_capex_df: pd.DataFrame) -> pd.DataFrame:
    """
    Group the capex DataFrame by 'Constrained archetype number' and sum
    the capex values for each year.

    Args
    -----
    all_capex_df: DataFrame containing capex data with year columns and a
        'Constrained archetype number' column.

    Returns
    ------
    summed_capex_df: DataFrame with summed capex values for each
        'Constrained archetype number' and year.
    """
    year_cols = [c for c in all_capex_df.columns if c.startswith("20")]

    # Descriptive (non-year) columns keep their first observed value per
    # group; year columns are totalled. Non-year columns are listed first
    # so the aggregated output keeps them ahead of the year columns.
    agg_dict = dict.fromkeys(
        (c for c in all_capex_df.columns if c not in year_cols), "first"
    )
    agg_dict.update(dict.fromkeys(year_cols, "sum"))

    grouped = all_capex_df.groupby(
        by="Constrained archetype number", as_index=False
    )
    return grouped.agg(agg_dict)
160
+
161
+
162
def drop_and_add_columns(
    all_capex_df: pd.DataFrame,
    columns_to_drop=None,
) -> pd.DataFrame:
    """
    Drop specified columns from the DataFrame and add two new metadata
    columns.

    Args
    -----
    all_capex_df: DataFrame from which columns will be dropped.
    columns_to_drop: List of column names to be dropped from the DataFrame
        (default: ["Cost type variant", "Assumptions", "Heating system"]).

    Returns
    ------
    all_capex_df: DataFrame with specified columns dropped and the
        "Variable unit" / "Data name" columns added.
    """
    # Resolve the default at call time instead of using a mutable list
    # literal as the default argument (shared across calls).
    if columns_to_drop is None:
        columns_to_drop = ["Cost type variant", "Assumptions", "Heating system"]

    all_capex_df = all_capex_df.drop(columns=columns_to_drop)

    # Metadata columns expected by the downstream output file.
    all_capex_df["Variable unit"] = "£"
    all_capex_df["Data name"] = "Baseline capex per home by renewal year"

    return all_capex_df
188
+
189
+
190
def create_output(
    all_capex_df: pd.DataFrame,
    output_filename="Baseline heating system capex.csv",
    output_path=DATA_PATH,
):
    """
    Write the provided DataFrame to a CSV file at the specified output
    path.

    Args
    -----
    all_capex_df: DataFrame to be written to CSV.
    output_filename: Name of the output CSV file (default: "Baseline heating
        system capex.csv").
    output_path: Path where the CSV file will be saved (default: DATA_PATH).

    Returns
    ------
    None
    """
    # os.path.join works whether output_path is a plain string or a
    # pathlib.Path. The previous `output_path / output_filename` raised a
    # TypeError whenever output_path was a str (as DATA_PATH is).
    import os

    all_capex_df.to_csv(os.path.join(output_path, output_filename), index=False)
214
+
215
+
216
def create_long_format(df: pd.DataFrame, index: int = -31):
    """
    Convert an input DataFrame into long format.

    parameters
    ----------
    df: pd.DataFrame
        The data should initially have capex columns for the years
        "2020" to "2050". These typically span index -31 to the end of
        the DataFrame.
    index: int (default -31)
        Index of the "2020" capex column

    returns
    -------
    pd.DataFrame
        Long-format version of the data with a "Year" column and
        "Capex" column
    """
    # Split the columns into identifiers (before `index`) and the year
    # columns to be melted (from `index` onwards).
    all_columns = list(df.columns)
    id_columns, year_columns = all_columns[:index], all_columns[index:]
    return df.melt(
        id_vars=id_columns,
        value_vars=year_columns,
        var_name="Year",
        value_name="Capex",
    )
242
+
243
+
244
def preprocess(
    archetype_df: pd.DataFrame,
    fixed_capex_df: pd.DataFrame,
    capex_by_power_df: pd.DataFrame,
    long_format: bool = False,
    index: int = -31,
) -> pd.DataFrame:
    """
    Run all of the preprocessing steps on the input data.

    Args
    -----
    archetype_df: The archetype DataFrame.
    fixed_capex_df: The fixed capex DataFrame.
    capex_by_power_df: The capex by power.
    long_format: If True, convert the output data into long
        format with "Year" and "Capex" columns (False by default)
    index: The index of the "2020" column in the preprocessed
        data (-31 by default)

    Returns
    ------
    pd.DataFrame
        Preprocessed DataFrame
    """
    # Attach the archetype attributes to both capex tables.
    fixed = join_on_archetypes(df=fixed_capex_df, archetype_df=archetype_df)
    by_power = join_on_archetypes(
        df=capex_by_power_df, archetype_df=archetype_df
    )

    # Per-power costs must be scaled by heating system size before the
    # two tables can be combined.
    by_power = get_size_scaled_capex(capex_by_power_df=by_power)

    combined = concatenate_dfs(
        capex_by_power_df=by_power,
        fixed_capex_df=fixed,
    )
    combined = get_grouped_total_capex(all_capex_df=combined)
    combined = drop_and_add_columns(all_capex_df=combined)

    # Optionally reshape to long format for plotting/analysis.
    if long_format:
        combined = create_long_format(df=combined, index=index)

    return combined
295
+
296
+
297
def calibrate_transformer(
    data: pd.DataFrame,
    numerical_features: list,
    categorical_features: list,
):
    """Calibrate a data transformer using the input data

    :parameters:
    data: pd.DataFrame
        Input data used for calibrating the transformer
    numerical_features: list
        List of the numerical variables to be transformed
    categorical_features: list
        List of the categorical variables to be transformed

    :returns:
        sklearn ColumnTransformer object for applying consistent scaling
        and encoding to the data
    """
    # Standard-scale the numeric columns; one-hot encode the categorical
    # columns, dropping the first level of each.
    steps = [
        ("num", StandardScaler(), numerical_features),
        ("cat", OneHotEncoder(drop="first"), categorical_features),
    ]
    # ColumnTransformer.fit returns the fitted transformer itself.
    return ColumnTransformer(steps).fit(data)
324
+
325
+
326
def transform_data(
    data: pd.DataFrame,
    transformer: ColumnTransformer,
):
    """Scale and encode data using a pre-calibrated transformer

    :parameters:
    data: pd.DataFrame
        Data to be preprocessed
    transformer: sklearn ColumnTransformer
        Pre-calibrated transformer for scaling and encoding the data

    :returns:
        pd.DataFrame object with transformed data
    """
    # Apply the already-fitted transformer; no refitting happens here.
    transformed = transformer.transform(data)
    return transformed
@@ -0,0 +1,99 @@
1
+ """Functions for managing direct interactions with the raw data"""
2
+
3
+ import os
4
+ import pandas as pd
5
+
6
+ from ..config.filenames import DATA_PATH
7
+
8
+
9
def construct_data_filepath(filename="raw_data"):
    """
    Construct filepath to data CSV file.

    Args
    ----
    filename: str (default: "raw_data")
        Filename of the data (without the .csv extension)

    Returns
    -------
    string with the full filepath for loading and writing the data
    """
    # The original line was f"(unknown).csv" — an f-string with no
    # placeholder that ignored `filename` entirely, so every caller got
    # the same broken path. Interpolate the filename as intended.
    file = f"{filename}.csv"
    return os.path.join(DATA_PATH, file)
25
+
26
+
27
def load_csv(filename: str, header_height: int = 4):
    """
    Load a CSV file from the project data folder.

    parameters
    ----------
    filename: str
        Name of the CSV file to load

    header_height: int
        Number of rows to skip at the top of the CSV
        (4 by default)

    returns
    -------
    pd.DataFrame
    """
    # Resolve the full path inside the data folder, then skip any
    # metadata rows above the real header.
    return pd.read_csv(construct_data_filepath(filename), skiprows=header_height)
46
+
47
+
48
def load_data() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Read in the three CSV files used in the workflow.

    returns
    -------
    archetype_df: DataFrame containing the archetypes data.
    fixed_capex_df: DataFrame containing the fixed capex data.
    capex_by_power_df: DataFrame containing the capex by power data.
    """
    # The archetype file has no metadata rows, so nothing is skipped.
    archetype_df = load_csv(
        filename="Baseline archetype properties",
        header_height=0,
    )

    # Both capex files carry 4 rows of metadata above the real header,
    # which load_csv skips by default.
    fixed_capex_df = load_csv(
        filename="Input Baseline capex fixed",
        header_height=4,
    )
    capex_by_power_df = load_csv(
        filename="Input Baseline capex by power",
        header_height=4,
    )

    return archetype_df, fixed_capex_df, capex_by_power_df
81
+
82
+
83
def save_data(data, output_filename="raw_data"):
    """
    Save data to a CSV file in the project data folder.

    parameters
    ----------
    data: pd.DataFrame
        Pandas DataFrame containing the data to be saved
    output_filename: str
        Output filename (without extension) for the CSV data

    returns
    -------
    None
    """
    # construct_data_filepath appends ".csv" itself, so the filename is
    # passed WITHOUT the extension. The previous default "raw_data.csv"
    # produced "raw_data.csv.csv". (Docstring also documented a
    # non-existent `output` parameter; corrected to `output_filename`.)
    filepath = construct_data_filepath(output_filename)
    data.to_csv(filepath, index=False)
@@ -0,0 +1,4 @@
1
+ from .ingredients import extract_target_variable, setup_model, setup_pipeline
2
+ from .predictions import model_predict
3
+ from .scoring import score_model
4
+ from .training import create_test_data, train_model
@@ -0,0 +1,58 @@
1
+ """
2
+ Functions associated with setting up the model object and the data input
3
+ variables
4
+ """
5
+
6
+ from sklearn.linear_model import LinearRegression
7
+ from sklearn.pipeline import Pipeline
8
+
9
+
10
def extract_target_variable(data, target, covariates):
    """Extract target variable to be predicted by trained model

    :parameters:
    data: pd.DataFrame
        Pandas DataFrame with the raw data from which to select the target variable
    target: str
        String with the name of the column to set as the target variable y
    covariates: list
        List of column names to set as the covariates x

    :returns:
        Two outputs as tuple (y, X):
        - Target variable column
        - Predictor variable columns
    """
    # Single-column selection for the target; multi-column for predictors.
    return data[target], data[covariates]
29
+
30
+
31
def setup_model():
    """Construct a model object (currently LinearRegression) to use for training

    :returns:
        Untrained sklearn LinearRegression model instance
    """
    return LinearRegression()
38
+
39
+
40
def setup_pipeline(transformer, model):
    """Construct a pipeline for data transformation and modelling

    :parameters:
    transformer: Sklearn ColumnTransformer object
        Calibrated transformer for applying scaling and one-hot encoding to data
    model: Sklearn model object
        Model to be trained on data

    :returns:
        Sklearn Pipeline object chaining the calibrated data transformer
        and a model that is ready for training
    """
    # Named steps: transform the features first, then fit/predict with
    # the model.
    steps = [
        ("transform", transformer),
        ("model", model),
    ]
    return Pipeline(steps)
@@ -0,0 +1,18 @@
1
+ """
2
+ Functions involved in generating model predictions using the sklearn model
3
+ """
4
+
5
+
6
def model_predict(model, data):
    """Get predictions from model

    :parameters:
    model: sklearn model or pipeline object
        Trained sklearn model or pipeline ready for predictions
    data: 2-D data structure
        Untransformed input data for predictions

    :returns:
        1-D data structure with the model predictions
    """
    # Thin wrapper so the rest of the package never calls the estimator
    # API directly.
    predictions = model.predict(data)
    return predictions
@@ -0,0 +1,18 @@
1
+ """Scoring functions for model testing"""
2
+
3
+ from sklearn.metrics import mean_absolute_error
4
+
5
+
6
def score_model(y_true, y_pred):
    """Scoring function for the mean absolute error

    :parameters:
    y_true: 1-D data structure
        True values of the target variable
    y_pred: 1-D data structure
        Predicted values of the target variable

    :returns:
        Mean absolute error between the true and predicted values
    """
    return mean_absolute_error(y_true, y_pred)