imsciences 0.9.7.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
imsciences/pull.py CHANGED
@@ -1,27 +1,26 @@
- import pandas as pd
- import numpy as np
+ import importlib
  import re
- from fredapi import Fred
  import time
+ import urllib.request
+ import xml.etree.ElementTree as ET
  from datetime import datetime, timedelta
  from io import StringIO
+
+ import numpy as np
+ import pandas as pd
  import requests
- import xml.etree.ElementTree as ET
- from bs4 import BeautifulSoup
  import yfinance as yf
- import holidays
+ from bs4 import BeautifulSoup
  from dateutil.easter import easter
- import urllib.request
+ from fredapi import Fred
  from geopy.geocoders import Nominatim
- import importlib
- import workalendar
 
  from imsciences.mmm import dataprocessing
 
  ims_proc = dataprocessing()
 
+
  class datapull:
-
  def help(self):
  print("This is the help section. The functions in the package are as follows:")
 
@@ -36,8 +35,12 @@ class datapull:
  print(" - Example: pull_boe_data('mon')")
 
  print("\n3. pull_oecd")
- print(" - Description: Fetch macroeconomic data from OECD for a specified country.")
- print(" - Usage: pull_oecd(country='GBR', week_commencing='mon', start_date: '2020-01-01')")
+ print(
+ " - Description: Fetch macroeconomic data from OECD for a specified country.",
+ )
+ print(
+ " - Usage: pull_oecd(country='GBR', week_commencing='mon', start_date: '2020-01-01')",
+ )
  print(" - Example: pull_oecd('GBR', 'mon', '2000-01-01')")
 
  print("\n4. get_google_mobility_data")
@@ -46,89 +49,136 @@ class datapull:
  print(" - Example: get_google_mobility_data('United Kingdom', 'mon')")
 
  print("\n5. pull_seasonality")
- print(" - Description: Generate combined dummy variables for seasonality, trends, and COVID lockdowns.")
+ print(
+ " - Description: Generate combined dummy variables for seasonality, trends, and COVID lockdowns.",
+ )
  print(" - Usage: pull_seasonality(week_commencing, start_date, countries)")
  print(" - Example: pull_seasonality('mon', '2020-01-01', ['US', 'GB'])")
 
  print("\n6. pull_weather")
- print(" - Description: Fetch and process historical weather data for the specified country.")
+ print(
+ " - Description: Fetch and process historical weather data for the specified country.",
+ )
  print(" - Usage: pull_weather(week_commencing, start_date, country)")
  print(" - Example: pull_weather('mon', '2020-01-01', ['GBR'])")
-
+
  print("\n7. pull_macro_ons_uk")
- print(" - Description: Fetch and process time series data from the Beta ONS API.")
+ print(
+ " - Description: Fetch and process time series data from the Beta ONS API.",
+ )
  print(" - Usage: pull_macro_ons_uk(aditional_list, week_commencing, sector)")
  print(" - Example: pull_macro_ons_uk(['HBOI'], 'mon', 'fast_food')")
-
+
  print("\n8. pull_yfinance")
- print(" - Description: Fetch and process time series data from the Beta ONS API.")
+ print(
+ " - Description: Fetch and process time series data from the Beta ONS API.",
+ )
  print(" - Usage: pull_yfinance(tickers, week_start_day)")
  print(" - Example: pull_yfinance(['^FTMC', '^IXIC'], 'mon')")
-
+
  print("\n9. pull_sports_events")
- print(" - Description: Pull a veriety of sports events primaraly football and rugby.")
+ print(
+ " - Description: Pull a veriety of sports events primaraly football and rugby.",
+ )
  print(" - Usage: pull_sports_events(start_date, week_commencing)")
  print(" - Example: pull_sports_events('2020-01-01', 'mon')")
 
  ############################################################### MACRO ##########################################################################
 
- def pull_fred_data(self, week_commencing: str = 'mon', series_id_list: list[str] = ["GPDIC1", "Y057RX1Q020SBEA", "GCEC1"]) -> pd.DataFrame:
- '''
+ def pull_fred_data(
+ self,
+ week_commencing: str = "mon",
+ series_id_list: list[str] = ["GPDIC1", "Y057RX1Q020SBEA", "GCEC1"],
+ ) -> pd.DataFrame:
+ """
  Parameters
  ----------
  week_commencing : str
  specify the day for the week commencing, the default is 'sun' (e.g., 'mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun')
 
  series_id_list : list[str]
- provide a list with IDs to download data series from FRED (link: https://fred.stlouisfed.org/tags/series?t=id). Default list is
+ provide a list with IDs to download data series from FRED (link: https://fred.stlouisfed.org/tags/series?t=id). Default list is
  ["GPDIC1", "Y057RX1Q020SBEA", "GCEC1"]
-
+
  Returns
- ----------
+ -------
  pd.DataFrame
  Return a data frame with FRED data according to the series IDs provided
- '''
+
+ """
  # Fred API
- fred = Fred(api_key='76f5f8156145fdb8fbaf66f1eb944f8a')
+ fred = Fred(api_key="76f5f8156145fdb8fbaf66f1eb944f8a")
 
  # Fetch the metadata for each series to get the full names
- series_names = {series_id: fred.get_series_info(series_id).title for series_id in series_id_list}
+ series_names = {
+ series_id: fred.get_series_info(series_id).title
+ for series_id in series_id_list
+ }
 
  # Download data from series id list
- fred_series = {series_id: fred.get_series(series_id) for series_id in series_id_list}
+ fred_series = {
+ series_id: fred.get_series(series_id) for series_id in series_id_list
+ }
 
  # Data processing
- date_range = {'OBS': pd.date_range("1950-01-01", datetime.today().strftime('%Y-%m-%d'), freq='d')}
+ date_range = {
+ "OBS": pd.date_range(
+ "1950-01-01",
+ datetime.today().strftime("%Y-%m-%d"),
+ freq="d",
+ ),
+ }
  fred_series_df = pd.DataFrame(date_range)
 
  for series_id, series_data in fred_series.items():
  series_data = series_data.reset_index()
- series_data.columns = ['OBS', series_names[series_id]] # Use the series name as the column header
- fred_series_df = pd.merge_asof(fred_series_df, series_data, on='OBS', direction='backward')
+ series_data.columns = [
+ "OBS",
+ series_names[series_id],
+ ] # Use the series name as the column header
+ fred_series_df = pd.merge_asof(
+ fred_series_df,
+ series_data,
+ on="OBS",
+ direction="backward",
+ )
 
  # Handle duplicate columns
  for col in fred_series_df.columns:
- if '_x' in col:
- base_col = col.replace('_x', '')
- fred_series_df[base_col] = fred_series_df[col].combine_first(fred_series_df[base_col + '_y'])
- fred_series_df.drop([col, base_col + '_y'], axis=1, inplace=True)
+ if "_x" in col:
+ base_col = col.replace("_x", "")
+ fred_series_df[base_col] = fred_series_df[col].combine_first(
+ fred_series_df[base_col + "_y"],
+ )
+ fred_series_df.drop([col, base_col + "_y"], axis=1, inplace=True)
 
  # Ensure sum_columns are present in the DataFrame
- sum_columns = [series_names[series_id] for series_id in series_id_list if series_names[series_id] in fred_series_df.columns]
+ sum_columns = [
+ series_names[series_id]
+ for series_id in series_id_list
+ if series_names[series_id] in fred_series_df.columns
+ ]
 
  # Aggregate results by week
- fred_df_final = ims_proc.aggregate_daily_to_wc_wide(df=fred_series_df,
- date_column="OBS",
- group_columns=[],
- sum_columns=sum_columns,
- wc=week_commencing,
- aggregation="average")
+ fred_df_final = ims_proc.aggregate_daily_to_wc_wide(
+ df=fred_series_df,
+ date_column="OBS",
+ group_columns=[],
+ sum_columns=sum_columns,
+ wc=week_commencing,
+ aggregation="average",
+ )
 
  # Remove anything after the instance of any ':' in the column names and rename, except for 'OBS'
- fred_df_final.columns = ['OBS' if col == 'OBS' else 'macro_' + col.lower().split(':')[0].replace(' ', '_') for col in fred_df_final.columns]
+ fred_df_final.columns = [
+ "OBS"
+ if col == "OBS"
+ else "macro_" + col.lower().split(":")[0].replace(" ", "_")
+ for col in fred_df_final.columns
+ ]
 
  return fred_df_final
-
+
  def pull_boe_data(self, week_commencing="mon", max_retries=5, delay=5):
  """
  Fetch and process Bank of England interest rate data.
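The hunk above finishes the reformatting of pull_fred_data without changing its public signature. A minimal usage sketch, assuming the class is importable as imsciences.pull.datapull (inferred from the file path shown here) and that the bundled FRED API key is still active; neither assumption is guaranteed by this diff:

from imsciences.pull import datapull  # assumed import path, not confirmed by the diff

puller = datapull()
# Weekly (Monday-commencing) averages for the default FRED series.
fred_df = puller.pull_fred_data(
    week_commencing="mon",
    series_id_list=["GPDIC1", "Y057RX1Q020SBEA", "GCEC1"],
)
print(fred_df.head())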
@@ -144,12 +194,21 @@ class datapull:
  pd.DataFrame: A DataFrame with weekly aggregated Bank of England interest rates.
  The 'OBS' column contains the week commencing dates in 'dd/mm/yyyy' format
  and 'macro_boe_intr_rate' contains the average interest rate for the week.
+
  """
  # Week commencing dictionary
- day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
+ day_dict = {
+ "mon": 0,
+ "tue": 1,
+ "wed": 2,
+ "thu": 3,
+ "fri": 4,
+ "sat": 5,
+ "sun": 6,
+ }
 
  # URL of the Bank of England data page
- url = 'https://www.bankofengland.co.uk/boeapps/database/Bank-Rate.asp'
+ url = "https://www.bankofengland.co.uk/boeapps/database/Bank-Rate.asp"
 
  # Retry logic for HTTP request
  for attempt in range(max_retries):
@@ -159,7 +218,7 @@ class datapull:
  "User-Agent": (
  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
  "Chrome/91.0.4472.124 Safari/537.36"
- )
+ ),
  }
  response = requests.get(url, headers=headers)
  response.raise_for_status() # Raise an exception for HTTP errors
@@ -177,10 +236,15 @@ class datapull:
  # Find the table on the page
  table = soup.find("table") # Locate the first table
  table_html = str(table) # Convert table to string
- df = pd.read_html(StringIO(table_html))[0] # Use StringIO to wrap the table HTML
+ df = pd.read_html(StringIO(table_html))[
+ 0
+ ] # Use StringIO to wrap the table HTML
 
  # Rename and clean up columns
- df.rename(columns={"Date Changed": "OBS", "Rate": "macro_boe_intr_rate"}, inplace=True)
+ df.rename(
+ columns={"Date Changed": "OBS", "Rate": "macro_boe_intr_rate"},
+ inplace=True,
+ )
  df["OBS"] = pd.to_datetime(df["OBS"], format="%d %b %y")
  df.sort_values("OBS", inplace=True)
 
@@ -190,7 +254,7 @@ class datapull:
 
  # Adjust each date to the specified week commencing day
  df_daily["Week_Commencing"] = df_daily["OBS"].apply(
- lambda x: x - timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
+ lambda x: x - timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7),
  )
 
  # Merge and forward-fill missing rates
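The lambda in the hunk above (unchanged apart from the trailing comma) is the week-alignment rule used throughout the module: stepping back (weekday - target) % 7 days snaps any date to the most recent occurrence of the chosen week-commencing day. A standalone sketch of that arithmetic; the helper name is illustrative, not part of the package:

from datetime import date, timedelta

def week_commencing(d: date, target_weekday: int = 0) -> date:
    # Monday=0 ... Sunday=6; step back (weekday - target) % 7 days.
    return d - timedelta(days=(d.weekday() - target_weekday) % 7)

# Wednesday 2020-01-08 snaps to Monday 2020-01-06; a Monday maps to itself.
assert week_commencing(date(2020, 1, 8)) == date(2020, 1, 6)
assert week_commencing(date(2020, 1, 6)) == date(2020, 1, 6)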
@@ -198,58 +262,135 @@ class datapull:
  df_daily["macro_boe_intr_rate"] = df_daily["macro_boe_intr_rate"].ffill()
 
  # Group by week commencing and calculate the average rate
- df_final = df_daily.groupby("Week_Commencing")["macro_boe_intr_rate"].mean().reset_index()
- df_final["Week_Commencing"] = df_final["Week_Commencing"].dt.strftime('%d/%m/%Y')
+ df_final = (
+ df_daily.groupby("Week_Commencing")["macro_boe_intr_rate"]
+ .mean()
+ .reset_index()
+ )
+ df_final["Week_Commencing"] = df_final["Week_Commencing"].dt.strftime(
+ "%d/%m/%Y",
+ )
  df_final.rename(columns={"Week_Commencing": "OBS"}, inplace=True)
 
  return df_final
-
- def pull_oecd(self, country: str = "GBR", week_commencing: str = "mon", start_date: str = "2020-01-01") -> pd.DataFrame:
+
+ def pull_oecd(
+ self,
+ country: str = "GBR",
+ week_commencing: str = "mon",
+ start_date: str = "2020-01-01",
+ ) -> pd.DataFrame:
  """
  Fetch and process time series data from the OECD API.
 
  Args:
  country (list): A string containing a 3-letter code the of country of interest (E.g: "GBR", "FRA", "USA", "DEU")
- week_commencing (str): The starting day of the week for aggregation.
+ week_commencing (str): The starting day of the week for aggregation.
  Options are "mon", "tue", "wed", "thu", "fri", "sat", "sun".
  start_date (str): Dataset start date in the format "YYYY-MM-DD"
 
  Returns:
- pd.DataFrame: A DataFrame with weekly aggregated OECD data. The 'OBS' column contains the week
+ pd.DataFrame: A DataFrame with weekly aggregated OECD data. The 'OBS' column contains the week
  commencing dates, and other columns contain the aggregated time series values.
- """
+
+ """
 
  def parse_quarter(date_str):
  """Parses a string in 'YYYY-Q#' format into a datetime object."""
- year, quarter = date_str.split('-')
+ year, quarter = date_str.split("-")
  quarter_number = int(quarter[1])
  month = (quarter_number - 1) * 3 + 1
  return pd.Timestamp(f"{year}-{month:02d}-01")
 
  # Generate a date range from 1950-01-01 to today
- date_range = pd.date_range(start=start_date, end=datetime.today(), freq='D')
+ date_range = pd.date_range(start=start_date, end=datetime.today(), freq="D")
 
  url_details = [
- ["BCICP", "SDD.STES,DSD_STES@DF_CLI,", ".....", "macro_business_confidence_index"],
- ["CCICP", "SDD.STES,DSD_STES@DF_CLI,", ".....", "macro_consumer_confidence_index"],
- ["N.CPI", "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,", "PA._T.N.GY", "macro_cpi_total"],
- ["N.CPI", "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,", "PA.CP041T043.N.GY", "macro_cpi_housing"],
- ["N.CPI", "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,", "PA.CP01.N.GY", "macro_cpi_food"],
- ["N.CPI", "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,", "PA.CP045_0722.N.GY", "macro_cpi_energy"],
- ["UNE_LF_M", "SDD.TPS,DSD_LFS@DF_IALFS_UNE_M,", "._Z.Y._T.Y_GE15.", "macro_unemployment_rate"],
- ["EAR", "SDD.TPS,DSD_EAR@DF_HOU_EAR,", ".Y..S1D", "macro_private_hourly_earnings"],
- ["RHP", "ECO.MPD,DSD_AN_HOUSE_PRICES@DF_HOUSE_PRICES,1.0", "", "macro_real_house_prices"],
- ["PRVM", "SDD.STES,DSD_KEI@DF_KEI,4.0", "IX.C..", "macro_manufacturing_production_volume"],
- ["TOVM", "SDD.STES,DSD_KEI@DF_KEI,4.0", "IX...", "macro_retail_trade_volume"],
- ["IRSTCI", "SDD.STES,DSD_KEI@DF_KEI,4.0", "PA...", "macro_interbank_rate"],
- ["IRLT", "SDD.STES,DSD_KEI@DF_KEI,4.0", "PA...", "macro_long_term_interest_rate"],
- ["B1GQ", "SDD.NAD,DSD_NAMAIN1@DF_QNA,1.1", "._Z....GY.T0102", "macro_gdp_growth_yoy"]
+ [
+ "BCICP",
+ "SDD.STES,DSD_STES@DF_CLI,",
+ ".....",
+ "macro_business_confidence_index",
+ ],
+ [
+ "CCICP",
+ "SDD.STES,DSD_STES@DF_CLI,",
+ ".....",
+ "macro_consumer_confidence_index",
+ ],
+ [
+ "N.CPI",
+ "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,",
+ "PA._T.N.GY",
+ "macro_cpi_total",
+ ],
+ [
+ "N.CPI",
+ "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,",
+ "PA.CP041T043.N.GY",
+ "macro_cpi_housing",
+ ],
+ [
+ "N.CPI",
+ "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,",
+ "PA.CP01.N.GY",
+ "macro_cpi_food",
+ ],
+ [
+ "N.CPI",
+ "SDD.TPS,DSD_PRICES@DF_PRICES_ALL,",
+ "PA.CP045_0722.N.GY",
+ "macro_cpi_energy",
+ ],
+ [
+ "UNE_LF_M",
+ "SDD.TPS,DSD_LFS@DF_IALFS_UNE_M,",
+ "._Z.Y._T.Y_GE15.",
+ "macro_unemployment_rate",
+ ],
+ [
+ "EAR",
+ "SDD.TPS,DSD_EAR@DF_HOU_EAR,",
+ ".Y..S1D",
+ "macro_private_hourly_earnings",
+ ],
+ [
+ "RHP",
+ "ECO.MPD,DSD_AN_HOUSE_PRICES@DF_HOUSE_PRICES,1.0",
+ "",
+ "macro_real_house_prices",
+ ],
+ [
+ "PRVM",
+ "SDD.STES,DSD_KEI@DF_KEI,4.0",
+ "IX.C..",
+ "macro_manufacturing_production_volume",
+ ],
+ [
+ "TOVM",
+ "SDD.STES,DSD_KEI@DF_KEI,4.0",
+ "IX...",
+ "macro_retail_trade_volume",
+ ],
+ ["IRSTCI", "SDD.STES,DSD_KEI@DF_KEI,4.0", "PA...", "macro_interbank_rate"],
+ [
+ "IRLT",
+ "SDD.STES,DSD_KEI@DF_KEI,4.0",
+ "PA...",
+ "macro_long_term_interest_rate",
+ ],
+ [
+ "B1GQ",
+ "SDD.NAD,DSD_NAMAIN1@DF_QNA,1.1",
+ "._Z....GY.T0102",
+ "macro_gdp_growth_yoy",
+ ],
  ]
 
  # Create empty final dataframe
  oecd_df_final = pd.DataFrame()
 
- daily_df = pd.DataFrame({'OBS': date_range})
+ daily_df = pd.DataFrame({"OBS": date_range})
  value_columns = []
 
  # Iterate for each variable of interest
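Each row of url_details above pairs an OECD series code with a dataflow identifier, an SDMX filter string, and the output column name; the loop in the following hunks interpolates these into an SDMX REST query, trying monthly, quarterly, then annual frequencies. A rough sketch of the URL assembly for the pattern the next hunk shows for UNE_LF_M/EAR (other series use slightly different orderings that are not fully visible in this diff; the variable is renamed sdmx_filter here to avoid shadowing Python's built-in filter):

country = "GBR"
series, dataset_id, sdmx_filter, col_name = (
    "UNE_LF_M",
    "SDD.TPS,DSD_LFS@DF_IALFS_UNE_M,",
    "._Z.Y._T.Y_GE15.",
    "macro_unemployment_rate",
)
freq = "M"
data_url = (
    f"https://sdmx.oecd.org/public/rest/data/OECD.{dataset_id}/"
    f"{country}.{series}.{sdmx_filter}.{freq}?startPeriod=1950-01"
)
print(data_url)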
@@ -260,8 +401,7 @@ class datapull:
  col_name = series_details[3]
 
  # check if request was successful and determine the most granular data available
- for freq in ['M', 'Q', 'A']:
-
+ for freq in ["M", "Q", "A"]:
  if series in ["UNE_LF_M", "EAR"]:
  data_url = f"https://sdmx.oecd.org/public/rest/data/OECD.{dataset_id}/{country}.{series}.{filter}.{freq}?startPeriod=1950-01"
  elif series in ["B1GQ"]:
@@ -274,13 +414,14 @@ class datapull:
 
  # Check if the request was successful
  if data_response.status_code != 200:
- print(f"Failed to fetch data for series {series} with frequency '{freq}' for {country}: {data_response.status_code} {data_response.text}")
+ print(
+ f"Failed to fetch data for series {series} with frequency '{freq}' for {country}: {data_response.status_code} {data_response.text}",
+ )
  url_test = False
  continue
- else:
- url_test = True
- break
-
+ url_test = True
+ break
+
  # get data for the next variable if url doesn't exist
  if url_test is False:
  continue
@@ -288,21 +429,24 @@ class datapull:
  root = ET.fromstring(data_response.content)
 
  # Define namespaces if necessary (the namespace is included in the tags)
- namespaces = {'generic': 'http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic'}
+ namespaces = {
+ "generic": "http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic",
+ }
 
  # Lists to store the data
  dates = []
  values = []
 
  # Iterate over all <Obs> elements and extract date and value
- for obs in root.findall('.//generic:Obs', namespaces):
-
+ for obs in root.findall(".//generic:Obs", namespaces):
  # Extracting the time period (date)
- time_period = obs.find('.//generic:ObsDimension', namespaces).get('value')
-
+ time_period = obs.find(".//generic:ObsDimension", namespaces).get(
+ "value",
+ )
+
  # Extracting the observation value
- value = obs.find('.//generic:ObsValue', namespaces).get('value')
-
+ value = obs.find(".//generic:ObsValue", namespaces).get("value")
+
  # Storing the data
  if time_period and value:
  dates.append(time_period)
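The parsing above relies on the SDMX-ML 2.1 generic namespace: each generic:Obs element carries its period in a generic:ObsDimension value attribute and its reading in a generic:ObsValue value attribute. A self-contained sketch against a hand-written fragment (the sample period and value are made up purely for illustration):

import xml.etree.ElementTree as ET

sample = """
<DataSet xmlns:generic="http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic">
  <generic:Obs>
    <generic:ObsDimension value="2020-01"/>
    <generic:ObsValue value="98.7"/>
  </generic:Obs>
</DataSet>
"""

ns = {"generic": "http://www.sdmx.org/resources/sdmxml/schemas/v2_1/data/generic"}
root = ET.fromstring(sample)
for obs in root.findall(".//generic:Obs", ns):
    period = obs.find(".//generic:ObsDimension", ns).get("value")
    value = obs.find(".//generic:ObsValue", ns).get("value")
    print(period, value)  # -> 2020-01 98.7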
@@ -312,888 +456,995 @@ class datapull:
  value_columns.append(col_name)
 
  # Creating a DataFrame
- data = pd.DataFrame({'OBS': dates, col_name: values})
+ data = pd.DataFrame({"OBS": dates, col_name: values})
 
  # Convert date strings into datetime format
- if freq == 'Q':
- data['OBS'] = data['OBS'].apply(parse_quarter)
+ if freq == "Q":
+ data["OBS"] = data["OBS"].apply(parse_quarter)
  else:
  # Display the DataFrame
- data['OBS'] = data['OBS'].apply(lambda x: datetime.strptime(x, '%Y-%m'))
+ data["OBS"] = data["OBS"].apply(lambda x: datetime.strptime(x, "%Y-%m"))
 
  # Sort data by chronological order
- data.sort_values(by='OBS', inplace=True)
+ data.sort_values(by="OBS", inplace=True)
 
  # Merge the data based on the observation date
- daily_df = pd.merge_asof(daily_df, data[['OBS', col_name]], on='OBS', direction='backward')
-
+ daily_df = pd.merge_asof(
+ daily_df,
+ data[["OBS", col_name]],
+ on="OBS",
+ direction="backward",
+ )
 
  # Ensure columns are numeric
  for col in value_columns:
  if col in daily_df.columns:
- daily_df[col] = pd.to_numeric(daily_df[col], errors='coerce').fillna(0)
+ daily_df[col] = pd.to_numeric(daily_df[col], errors="coerce").fillna(0)
  else:
  print(f"Column {col} not found in daily_df")
 
  # Aggregate results by week
- country_df = ims_proc.aggregate_daily_to_wc_wide(df=daily_df,
- date_column="OBS",
- group_columns=[],
- sum_columns=value_columns,
- wc=week_commencing,
- aggregation="average")
-
- oecd_df_final = pd.concat([oecd_df_final, country_df], axis=0, ignore_index=True)
+ country_df = ims_proc.aggregate_daily_to_wc_wide(
+ df=daily_df,
+ date_column="OBS",
+ group_columns=[],
+ sum_columns=value_columns,
+ wc=week_commencing,
+ aggregation="average",
+ )
+
+ oecd_df_final = pd.concat(
+ [oecd_df_final, country_df],
+ axis=0,
+ ignore_index=True,
+ )
 
  return oecd_df_final
-
- def get_google_mobility_data(self, country="United Kingdom", wc="mon") -> pd.DataFrame:
+
+ def get_google_mobility_data(
+ self,
+ country="United Kingdom",
+ wc="mon",
+ ) -> pd.DataFrame:
  """
  Fetch Google Mobility data for the specified country.
-
- Parameters:
+
+ Parameters
+ ----------
  - country (str): The name of the country for which to fetch data.
 
- Returns:
+ Returns
+ -------
  - pd.DataFrame: A DataFrame containing the Google Mobility data.
+
  """
  # URL of the Google Mobility Reports CSV file
  url = "https://www.gstatic.com/covid19/mobility/Global_Mobility_Report.csv"
-
+
  # Fetch the CSV file
  response = requests.get(url)
  if response.status_code != 200:
  raise Exception(f"Failed to fetch data: {response.status_code}")
-
+
  # Load the CSV file into a pandas DataFrame
  csv_data = StringIO(response.text)
  df = pd.read_csv(csv_data, low_memory=False)
-
+
  # Filter the DataFrame for the specified country
- country_df = df[df['country_region'] == country]
-
- final_covid = ims_proc.aggregate_daily_to_wc_wide(country_df, "date", [], ['retail_and_recreation_percent_change_from_baseline', 'grocery_and_pharmacy_percent_change_from_baseline',
- 'parks_percent_change_from_baseline', 'transit_stations_percent_change_from_baseline',
- 'workplaces_percent_change_from_baseline', 'residential_percent_change_from_baseline'], wc, "average")
-
- final_covid1 = ims_proc.rename_cols(final_covid, 'covid_')
+ country_df = df[df["country_region"] == country]
+
+ final_covid = ims_proc.aggregate_daily_to_wc_wide(
+ country_df,
+ "date",
+ [],
+ [
+ "retail_and_recreation_percent_change_from_baseline",
+ "grocery_and_pharmacy_percent_change_from_baseline",
+ "parks_percent_change_from_baseline",
+ "transit_stations_percent_change_from_baseline",
+ "workplaces_percent_change_from_baseline",
+ "residential_percent_change_from_baseline",
+ ],
+ wc,
+ "average",
+ )
+
+ final_covid1 = ims_proc.rename_cols(final_covid, "covid_")
  return final_covid1
-
+
  ############################################################### Seasonality ##########################################################################
 
  def pull_seasonality(self, week_commencing, start_date, countries):
- """
- Generates a DataFrame with weekly seasonality features.
+ """
+ Generates a DataFrame with weekly seasonality features.
 
- Args:
- week_commencing (str): The starting day of the week ('mon', 'tue', ..., 'sun').
- start_date (str): The start date in 'YYYY-MM-DD' format.
- countries (list): A list of country codes (e.g., ['GB', 'US']) for holidays.
+ Args:
+ week_commencing (str): The starting day of the week ('mon', 'tue', ..., 'sun').
+ start_date (str): The start date in 'YYYY-MM-DD' format.
+ countries (list): A list of country codes (e.g., ['GB', 'US']) for holidays.
 
- Returns:
- pd.DataFrame: A DataFrame indexed by week start date, containing various
- seasonal dummy variables, holidays, trend, and constant.
- The date column is named 'OBS'.
- """
- # ---------------------------------------------------------------------
- # 0. Setup: dictionary for 'week_commencing' to Python weekday() integer
- # ---------------------------------------------------------------------
- day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
- if week_commencing not in day_dict:
- raise ValueError(f"Invalid week_commencing value: {week_commencing}. Use one of {list(day_dict.keys())}")
-
- # ---------------------------------------------------------------------
- # 0.2 Setup: dictionary continents and countries
- # ---------------------------------------------------------------------
- COUNTRY_TO_CONTINENT = {
- #Europe
- "Austria": "europe",
- "Belarus": "europe",
- "Belgium": "europe",
- "Bulgaria": "europe",
- "Croatia": "europe",
- "Cyprus": "europe",
- "Czechia": "europe",
- "CzechRepublic": "europe",
- "Denmark": "europe",
- "Estonia": "europe",
- "EuropeanCentralBank": "europe",
- "Finland": "europe",
- "France": "europe",
- "FranceAlsaceMoselle": "europe",
- "Germany": "europe",
- "GermanyBaden": "europe",
- "GermanyBavaria": "europe",
- "GermanyBerlin": "europe",
- "GermanyBrandenburg": "europe",
- "GermanyBremen": "europe",
- "GermanyHamburg": "europe",
- "GermanyHesse": "europe",
- "GermanyLowerSaxony": "europe",
- "GermanyMecklenburgVorpommern": "europe",
- "GermanyNorthRhineWestphalia": "europe",
- "GermanyRhinelandPalatinate": "europe",
- "GermanySaarland": "europe",
- "GermanySaxony": "europe",
- "GermanySaxonyAnhalt": "europe",
- "GermanySchleswigHolstein": "europe",
- "GermanyThuringia": "europe",
- "Greece": "europe",
- "Hungary": "europe",
- "Iceland": "europe",
- "Ireland": "europe",
- "Italy": "europe",
- "Latvia": "europe",
- "Lithuania": "europe",
- "Luxembourg": "europe",
- "Malta": "europe",
- "Monaco": "europe",
- "Netherlands": "europe",
- "Norway": "europe",
- "Poland": "europe",
- "Portugal": "europe",
- "Romania": "europe",
- "Russia": "europe",
- "Serbia": "europe",
- "Slovakia": "europe",
- "Slovenia": "europe",
- "Spain": "europe",
- "SpainAndalusia": "europe",
- "SpainAragon": "europe",
- "SpainAsturias": "europe",
- "SpainBalearicIslands": "europe",
- "SpainBasqueCountry": "europe",
- "SpainCanaryIslands": "europe",
- "SpainCantabria": "europe",
- "SpainCastileAndLeon": "europe",
- "SpainCastillaLaMancha": "europe",
- "SpainCatalonia": "europe",
- "SpainExtremadura": "europe",
- "SpainGalicia": "europe",
- "SpainLaRioja": "europe",
- "SpainMadrid": "europe",
- "SpainMurcia": "europe",
- "SpainNavarre": "europe",
- "SpainValencia": "europe",
- "Sweden": "europe",
- "Switzerland": "europe",
- "Ukraine": "europe",
- "UnitedKingdom": "europe",
+ Returns:
+ pd.DataFrame: A DataFrame indexed by week start date, containing various
+ seasonal dummy variables, holidays, trend, and constant.
+ The date column is named 'OBS'.
 
- # Americas
- "Argentina": "america",
- "Barbados": "america",
- "Brazil": "america",
- "Canada": "america",
- "Chile": "america",
- "Colombia": "america",
- "Mexico": "america",
- "Panama": "america",
- "Paraguay": "america",
- "Peru": "america",
- "UnitedStates": "usa",
-
- #US States
- "Alabama": "usa.states",
- "Alaska": "usa.states",
- "Arizona": "usa.states",
- "Arkansas": "usa.states",
- "California": "usa.states",
- "Colorado": "usa.states",
- "Connecticut": "usa.states",
- "Delaware": "usa.states",
- "DistrictOfColumbia": "usa.states",
- "Florida": "usa.states",
- "Georgia": "usa.states",
- "Hawaii": "usa.states",
- "Idaho": "usa.states",
- "Illinois": "usa.states",
- "Indiana": "usa.states",
- "Iowa": "usa.states",
- "Kansas": "usa.states",
- "Kentucky": "usa.states",
- "Louisiana": "usa.states",
- "Maine": "usa.states",
- "Maryland": "usa.states",
- "Massachusetts": "usa.states",
- "Michigan": "usa.states",
- "Minnesota": "usa.states",
- "Mississippi": "usa.states",
- "Missouri": "usa.states",
- "Montana": "usa.states",
- "Nebraska": "usa.states",
- "Nevada": "usa.states",
- "NewHampshire": "usa.states",
- "NewJersey": "usa.states",
- "NewMexico": "usa.states",
- "NewYork": "usa.states",
- "NorthCarolina": "usa.states",
- "NorthDakota": "usa.states",
- "Ohio": "usa.states",
- "Oklahoma": "usa.states",
- "Oregon": "usa.states",
- "Pennsylvania": "usa.states",
- "RhodeIsland": "usa.states",
- "SouthCarolina": "usa.states",
- "SouthDakota": "usa.states",
- "Tennessee": "usa.states",
- "Texas": "usa.states",
- "Utah": "usa.states",
- "Vermont": "usa.states",
- "Virginia": "usa.states",
- "Washington": "usa.states",
- "WestVirginia": "usa.states",
- "Wisconsin": "usa.states",
- "Wyoming": "usa.states",
-
- #Oceania
- "Australia": "oceania",
- "AustraliaCapitalTerritory": "oceania",
- "AustraliaNewSouthWales": "oceania",
- "AustraliaNorthernTerritory": "oceania",
- "AustraliaQueensland": "oceania",
- "AustraliaSouthAustralia": "oceania",
- "AustraliaTasmania": "oceania",
- "AustraliaVictoria": "oceania",
- "AustraliaWesternAustralia": "oceania",
- "MarshallIslands": "oceania",
- "NewZealand": "oceania",
-
- #Asia
- "China": "asia",
- "HongKong": "asia",
- "India": "asia",
- "Israel": "asia",
- "Japan": "asia",
- "Kazakhstan": "asia",
- "Malaysia": "asia",
- "Qatar": "asia",
- "Singapore": "asia",
- "SouthKorea": "asia",
- "Taiwan": "asia",
- "Turkey": "asia",
- "Vietnam": "asia",
-
- #Africa
- "Algeria": "africa",
- "Angola": "africa",
- "Benin": "africa",
- "IvoryCoast": "africa",
- "Kenya": "africa",
- "Madagascar": "africa",
- "Nigeria": "africa",
- "SaoTomeAndPrincipe": "africa",
- "SouthAfrica": "africa",
- }
+ """
+ # ---------------------------------------------------------------------
+ # 0. Setup: dictionary for 'week_commencing' to Python weekday() integer
+ # ---------------------------------------------------------------------
+ day_dict = {
+ "mon": 0,
+ "tue": 1,
+ "wed": 2,
+ "thu": 3,
+ "fri": 4,
+ "sat": 5,
+ "sun": 6,
+ }
+ if week_commencing not in day_dict:
+ raise ValueError(
+ f"Invalid week_commencing value: {week_commencing}. Use one of {list(day_dict.keys())}",
+ )
 
- # Dictionary mapping ISO country codes to their corresponding workalendar country names
- holiday_country = {
- # Major countries with required formats
- "GB": "UnitedKingdom",
- "US": "UnitedStates",
- "USA": "UnitedStates", # Alternative code for US
- "CA": "Canada",
- "ZA": "SouthAfrica",
- "FR": "France",
- "DE": "Germany",
- "AU": "Australia",
- "AUS": "Australia", # Alternative code for Australia
-
- # European countries
- "AT": "Austria",
- "BY": "Belarus",
- "BE": "Belgium",
- "BG": "Bulgaria",
- "HR": "Croatia",
- "CY": "Cyprus",
- "CZ": "Czechia",
- "DK": "Denmark",
- "EE": "Estonia",
- "FI": "Finland",
- "GR": "Greece",
- "HU": "Hungary",
- "IS": "Iceland",
- "IE": "Ireland",
- "IT": "Italy",
- "LV": "Latvia",
- "LT": "Lithuania",
- "LU": "Luxembourg",
- "MT": "Malta",
- "MC": "Monaco",
- "NL": "Netherlands",
- "NO": "Norway",
- "PL": "Poland",
- "PT": "Portugal",
- "RO": "Romania",
- "RU": "Russia",
- "RS": "Serbia",
- "SK": "Slovakia",
- "SI": "Slovenia",
- "ES": "Spain",
- "SE": "Sweden",
- "CH": "Switzerland",
- "UA": "Ukraine",
-
- # Americas
- "AR": "Argentina",
- "BB": "Barbados",
- "BR": "Brazil",
- "CL": "Chile",
- "CO": "Colombia",
- "MX": "Mexico",
- "PA": "Panama",
- "PY": "Paraguay",
- "PE": "Peru",
-
- # USA States (using common abbreviations)
- "AL": "Alabama",
- "AK": "Alaska",
- "AZ": "Arizona",
- "AR": "Arkansas",
- "CA_US": "California",
- "CO_US": "Colorado",
- "CT": "Connecticut",
- "DE_US": "Delaware",
- "DC": "DistrictOfColumbia",
- "FL": "Florida",
- "GA": "Georgia",
- "HI": "Hawaii",
- "ID": "Idaho",
- "IL": "Illinois",
- "IN": "Indiana",
- "IA": "Iowa",
- "KS": "Kansas",
- "KY": "Kentucky",
- "LA": "Louisiana",
- "ME": "Maine",
- "MD": "Maryland",
- "MA": "Massachusetts",
- "MI": "Michigan",
- "MN": "Minnesota",
- "MS": "Mississippi",
- "MO": "Missouri",
- "MT": "Montana",
- "NE": "Nebraska",
- "NV": "Nevada",
- "NH": "NewHampshire",
- "NJ": "NewJersey",
- "NM": "NewMexico",
- "NY": "NewYork",
- "NC": "NorthCarolina",
- "ND": "NorthDakota",
- "OH": "Ohio",
- "OK": "Oklahoma",
- "OR": "Oregon",
- "PA_US": "Pennsylvania",
- "RI": "RhodeIsland",
- "SC": "SouthCarolina",
- "SD": "SouthDakota",
- "TN": "Tennessee",
- "TX": "Texas",
- "UT": "Utah",
- "VT": "Vermont",
- "VA": "Virginia",
- "WA": "Washington",
- "WV": "WestVirginia",
- "WI": "Wisconsin",
- "WY": "Wyoming",
-
- # Australia territories
- "ACT": "AustraliaCapitalTerritory",
- "NSW": "AustraliaNewSouthWales",
- "NT": "AustraliaNorthernTerritory",
- "QLD": "AustraliaQueensland",
- "SA": "AustraliaSouthAustralia",
- "TAS": "AustraliaTasmania",
- "VIC": "AustraliaVictoria",
- "WA_AU": "AustraliaWesternAustralia",
-
- # Asian countries
- "CN": "China",
- "HK": "HongKong",
- "IN": "India",
- "IL": "Israel",
- "JP": "Japan",
- "KZ": "Kazakhstan",
- "MY": "Malaysia",
- "QA": "Qatar",
- "SG": "Singapore",
- "KR": "SouthKorea",
- "TW": "Taiwan",
- "TR": "Turkey",
- "VN": "Vietnam",
-
- # Other Oceania countries
- "MH": "MarshallIslands",
- "NZ": "NewZealand",
-
- # African countries
- "DZ": "Algeria",
- "AO": "Angola",
- "BJ": "Benin",
- "CI": "IvoryCoast",
- "KE": "Kenya",
- "MG": "Madagascar",
- "NG": "Nigeria",
- "ST": "SaoTomeAndPrincipe"
- }
+ # ---------------------------------------------------------------------
+ # 0.2 Setup: dictionary continents and countries
+ # ---------------------------------------------------------------------
+ COUNTRY_TO_CONTINENT = {
+ # Europe
+ "Austria": "europe",
+ "Belarus": "europe",
+ "Belgium": "europe",
+ "Bulgaria": "europe",
+ "Croatia": "europe",
+ "Cyprus": "europe",
+ "Czechia": "europe",
+ "CzechRepublic": "europe",
+ "Denmark": "europe",
+ "Estonia": "europe",
+ "EuropeanCentralBank": "europe",
+ "Finland": "europe",
+ "France": "europe",
+ "FranceAlsaceMoselle": "europe",
+ "Germany": "europe",
+ "GermanyBaden": "europe",
+ "GermanyBavaria": "europe",
+ "GermanyBerlin": "europe",
+ "GermanyBrandenburg": "europe",
+ "GermanyBremen": "europe",
+ "GermanyHamburg": "europe",
+ "GermanyHesse": "europe",
+ "GermanyLowerSaxony": "europe",
+ "GermanyMecklenburgVorpommern": "europe",
+ "GermanyNorthRhineWestphalia": "europe",
+ "GermanyRhinelandPalatinate": "europe",
+ "GermanySaarland": "europe",
+ "GermanySaxony": "europe",
+ "GermanySaxonyAnhalt": "europe",
+ "GermanySchleswigHolstein": "europe",
+ "GermanyThuringia": "europe",
+ "Greece": "europe",
+ "Hungary": "europe",
+ "Iceland": "europe",
+ "Ireland": "europe",
+ "Italy": "europe",
+ "Latvia": "europe",
+ "Lithuania": "europe",
+ "Luxembourg": "europe",
+ "Malta": "europe",
+ "Monaco": "europe",
+ "Netherlands": "europe",
+ "Norway": "europe",
+ "Poland": "europe",
+ "Portugal": "europe",
+ "Romania": "europe",
+ "Russia": "europe",
+ "Serbia": "europe",
+ "Slovakia": "europe",
+ "Slovenia": "europe",
+ "Spain": "europe",
+ "SpainAndalusia": "europe",
+ "SpainAragon": "europe",
+ "SpainAsturias": "europe",
+ "SpainBalearicIslands": "europe",
+ "SpainBasqueCountry": "europe",
+ "SpainCanaryIslands": "europe",
+ "SpainCantabria": "europe",
+ "SpainCastileAndLeon": "europe",
+ "SpainCastillaLaMancha": "europe",
+ "SpainCatalonia": "europe",
+ "SpainExtremadura": "europe",
+ "SpainGalicia": "europe",
+ "SpainLaRioja": "europe",
+ "SpainMadrid": "europe",
+ "SpainMurcia": "europe",
+ "SpainNavarre": "europe",
+ "SpainValencia": "europe",
+ "Sweden": "europe",
+ "Switzerland": "europe",
+ "Ukraine": "europe",
+ "UnitedKingdom": "europe",
+ # Americas
+ "Argentina": "america",
+ "Barbados": "america",
+ "Brazil": "america",
+ "Canada": "america",
+ "Chile": "america",
+ "Colombia": "america",
+ "Mexico": "america",
+ "Panama": "america",
+ "Paraguay": "america",
+ "Peru": "america",
+ "UnitedStates": "usa",
+ # US States
+ "Alabama": "usa.states",
+ "Alaska": "usa.states",
+ "Arizona": "usa.states",
+ "Arkansas": "usa.states",
+ "California": "usa.states",
+ "Colorado": "usa.states",
+ "Connecticut": "usa.states",
+ "Delaware": "usa.states",
+ "DistrictOfColumbia": "usa.states",
+ "Florida": "usa.states",
+ "Georgia": "usa.states",
+ "Hawaii": "usa.states",
+ "Idaho": "usa.states",
+ "Illinois": "usa.states",
+ "Indiana": "usa.states",
+ "Iowa": "usa.states",
+ "Kansas": "usa.states",
+ "Kentucky": "usa.states",
+ "Louisiana": "usa.states",
+ "Maine": "usa.states",
+ "Maryland": "usa.states",
+ "Massachusetts": "usa.states",
+ "Michigan": "usa.states",
+ "Minnesota": "usa.states",
+ "Mississippi": "usa.states",
+ "Missouri": "usa.states",
+ "Montana": "usa.states",
+ "Nebraska": "usa.states",
+ "Nevada": "usa.states",
+ "NewHampshire": "usa.states",
+ "NewJersey": "usa.states",
+ "NewMexico": "usa.states",
+ "NewYork": "usa.states",
+ "NorthCarolina": "usa.states",
+ "NorthDakota": "usa.states",
+ "Ohio": "usa.states",
+ "Oklahoma": "usa.states",
+ "Oregon": "usa.states",
+ "Pennsylvania": "usa.states",
+ "RhodeIsland": "usa.states",
+ "SouthCarolina": "usa.states",
+ "SouthDakota": "usa.states",
+ "Tennessee": "usa.states",
+ "Texas": "usa.states",
+ "Utah": "usa.states",
+ "Vermont": "usa.states",
+ "Virginia": "usa.states",
+ "Washington": "usa.states",
+ "WestVirginia": "usa.states",
+ "Wisconsin": "usa.states",
+ "Wyoming": "usa.states",
+ # Oceania
+ "Australia": "oceania",
+ "AustraliaCapitalTerritory": "oceania",
+ "AustraliaNewSouthWales": "oceania",
+ "AustraliaNorthernTerritory": "oceania",
+ "AustraliaQueensland": "oceania",
+ "AustraliaSouthAustralia": "oceania",
+ "AustraliaTasmania": "oceania",
+ "AustraliaVictoria": "oceania",
+ "AustraliaWesternAustralia": "oceania",
+ "MarshallIslands": "oceania",
+ "NewZealand": "oceania",
+ # Asia
+ "China": "asia",
+ "HongKong": "asia",
+ "India": "asia",
+ "Israel": "asia",
+ "Japan": "asia",
+ "Kazakhstan": "asia",
+ "Malaysia": "asia",
+ "Qatar": "asia",
+ "Singapore": "asia",
+ "SouthKorea": "asia",
+ "Taiwan": "asia",
+ "Turkey": "asia",
+ "Vietnam": "asia",
+ # Africa
+ "Algeria": "africa",
+ "Angola": "africa",
+ "Benin": "africa",
+ "IvoryCoast": "africa",
+ "Kenya": "africa",
+ "Madagascar": "africa",
+ "Nigeria": "africa",
+ "SaoTomeAndPrincipe": "africa",
+ "SouthAfrica": "africa",
+ }
 
+ # Dictionary mapping ISO country codes to their corresponding workalendar country names
+ holiday_country = {
+ # Major countries with required formats
+ "GB": "UnitedKingdom",
+ "US": "UnitedStates",
+ "USA": "UnitedStates", # Alternative code for US
+ "CA": "Canada",
+ "ZA": "SouthAfrica",
+ "FR": "France",
+ "DE": "Germany",
+ "AU": "Australia",
+ "AUS": "Australia", # Alternative code for Australia
+ # European countries
+ "AT": "Austria",
+ "BY": "Belarus",
+ "BE": "Belgium",
+ "BG": "Bulgaria",
+ "HR": "Croatia",
+ "CY": "Cyprus",
+ "CZ": "Czechia",
+ "DK": "Denmark",
+ "EE": "Estonia",
+ "FI": "Finland",
+ "GR": "Greece",
+ "HU": "Hungary",
+ "IS": "Iceland",
+ "IE": "Ireland",
+ "IT": "Italy",
+ "LV": "Latvia",
+ "LT": "Lithuania",
+ "LU": "Luxembourg",
+ "MT": "Malta",
+ "MC": "Monaco",
+ "NL": "Netherlands",
+ "NO": "Norway",
+ "PL": "Poland",
+ "PT": "Portugal",
+ "RO": "Romania",
+ "RU": "Russia",
+ "RS": "Serbia",
+ "SK": "Slovakia",
+ "SI": "Slovenia",
+ "ES": "Spain",
+ "SE": "Sweden",
+ "CH": "Switzerland",
+ "UA": "Ukraine",
+ # Americas
+ "AR": "Argentina",
+ "BB": "Barbados",
+ "BR": "Brazil",
+ "CL": "Chile",
+ "CO": "Colombia",
+ "MX": "Mexico",
+ "PA": "Panama",
+ "PY": "Paraguay",
+ "PE": "Peru",
+ # USA States (using common abbreviations)
+ "AL": "Alabama",
+ "AK": "Alaska",
+ "AZ": "Arizona",
+ "AR": "Arkansas",
+ "CA_US": "California",
+ "CO_US": "Colorado",
+ "CT": "Connecticut",
+ "DE_US": "Delaware",
+ "DC": "DistrictOfColumbia",
+ "FL": "Florida",
+ "GA": "Georgia",
+ "HI": "Hawaii",
+ "ID": "Idaho",
+ "IL": "Illinois",
+ "IN": "Indiana",
+ "IA": "Iowa",
+ "KS": "Kansas",
+ "KY": "Kentucky",
+ "LA": "Louisiana",
+ "ME": "Maine",
+ "MD": "Maryland",
+ "MA": "Massachusetts",
+ "MI": "Michigan",
+ "MN": "Minnesota",
+ "MS": "Mississippi",
+ "MO": "Missouri",
+ "MT": "Montana",
+ "NE": "Nebraska",
+ "NV": "Nevada",
+ "NH": "NewHampshire",
+ "NJ": "NewJersey",
+ "NM": "NewMexico",
+ "NY": "NewYork",
+ "NC": "NorthCarolina",
+ "ND": "NorthDakota",
+ "OH": "Ohio",
+ "OK": "Oklahoma",
+ "OR": "Oregon",
+ "PA_US": "Pennsylvania",
+ "RI": "RhodeIsland",
+ "SC": "SouthCarolina",
+ "SD": "SouthDakota",
+ "TN": "Tennessee",
+ "TX": "Texas",
+ "UT": "Utah",
+ "VT": "Vermont",
+ "VA": "Virginia",
+ "WA": "Washington",
+ "WV": "WestVirginia",
+ "WI": "Wisconsin",
+ "WY": "Wyoming",
+ # Australia territories
+ "ACT": "AustraliaCapitalTerritory",
+ "NSW": "AustraliaNewSouthWales",
+ "NT": "AustraliaNorthernTerritory",
+ "QLD": "AustraliaQueensland",
+ "SA": "AustraliaSouthAustralia",
+ "TAS": "AustraliaTasmania",
+ "VIC": "AustraliaVictoria",
+ "WA_AU": "AustraliaWesternAustralia",
+ # Asian countries
+ "CN": "China",
+ "HK": "HongKong",
+ "IN": "India",
+ "IL": "Israel",
+ "JP": "Japan",
+ "KZ": "Kazakhstan",
+ "MY": "Malaysia",
+ "QA": "Qatar",
+ "SG": "Singapore",
+ "KR": "SouthKorea",
+ "TW": "Taiwan",
+ "TR": "Turkey",
+ "VN": "Vietnam",
+ # Other Oceania countries
+ "MH": "MarshallIslands",
+ "NZ": "NewZealand",
+ # African countries
+ "DZ": "Algeria",
+ "AO": "Angola",
+ "BJ": "Benin",
+ "CI": "IvoryCoast",
+ "KE": "Kenya",
+ "MG": "Madagascar",
+ "NG": "Nigeria",
+ "ST": "SaoTomeAndPrincipe",
+ }
 
- # ---------------------------------------------------------------------
- # 1. Create daily date range from start_date to today
- # ---------------------------------------------------------------------
- try:
- start_dt = pd.to_datetime(start_date)
- except ValueError:
- raise ValueError(f"Invalid start_date format: {start_date}. Use 'YYYY-MM-DD'")
-
- end_dt = datetime.today()
- # Ensure end date is not before start date
- if end_dt < start_dt:
- end_dt = start_dt + timedelta(days=1) # Or handle as error if preferred
-
- date_range = pd.date_range(
- start=start_dt,
- end=end_dt,
- freq="D"
+ # ---------------------------------------------------------------------
+ # 1. Create daily date range from start_date to today
+ # ---------------------------------------------------------------------
+ try:
+ start_dt = pd.to_datetime(start_date)
+ except ValueError:
+ raise ValueError(
+ f"Invalid start_date format: {start_date}. Use 'YYYY-MM-DD'",
  )
- df_daily = pd.DataFrame(date_range, columns=["Date"])
 
- # ---------------------------------------------------------------------
- # 1.1 Identify "week_start" for each daily row, based on week_commencing
- # ---------------------------------------------------------------------
- start_day_int = day_dict[week_commencing]
- df_daily['week_start'] = df_daily["Date"].apply(
- lambda x: x - pd.Timedelta(days=(x.weekday() - start_day_int) % 7)
- )
+ end_dt = datetime.today()
+ # Ensure end date is not before start date
+ if end_dt < start_dt:
+ end_dt = start_dt + timedelta(days=1) # Or handle as error if preferred
 
- # ---------------------------------------------------------------------
- # 1.2 Calculate ISO week number for each DAY (for later aggregation)
- # Also calculate Year for each DAY to handle year transitions correctly
- # ---------------------------------------------------------------------
- df_daily['iso_week_daily'] = df_daily['Date'].dt.isocalendar().week.astype(int)
- df_daily['iso_year_daily'] = df_daily['Date'].dt.isocalendar().year.astype(int)
+ date_range = pd.date_range(start=start_dt, end=end_dt, freq="D")
+ df_daily = pd.DataFrame(date_range, columns=["Date"])
 
+ # ---------------------------------------------------------------------
+ # 1.1 Identify "week_start" for each daily row, based on week_commencing
+ # ---------------------------------------------------------------------
+ start_day_int = day_dict[week_commencing]
+ df_daily["week_start"] = df_daily["Date"].apply(
+ lambda x: x - pd.Timedelta(days=(x.weekday() - start_day_int) % 7),
+ )
 
- # ---------------------------------------------------------------------
- # 2. Build a weekly index (df_weekly_start) based on unique week_start dates
- # ---------------------------------------------------------------------
- df_weekly_start = df_daily[['week_start']].drop_duplicates().sort_values('week_start').reset_index(drop=True)
- df_weekly_start.rename(columns={'week_start': "Date"}, inplace=True)
- df_weekly_start.set_index("Date", inplace=True)
+ # ---------------------------------------------------------------------
+ # 1.2 Calculate ISO week number for each DAY (for later aggregation)
+ # Also calculate Year for each DAY to handle year transitions correctly
+ # ---------------------------------------------------------------------
+ df_daily["iso_week_daily"] = df_daily["Date"].dt.isocalendar().week.astype(int)
+ df_daily["iso_year_daily"] = df_daily["Date"].dt.isocalendar().year.astype(int)
+
+ # ---------------------------------------------------------------------
+ # 2. Build a weekly index (df_weekly_start) based on unique week_start dates
+ # ---------------------------------------------------------------------
+ df_weekly_start = (
+ df_daily[["week_start"]]
+ .drop_duplicates()
+ .sort_values("week_start")
+ .reset_index(drop=True)
+ )
+ df_weekly_start.rename(columns={"week_start": "Date"}, inplace=True)
+ df_weekly_start.set_index("Date", inplace=True)
+
+ # Create individual weekly dummies (optional, uncomment if needed)
+ dummy_columns = {}
+ for i, date_index in enumerate(df_weekly_start.index):
+ col_name = f"dum_{date_index.strftime('%Y_%m_%d')}"
+ dummy_columns[col_name] = [0] * len(df_weekly_start)
+ dummy_columns[col_name][i] = 1
+ df_dummies = pd.DataFrame(dummy_columns, index=df_weekly_start.index)
+ df_weekly_start = pd.concat([df_weekly_start, df_dummies], axis=1)
+
+ # ---------------------------------------------------------------------
+ # 3. Public holidays (daily) using WorkCalendar
+ # ---------------------------------------------------------------------
+ start_year = start_dt.year
+ end_year = end_dt.year
+ years_range = range(start_year, end_year + 1)
+
+ # Dictionary to store holiday dummies for each country
+ country_holiday_dummies = {}
+
+ for country_code in countries:
+ # Skip if country code not found in holiday_country dictionary
+ if country_code not in holiday_country:
+ print(
+ f"Warning: Country code '{country_code}' not found in country code dictionary. Skipping.",
+ )
+ continue
+
+ country = holiday_country[country_code]
+
+ # Skip if country not found in continent lookup dictionary
+ if country not in COUNTRY_TO_CONTINENT:
+ print(
+ f"Warning: Country '{country}' not found in continent lookup dictionary. Skipping.",
+ )
+ continue
 
- # Create individual weekly dummies (optional, uncomment if needed)
- dummy_columns = {}
- for i, date_index in enumerate(df_weekly_start.index):
- col_name = f"dum_{date_index.strftime('%Y_%m_%d')}"
- dummy_columns[col_name] = [0] * len(df_weekly_start)
- dummy_columns[col_name][i] = 1
- df_dummies = pd.DataFrame(dummy_columns, index=df_weekly_start.index)
- df_weekly_start = pd.concat([df_weekly_start, df_dummies], axis=1)
+ continent = COUNTRY_TO_CONTINENT[country]
+ module_path = f"workalendar.{continent}"
+ try:
+ module = importlib.import_module(module_path)
+ calendar_class = getattr(module, country)
+ cal = calendar_class()
+ except (ImportError, AttributeError) as e:
+ print(f"Error importing calendar for {country}: {e}. Skipping.")
+ continue
+
+ # Collect holidays
+ holidays_list = []
+ for year in years_range:
+ holidays_list.extend(cal.holidays(year))
 
+ holidays_df = pd.DataFrame(holidays_list, columns=["Date", "Holiday"])
+ holidays_df["Date"] = pd.to_datetime(holidays_df["Date"])
 
+ # Filter out any holidays with "shift" or "substitute" in their name
+ holidays_df = holidays_df[
+ ~(
+ holidays_df["Holiday"].str.lower().str.contains("shift")
+ | holidays_df["Holiday"].str.lower().str.contains("substitute")
+ )
+ ]
+
+ # Filter by date range
+ holidays_df = holidays_df[
+ (holidays_df["Date"] >= start_dt) & (holidays_df["Date"] <= end_dt)
+ ]
  # ---------------------------------------------------------------------
- # 3. Public holidays (daily) using WorkCalendar
+ # 3.1 Additional Public Holidays for Canada due to poor API data
  # ---------------------------------------------------------------------
- start_year = start_dt.year
- end_year = end_dt.year
- years_range = range(start_year, end_year + 1)
-
- # Dictionary to store holiday dummies for each country
- country_holiday_dummies = {}
-
- for country_code in countries:
- # Skip if country code not found in holiday_country dictionary
- if country_code not in holiday_country:
- print(f"Warning: Country code '{country_code}' not found in country code dictionary. Skipping.")
- continue
-
- country = holiday_country[country_code]
-
- # Skip if country not found in continent lookup dictionary
- if country not in COUNTRY_TO_CONTINENT:
- print(f"Warning: Country '{country}' not found in continent lookup dictionary. Skipping.")
- continue
-
- continent = COUNTRY_TO_CONTINENT[country]
- module_path = f'workalendar.{continent}'
- try:
- module = importlib.import_module(module_path)
- calendar_class = getattr(module, country)
- cal = calendar_class()
- except (ImportError, AttributeError) as e:
- print(f"Error importing calendar for {country}: {e}. Skipping.")
- continue
-
- # Collect holidays
- holidays_list = []
+ if country_code == "CA":
+ # Add Canada Day (July 1st) if not already in the list
  for year in years_range:
- holidays_list.extend(cal.holidays(year))
-
- holidays_df = pd.DataFrame(holidays_list, columns=['Date', 'Holiday'])
- holidays_df['Date'] = pd.to_datetime(holidays_df['Date'])
-
- # Filter out any holidays with "shift" or "substitute" in their name
- holidays_df = holidays_df[~(holidays_df['Holiday'].str.lower().str.contains('shift') |
- holidays_df['Holiday'].str.lower().str.contains('substitute'))]
-
- # Filter by date range
- holidays_df = holidays_df[(holidays_df['Date'] >= start_dt) & (holidays_df['Date'] <= end_dt)]
- # ---------------------------------------------------------------------
- # 3.1 Additional Public Holidays for Canada due to poor API data
- # ---------------------------------------------------------------------
- if country_code == 'CA':
- # Add Canada Day (July 1st) if not already in the list
- for year in years_range:
- canada_day = pd.Timestamp(f"{year}-07-01")
- if canada_day >= start_dt and canada_day <= end_dt:
- if not ((holidays_df['Date'] == canada_day) &
- (holidays_df['Holiday'].str.lower().str.contains('canada day'))).any():
- holidays_df = pd.concat([holidays_df,
- pd.DataFrame({'Date': [canada_day],
- 'Holiday': ['Canada Day']})],
- ignore_index=True)
-
- # Add Labour Day (first Monday in September)
- for year in years_range:
- # Get first day of September
- first_day = pd.Timestamp(f"{year}-09-01")
- # Calculate days until first Monday (Monday is weekday 0)
- days_until_monday = (7 - first_day.weekday()) % 7
- if days_until_monday == 0: # If first day is already Monday
- labour_day = first_day
- else:
- labour_day = first_day + pd.Timedelta(days=days_until_monday)
-
- if labour_day >= start_dt and labour_day <= end_dt:
- if not ((holidays_df['Date'] == labour_day) &
- (holidays_df['Holiday'].str.lower().str.contains('labour day'))).any():
- holidays_df = pd.concat([holidays_df,
- pd.DataFrame({'Date': [labour_day],
- 'Holiday': ['Labour Day']})],
- ignore_index=True)
-
- # Add Thanksgiving (second Monday in October)
- for year in years_range:
- # Get first day of October
- first_day = pd.Timestamp(f"{year}-10-01")
- # Calculate days until first Monday
- days_until_monday = (7 - first_day.weekday()) % 7
- if days_until_monday == 0: # If first day is already Monday
- first_monday = first_day
- else:
- first_monday = first_day + pd.Timedelta(days=days_until_monday)
-
- # Second Monday is 7 days after first Monday
- thanksgiving = first_monday + pd.Timedelta(days=7)
-
- if thanksgiving >= start_dt and thanksgiving <= end_dt:
- if not ((holidays_df['Date'] == thanksgiving) &
- (holidays_df['Holiday'].str.lower().str.contains('thanksgiving'))).any():
- holidays_df = pd.concat([holidays_df,
- pd.DataFrame({'Date': [thanksgiving],
- 'Holiday': ['Thanksgiving']})],
- ignore_index=True)
-
- # Now process the collected holidays and add to df_daily
- for _, row in holidays_df.iterrows():
- holiday_date = row['Date']
- # Create column name without modifying original holiday names
- holiday_name = row['Holiday'].lower().replace(' ', '_')
-
- # Remove "_shift" or "_substitute" if they appear as standalone suffixes
- if holiday_name.endswith('_shift'):
- holiday_name = holiday_name[:-6]
- elif holiday_name.endswith('_substitute'):
- holiday_name = holiday_name[:-11]
-
- column_name = f"seas_{holiday_name}_{country_code.lower()}"
-
- if column_name not in df_daily.columns:
- df_daily[column_name] = 0
-
- # Mark the specific holiday date
- df_daily.loc[df_daily["Date"] == holiday_date, column_name] = 1
-
- # Also mark a general holiday indicator for each country
- holiday_indicator = f"seas_holiday_{country_code.lower()}"
- if holiday_indicator not in df_daily.columns:
- df_daily[holiday_indicator] = 0
- df_daily.loc[df_daily["Date"] == holiday_date, holiday_indicator] = 1
+ canada_day = pd.Timestamp(f"{year}-07-01")
+ if canada_day >= start_dt and canada_day <= end_dt:
+ if not (
+ (holidays_df["Date"] == canada_day)
+ & (
+ holidays_df["Holiday"]
+ .str.lower()
+ .str.contains("canada day")
+ )
+ ).any():
+ holidays_df = pd.concat(
+ [
+ holidays_df,
+ pd.DataFrame(
+ {
+ "Date": [canada_day],
+ "Holiday": ["Canada Day"],
+ },
+ ),
+ ],
+ ignore_index=True,
+ )
+
+ # Add Labour Day (first Monday in September)
+ for year in years_range:
+ # Get first day of September
+ first_day = pd.Timestamp(f"{year}-09-01")
+ # Calculate days until first Monday (Monday is weekday 0)
+ days_until_monday = (7 - first_day.weekday()) % 7
+ if days_until_monday == 0: # If first day is already Monday
+ labour_day = first_day
+ else:
+ labour_day = first_day + pd.Timedelta(days=days_until_monday)
+
+ if labour_day >= start_dt and labour_day <= end_dt:
+ if not (
+ (holidays_df["Date"] == labour_day)
+ & (
+ holidays_df["Holiday"]
+ .str.lower()
+ .str.contains("labour day")
+ )
+ ).any():
+ holidays_df = pd.concat(
+ [
+ holidays_df,
+ pd.DataFrame(
+ {
+ "Date": [labour_day],
+ "Holiday": ["Labour Day"],
+ },
+ ),
+ ],
+ ignore_index=True,
+ )
+
+ # Add Thanksgiving (second Monday in October)
+ for year in years_range:
+ # Get first day of October
+ first_day = pd.Timestamp(f"{year}-10-01")
+ # Calculate days until first Monday
+ days_until_monday = (7 - first_day.weekday()) % 7
+ if days_until_monday == 0: # If first day is already Monday
+ first_monday = first_day
+ else:
+ first_monday = first_day + pd.Timedelta(days=days_until_monday)
1096
+
1097
+ # Second Monday is 7 days after first Monday
1098
+ thanksgiving = first_monday + pd.Timedelta(days=7)
1099
+
1100
+ if thanksgiving >= start_dt and thanksgiving <= end_dt:
1101
+ if not (
1102
+ (holidays_df["Date"] == thanksgiving)
1103
+ & (
1104
+ holidays_df["Holiday"]
1105
+ .str.lower()
1106
+ .str.contains("thanksgiving")
1107
+ )
1108
+ ).any():
1109
+ holidays_df = pd.concat(
1110
+ [
1111
+ holidays_df,
1112
+ pd.DataFrame(
1113
+ {
1114
+ "Date": [thanksgiving],
1115
+ "Holiday": ["Thanksgiving"],
1116
+ },
1117
+ ),
1118
+ ],
1119
+ ignore_index=True,
1120
+ )
1121
+
1122
+ # Now process the collected holidays and add to df_daily
1123
+ for _, row in holidays_df.iterrows():
1124
+ holiday_date = row["Date"]
1125
+ # Create column name without modifying original holiday names
1126
+ holiday_name = row["Holiday"].lower().replace(" ", "_")
1127
+
1128
+ # Remove "_shift" or "_substitute" if they appear as standalone suffixes
1129
+ if holiday_name.endswith("_shift"):
1130
+ holiday_name = holiday_name[:-6]
1131
+ elif holiday_name.endswith("_substitute"):
1132
+ holiday_name = holiday_name[:-11]
1133
+
1134
+ column_name = f"seas_{holiday_name}_{country_code.lower()}"
1135
+
1136
+ if column_name not in df_daily.columns:
1137
+ df_daily[column_name] = 0
1138
+
1139
+ # Mark the specific holiday date
1140
+ df_daily.loc[df_daily["Date"] == holiday_date, column_name] = 1
1141
+
1142
+ # Also mark a general holiday indicator for each country
1143
+ holiday_indicator = f"seas_holiday_{country_code.lower()}"
1144
+ if holiday_indicator not in df_daily.columns:
1145
+ df_daily[holiday_indicator] = 0
1146
+ df_daily.loc[df_daily["Date"] == holiday_date, holiday_indicator] = 1
1147
+
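The block above turns every collected holiday into a one-hot column named seas_<holiday>_<country> plus a catch-all seas_holiday_<country> flag. A minimal sketch of that marking step on a toy frame (the dates, holiday name and country code here are illustrative, not package output):

import pandas as pd

# Toy stand-in for df_daily (illustrative dates only).
df_daily = pd.DataFrame({"Date": pd.date_range("2024-12-23", "2024-12-27")})

holiday_date = pd.Timestamp("2024-12-25")
holiday_name = "christmas_day"          # already lower-cased and underscored
country_code = "GB"

column_name = f"seas_{holiday_name}_{country_code.lower()}"
holiday_indicator = f"seas_holiday_{country_code.lower()}"

for col in (column_name, holiday_indicator):
    if col not in df_daily.columns:
        df_daily[col] = 0
    df_daily.loc[df_daily["Date"] == holiday_date, col] = 1

print(df_daily)
# Only the 2024-12-25 row carries a 1 in seas_christmas_day_gb and seas_holiday_gb.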
1148
+ # ---------------------------------------------------------------------
1149
+ # 3.1 Additional Special Days (Father's Day, Mother's Day, etc.)
1150
+ # ---------------------------------------------------------------------
1151
+ extra_cols = [
1152
+ "seas_valentines_day",
1153
+ "seas_halloween",
1154
+ "seas_fathers_day_us_uk", # Note: UK/US is 3rd Sun Jun, others vary
1155
+ "seas_mothers_day_us", # Note: US is 2nd Sun May
1156
+ "seas_mothers_day_uk", # Note: UK Mothering Sunday varies with Easter
1157
+ "seas_good_friday",
1158
+ "seas_easter_monday",
1159
+ "seas_black_friday", # US-centric, but globally adopted
1160
+ "seas_cyber_monday", # US-centric, but globally adopted
1161
+ ]
1162
+ for c in extra_cols:
1163
+ if (
1164
+ c not in df_daily.columns
1165
+ ): # Avoid overwriting if already created by the holiday loop above
1166
+ df_daily[c] = 0
1167
+
1168
+ # Helper: nth_weekday_of_month(year, month, weekday, nth)
1169
+ def nth_weekday_of_month(year, month, weekday, nth):
1170
+ d = datetime(year, month, 1)
1171
+ w = d.weekday()
1172
+ delta = (weekday - w + 7) % 7 # Ensure positive delta
1173
+ first_weekday = d + timedelta(days=delta)
1174
+ target_date = first_weekday + timedelta(days=7 * (nth - 1))
1175
+ # Check if the calculated date is still in the same month
1176
+ if target_date.month == month:
1177
+ return target_date
1178
+ # This can happen if nth is too large (e.g., 5th Friday)
1179
+ # Return the last occurrence of that weekday in the month instead
1180
+ return target_date - timedelta(days=7)
1181
+
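The same roll-forward arithmetic also drives the inline Canadian Labour Day and Thanksgiving additions earlier in the method. A standalone sketch of the helper (same logic, written here so it can run in isolation), with two quick calendar checks:

from datetime import datetime, timedelta

def nth_weekday_of_month(year, month, weekday, nth):
    # Roll forward from the 1st to the first `weekday`, then jump (nth - 1) weeks;
    # if that overshoots the month, fall back to the last occurrence instead.
    d = datetime(year, month, 1)
    delta = (weekday - d.weekday() + 7) % 7
    target = d + timedelta(days=delta + 7 * (nth - 1))
    return target if target.month == month else target - timedelta(days=7)

print(nth_weekday_of_month(2024, 11, 3, 4).date())  # 2024-11-28, 4th Thursday (US Thanksgiving)
print(nth_weekday_of_month(2024, 2, 4, 5).date())   # 2024-02-23, "5th Friday" falls back to the last Friday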
1182
+ def get_good_friday(year):
1183
+ return easter(year) - timedelta(days=2)
1184
+
1185
+ def get_easter_monday(year):
1186
+ return easter(year) + timedelta(days=1)
1187
+
1188
+ def get_black_friday(year):
1189
+ # US Thanksgiving is 4th Thursday in November (weekday=3)
1190
+ thanksgiving = nth_weekday_of_month(year, 11, 3, 4)
1191
+ return thanksgiving + timedelta(days=1)
1192
+
1193
+ def get_cyber_monday(year):
1194
+ # Monday after US Thanksgiving
1195
+ thanksgiving = nth_weekday_of_month(year, 11, 3, 4)
1196
+ return thanksgiving + timedelta(days=4)
1197
+
1198
+ def get_mothering_sunday_uk(year):
1199
+ # Fourth Sunday in Lent (3 weeks before Easter Sunday)
1200
+ # Lent starts on Ash Wednesday, 46 days before Easter.
1201
+ # Easter Sunday is day 0. Sunday before is -7, etc.
1202
+ # 4th Sunday in Lent is 3 weeks before Easter.
1203
+ return easter(year) - timedelta(days=21)
1204
+
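The movable feasts all hang off dateutil.easter.easter, which the module already imports. A quick sanity check of the offsets used above (the 2024 dates are ordinary calendar facts, not package output):

from datetime import timedelta
from dateutil.easter import easter

year = 2024
print(easter(year))                       # 2024-03-31, Easter Sunday
print(easter(year) - timedelta(days=2))   # 2024-03-29, Good Friday
print(easter(year) + timedelta(days=1))   # 2024-04-01, Easter Monday
print(easter(year) - timedelta(days=21))  # 2024-03-10, UK Mothering Sunday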
1205
+ # Loop over each year in range
1206
+ for yr in range(start_year, end_year + 1):
1207
+ try: # Wrap calculations in try-except for robustness
1208
+ # Valentines = Feb 14
1209
+ valentines_day = datetime(yr, 2, 14)
1210
+ # Halloween = Oct 31
1211
+ halloween_day = datetime(yr, 10, 31)
1212
+ # Father's Day (US & UK) = 3rd Sunday (6) in June
1213
+ fathers_day = nth_weekday_of_month(yr, 6, 6, 3)
1214
+ # Mother's Day US = 2nd Sunday (6) in May
1215
+ mothers_day_us = nth_weekday_of_month(yr, 5, 6, 2)
1216
+ # Mother's Day UK (Mothering Sunday)
1217
+ mothering_sunday = get_mothering_sunday_uk(yr)
1218
+
1219
+ # Good Friday, Easter Monday
1220
+ gf = get_good_friday(yr)
1221
+ em = get_easter_monday(yr)
1222
+
1223
+ # Black Friday, Cyber Monday
1224
+ bf = get_black_friday(yr)
1225
+ cm = get_cyber_monday(yr)
1226
+
1227
+ # Mark them in df_daily if in range
1228
+ special_days_map = [
1229
+ (valentines_day, "seas_valentines_day"),
1230
+ (halloween_day, "seas_halloween"),
1231
+ (fathers_day, "seas_fathers_day_us_uk"),
1232
+ (mothers_day_us, "seas_mothers_day_us"),
1233
+ (mothering_sunday, "seas_mothers_day_uk"),
1234
+ (gf, "seas_good_friday"),
1235
+ (em, "seas_easter_monday"),
1236
+ (bf, "seas_black_friday"),
1237
+ (cm, "seas_cyber_monday"),
1238
+ ]
1239
+
1240
+ for special_date, col in special_days_map:
1241
+ if (
1242
+ special_date is not None
1243
+ ): # defensive guard; the date helpers above always return a value
1244
+ special_ts = pd.Timestamp(special_date)
1245
+ # Only set if it's within the daily range AND column exists
1246
+ if (
1247
+ (special_ts >= df_daily["Date"].min())
1248
+ and (special_ts <= df_daily["Date"].max())
1249
+ and (col in df_daily.columns)
1250
+ ):
1251
+ df_daily.loc[df_daily["Date"] == special_ts, col] = 1
1252
+ except Exception as e:
1253
+ print(f"Warning: Could not calculate special days for year {yr}: {e}")
923
1254
 
1255
+ # ---------------------------------------------------------------------
1256
+ # 4. Add daily indicators for last day & last Friday of month & payday
1257
+ # ---------------------------------------------------------------------
1258
+ df_daily["is_last_day_of_month"] = df_daily["Date"].dt.is_month_end
924
1259
 
925
- # ---------------------------------------------------------------------
926
- # 3.1 Additional Special Days (Father's Day, Mother's Day, etc.)
927
- # ---------------------------------------------------------------------
928
- extra_cols = [
929
- "seas_valentines_day",
930
- "seas_halloween",
931
- "seas_fathers_day_us_uk", # Note: UK/US is 3rd Sun Jun, others vary
932
- "seas_mothers_day_us", # Note: US is 2nd Sun May
933
- "seas_mothers_day_uk", # Note: UK Mothering Sunday varies with Easter
934
- "seas_good_friday",
935
- "seas_easter_monday",
936
- "seas_black_friday", # US-centric, but globally adopted
937
- "seas_cyber_monday", # US-centric, but globally adopted
938
- ]
939
- for c in extra_cols:
940
- if c not in df_daily.columns: # Avoid overwriting if already created by holidays pkg
941
- df_daily[c] = 0
942
-
943
- # Helper: nth_weekday_of_month(year, month, weekday, nth)
944
- def nth_weekday_of_month(year, month, weekday, nth):
945
- d = datetime(year, month, 1)
946
- w = d.weekday()
947
- delta = (weekday - w + 7) % 7 # Ensure positive delta
948
- first_weekday = d + timedelta(days=delta)
949
- target_date = first_weekday + timedelta(days=7 * (nth - 1))
950
- # Check if the calculated date is still in the same month
951
- if target_date.month == month:
952
- return target_date
953
- else:
954
- # This can happen if nth is too large (e.g., 5th Friday)
955
- # Return the last occurrence of that weekday in the month instead
956
- return target_date - timedelta(days=7)
957
-
958
-
959
- def get_good_friday(year):
960
- return easter(year) - timedelta(days=2)
961
-
962
- def get_easter_monday(year):
963
- return easter(year) + timedelta(days=1)
964
-
965
- def get_black_friday(year):
966
- # US Thanksgiving is 4th Thursday in November (weekday=3)
967
- thanksgiving = nth_weekday_of_month(year, 11, 3, 4)
968
- return thanksgiving + timedelta(days=1)
969
-
970
- def get_cyber_monday(year):
971
- # Monday after US Thanksgiving
972
- thanksgiving = nth_weekday_of_month(year, 11, 3, 4)
973
- return thanksgiving + timedelta(days=4)
974
-
975
- def get_mothering_sunday_uk(year):
976
- # Fourth Sunday in Lent (3 weeks before Easter Sunday)
977
- # Lent starts on Ash Wednesday, 46 days before Easter.
978
- # Easter Sunday is day 0. Sunday before is -7, etc.
979
- # 4th Sunday in Lent is 3 weeks before Easter.
980
- return easter(year) - timedelta(days=21)
981
-
982
-
983
- # Loop over each year in range
984
- for yr in range(start_year, end_year + 1):
985
- try: # Wrap calculations in try-except for robustness
986
- # Valentines = Feb 14
987
- valentines_day = datetime(yr, 2, 14)
988
- # Halloween = Oct 31
989
- halloween_day = datetime(yr, 10, 31)
990
- # Father's Day (US & UK) = 3rd Sunday (6) in June
991
- fathers_day = nth_weekday_of_month(yr, 6, 6, 3)
992
- # Mother's Day US = 2nd Sunday (6) in May
993
- mothers_day_us = nth_weekday_of_month(yr, 5, 6, 2)
994
- # Mother's Day UK (Mothering Sunday)
995
- mothering_sunday = get_mothering_sunday_uk(yr)
996
-
997
- # Good Friday, Easter Monday
998
- gf = get_good_friday(yr)
999
- em = get_easter_monday(yr)
1000
-
1001
- # Black Friday, Cyber Monday
1002
- bf = get_black_friday(yr)
1003
- cm = get_cyber_monday(yr)
1004
-
1005
- # Mark them in df_daily if in range
1006
- special_days_map = [
1007
- (valentines_day, "seas_valentines_day"),
1008
- (halloween_day, "seas_halloween"),
1009
- (fathers_day, "seas_fathers_day_us_uk"),
1010
- (mothers_day_us, "seas_mothers_day_us"),
1011
- (mothering_sunday,"seas_mothers_day_uk"),
1012
- (gf, "seas_good_friday"),
1013
- (em, "seas_easter_monday"),
1014
- (bf, "seas_black_friday"),
1015
- (cm, "seas_cyber_monday"),
1016
- ]
1017
-
1018
- for special_date, col in special_days_map:
1019
- if special_date is not None: # nth_weekday_of_month can return None edge cases
1020
- special_ts = pd.Timestamp(special_date)
1021
- # Only set if it's within the daily range AND column exists
1022
- if (special_ts >= df_daily["Date"].min()) and \
1023
- (special_ts <= df_daily["Date"].max()) and \
1024
- (col in df_daily.columns):
1025
- df_daily.loc[df_daily["Date"] == special_ts, col] = 1
1026
- except Exception as e:
1027
- print(f"Warning: Could not calculate special days for year {yr}: {e}")
1260
+ def is_last_friday(date):
1261
+ # Check if it's a Friday first
1262
+ if date.weekday() != 4: # Friday is 4
1263
+ return 0
1264
+ # Check if next Friday is in the next month
1265
+ next_friday = date + timedelta(days=7)
1266
+ return 1 if next_friday.month != date.month else 0
1028
1267
 
1268
+ def is_payday(date):
1269
+ return 1 if date.day >= 25 else 0
1029
1270
 
1030
- # ---------------------------------------------------------------------
1031
- # 4. Add daily indicators for last day & last Friday of month & payday
1032
- # ---------------------------------------------------------------------
1033
- df_daily["is_last_day_of_month"] = df_daily["Date"].dt.is_month_end
1034
-
1035
- def is_last_friday(date):
1036
- # Check if it's a Friday first
1037
- if date.weekday() != 4: # Friday is 4
1038
- return 0
1039
- # Check if next Friday is in the next month
1040
- next_friday = date + timedelta(days=7)
1041
- return 1 if next_friday.month != date.month else 0
1042
-
1043
- def is_payday(date):
1044
- return 1 if date.day >= 25 else 0
1045
-
1046
- df_daily["is_last_friday_of_month"] = df_daily["Date"].apply(is_last_friday)
1047
-
1048
- df_daily["is_payday"] = df_daily["Date"].apply(is_payday)
1049
-
1050
- # Rename for clarity prefix
1051
- df_daily.rename(columns={
1271
+ df_daily["is_last_friday_of_month"] = df_daily["Date"].apply(is_last_friday)
1272
+
1273
+ df_daily["is_payday"] = df_daily["Date"].apply(is_payday)
1274
+
1275
+ # Rename for clarity prefix
1276
+ df_daily.rename(
1277
+ columns={
1052
1278
  "is_last_day_of_month": "seas_last_day_of_month",
1053
1279
  "is_last_friday_of_month": "seas_last_friday_of_month",
1054
- "is_payday": "seas_payday"
1055
- }, inplace=True)
1280
+ "is_payday": "seas_payday",
1281
+ },
1282
+ inplace=True,
1283
+ )
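A minimal check of the two daily indicators defined above: is_last_friday flags a Friday whose following Friday falls in the next month, and the payday flag is simply day >= 25. The date range is illustrative:

import pandas as pd
from datetime import timedelta

def is_last_friday(date):
    if date.weekday() != 4:                       # not a Friday
        return 0
    return 1 if (date + timedelta(days=7)).month != date.month else 0

for d in pd.date_range("2024-06-24", "2024-06-30"):
    print(d.date(), is_last_friday(d), int(d.day >= 25))
# Only 2024-06-28 (the last Friday of June) gets a 1 in the middle column;
# the payday flag switches on from the 25th onwards.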
1056
1284
 
1057
- # ---------------------------------------------------------------------
1058
- # 5. Weekly aggregation
1059
- # ---------------------------------------------------------------------
1285
+ # ---------------------------------------------------------------------
1286
+ # 5. Weekly aggregation
1287
+ # ---------------------------------------------------------------------
1060
1288
 
1061
- # Select only columns that are indicators/flags (intended for max aggregation)
1062
- flag_cols = [col for col in df_daily.columns if (col.startswith('seas_') or col.startswith('is_')) and col != "seas_payday"]
1063
- # Ensure 'week_start' is present for grouping
1064
- df_to_agg = df_daily[['week_start'] + flag_cols]
1289
+ # Select only columns that are indicators/flags (intended for max aggregation)
1290
+ flag_cols = [
1291
+ col
1292
+ for col in df_daily.columns
1293
+ if (col.startswith("seas_") or col.startswith("is_"))
1294
+ and col != "seas_payday"
1295
+ ]
1296
+ # Ensure 'week_start' is present for grouping
1297
+ df_to_agg = df_daily[["week_start"] + flag_cols]
1298
+
1299
+ df_weekly_flags = (
1300
+ df_to_agg.groupby("week_start")
1301
+ .max() # if any day=1 in that week, entire week=1
1302
+ .reset_index()
1303
+ .rename(columns={"week_start": "Date"})
1304
+ .set_index("Date")
1305
+ )
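The groupby(...).max() above means a flag that fires on any single day propagates to the whole week. A toy illustration (the dates and column name are made up):

import pandas as pd

daily = pd.DataFrame({
    "week_start": pd.to_datetime(["2024-01-01"] * 7 + ["2024-01-08"] * 7),
    "seas_holiday_gb": [0, 0, 1, 0, 0, 0, 0] + [0] * 7,
})
print(daily.groupby("week_start").max())
# week of 2024-01-01 -> 1 (one holiday that week), week of 2024-01-08 -> 0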
1065
1306
 
1066
- df_weekly_flags = (
1067
- df_to_agg
1068
- .groupby('week_start')
1069
- .max() # if any day=1 in that week, entire week=1
1070
- .reset_index()
1071
- .rename(columns={'week_start': "Date"})
1072
- .set_index("Date")
1073
- )
1074
-
1075
- # Do specific aggregation for payday
1076
- # Make sure 'date' column exists in df_daily
1077
- df_daily["month"] = df_daily["Date"].dt.month
1078
- df_daily["year"] = df_daily["Date"].dt.year
1079
-
1080
- # Sum of seas_payday flags per week
1081
- week_payday_sum = df_daily.groupby("week_start")["seas_payday"].sum()
1082
-
1083
- # Divide the number of payday flags by number of paydays per month
1084
- payday_days_in_month = (
1085
- df_daily.groupby(["year", "month"])["seas_payday"].sum()
1086
- )
1087
- week_month = df_daily.groupby("week_start").first()[["month", "year"]]
1088
- week_days_in_month = week_month.apply(lambda row: payday_days_in_month.loc[(row["year"], row["month"])], axis=1)
1089
- df_weekly_flags["seas_payday"] = (week_payday_sum / week_days_in_month).fillna(0).values
1090
-
1091
- # # Drop intermediate columns
1092
- # df_weekly_flags = df_weekly_flags.drop(columns=["month", "year"])
1093
-
1094
- # --- Aggregate Week Number using MODE ---
1095
- # Define aggregation function for mode (handling potential multi-modal cases by taking the first)
1096
- def get_mode(x):
1097
- modes = pd.Series.mode(x)
1098
- return modes[0] if not modes.empty else np.nan # Return first mode or NaN
1099
-
1100
- df_weekly_iso_week_year = (
1101
- df_daily[['week_start', 'iso_week_daily', 'iso_year_daily']]
1102
- .groupby('week_start')
1103
- .agg(
1104
- # Find the most frequent week number and year within the group
1105
- Week=('iso_week_daily', get_mode),
1106
- Year=('iso_year_daily', get_mode)
1107
- )
1108
- .reset_index()
1109
- .rename(columns={'week_start': 'Date'})
1110
- .set_index('Date')
1111
- )
1112
- # Convert Week/Year back to integer type after aggregation
1113
- df_weekly_iso_week_year['Week'] = df_weekly_iso_week_year['Week'].astype(int)
1114
- df_weekly_iso_week_year['Year'] = df_weekly_iso_week_year['Year'].astype(int)
1115
-
1116
-
1117
- # --- Monthly dummies (spread evenly across week) ---
1118
- df_daily["Month"] = df_daily["Date"].dt.month_name().str.lower()
1119
- df_monthly_dummies_daily = pd.get_dummies(
1120
- df_daily[["week_start", "Month"]], # Only need these columns
1121
- prefix="seas_month",
1122
- columns=["Month"],
1123
- dtype=float # Use float for division
1124
- )
1125
- # Sum daily dummies within the week
1126
- df_monthly_dummies_summed = df_monthly_dummies_daily.groupby('week_start').sum()
1127
- # Divide by number of days in that specific week group (usually 7, except potentially start/end)
1128
- days_in_week = df_daily.groupby('week_start').size()
1129
- df_weekly_monthly_dummies = df_monthly_dummies_summed.div(days_in_week, axis=0)
1307
+ # Do specific aggregation for payday
1308
+ # Derive month and year from the "Date" column for the per-month payday normalisation below
1309
+ df_daily["month"] = df_daily["Date"].dt.month
1310
+ df_daily["year"] = df_daily["Date"].dt.year
1130
1311
 
1131
- # Reset index to merge
1132
- df_weekly_monthly_dummies.reset_index(inplace=True)
1133
- df_weekly_monthly_dummies.rename(columns={'week_start': 'Date'}, inplace=True)
1134
- df_weekly_monthly_dummies.set_index('Date', inplace=True)
1312
+ # Sum of seas_payday flags per week
1313
+ week_payday_sum = df_daily.groupby("week_start")["seas_payday"].sum()
1135
1314
 
1136
- # ---------------------------------------------------------------------
1137
- # 6. Combine all weekly components
1138
- # ---------------------------------------------------------------------
1139
- # Start with the basic weekly index
1140
- df_combined = df_weekly_start.copy()
1141
-
1142
- # Join the other aggregated DataFrames
1143
- df_combined = df_combined.join(df_weekly_flags, how='left')
1144
- df_combined = df_combined.join(df_weekly_iso_week_year, how='left')
1145
- df_combined = df_combined.join(df_weekly_monthly_dummies, how='left')
1146
-
1147
- # Fill potential NaNs created by joins (e.g., if a flag column didn't exist) with 0
1148
- # Exclude 'Week' and 'Year' which should always be present
1149
- cols_to_fill = df_combined.columns.difference(['Week', 'Year'])
1150
- df_combined[cols_to_fill] = df_combined[cols_to_fill].fillna(0)
1151
-
1152
- # Ensure correct types for flag columns (int)
1153
- for col in df_weekly_flags.columns:
1154
- if col != 'seas_payday':
1155
- if col in df_combined.columns:
1156
- df_combined[col] = df_combined[col].astype(int)
1157
-
1158
- # Ensure correct types for month columns (float)
1159
- for col in df_weekly_monthly_dummies.columns:
1160
- if col in df_combined.columns:
1161
- df_combined[col] = df_combined[col].astype(float)
1315
+ # Divide the number of payday flags by number of paydays per month
1316
+ payday_days_in_month = df_daily.groupby(["year", "month"])["seas_payday"].sum()
1317
+ week_month = df_daily.groupby("week_start").first()[["month", "year"]]
1318
+ week_days_in_month = week_month.apply(
1319
+ lambda row: payday_days_in_month.loc[(row["year"], row["month"])],
1320
+ axis=1,
1321
+ )
1322
+ df_weekly_flags["seas_payday"] = (
1323
+ (week_payday_sum / week_days_in_month).fillna(0).values
1324
+ )
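Unlike the other flags, seas_payday is aggregated as a share: the number of payday days (day >= 25) falling in the week divided by the number of payday days in that week's month. A simplified single-month sketch with Monday-commencing weeks (dates are illustrative, and the per-month lookup is collapsed to one month):

import pandas as pd

days = pd.date_range("2024-01-22", "2024-01-31")
df = pd.DataFrame({"Date": days})
df["seas_payday"] = (df["Date"].dt.day >= 25).astype(int)
df["week_start"] = df["Date"] - pd.to_timedelta(df["Date"].dt.weekday, unit="D")

week_sum = df.groupby("week_start")["seas_payday"].sum()   # payday days per week: 4 and 3
paydays_in_january = df["seas_payday"].sum()               # 7 (the 25th to the 31st)
print(week_sum / paydays_in_january)                       # 4/7 for w/c 22 Jan, 3/7 for w/c 29 Jan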
1162
1325
 
1163
- # ---------------------------------------------------------------------
1164
- # 7. Create weekly dummies for Week of Year & yearly dummies from aggregated cols
1165
- # ---------------------------------------------------------------------
1166
- df_combined.reset_index(inplace=True) # 'Date', 'Week', 'Year' become columns
1326
+ # # Drop intermediate columns
1327
+ # df_weekly_flags = df_weekly_flags.drop(columns=["month", "year"])
1328
+
1329
+ # --- Aggregate Week Number using MODE ---
1330
+ # Define aggregation function for mode (handling potential multi-modal cases by taking the first)
1331
+ def get_mode(x):
1332
+ modes = pd.Series.mode(x)
1333
+ return modes[0] if not modes.empty else np.nan # Return first mode or NaN
1334
+
1335
+ df_weekly_iso_week_year = (
1336
+ df_daily[["week_start", "iso_week_daily", "iso_year_daily"]]
1337
+ .groupby("week_start")
1338
+ .agg(
1339
+ # Find the most frequent week number and year within the group
1340
+ Week=("iso_week_daily", get_mode),
1341
+ Year=("iso_year_daily", get_mode),
1342
+ )
1343
+ .reset_index()
1344
+ .rename(columns={"week_start": "Date"})
1345
+ .set_index("Date")
1346
+ )
1347
+ # Convert Week/Year back to integer type after aggregation
1348
+ df_weekly_iso_week_year["Week"] = df_weekly_iso_week_year["Week"].astype(int)
1349
+ df_weekly_iso_week_year["Year"] = df_weekly_iso_week_year["Year"].astype(int)
1350
+
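For weeks that straddle an ISO year boundary, the daily week numbers disagree, so the aggregation keeps the most frequent value. A small check of the mode rule (the week split shown is illustrative):

import numpy as np
import pandas as pd

def get_mode(x):
    modes = pd.Series.mode(x)
    return modes[0] if not modes.empty else np.nan

iso_weeks = pd.Series([52, 52, 52, 1, 1, 1, 1])  # 3 days in week 52, 4 days in week 1
print(get_mode(iso_weeks))                        # 1 -> the majority week number labels the week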
1351
+ # --- Monthly dummies (spread evenly across week) ---
1352
+ df_daily["Month"] = df_daily["Date"].dt.month_name().str.lower()
1353
+ df_monthly_dummies_daily = pd.get_dummies(
1354
+ df_daily[["week_start", "Month"]], # Only need these columns
1355
+ prefix="seas_month",
1356
+ columns=["Month"],
1357
+ dtype=float, # Use float for division
1358
+ )
1359
+ # Sum daily dummies within the week
1360
+ df_monthly_dummies_summed = df_monthly_dummies_daily.groupby("week_start").sum()
1361
+ # Divide by number of days in that specific week group (usually 7, except potentially start/end)
1362
+ days_in_week = df_daily.groupby("week_start").size()
1363
+ df_weekly_monthly_dummies = df_monthly_dummies_summed.div(days_in_week, axis=0)
1364
+
1365
+ # Reset index to merge
1366
+ df_weekly_monthly_dummies.reset_index(inplace=True)
1367
+ df_weekly_monthly_dummies.rename(columns={"week_start": "Date"}, inplace=True)
1368
+ df_weekly_monthly_dummies.set_index("Date", inplace=True)
1369
+
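Because a week can straddle two months, the month dummies are fractional: each month's dummy gets the share of the week's days that belong to it. A sketch for a Monday-commencing week spanning January and February (dates are illustrative):

import pandas as pd

days = pd.date_range("2024-01-29", "2024-02-04")      # Mon 29 Jan .. Sun 4 Feb
df = pd.DataFrame({
    "week_start": pd.Timestamp("2024-01-29"),
    "Month": days.month_name().str.lower(),
})
dummies = pd.get_dummies(df, prefix="seas_month", columns=["Month"], dtype=float)
weekly = dummies.groupby("week_start").sum().div(df.groupby("week_start").size(), axis=0)
print(weekly.round(3))   # seas_month_january = 0.429 (3/7), seas_month_february = 0.571 (4/7)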
1370
+ # ---------------------------------------------------------------------
1371
+ # 6. Combine all weekly components
1372
+ # ---------------------------------------------------------------------
1373
+ # Start with the basic weekly index
1374
+ df_combined = df_weekly_start.copy()
1375
+
1376
+ # Join the other aggregated DataFrames
1377
+ df_combined = df_combined.join(df_weekly_flags, how="left")
1378
+ df_combined = df_combined.join(df_weekly_iso_week_year, how="left")
1379
+ df_combined = df_combined.join(df_weekly_monthly_dummies, how="left")
1380
+
1381
+ # Fill potential NaNs created by joins (e.g., if a flag column didn't exist) with 0
1382
+ # Exclude 'Week' and 'Year' which should always be present
1383
+ cols_to_fill = df_combined.columns.difference(["Week", "Year"])
1384
+ df_combined[cols_to_fill] = df_combined[cols_to_fill].fillna(0)
1385
+
1386
+ # Ensure correct types for flag columns (int)
1387
+ for col in df_weekly_flags.columns:
1388
+ if col != "seas_payday":
1389
+ if col in df_combined.columns:
1390
+ df_combined[col] = df_combined[col].astype(int)
1391
+
1392
+ # Ensure correct types for month columns (float)
1393
+ for col in df_weekly_monthly_dummies.columns:
1394
+ if col in df_combined.columns:
1395
+ df_combined[col] = df_combined[col].astype(float)
1396
+
1397
+ # ---------------------------------------------------------------------
1398
+ # 7. Create weekly dummies for Week of Year & yearly dummies from aggregated cols
1399
+ # ---------------------------------------------------------------------
1400
+ df_combined.reset_index(inplace=True) # 'Date', 'Week', 'Year' become columns
1401
+
1402
+ # Create dummies from the aggregated 'Week' column
1403
+ df_combined = pd.get_dummies(
1404
+ df_combined,
1405
+ prefix="seas",
1406
+ columns=["Week"],
1407
+ dtype=int,
1408
+ prefix_sep="_",
1409
+ )
1167
1410
 
1168
- # Create dummies from the aggregated 'Week' column
1169
- df_combined = pd.get_dummies(df_combined, prefix="seas", columns=["Week"], dtype=int, prefix_sep='_')
1411
+ # Create dummies from the aggregated 'Year' column
1412
+ df_combined = pd.get_dummies(
1413
+ df_combined,
1414
+ prefix="seas",
1415
+ columns=["Year"],
1416
+ dtype=int,
1417
+ prefix_sep="_",
1418
+ )
1170
1419
 
1171
- # Create dummies from the aggregated 'Year' column
1172
- df_combined = pd.get_dummies(df_combined, prefix="seas", columns=["Year"], dtype=int, prefix_sep='_')
1420
+ # ---------------------------------------------------------------------
1421
+ # 8. Add constant & trend
1422
+ # ---------------------------------------------------------------------
1423
+ df_combined["Constant"] = 1
1424
+ df_combined.reset_index(
1425
+ drop=True,
1426
+ inplace=True,
1427
+ ) # Ensure index is 0, 1, 2... for trend
1428
+ df_combined["Trend"] = df_combined.index + 1
1429
+
1430
+ # ---------------------------------------------------------------------
1431
+ # 9. Rename Date -> OBS and select final columns
1432
+ # ---------------------------------------------------------------------
1433
+ df_combined.rename(columns={"Date": "OBS"}, inplace=True)
1434
+
1435
+ # Reorder columns - OBS first, then Constant, Trend, then seasonal features
1436
+ cols_order = (
1437
+ ["OBS", "Constant", "Trend"]
1438
+ + sorted([col for col in df_combined.columns if col.startswith("seas_")])
1439
+ + sorted([col for col in df_combined.columns if col.startswith("dum_")])
1440
+ ) # If individual week dummies were enabled
1441
+
1442
+ # Filter out columns not in the desired order list (handles case where dum_ cols are off)
1443
+ final_cols = [col for col in cols_order if col in df_combined.columns]
1444
+ df_combined = df_combined[final_cols]
1445
+
1446
+ return df_combined
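The final ordering puts OBS, Constant and Trend first, then the sorted seas_* block (plus any dum_* columns if they exist). A tiny sketch of that selection on a made-up frame:

import pandas as pd

df_combined = pd.DataFrame({
    "OBS": pd.to_datetime(["2024-01-01"]),
    "seas_week_1": [1],
    "seas_month_january": [1.0],
    "Constant": [1],
    "Trend": [1],
})
cols_order = (
    ["OBS", "Constant", "Trend"]
    + sorted(c for c in df_combined.columns if c.startswith("seas_"))
    + sorted(c for c in df_combined.columns if c.startswith("dum_"))
)
print([c for c in cols_order if c in df_combined.columns])
# ['OBS', 'Constant', 'Trend', 'seas_month_january', 'seas_week_1']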
1173
1447
 
1174
- # ---------------------------------------------------------------------
1175
- # 8. Add constant & trend
1176
- # ---------------------------------------------------------------------
1177
- df_combined["Constant"] = 1
1178
- df_combined.reset_index(drop=True, inplace=True) # Ensure index is 0, 1, 2... for trend
1179
- df_combined["Trend"] = df_combined.index + 1
1180
-
1181
- # ---------------------------------------------------------------------
1182
- # 9. Rename Date -> OBS and select final columns
1183
- # ---------------------------------------------------------------------
1184
- df_combined.rename(columns={"Date": "OBS"}, inplace=True)
1185
-
1186
- # Reorder columns - OBS first, then Constant, Trend, then seasonal features
1187
- cols_order = ['OBS', 'Constant', 'Trend'] + \
1188
- sorted([col for col in df_combined.columns if col.startswith('seas_')]) + \
1189
- sorted([col for col in df_combined.columns if col.startswith('dum_')]) # If individual week dummies were enabled
1190
-
1191
- # Filter out columns not in the desired order list (handles case where dum_ cols are off)
1192
- final_cols = [col for col in cols_order if col in df_combined.columns]
1193
- df_combined = df_combined[final_cols]
1194
-
1195
- return df_combined
1196
-
1197
1448
  def pull_weather(self, week_commencing, start_date, country_codes) -> pd.DataFrame:
1198
1449
  """
1199
1450
  Pull weather data for a given week-commencing day and one or more country codes.
@@ -1223,7 +1474,15 @@ class datapull:
1223
1474
  raise ValueError("country_codes must be a list/tuple or a single string.")
1224
1475
 
1225
1476
  # --- Setup / Constants --- #
1226
- day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
1477
+ day_dict = {
1478
+ "mon": 0,
1479
+ "tue": 1,
1480
+ "wed": 2,
1481
+ "thu": 3,
1482
+ "fri": 4,
1483
+ "sat": 5,
1484
+ "sun": 6,
1485
+ }
1227
1486
  # Map each 2-letter code to a key
1228
1487
  country_dict = {
1229
1488
  "US": "US_STATES",
@@ -1231,24 +1490,45 @@ class datapull:
1231
1490
  "AU": "AU__ASOS",
1232
1491
  "GB": "GB__ASOS",
1233
1492
  "DE": "DE__ASOS",
1234
- "ZA": "ZA__ASOS"
1493
+ "ZA": "ZA__ASOS",
1235
1494
  }
1236
1495
 
1237
1496
  # Station-based countries for Mesonet
1238
1497
  station_map = {
1239
1498
  "GB__ASOS": [
1240
- "&stations=EGCC", "&stations=EGNM", "&stations=EGBB", "&stations=EGSH",
1241
- "&stations=EGFF", "&stations=EGHI", "&stations=EGLC", "&stations=EGHQ",
1242
- "&stations=EGAC", "&stations=EGPF", "&stations=EGGD", "&stations=EGPE",
1243
- "&stations=EGNT"
1499
+ "&stations=EGCC",
1500
+ "&stations=EGNM",
1501
+ "&stations=EGBB",
1502
+ "&stations=EGSH",
1503
+ "&stations=EGFF",
1504
+ "&stations=EGHI",
1505
+ "&stations=EGLC",
1506
+ "&stations=EGHQ",
1507
+ "&stations=EGAC",
1508
+ "&stations=EGPF",
1509
+ "&stations=EGGD",
1510
+ "&stations=EGPE",
1511
+ "&stations=EGNT",
1244
1512
  ],
1245
1513
  "AU__ASOS": [
1246
- "&stations=YPDN", "&stations=YBCS", "&stations=YBBN", "&stations=YSSY",
1247
- "&stations=YSSY", "&stations=YMEN", "&stations=YPAD", "&stations=YPPH"
1514
+ "&stations=YPDN",
1515
+ "&stations=YBCS",
1516
+ "&stations=YBBN",
1517
+ "&stations=YSSY",
1518
+ "&stations=YSSY",
1519
+ "&stations=YMEN",
1520
+ "&stations=YPAD",
1521
+ "&stations=YPPH",
1248
1522
  ],
1249
1523
  "DE__ASOS": [
1250
- "&stations=EDDL", "&stations=EDDH", "&stations=EDDB", "&stations=EDDN",
1251
- "&stations=EDDF", "&stations=EDDK", "&stations=EDLW", "&stations=EDDM"
1524
+ "&stations=EDDL",
1525
+ "&stations=EDDH",
1526
+ "&stations=EDDB",
1527
+ "&stations=EDDN",
1528
+ "&stations=EDDF",
1529
+ "&stations=EDDK",
1530
+ "&stations=EDLW",
1531
+ "&stations=EDDM",
1252
1532
  ],
1253
1533
  # Example: if ZA is also station-based, add it here.
1254
1534
  "ZA__ASOS": [
@@ -1261,110 +1541,467 @@ class datapull:
1261
1541
  # Non-US countries that also fetch RAIN & SNOW from Open-Meteo
1262
1542
  rainfall_city_map = {
1263
1543
  "GB__ASOS": [
1264
- "Manchester", "Leeds", "Birmingham", "London","Glasgow",
1544
+ "Manchester",
1545
+ "Leeds",
1546
+ "Birmingham",
1547
+ "London",
1548
+ "Glasgow",
1265
1549
  ],
1266
1550
  "AU__ASOS": [
1267
- "Darwin", "Cairns", "Brisbane", "Sydney", "Melbourne", "Adelaide", "Perth"
1551
+ "Darwin",
1552
+ "Cairns",
1553
+ "Brisbane",
1554
+ "Sydney",
1555
+ "Melbourne",
1556
+ "Adelaide",
1557
+ "Perth",
1268
1558
  ],
1269
1559
  "DE__ASOS": [
1270
- "Dortmund", "Düsseldorf", "Frankfurt", "Munich", "Cologne", "Berlin", "Hamburg", "Nuernberg"
1271
- ],
1272
- "ZA__ASOS": [
1273
- "Johannesburg", "Cape Town", "Durban", "Pretoria"
1560
+ "Dortmund",
1561
+ "Düsseldorf",
1562
+ "Frankfurt",
1563
+ "Munich",
1564
+ "Cologne",
1565
+ "Berlin",
1566
+ "Hamburg",
1567
+ "Nuernberg",
1274
1568
  ],
1569
+ "ZA__ASOS": ["Johannesburg", "Cape Town", "Durban", "Pretoria"],
1275
1570
  }
1276
1571
 
1277
1572
  # Canada sub-networks
1278
1573
  institute_vector = [
1279
- "CA_NB_ASOS", "CA_NF_ASOS", "CA_NT_ASOS", "CA_NS_ASOS", "CA_NU_ASOS"
1574
+ "CA_NB_ASOS",
1575
+ "CA_NF_ASOS",
1576
+ "CA_NT_ASOS",
1577
+ "CA_NS_ASOS",
1578
+ "CA_NU_ASOS",
1280
1579
  ]
1281
1580
  stations_list_canada = [
1282
1581
  [
1283
- "&stations=CYQM", "&stations=CERM", "&stations=CZCR",
1284
- "&stations=CZBF", "&stations=CYFC", "&stations=CYCX"
1582
+ "&stations=CYQM",
1583
+ "&stations=CERM",
1584
+ "&stations=CZCR",
1585
+ "&stations=CZBF",
1586
+ "&stations=CYFC",
1587
+ "&stations=CYCX",
1285
1588
  ],
1286
1589
  [
1287
- "&stations=CWZZ", "&stations=CYDP", "&stations=CYMH", "&stations=CYAY",
1288
- "&stations=CWDO", "&stations=CXTP", "&stations=CYJT", "&stations=CYYR",
1289
- "&stations=CZUM", "&stations=CYWK", "&stations=CYWK"
1590
+ "&stations=CWZZ",
1591
+ "&stations=CYDP",
1592
+ "&stations=CYMH",
1593
+ "&stations=CYAY",
1594
+ "&stations=CWDO",
1595
+ "&stations=CXTP",
1596
+ "&stations=CYJT",
1597
+ "&stations=CYYR",
1598
+ "&stations=CZUM",
1599
+ "&stations=CYWK",
1600
+ "&stations=CYWK",
1290
1601
  ],
1291
1602
  [
1292
- "&stations=CYHI", "&stations=CZCP", "&stations=CWLI", "&stations=CWND",
1293
- "&stations=CXTV", "&stations=CYVL", "&stations=CYCO", "&stations=CXDE",
1294
- "&stations=CYWE", "&stations=CYLK", "&stations=CWID", "&stations=CYRF",
1295
- "&stations=CXYH", "&stations=CYWY", "&stations=CWMT"
1603
+ "&stations=CYHI",
1604
+ "&stations=CZCP",
1605
+ "&stations=CWLI",
1606
+ "&stations=CWND",
1607
+ "&stations=CXTV",
1608
+ "&stations=CYVL",
1609
+ "&stations=CYCO",
1610
+ "&stations=CXDE",
1611
+ "&stations=CYWE",
1612
+ "&stations=CYLK",
1613
+ "&stations=CWID",
1614
+ "&stations=CYRF",
1615
+ "&stations=CXYH",
1616
+ "&stations=CYWY",
1617
+ "&stations=CWMT",
1296
1618
  ],
1297
1619
  [
1298
- "&stations=CWEF", "&stations=CXIB", "&stations=CYQY", "&stations=CYPD",
1299
- "&stations=CXNP", "&stations=CXMY", "&stations=CYAW", "&stations=CWKG",
1300
- "&stations=CWVU", "&stations=CXLB", "&stations=CWSA", "&stations=CWRN"
1620
+ "&stations=CWEF",
1621
+ "&stations=CXIB",
1622
+ "&stations=CYQY",
1623
+ "&stations=CYPD",
1624
+ "&stations=CXNP",
1625
+ "&stations=CXMY",
1626
+ "&stations=CYAW",
1627
+ "&stations=CWKG",
1628
+ "&stations=CWVU",
1629
+ "&stations=CXLB",
1630
+ "&stations=CWSA",
1631
+ "&stations=CWRN",
1301
1632
  ],
1302
1633
  [
1303
- "&stations=CYLT", "&stations=CWEU", "&stations=CWGZ", "&stations=CYIO",
1304
- "&stations=CXSE", "&stations=CYCB", "&stations=CWIL", "&stations=CXWB",
1305
- "&stations=CYZS", "&stations=CWJC", "&stations=CYFB", "&stations=CWUW"
1306
- ]
1634
+ "&stations=CYLT",
1635
+ "&stations=CWEU",
1636
+ "&stations=CWGZ",
1637
+ "&stations=CYIO",
1638
+ "&stations=CXSE",
1639
+ "&stations=CYCB",
1640
+ "&stations=CWIL",
1641
+ "&stations=CXWB",
1642
+ "&stations=CYZS",
1643
+ "&stations=CWJC",
1644
+ "&stations=CYFB",
1645
+ "&stations=CWUW",
1646
+ ],
1307
1647
  ]
1308
1648
 
1309
1649
  # US states and stations - each sub-network
1310
1650
  us_state_networks = {
1311
- state: f"{state}_ASOS" for state in [
1312
- "AL", "AR", "AZ", "CA", "CO", "CT", "DE", "FL", "GA", "IA", "ID", "IL", "IN",
1313
- "KS", "KY", "LA", "MA", "MD", "ME", "MI", "MN", "MO", "MS", "MT", "NC", "ND",
1314
- "NE", "NH", "NJ", "NM", "NV", "NY", "OH", "OK", "OR", "PA", "RI", "SC", "SD",
1315
- "TN", "TX", "UT", "VA", "VT", "WA", "WI", "WV", "WY"
1651
+ state: f"{state}_ASOS"
1652
+ for state in [
1653
+ "AL",
1654
+ "AR",
1655
+ "AZ",
1656
+ "CA",
1657
+ "CO",
1658
+ "CT",
1659
+ "DE",
1660
+ "FL",
1661
+ "GA",
1662
+ "IA",
1663
+ "ID",
1664
+ "IL",
1665
+ "IN",
1666
+ "KS",
1667
+ "KY",
1668
+ "LA",
1669
+ "MA",
1670
+ "MD",
1671
+ "ME",
1672
+ "MI",
1673
+ "MN",
1674
+ "MO",
1675
+ "MS",
1676
+ "MT",
1677
+ "NC",
1678
+ "ND",
1679
+ "NE",
1680
+ "NH",
1681
+ "NJ",
1682
+ "NM",
1683
+ "NV",
1684
+ "NY",
1685
+ "OH",
1686
+ "OK",
1687
+ "OR",
1688
+ "PA",
1689
+ "RI",
1690
+ "SC",
1691
+ "SD",
1692
+ "TN",
1693
+ "TX",
1694
+ "UT",
1695
+ "VA",
1696
+ "VT",
1697
+ "WA",
1698
+ "WI",
1699
+ "WV",
1700
+ "WY",
1316
1701
  ]
1317
1702
  }
1318
-
1703
+
1319
1704
  us_stations_map = {
1320
- "AL_ASOS": ["&stations=BHM", "&stations=HSV", "&stations=MGM", "&stations=MOB", "&stations=TCL"],
1321
- "AR_ASOS": ["&stations=LIT", "&stations=FSM", "&stations=TXK", "&stations=HOT", "&stations=FYV"],
1322
- "AZ_ASOS": ["&stations=PHX", "&stations=TUS", "&stations=FLG", "&stations=YUM", "&stations=PRC"],
1323
- "CA_ASOS": ["&stations=LAX", "&stations=SAN", "&stations=SJC", "&stations=SFO", "&stations=FAT"],
1324
- "CO_ASOS": ["&stations=DEN", "&stations=COS", "&stations=GJT", "&stations=PUB", "&stations=ASE"],
1325
- "CT_ASOS": ["&stations=BDL", "&stations=HVN", "&stations=BDR", "&stations=GON", "&stations=HFD"],
1705
+ "AL_ASOS": [
1706
+ "&stations=BHM",
1707
+ "&stations=HSV",
1708
+ "&stations=MGM",
1709
+ "&stations=MOB",
1710
+ "&stations=TCL",
1711
+ ],
1712
+ "AR_ASOS": [
1713
+ "&stations=LIT",
1714
+ "&stations=FSM",
1715
+ "&stations=TXK",
1716
+ "&stations=HOT",
1717
+ "&stations=FYV",
1718
+ ],
1719
+ "AZ_ASOS": [
1720
+ "&stations=PHX",
1721
+ "&stations=TUS",
1722
+ "&stations=FLG",
1723
+ "&stations=YUM",
1724
+ "&stations=PRC",
1725
+ ],
1726
+ "CA_ASOS": [
1727
+ "&stations=LAX",
1728
+ "&stations=SAN",
1729
+ "&stations=SJC",
1730
+ "&stations=SFO",
1731
+ "&stations=FAT",
1732
+ ],
1733
+ "CO_ASOS": [
1734
+ "&stations=DEN",
1735
+ "&stations=COS",
1736
+ "&stations=GJT",
1737
+ "&stations=PUB",
1738
+ "&stations=ASE",
1739
+ ],
1740
+ "CT_ASOS": [
1741
+ "&stations=BDL",
1742
+ "&stations=HVN",
1743
+ "&stations=BDR",
1744
+ "&stations=GON",
1745
+ "&stations=HFD",
1746
+ ],
1326
1747
  "DE_ASOS": ["&stations=ILG", "&stations=GED", "&stations=DOV"],
1327
- "FL_ASOS": ["&stations=MIA", "&stations=TPA", "&stations=ORL", "&stations=JAX", "&stations=TLH"],
1328
- "GA_ASOS": ["&stations=ATL", "&stations=SAV", "&stations=CSG", "&stations=MCN", "&stations=AGS"],
1329
- "IA_ASOS": ["&stations=DSM", "&stations=CID", "&stations=DBQ", "&stations=ALO", "&stations=SUX"],
1330
- "ID_ASOS": ["&stations=BOI", "&stations=IDA", "&stations=PIH", "&stations=SUN", "&stations=COE"],
1331
- "IL_ASOS": ["&stations=ORD", "&stations=MDW", "&stations=PIA", "&stations=SPI", "&stations=MLI"],
1332
- "IN_ASOS": ["&stations=IND", "&stations=FWA", "&stations=SBN", "&stations=EVV", "&stations=HUF"],
1333
- "KS_ASOS": ["&stations=ICT", "&stations=FOE", "&stations=GCK", "&stations=HYS", "&stations=SLN"],
1334
- "KY_ASOS": ["&stations=SDF", "&stations=LEX", "&stations=CVG", "&stations=PAH", "&stations=BWG"],
1335
- "LA_ASOS": ["&stations=MSY", "&stations=SHV", "&stations=LFT", "&stations=BTR", "&stations=MLU"],
1336
- "MA_ASOS": ["&stations=BOS", "&stations=ORH", "&stations=HYA", "&stations=ACK", "&stations=BED"],
1337
- "MD_ASOS": ["&stations=BWI", "&stations=MTN", "&stations=SBY", "&stations=HGR", "&stations=ADW"],
1338
- "ME_ASOS": ["&stations=PWM", "&stations=BGR", "&stations=CAR", "&stations=PQI", "&stations=RKD"],
1339
- "MI_ASOS": ["&stations=DTW", "&stations=GRR", "&stations=FNT", "&stations=LAN", "&stations=MKG"],
1340
- "MN_ASOS": ["&stations=MSP", "&stations=DLH", "&stations=RST", "&stations=STC", "&stations=INL"],
1341
- "MO_ASOS": ["&stations=STL", "&stations=MCI", "&stations=SGF", "&stations=COU", "&stations=JLN"],
1342
- "MS_ASOS": ["&stations=JAN", "&stations=GPT", "&stations=MEI", "&stations=PIB", "&stations=GLH"],
1343
- "MT_ASOS": ["&stations=BIL", "&stations=MSO", "&stations=GTF", "&stations=HLN", "&stations=BZN"],
1344
- "NC_ASOS": ["&stations=CLT", "&stations=RDU", "&stations=GSO", "&stations=ILM", "&stations=AVL"],
1345
- "ND_ASOS": ["&stations=BIS", "&stations=FAR", "&stations=GFK", "&stations=ISN", "&stations=JMS"],
1748
+ "FL_ASOS": [
1749
+ "&stations=MIA",
1750
+ "&stations=TPA",
1751
+ "&stations=ORL",
1752
+ "&stations=JAX",
1753
+ "&stations=TLH",
1754
+ ],
1755
+ "GA_ASOS": [
1756
+ "&stations=ATL",
1757
+ "&stations=SAV",
1758
+ "&stations=CSG",
1759
+ "&stations=MCN",
1760
+ "&stations=AGS",
1761
+ ],
1762
+ "IA_ASOS": [
1763
+ "&stations=DSM",
1764
+ "&stations=CID",
1765
+ "&stations=DBQ",
1766
+ "&stations=ALO",
1767
+ "&stations=SUX",
1768
+ ],
1769
+ "ID_ASOS": [
1770
+ "&stations=BOI",
1771
+ "&stations=IDA",
1772
+ "&stations=PIH",
1773
+ "&stations=SUN",
1774
+ "&stations=COE",
1775
+ ],
1776
+ "IL_ASOS": [
1777
+ "&stations=ORD",
1778
+ "&stations=MDW",
1779
+ "&stations=PIA",
1780
+ "&stations=SPI",
1781
+ "&stations=MLI",
1782
+ ],
1783
+ "IN_ASOS": [
1784
+ "&stations=IND",
1785
+ "&stations=FWA",
1786
+ "&stations=SBN",
1787
+ "&stations=EVV",
1788
+ "&stations=HUF",
1789
+ ],
1790
+ "KS_ASOS": [
1791
+ "&stations=ICT",
1792
+ "&stations=FOE",
1793
+ "&stations=GCK",
1794
+ "&stations=HYS",
1795
+ "&stations=SLN",
1796
+ ],
1797
+ "KY_ASOS": [
1798
+ "&stations=SDF",
1799
+ "&stations=LEX",
1800
+ "&stations=CVG",
1801
+ "&stations=PAH",
1802
+ "&stations=BWG",
1803
+ ],
1804
+ "LA_ASOS": [
1805
+ "&stations=MSY",
1806
+ "&stations=SHV",
1807
+ "&stations=LFT",
1808
+ "&stations=BTR",
1809
+ "&stations=MLU",
1810
+ ],
1811
+ "MA_ASOS": [
1812
+ "&stations=BOS",
1813
+ "&stations=ORH",
1814
+ "&stations=HYA",
1815
+ "&stations=ACK",
1816
+ "&stations=BED",
1817
+ ],
1818
+ "MD_ASOS": [
1819
+ "&stations=BWI",
1820
+ "&stations=MTN",
1821
+ "&stations=SBY",
1822
+ "&stations=HGR",
1823
+ "&stations=ADW",
1824
+ ],
1825
+ "ME_ASOS": [
1826
+ "&stations=PWM",
1827
+ "&stations=BGR",
1828
+ "&stations=CAR",
1829
+ "&stations=PQI",
1830
+ "&stations=RKD",
1831
+ ],
1832
+ "MI_ASOS": [
1833
+ "&stations=DTW",
1834
+ "&stations=GRR",
1835
+ "&stations=FNT",
1836
+ "&stations=LAN",
1837
+ "&stations=MKG",
1838
+ ],
1839
+ "MN_ASOS": [
1840
+ "&stations=MSP",
1841
+ "&stations=DLH",
1842
+ "&stations=RST",
1843
+ "&stations=STC",
1844
+ "&stations=INL",
1845
+ ],
1846
+ "MO_ASOS": [
1847
+ "&stations=STL",
1848
+ "&stations=MCI",
1849
+ "&stations=SGF",
1850
+ "&stations=COU",
1851
+ "&stations=JLN",
1852
+ ],
1853
+ "MS_ASOS": [
1854
+ "&stations=JAN",
1855
+ "&stations=GPT",
1856
+ "&stations=MEI",
1857
+ "&stations=PIB",
1858
+ "&stations=GLH",
1859
+ ],
1860
+ "MT_ASOS": [
1861
+ "&stations=BIL",
1862
+ "&stations=MSO",
1863
+ "&stations=GTF",
1864
+ "&stations=HLN",
1865
+ "&stations=BZN",
1866
+ ],
1867
+ "NC_ASOS": [
1868
+ "&stations=CLT",
1869
+ "&stations=RDU",
1870
+ "&stations=GSO",
1871
+ "&stations=ILM",
1872
+ "&stations=AVL",
1873
+ ],
1874
+ "ND_ASOS": [
1875
+ "&stations=BIS",
1876
+ "&stations=FAR",
1877
+ "&stations=GFK",
1878
+ "&stations=ISN",
1879
+ "&stations=JMS",
1880
+ ],
1346
1881
  "NE_ASOS": ["&stations=OMA"],
1347
- "NH_ASOS": ["&stations=MHT", "&stations=PSM", "&stations=CON", "&stations=LEB", "&stations=ASH"],
1348
- "NJ_ASOS": ["&stations=EWR", "&stations=ACY", "&stations=TTN", "&stations=MMU", "&stations=TEB"],
1349
- "NM_ASOS": ["&stations=ABQ", "&stations=SAF", "&stations=ROW", "&stations=HOB", "&stations=FMN"],
1882
+ "NH_ASOS": [
1883
+ "&stations=MHT",
1884
+ "&stations=PSM",
1885
+ "&stations=CON",
1886
+ "&stations=LEB",
1887
+ "&stations=ASH",
1888
+ ],
1889
+ "NJ_ASOS": [
1890
+ "&stations=EWR",
1891
+ "&stations=ACY",
1892
+ "&stations=TTN",
1893
+ "&stations=MMU",
1894
+ "&stations=TEB",
1895
+ ],
1896
+ "NM_ASOS": [
1897
+ "&stations=ABQ",
1898
+ "&stations=SAF",
1899
+ "&stations=ROW",
1900
+ "&stations=HOB",
1901
+ "&stations=FMN",
1902
+ ],
1350
1903
  "NV_ASOS": ["&stations=LAS"],
1351
- "NY_ASOS": ["&stations=JFK", "&stations=LGA", "&stations=BUF", "&stations=ALB", "&stations=SYR"],
1904
+ "NY_ASOS": [
1905
+ "&stations=JFK",
1906
+ "&stations=LGA",
1907
+ "&stations=BUF",
1908
+ "&stations=ALB",
1909
+ "&stations=SYR",
1910
+ ],
1352
1911
  "OH_ASOS": ["&stations=CMH"],
1353
- "OK_ASOS": ["&stations=OKC", "&stations=TUL", "&stations=LAW", "&stations=SWO", "&stations=PNC"],
1912
+ "OK_ASOS": [
1913
+ "&stations=OKC",
1914
+ "&stations=TUL",
1915
+ "&stations=LAW",
1916
+ "&stations=SWO",
1917
+ "&stations=PNC",
1918
+ ],
1354
1919
  "OR_ASOS": ["&stations=PDX"],
1355
- "PA_ASOS": ["&stations=PHL", "&stations=PIT", "&stations=ERI", "&stations=MDT", "&stations=AVP"],
1920
+ "PA_ASOS": [
1921
+ "&stations=PHL",
1922
+ "&stations=PIT",
1923
+ "&stations=ERI",
1924
+ "&stations=MDT",
1925
+ "&stations=AVP",
1926
+ ],
1356
1927
  "RI_ASOS": ["&stations=PVD", "&stations=WST", "&stations=UUU"],
1357
- "SC_ASOS": ["&stations=CHS", "&stations=CAE", "&stations=GSP", "&stations=MYR", "&stations=FLO"],
1358
- "SD_ASOS": ["&stations=FSD", "&stations=RAP", "&stations=PIR", "&stations=ABR", "&stations=YKN"],
1359
- "TN_ASOS": ["&stations=BNA", "&stations=MEM", "&stations=TYS", "&stations=CHA", "&stations=TRI"],
1360
- "TX_ASOS": ["&stations=DFW", "&stations=IAH", "&stations=AUS", "&stations=SAT", "&stations=ELP"],
1361
- "UT_ASOS": ["&stations=SLC", "&stations=OGD", "&stations=PVU", "&stations=SGU", "&stations=CNY"],
1362
- "VA_ASOS": ["&stations=DCA", "&stations=RIC", "&stations=ROA", "&stations=ORF", "&stations=SHD"],
1363
- "VT_ASOS": ["&stations=BTV", "&stations=MPV", "&stations=RUT", "&stations=VSF", "&stations=MVL"],
1364
- "WA_ASOS": ["&stations=SEA", "&stations=GEG", "&stations=TIW", "&stations=VUO", "&stations=BFI"],
1365
- "WI_ASOS": ["&stations=MKE", "&stations=MSN", "&stations=GRB", "&stations=EAU", "&stations=LSE"],
1366
- "WV_ASOS": ["&stations=CRW", "&stations=CKB", "&stations=HTS", "&stations=MGW", "&stations=BKW"],
1367
- "WY_ASOS": ["&stations=CPR", "&stations=JAC", "&stations=SHR", "&stations=COD", "&stations=RKS"],
1928
+ "SC_ASOS": [
1929
+ "&stations=CHS",
1930
+ "&stations=CAE",
1931
+ "&stations=GSP",
1932
+ "&stations=MYR",
1933
+ "&stations=FLO",
1934
+ ],
1935
+ "SD_ASOS": [
1936
+ "&stations=FSD",
1937
+ "&stations=RAP",
1938
+ "&stations=PIR",
1939
+ "&stations=ABR",
1940
+ "&stations=YKN",
1941
+ ],
1942
+ "TN_ASOS": [
1943
+ "&stations=BNA",
1944
+ "&stations=MEM",
1945
+ "&stations=TYS",
1946
+ "&stations=CHA",
1947
+ "&stations=TRI",
1948
+ ],
1949
+ "TX_ASOS": [
1950
+ "&stations=DFW",
1951
+ "&stations=IAH",
1952
+ "&stations=AUS",
1953
+ "&stations=SAT",
1954
+ "&stations=ELP",
1955
+ ],
1956
+ "UT_ASOS": [
1957
+ "&stations=SLC",
1958
+ "&stations=OGD",
1959
+ "&stations=PVU",
1960
+ "&stations=SGU",
1961
+ "&stations=CNY",
1962
+ ],
1963
+ "VA_ASOS": [
1964
+ "&stations=DCA",
1965
+ "&stations=RIC",
1966
+ "&stations=ROA",
1967
+ "&stations=ORF",
1968
+ "&stations=SHD",
1969
+ ],
1970
+ "VT_ASOS": [
1971
+ "&stations=BTV",
1972
+ "&stations=MPV",
1973
+ "&stations=RUT",
1974
+ "&stations=VSF",
1975
+ "&stations=MVL",
1976
+ ],
1977
+ "WA_ASOS": [
1978
+ "&stations=SEA",
1979
+ "&stations=GEG",
1980
+ "&stations=TIW",
1981
+ "&stations=VUO",
1982
+ "&stations=BFI",
1983
+ ],
1984
+ "WI_ASOS": [
1985
+ "&stations=MKE",
1986
+ "&stations=MSN",
1987
+ "&stations=GRB",
1988
+ "&stations=EAU",
1989
+ "&stations=LSE",
1990
+ ],
1991
+ "WV_ASOS": [
1992
+ "&stations=CRW",
1993
+ "&stations=CKB",
1994
+ "&stations=HTS",
1995
+ "&stations=MGW",
1996
+ "&stations=BKW",
1997
+ ],
1998
+ "WY_ASOS": [
1999
+ "&stations=CPR",
2000
+ "&stations=JAC",
2001
+ "&stations=SHR",
2002
+ "&stations=COD",
2003
+ "&stations=RKS",
2004
+ ],
1368
2005
  }
1369
2006
  # --- Date setup --- #
1370
2007
  date_object = datetime.strptime(start_date, "%Y-%m-%d")
@@ -1386,7 +2023,7 @@ class datapull:
1386
2023
  """Fetch station-based data (daily) from Iowa Mesonet."""
1387
2024
  import csv
1388
2025
 
1389
- station_query = ''.join(stations)
2026
+ station_query = "".join(stations)
1390
2027
  url = (
1391
2028
  "https://mesonet.agron.iastate.edu/cgi-bin/request/daily.py?"
1392
2029
  f"network={network}{station_query}"
@@ -1400,9 +2037,10 @@ class datapull:
1400
2037
  def fetch_canada_data() -> pd.DataFrame:
1401
2038
  """Canada uses multiple sub-networks. Combine them all."""
1402
2039
  import csv
2040
+
1403
2041
  final_df = pd.DataFrame()
1404
2042
  for i, institute_temp in enumerate(institute_vector):
1405
- station_query_temp = ''.join(stations_list_canada[i])
2043
+ station_query_temp = "".join(stations_list_canada[i])
1406
2044
  mesonet_url = (
1407
2045
  "https://mesonet.agron.iastate.edu/cgi-bin/request/daily.py?"
1408
2046
  f"network={institute_temp}{station_query_temp}"
@@ -1438,11 +2076,13 @@ class datapull:
1438
2076
  "start_date": formatted_date,
1439
2077
  "end_date": today.strftime("%Y-%m-%d"),
1440
2078
  "daily": "precipitation_sum,snowfall_sum",
1441
- "timezone": "auto"
2079
+ "timezone": "auto",
1442
2080
  }
1443
2081
  resp = requests.get(url, params=params)
1444
2082
  if resp.status_code != 200:
1445
- print(f"[ERROR] open-meteo returned status {resp.status_code} for city={city}")
2083
+ print(
2084
+ f"[ERROR] open-meteo returned status {resp.status_code} for city={city}",
2085
+ )
1446
2086
  continue
1447
2087
  try:
1448
2088
  data_json = resp.json()
@@ -1451,22 +2091,27 @@ class datapull:
1451
2091
  continue
1452
2092
 
1453
2093
  daily_block = data_json.get("daily", {})
1454
- if not {"time", "precipitation_sum", "snowfall_sum"}.issubset(daily_block.keys()):
1455
- print(f"[ERROR] missing required keys in open-meteo for city={city}")
2094
+ if not {"time", "precipitation_sum", "snowfall_sum"}.issubset(
2095
+ daily_block.keys(),
2096
+ ):
2097
+ print(
2098
+ f"[ERROR] missing required keys in open-meteo for city={city}",
2099
+ )
1456
2100
  continue
1457
2101
 
1458
- df_temp = pd.DataFrame({
1459
- "date": daily_block["time"],
1460
- "rain_sum": daily_block["precipitation_sum"],
1461
- "snow_sum": daily_block["snowfall_sum"]
1462
- })
2102
+ df_temp = pd.DataFrame(
2103
+ {
2104
+ "date": daily_block["time"],
2105
+ "rain_sum": daily_block["precipitation_sum"],
2106
+ "snow_sum": daily_block["snowfall_sum"],
2107
+ },
2108
+ )
1463
2109
  df_temp["city"] = city
1464
2110
  weather_data_list.append(df_temp)
1465
2111
 
1466
2112
  if weather_data_list:
1467
2113
  return pd.concat(weather_data_list, ignore_index=True)
1468
- else:
1469
- return pd.DataFrame()
2114
+ return pd.DataFrame()
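A minimal standalone version of the per-city Open-Meteo pull above. The daily fields and timezone=auto mirror the code; the archive endpoint URL and the London coordinates are assumptions for illustration, since the geocoding step is not shown in this hunk:

import pandas as pd
import requests

params = {
    "latitude": 51.5072,    # assumed: London
    "longitude": -0.1276,
    "start_date": "2024-01-01",
    "end_date": "2024-01-07",
    "daily": "precipitation_sum,snowfall_sum",
    "timezone": "auto",
}
resp = requests.get("https://archive-api.open-meteo.com/v1/archive", params=params)
daily_block = resp.json().get("daily", {})
df_temp = pd.DataFrame({
    "date": daily_block.get("time", []),
    "rain_sum": daily_block.get("precipitation_sum", []),
    "snow_sum": daily_block.get("snowfall_sum", []),
})
print(df_temp)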
1470
2115
 
1471
2116
  def weekly_aggregate_temp_mesonet(df: pd.DataFrame) -> pd.DataFrame:
1472
2117
  """
@@ -1500,10 +2145,12 @@ class datapull:
1500
2145
 
1501
2146
  # Group by "week_starting"
1502
2147
  df["week_starting"] = df["day"].apply(
1503
- lambda x: x - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
1504
- if pd.notnull(x) else pd.NaT
2148
+ lambda x: x
2149
+ - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
2150
+ if pd.notnull(x)
2151
+ else pd.NaT,
1505
2152
  )
1506
- numeric_cols = df.select_dtypes(include='number').columns
2153
+ numeric_cols = df.select_dtypes(include="number").columns
1507
2154
  weekly = df.groupby("week_starting")[numeric_cols].mean()
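The lambda above snaps every observation back to the chosen week-commencing day via (weekday - day_dict[week_commencing]) % 7. A quick check with a Sunday-commencing week (dates are illustrative):

import pandas as pd

day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
week_commencing = "sun"

for d in pd.to_datetime(["2024-06-05", "2024-06-09", "2024-06-10"]):  # Wed, Sun, Mon
    week_starting = d - pd.Timedelta(days=(d.weekday() - day_dict[week_commencing]) % 7)
    print(d.date(), "->", week_starting.date())
# 2024-06-05 -> 2024-06-02, 2024-06-09 -> 2024-06-09, 2024-06-10 -> 2024-06-09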
1508
2155
 
1509
2156
  # Rename columns
@@ -1526,13 +2173,16 @@ class datapull:
1526
2173
  We'll do weekly average of each. -> 'avg_rain_sum', 'avg_snow_sum'.
1527
2174
  """
1528
2175
  import pandas as pd
2176
+
1529
2177
  if "date" not in df.columns:
1530
2178
  return pd.DataFrame()
1531
2179
 
1532
2180
  df["date"] = pd.to_datetime(df["date"], errors="coerce")
1533
2181
  df["week_starting"] = df["date"].apply(
1534
- lambda x: x - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
1535
- if pd.notnull(x) else pd.NaT
2182
+ lambda x: x
2183
+ - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
2184
+ if pd.notnull(x)
2185
+ else pd.NaT,
1536
2186
  )
1537
2187
 
1538
2188
  # Convert to numeric
@@ -1540,13 +2190,10 @@ class datapull:
1540
2190
  if c in df.columns:
1541
2191
  df[c] = pd.to_numeric(df[c], errors="coerce")
1542
2192
 
1543
- numeric_cols = df.select_dtypes(include='number').columns
2193
+ numeric_cols = df.select_dtypes(include="number").columns
1544
2194
  weekly = df.groupby("week_starting")[numeric_cols].mean()
1545
2195
 
1546
- rename_map = {
1547
- "rain_sum": "avg_rain_sum",
1548
- "snow_sum": "avg_snow_sum"
1549
- }
2196
+ rename_map = {"rain_sum": "avg_rain_sum", "snow_sum": "avg_snow_sum"}
1550
2197
  weekly.rename(columns=rename_map, inplace=True)
1551
2198
  return weekly
1552
2199
 
@@ -1562,6 +2209,7 @@ class datapull:
1562
2209
  snow_in -> avg_snow_sum
1563
2210
  """
1564
2211
  import pandas as pd
2212
+
1565
2213
  if "day" not in df.columns:
1566
2214
  return pd.DataFrame()
1567
2215
 
@@ -1582,10 +2230,12 @@ class datapull:
1582
2230
 
1583
2231
  # Weekly grouping
1584
2232
  df["week_starting"] = df["day"].apply(
1585
- lambda x: x - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
1586
- if pd.notnull(x) else pd.NaT
2233
+ lambda x: x
2234
+ - pd.Timedelta(days=(x.weekday() - day_dict[week_commencing]) % 7)
2235
+ if pd.notnull(x)
2236
+ else pd.NaT,
1587
2237
  )
1588
- numeric_cols = df.select_dtypes(include='number').columns
2238
+ numeric_cols = df.select_dtypes(include="number").columns
1589
2239
  weekly = df.groupby("week_starting")[numeric_cols].mean()
1590
2240
 
1591
2241
  rename_map = {
@@ -1596,7 +2246,7 @@ class datapull:
1596
2246
  "min_temp_c": "avg_min_temp_c",
1597
2247
  "mean_temp_c": "avg_mean_temp_c",
1598
2248
  "precip_in": "avg_rain_sum",
1599
- "snow_in": "avg_snow_sum"
2249
+ "snow_in": "avg_snow_sum",
1600
2250
  }
1601
2251
  weekly.rename(columns=rename_map, inplace=True)
1602
2252
  return weekly
@@ -1642,7 +2292,9 @@ class datapull:
1642
2292
 
1643
2293
  weekly_state = weekly_aggregate_us(raw_df)
1644
2294
  if weekly_state.empty:
1645
- print(f"[DEBUG] Aggregated weekly DataFrame empty for {network_code}, skipping.")
2295
+ print(
2296
+ f"[DEBUG] Aggregated weekly DataFrame empty for {network_code}, skipping.",
2297
+ )
1646
2298
  continue
1647
2299
 
1648
2300
  weekly_state.reset_index(inplace=True)
@@ -1656,7 +2308,12 @@ class datapull:
1656
2308
  if combined_df.empty:
1657
2309
  combined_df = weekly_state
1658
2310
  else:
1659
- combined_df = pd.merge(combined_df, weekly_state, on="OBS", how="outer")
2311
+ combined_df = pd.merge(
2312
+ combined_df,
2313
+ weekly_state,
2314
+ on="OBS",
2315
+ how="outer",
2316
+ )
1660
2317
 
1661
2318
  # Done with the US. Move on to the next country in the loop
1662
2319
  continue
@@ -1692,7 +2349,13 @@ class datapull:
1692
2349
 
1693
2350
  # C) Merge the temperature data + precip/snow data on the weekly index
1694
2351
  if not weekly_temp.empty and not weekly_precip.empty:
1695
- merged_df = pd.merge(weekly_temp, weekly_precip, left_index=True, right_index=True, how="outer")
2352
+ merged_df = pd.merge(
2353
+ weekly_temp,
2354
+ weekly_precip,
2355
+ left_index=True,
2356
+ right_index=True,
2357
+ how="outer",
2358
+ )
1696
2359
  elif not weekly_temp.empty:
1697
2360
  merged_df = weekly_temp
1698
2361
  else:
@@ -1723,36 +2386,39 @@ class datapull:
1723
2386
  combined_df.sort_values(by="OBS", inplace=True)
1724
2387
 
1725
2388
  return combined_df
1726
-
2389
+
1727
2390
  def pull_macro_ons_uk(self, cdid_list=None, week_start_day="mon", sector=None):
1728
2391
  """
1729
2392
  Fetches time series data for multiple CDIDs from the ONS API, converts it to daily frequency,
1730
2393
  aggregates it to weekly averages, and renames variables based on specified rules.
1731
2394
 
1732
- Parameters:
2395
+ Parameters
2396
+ ----------
1733
2397
  cdid_list (list, optional): A list of additional CDIDs to fetch (e.g., ['JP9Z', 'UKPOP']). Defaults to None.
1734
2398
  week_start_day (str, optional): The day the week starts on ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'). Defaults to 'mon'.
1735
2399
  sector (str or list, optional): The sector(s) for which the standard CDIDs are fetched
1736
2400
  (e.g., 'fast_food', ['fast_food', 'fuel']). Defaults to None (only default CDIDs).
1737
2401
 
1738
- Returns:
2402
+ Returns
2403
+ -------
1739
2404
  pd.DataFrame: A DataFrame with weekly frequency, containing an 'OBS' column (week commencing date)
1740
2405
  and all series as renamed columns (e.g., 'macro_retail_sales_uk').
1741
2406
  Returns an empty DataFrame if no data is fetched or processed.
2407
+
1742
2408
  """
1743
2409
  # Define CDIDs for sectors and defaults
1744
2410
  sector_cdids_map = {
1745
2411
  "fast_food": ["L7TD", "L78Q", "DOAD"],
1746
- "clothing_footwear": ["D7BW","D7GO","CHBJ"],
1747
- "fuel": ["A9FS","L7FP","CHOL"],
1748
- "cars":["D7E8","D7E9","D7CO"],
2412
+ "clothing_footwear": ["D7BW", "D7GO", "CHBJ"],
2413
+ "fuel": ["A9FS", "L7FP", "CHOL"],
2414
+ "cars": ["D7E8", "D7E9", "D7CO"],
1749
2415
  "default": ["D7G7", "MGSX", "UKPOP", "IHYQ", "YBEZ", "MS77"],
1750
2416
  }
1751
2417
 
1752
2418
  default_cdids = sector_cdids_map["default"]
1753
- sector_specific_cdids = [] # Initialize empty list for sector CDIDs
2419
+ sector_specific_cdids = [] # Initialize empty list for sector CDIDs
1754
2420
 
1755
- if sector: # Check if sector is not None or empty
2421
+ if sector: # Check if sector is not None or empty
1756
2422
  if isinstance(sector, str):
1757
2423
  # If it's a single string, wrap it in a list
1758
2424
  sector_list = [sector]
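The sector handling in this function reduces to normalising the sector argument to a list and collecting CDIDs from the map; a minimal sketch using a trimmed-down copy of the map (not the full CDID set):

sector_cdids_map = {
    "fast_food": ["L7TD", "L78Q", "DOAD"],
    "fuel": ["A9FS", "L7FP", "CHOL"],
    "default": ["D7G7", "MGSX", "UKPOP"],
}

def build_cdid_list(sector=None, extra=None):
    # Accept a single sector name or a list of them, then deduplicate.
    sectors = [sector] if isinstance(sector, str) else list(sector or [])
    cdids = list(sector_cdids_map["default"])
    for sec in sectors:
        cdids.extend(sector_cdids_map.get(sec, []))
    cdids.extend(extra or [])
    return sorted(set(cdids))

print(build_cdid_list("fast_food", extra=["HBOI"]))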
@@ -1760,34 +2426,56 @@ class datapull:
1760
2426
  # If it's already a list, use it directly
1761
2427
  sector_list = sector
1762
2428
  else:
1763
- raise TypeError("`sector` parameter must be a string or a list of strings.")
2429
+ raise TypeError(
2430
+ "`sector` parameter must be a string or a list of strings.",
2431
+ )
1764
2432
 
1765
2433
  # Iterate through the list of sectors and collect their CDIDs
1766
2434
  for sec in sector_list:
1767
- sector_specific_cdids.extend(sector_cdids_map.get(sec, [])) # Use extend to add items from the list
2435
+ sector_specific_cdids.extend(
2436
+ sector_cdids_map.get(sec, []),
2437
+ ) # Use extend to add items from the list
1768
2438
 
1769
- standard_cdids = list(set(default_cdids + sector_specific_cdids)) # Combine default and selected sector CDIDs, ensure uniqueness
2439
+ standard_cdids = list(
2440
+ set(default_cdids + sector_specific_cdids),
2441
+ ) # Combine default and selected sector CDIDs, ensure uniqueness
1770
2442
 
1771
2443
  # Combine standard CDIDs and any additional user-provided CDIDs
1772
2444
  if cdid_list is None:
1773
2445
  cdid_list = []
1774
- final_cdid_list = list(set(standard_cdids + cdid_list)) # Ensure uniqueness in the final list
2446
+ final_cdid_list = list(
2447
+ set(standard_cdids + cdid_list),
2448
+ ) # Ensure uniqueness in the final list
1775
2449
 
1776
- base_search_url = "https://api.beta.ons.gov.uk/v1/search?content_type=timeseries&cdids="
2450
+ base_search_url = (
2451
+ "https://api.beta.ons.gov.uk/v1/search?content_type=timeseries&cdids="
2452
+ )
1777
2453
  base_data_url = "https://api.beta.ons.gov.uk/v1/data?uri="
1778
2454
  combined_df = pd.DataFrame()
1779
2455
 
1780
2456
  # Map week start day to pandas weekday convention
1781
- days_map = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
2457
+ days_map = {
2458
+ "mon": 0,
2459
+ "tue": 1,
2460
+ "wed": 2,
2461
+ "thu": 3,
2462
+ "fri": 4,
2463
+ "sat": 5,
2464
+ "sun": 6,
2465
+ }
1782
2466
  if week_start_day.lower() not in days_map:
1783
- raise ValueError("Invalid week start day. Choose from: " + ", ".join(days_map.keys()))
1784
- week_start = days_map[week_start_day.lower()] # Use lower() for case-insensitivity
2467
+ raise ValueError(
2468
+ "Invalid week start day. Choose from: " + ", ".join(days_map.keys()),
2469
+ )
2470
+ week_start = days_map[
2471
+ week_start_day.lower()
2472
+ ] # Use lower() for case-insensitivity
1785
2473
 
1786
- for cdid in final_cdid_list: # Use the final combined list
2474
+ for cdid in final_cdid_list: # Use the final combined list
1787
2475
  try:
1788
2476
  # Search for the series
1789
2477
  search_url = f"{base_search_url}{cdid}"
1790
- search_response = requests.get(search_url, timeout=30) # Add timeout
2478
+ search_response = requests.get(search_url, timeout=30) # Add timeout
1791
2479
  search_response.raise_for_status()
1792
2480
  search_data = search_response.json()
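A hedged sketch of the ONS Beta API search step above, assuming the response shape the diff implies (an "items" list whose entries carry "release_date" and "uri"); illustrative only, not an official client, and it needs network access:

from datetime import datetime

import requests

SEARCH_URL = "https://api.beta.ons.gov.uk/v1/search?content_type=timeseries&cdids="

def parse_release(item):
    # release_date is assumed to be ISO-8601 with a trailing "Z", as the .replace() above suggests.
    try:
        return datetime.fromisoformat(item["release_date"].replace("Z", "+00:00"))
    except (KeyError, ValueError):
        return None

def latest_release_uri(cdid):
    resp = requests.get(f"{SEARCH_URL}{cdid}", timeout=30)
    resp.raise_for_status()
    dated = [(parse_release(i), i) for i in resp.json().get("items", [])]
    dated = [(d, i) for d, i in dated if d is not None]
    if not dated:
        return None
    # Keep the item with the most recent release and return its URI for the data endpoint.
    return max(dated, key=lambda pair: pair[0])[1].get("uri")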
1793
2481
 
@@ -1804,46 +2492,59 @@ class datapull:
1804
2492
  if "release_date" in item:
1805
2493
  try:
1806
2494
  # Ensure timezone awareness for comparison
1807
- current_date = datetime.fromisoformat(item["release_date"].replace("Z", "+00:00"))
2495
+ current_date = datetime.fromisoformat(
2496
+ item["release_date"].replace("Z", "+00:00"),
2497
+ )
1808
2498
  if latest_date is None or current_date > latest_date:
1809
2499
  latest_date = current_date
1810
2500
  latest_item = item
1811
2501
  except ValueError:
1812
- print(f"Warning: Could not parse release_date '{item['release_date']}' for CDID {cdid}")
1813
- continue # Skip this item if date is invalid
2502
+ print(
2503
+ f"Warning: Could not parse release_date '{item['release_date']}' for CDID {cdid}",
2504
+ )
2505
+ continue # Skip this item if date is invalid
1814
2506
 
1815
2507
  if latest_item is None:
1816
- print(f"Warning: No valid release date found for CDID: {cdid}")
1817
- continue
2508
+ print(f"Warning: No valid release date found for CDID: {cdid}")
2509
+ continue
1818
2510
 
1819
- series_name = latest_item.get("title", f"Series_{cdid}") # Use title from the latest item
2511
+ series_name = latest_item.get(
2512
+ "title",
2513
+ f"Series_{cdid}",
2514
+ ) # Use title from the latest item
1820
2515
  latest_uri = latest_item.get("uri")
1821
2516
  if not latest_uri:
1822
- print(f"Warning: No URI found for the latest release of CDID: {cdid}")
1823
- continue
2517
+ print(
2518
+ f"Warning: No URI found for the latest release of CDID: {cdid}",
2519
+ )
2520
+ continue
1824
2521
 
1825
2522
  # Fetch the dataset
1826
2523
  data_url = f"{base_data_url}{latest_uri}"
1827
- data_response = requests.get(data_url, timeout=30) # Add timeout
2524
+ data_response = requests.get(data_url, timeout=30) # Add timeout
1828
2525
  data_response.raise_for_status()
1829
2526
  data_json = data_response.json()
1830
2527
 
1831
2528
  # Detect the frequency and process accordingly
1832
2529
  frequency_key = None
1833
- if "months" in data_json and data_json["months"]:
2530
+ if data_json.get("months"):
1834
2531
  frequency_key = "months"
1835
- elif "quarters" in data_json and data_json["quarters"]:
2532
+ elif data_json.get("quarters"):
1836
2533
  frequency_key = "quarters"
1837
- elif "years" in data_json and data_json["years"]:
2534
+ elif data_json.get("years"):
1838
2535
  frequency_key = "years"
1839
2536
  else:
1840
- print(f"Warning: Unsupported frequency or no data values found for CDID: {cdid} at URI {latest_uri}")
2537
+ print(
2538
+ f"Warning: Unsupported frequency or no data values found for CDID: {cdid} at URI {latest_uri}",
2539
+ )
1841
2540
  continue
1842
2541
 
1843
2542
  # Prepare the DataFrame
1844
- if not data_json[frequency_key]: # Check if the list of values is empty
1845
- print(f"Warning: Empty data list for frequency '{frequency_key}' for CDID: {cdid}")
1846
- continue
2543
+ if not data_json[frequency_key]: # Check if the list of values is empty
2544
+ print(
2545
+ f"Warning: Empty data list for frequency '{frequency_key}' for CDID: {cdid}",
2546
+ )
2547
+ continue
1847
2548
 
1848
2549
  df = pd.DataFrame(data_json[frequency_key])
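The frequency detection above simply probes which of three keys is populated; schematically (the sample payload is invented):

def detect_frequency(data_json):
    # ONS time-series payloads handled above carry observations under one of these keys.
    for key in ("months", "quarters", "years"):
        if data_json.get(key):
            return key
    return None

sample = {"months": [], "quarters": [{"date": "2024 Q1", "value": "101.3"}], "years": []}
print(detect_frequency(sample))  # -> "quarters"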
1849
2550
 
@@ -1856,21 +2557,33 @@ class datapull:
1856
2557
  try:
1857
2558
  if frequency_key == "months":
1858
2559
  # Handles "YYYY Mon" format (e.g., "2023 FEB") - adjust if format differs
1859
- df["date"] = pd.to_datetime(df["date"], format="%Y %b", errors="coerce")
2560
+ df["date"] = pd.to_datetime(
2561
+ df["date"],
2562
+ format="%Y %b",
2563
+ errors="coerce",
2564
+ )
1860
2565
  elif frequency_key == "quarters":
2566
+
1861
2567
  def parse_quarter(quarter_str):
1862
2568
  try:
1863
2569
  year, qtr = quarter_str.split(" Q")
1864
2570
  month = {"1": 1, "2": 4, "3": 7, "4": 10}[qtr]
1865
2571
  return datetime(int(year), month, 1)
1866
2572
  except (ValueError, KeyError):
1867
- return pd.NaT # Return Not a Time for parsing errors
2573
+ return pd.NaT # Return Not a Time for parsing errors
2574
+
1868
2575
  df["date"] = df["date"].apply(parse_quarter)
1869
2576
  elif frequency_key == "years":
1870
- df["date"] = pd.to_datetime(df["date"], format="%Y", errors="coerce")
2577
+ df["date"] = pd.to_datetime(
2578
+ df["date"],
2579
+ format="%Y",
2580
+ errors="coerce",
2581
+ )
1871
2582
  except Exception as e:
1872
- print(f"Error parsing date for CDID {cdid} with frequency {frequency_key}: {e}")
1873
- continue # Skip this series if date parsing fails
2583
+ print(
2584
+ f"Error parsing date for CDID {cdid} with frequency {frequency_key}: {e}",
2585
+ )
2586
+ continue # Skip this series if date parsing fails
1874
2587
 
1875
2588
  # Coerce value to numeric, handle potential errors
1876
2589
  df["value"] = pd.to_numeric(df["value"], errors="coerce")
@@ -1879,26 +2592,34 @@ class datapull:
1879
2592
  df.dropna(subset=["date", "value"], inplace=True)
1880
2593
 
1881
2594
  if df.empty:
1882
- print(f"Warning: No valid data points after processing for CDID: {cdid}")
2595
+ print(
2596
+ f"Warning: No valid data points after processing for CDID: {cdid}",
2597
+ )
1883
2598
  continue
1884
2599
 
1885
2600
  df.rename(columns={"value": series_name}, inplace=True)
1886
2601
 
1887
2602
  # Combine data
1888
- df_subset = df.loc[:, ["date", series_name]].reset_index(drop=True) # Explicitly select columns
2603
+ df_subset = df.loc[:, ["date", series_name]].reset_index(
2604
+ drop=True,
2605
+ ) # Explicitly select columns
1889
2606
  if combined_df.empty:
1890
2607
  combined_df = df_subset
1891
2608
  else:
1892
2609
  # Use outer merge to keep all dates, sort afterwards
1893
- combined_df = pd.merge(combined_df, df_subset, on="date", how="outer")
2610
+ combined_df = pd.merge(
2611
+ combined_df,
2612
+ df_subset,
2613
+ on="date",
2614
+ how="outer",
2615
+ )
1894
2616
 
1895
2617
  except requests.exceptions.RequestException as e:
1896
2618
  print(f"Error fetching data for CDID {cdid}: {e}")
1897
- except (KeyError, ValueError, TypeError) as e: # Added TypeError
2619
+ except (KeyError, ValueError, TypeError) as e: # Added TypeError
1898
2620
  print(f"Error processing data for CDID {cdid}: {e}")
1899
- except Exception as e: # Catch unexpected errors
1900
- print(f"An unexpected error occurred for CDID {cdid}: {e}")
1901
-
2621
+ except Exception as e: # Catch unexpected errors
2622
+ print(f"An unexpected error occurred for CDID {cdid}: {e}")
1902
2623
 
1903
2624
  if not combined_df.empty:
1904
2625
  # Sort by date after merging to ensure correct forward fill
@@ -1908,36 +2629,45 @@ class datapull:
1908
2629
  # Create a complete daily date range
1909
2630
  min_date = combined_df["date"].min()
1910
2631
  # Ensure max_date is timezone-naive if min_date is, or consistent otherwise
1911
- max_date = pd.Timestamp(datetime.today().date()) # Use today's date, timezone-naive
2632
+ max_date = pd.Timestamp(
2633
+ datetime.today().date(),
2634
+ ) # Use today's date, timezone-naive
1912
2635
 
1913
2636
  if pd.isna(min_date):
1914
- print("Error: Minimum date is NaT, cannot create date range.")
1915
- return pd.DataFrame()
2637
+ print("Error: Minimum date is NaT, cannot create date range.")
2638
+ return pd.DataFrame()
1916
2639
 
1917
2640
  # Make sure min_date is not NaT before creating the range
1918
- date_range = pd.date_range(start=min_date, end=max_date, freq='D')
1919
- daily_df = pd.DataFrame(date_range, columns=['date'])
2641
+ date_range = pd.date_range(start=min_date, end=max_date, freq="D")
2642
+ daily_df = pd.DataFrame(date_range, columns=["date"])
1920
2643
 
1921
2644
  # Merge with original data and forward fill
1922
2645
  daily_df = pd.merge(daily_df, combined_df, on="date", how="left")
1923
2646
  daily_df = daily_df.ffill()
1924
2647
 
1925
2648
  # Drop rows before the first valid data point after ffill
1926
- first_valid_index = daily_df.dropna(subset=daily_df.columns.difference(['date'])).index.min()
2649
+ first_valid_index = daily_df.dropna(
2650
+ subset=daily_df.columns.difference(["date"]),
2651
+ ).index.min()
1927
2652
  if pd.notna(first_valid_index):
1928
- daily_df = daily_df.loc[first_valid_index:]
2653
+ daily_df = daily_df.loc[first_valid_index:]
1929
2654
  else:
1930
- print("Warning: No valid data points found after forward filling.")
1931
- return pd.DataFrame() # Return empty if ffill results in no data
1932
-
2655
+ print("Warning: No valid data points found after forward filling.")
2656
+ return pd.DataFrame() # Return empty if ffill results in no data
1933
2657
 
1934
2658
  # Aggregate to weekly frequency
1935
2659
  # Ensure 'date' column is datetime type before dt accessor
1936
- daily_df['date'] = pd.to_datetime(daily_df['date'])
1937
- daily_df["week_commencing"] = daily_df["date"] - pd.to_timedelta((daily_df["date"].dt.weekday - week_start + 7) % 7, unit='D') # Corrected logic for week start
2660
+ daily_df["date"] = pd.to_datetime(daily_df["date"])
2661
+ daily_df["week_commencing"] = daily_df["date"] - pd.to_timedelta(
2662
+ (daily_df["date"].dt.weekday - week_start + 7) % 7,
2663
+ unit="D",
2664
+ ) # Corrected logic for week start
1938
2665
  # Group by week_commencing and calculate mean for numeric columns only
1939
- weekly_df = daily_df.groupby("week_commencing").mean(numeric_only=True).reset_index()
1940
-
2666
+ weekly_df = (
2667
+ daily_df.groupby("week_commencing")
2668
+ .mean(numeric_only=True)
2669
+ .reset_index()
2670
+ )
1941
2671
 
1942
2672
  def clean_column_name(name):
1943
2673
  # Remove content within parentheses (e.g., CPI INDEX 00: ALL ITEMS 2015=100)
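The forward-fill-to-daily and weekly-mean step in the hunk above can be reduced to the following sketch (toy monthly series; week_start=0 means Monday, matching days_map):

import pandas as pd

def coarse_to_weekly(df, week_start=0):
    # df has a 'date' column at a coarse frequency; forward-fill to daily, then average per week.
    daily = pd.DataFrame(
        {"date": pd.date_range(df["date"].min(), pd.Timestamp.today().normalize(), freq="D")}
    )
    daily = daily.merge(df, on="date", how="left").ffill()
    daily["week_commencing"] = daily["date"] - pd.to_timedelta(
        (daily["date"].dt.weekday - week_start + 7) % 7, unit="D"
    )
    return daily.groupby("week_commencing").mean(numeric_only=True).reset_index()

monthly = pd.DataFrame({"date": pd.to_datetime(["2024-01-01", "2024-02-01"]), "cpi": [100.0, 100.5]})
print(coarse_to_weekly(monthly).head())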
@@ -1945,13 +2675,13 @@ class datapull:
1945
2675
  # Take only the part before the first colon if present
1946
2676
  name = re.split(r":", name)[0]
1947
2677
  # Remove digits
1948
- #name = re.sub(r"\d+", "", name) # Reconsider removing all digits, might be needed for some series
2678
+ # name = re.sub(r"\d+", "", name) # Reconsider removing all digits, might be needed for some series
1949
2679
  # Remove specific words like 'annual', 'rate' case-insensitively
1950
2680
  name = re.sub(r"\b(annual|rate)\b", "", name, flags=re.IGNORECASE)
1951
2681
  # Remove non-alphanumeric characters (except underscore and space)
1952
2682
  name = re.sub(r"[^\w\s]", "", name)
1953
2683
  # Replace spaces with underscores
1954
- name = name.strip() # Remove leading/trailing whitespace
2684
+ name = name.strip() # Remove leading/trailing whitespace
1955
2685
  name = name.replace(" ", "_")
1956
2686
  # Replace multiple underscores with a single one
1957
2687
  name = re.sub(r"_+", "_", name)
@@ -1961,30 +2691,38 @@ class datapull:
1961
2691
  return f"macro_{name.lower()}_uk"
1962
2692
 
1963
2693
  # Apply cleaning function to relevant columns
1964
- weekly_df.columns = [clean_column_name(col) if col != "week_commencing" else col for col in weekly_df.columns]
1965
- weekly_df.rename(columns={"week_commencing": "OBS"}, inplace=True) # Rename week commencing col
2694
+ weekly_df.columns = [
2695
+ clean_column_name(col) if col != "week_commencing" else col
2696
+ for col in weekly_df.columns
2697
+ ]
2698
+ weekly_df.rename(
2699
+ columns={"week_commencing": "OBS"},
2700
+ inplace=True,
2701
+ ) # Rename week commencing col
1966
2702
 
1967
2703
  # Optional: Fill remaining NaNs (e.g., at the beginning if ffill didn't cover) with 0
1968
2704
  # Consider if 0 is the appropriate fill value for your use case
1969
2705
  # weekly_df = weekly_df.fillna(0)
1970
2706
 
1971
2707
  return weekly_df
1972
- else:
1973
- print("No data successfully fetched or processed.")
1974
- return pd.DataFrame()
1975
-
2708
+ print("No data successfully fetched or processed.")
2709
+ return pd.DataFrame()
2710
+
1976
2711
  def pull_yfinance(self, tickers=None, week_start_day="mon"):
1977
2712
  """
1978
- Fetches stock data for multiple tickers from Yahoo Finance, converts it to daily frequency,
2713
+ Fetches stock data for multiple tickers from Yahoo Finance, converts it to daily frequency,
1979
2714
  aggregates it to weekly averages, and renames variables.
1980
2715
 
1981
- Parameters:
2716
+ Parameters
2717
+ ----------
1982
2718
  tickers (list): A list of additional stock tickers to fetch (e.g., ['AAPL', 'MSFT']). Defaults to None.
1983
2719
  week_start_day (str): The day the week starts on ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'). Defaults to 'mon'.
1984
2720
 
1985
- Returns:
1986
- pd.DataFrame: A DataFrame with weekly frequency, containing an 'OBS' column
2721
+ Returns
2722
+ -------
2723
+ pd.DataFrame: A DataFrame with weekly frequency, containing an 'OBS' column
1987
2724
  and aggregated stock data for the specified tickers, with NaN values filled with 0.
2725
+
1988
2726
  """
1989
2727
  # Define default tickers
1990
2728
  default_tickers = ["^FTSE", "GBPUSD=X", "GBPEUR=X", "^GSPC"]
@@ -1996,16 +2734,26 @@ class datapull:
1996
2734
 
1997
2735
  # Automatically set end_date to today
1998
2736
  end_date = datetime.today().strftime("%Y-%m-%d")
1999
-
2737
+
2000
2738
  # Mapping week start day to pandas weekday convention
2001
- days_map = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
2739
+ days_map = {
2740
+ "mon": 0,
2741
+ "tue": 1,
2742
+ "wed": 2,
2743
+ "thu": 3,
2744
+ "fri": 4,
2745
+ "sat": 5,
2746
+ "sun": 6,
2747
+ }
2002
2748
  if week_start_day not in days_map:
2003
- raise ValueError("Invalid week start day. Choose from: " + ", ".join(days_map.keys()))
2749
+ raise ValueError(
2750
+ "Invalid week start day. Choose from: " + ", ".join(days_map.keys()),
2751
+ )
2004
2752
  week_start = days_map[week_start_day]
2005
2753
 
2006
2754
  # Fetch data for all tickers without specifying a start date to get all available data
2007
2755
  data = yf.download(tickers, end=end_date, group_by="ticker", auto_adjust=True)
2008
-
2756
+
2009
2757
  # Process the data
2010
2758
  combined_df = pd.DataFrame()
2011
2759
  for ticker in tickers:
@@ -2016,8 +2764,10 @@ class datapull:
2016
2764
 
2017
2765
  # Ensure necessary columns are present
2018
2766
  if "Close" not in ticker_data.columns:
2019
- raise ValueError(f"Ticker {ticker} does not have 'Close' price data.")
2020
-
2767
+ raise ValueError(
2768
+ f"Ticker {ticker} does not have 'Close' price data.",
2769
+ )
2770
+
2021
2771
  # Keep only relevant columns
2022
2772
  ticker_data = ticker_data[["Date", "Close"]]
2023
2773
  ticker_data.rename(columns={"Close": ticker}, inplace=True)
@@ -2026,7 +2776,12 @@ class datapull:
2026
2776
  if combined_df.empty:
2027
2777
  combined_df = ticker_data
2028
2778
  else:
2029
- combined_df = pd.merge(combined_df, ticker_data, on="Date", how="outer")
2779
+ combined_df = pd.merge(
2780
+ combined_df,
2781
+ ticker_data,
2782
+ on="Date",
2783
+ how="outer",
2784
+ )
2030
2785
 
2031
2786
  except KeyError:
2032
2787
  print(f"Data for ticker {ticker} not available.")
@@ -2041,13 +2796,16 @@ class datapull:
2041
2796
  # Fill missing dates
2042
2797
  min_date = combined_df.index.min()
2043
2798
  max_date = combined_df.index.max()
2044
- daily_index = pd.date_range(start=min_date, end=max_date, freq='D')
2799
+ daily_index = pd.date_range(start=min_date, end=max_date, freq="D")
2045
2800
  combined_df = combined_df.reindex(daily_index)
2046
2801
  combined_df.index.name = "Date"
2047
2802
  combined_df = combined_df.ffill()
2048
2803
 
2049
2804
  # Aggregate to weekly frequency
2050
- combined_df["OBS"] = combined_df.index - pd.to_timedelta((combined_df.index.weekday - week_start) % 7, unit="D")
2805
+ combined_df["OBS"] = combined_df.index - pd.to_timedelta(
2806
+ (combined_df.index.weekday - week_start) % 7,
2807
+ unit="D",
2808
+ )
2051
2809
  weekly_df = combined_df.groupby("OBS").mean(numeric_only=True).reset_index()
2052
2810
 
2053
2811
  # Fill NaN values with 0
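A network-dependent sketch of the Yahoo Finance pull and weekly averaging above, using the same yf.download call as the diff; the tickers and the Monday week start are just examples:

from datetime import datetime

import pandas as pd
import yfinance as yf

tickers = ["^FTSE", "GBPUSD=X"]
data = yf.download(tickers, end=datetime.today().strftime("%Y-%m-%d"),
                   group_by="ticker", auto_adjust=True)

# Keep daily closes, forward-fill gaps, then snap each day back to its week-commencing Monday.
closes = pd.concat({t: data[t]["Close"] for t in tickers}, axis=1).ffill()
week_start = 0  # Monday, as in days_map above
obs = closes.index - pd.to_timedelta((closes.index.weekday - week_start) % 7, unit="D")
weekly = closes.groupby(obs).mean()
weekly.index.name = "OBS"
print(weekly.tail())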
@@ -2058,14 +2816,16 @@ class datapull:
2058
2816
  name = re.sub(r"[^\w\s]", "", name)
2059
2817
  return f"macro_{name.lower()}"
2060
2818
 
2061
- weekly_df.columns = [clean_column_name(col) if col != "OBS" else col for col in weekly_df.columns]
2819
+ weekly_df.columns = [
2820
+ clean_column_name(col) if col != "OBS" else col
2821
+ for col in weekly_df.columns
2822
+ ]
2062
2823
 
2063
2824
  return weekly_df
2064
2825
 
2065
- else:
2066
- print("No data available to process.")
2067
- return pd.DataFrame()
2068
-
2826
+ print("No data available to process.")
2827
+ return pd.DataFrame()
2828
+
2069
2829
  def pull_sports_events(self, start_date="2020-01-01", week_commencing="mon"):
2070
2830
  """
2071
2831
  Combines scraping logic for:
@@ -2078,27 +2838,33 @@ class datapull:
2078
2838
  ############################################################
2079
2839
  # 1) SCRAPE UEFA CHAMPIONS LEAGUE & NFL (YOUR FIRST FUNCTION)
2080
2840
  ############################################################
2081
- def scrape_sports_events(start_date=start_date, week_commencing=week_commencing):
2841
+ def scrape_sports_events(
2842
+ start_date=start_date,
2843
+ week_commencing=week_commencing,
2844
+ ):
2082
2845
  sports = {
2083
2846
  "uefa_champions_league": {
2084
2847
  "league_id": "4480",
2085
2848
  "seasons_url": "https://www.thesportsdb.com/league/4480-UEFA-Champions-League?a=1#allseasons",
2086
2849
  "season_url_template": "https://www.thesportsdb.com/season/4480-UEFA-Champions-League/{season}&all=1&view=",
2087
- "round_filters": ["quarter", "semi", "final"]
2850
+ "round_filters": ["quarter", "semi", "final"],
2088
2851
  },
2089
2852
  "nfl": {
2090
2853
  "league_id": "4391",
2091
2854
  "seasons_url": "https://www.thesportsdb.com/league/4391-NFL?a=1#allseasons",
2092
2855
  "season_url_template": "https://www.thesportsdb.com/season/4391-NFL/{season}&all=1&view=",
2093
- "round_filters": ["quarter", "semi", "final"]
2094
- }
2856
+ "round_filters": ["quarter", "semi", "final"],
2857
+ },
2095
2858
  }
2096
2859
 
2097
2860
  headers = {"User-Agent": "Mozilla/5.0"}
2098
2861
  start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
2099
2862
 
2100
2863
  # Create a full date range DataFrame
2101
- full_date_range = pd.date_range(start=start_date, end=pd.to_datetime("today"))
2864
+ full_date_range = pd.date_range(
2865
+ start=start_date,
2866
+ end=pd.to_datetime("today"),
2867
+ )
2102
2868
  time_series_df = pd.DataFrame({"date": full_date_range})
2103
2869
  time_series_df["seas_uefa_champions_league"] = 0
2104
2870
  time_series_df["seas_nfl"] = 0
@@ -2140,9 +2906,13 @@ class datapull:
2140
2906
  match_date = cols[0].text.strip()
2141
2907
  round_name = cols[1].text.strip().lower()
2142
2908
  try:
2143
- match_date_dt = datetime.strptime(match_date, "%d %b %y")
2144
- if (match_date_dt >= start_date_dt
2145
- and any(r in round_name for r in details["round_filters"])):
2909
+ match_date_dt = datetime.strptime(
2910
+ match_date,
2911
+ "%d %b %y",
2912
+ )
2913
+ if match_date_dt >= start_date_dt and any(
2914
+ r in round_name for r in details["round_filters"]
2915
+ ):
2146
2916
  filtered_matches.append(match_date_dt)
2147
2917
  except ValueError:
2148
2918
  continue
@@ -2152,27 +2922,35 @@ class datapull:
2152
2922
  if df_sport.empty:
2153
2923
  continue
2154
2924
 
2155
- col_name = "seas_nfl" if sport == "nfl" else "seas_uefa_champions_league"
2156
- time_series_df.loc[time_series_df["date"].isin(df_sport["date"]), col_name] = 1
2925
+ col_name = (
2926
+ "seas_nfl" if sport == "nfl" else "seas_uefa_champions_league"
2927
+ )
2928
+ time_series_df.loc[
2929
+ time_series_df["date"].isin(df_sport["date"]),
2930
+ col_name,
2931
+ ] = 1
2157
2932
 
2158
2933
  # Aggregate by week commencing
2159
2934
  day_offsets = {
2160
- 'mon': 'W-MON',
2161
- 'tue': 'W-TUE',
2162
- 'wed': 'W-WED',
2163
- 'thu': 'W-THU',
2164
- 'fri': 'W-FRI',
2165
- 'sat': 'W-SAT',
2166
- 'sun': 'W-SUN'
2935
+ "mon": "W-MON",
2936
+ "tue": "W-TUE",
2937
+ "wed": "W-WED",
2938
+ "thu": "W-THU",
2939
+ "fri": "W-FRI",
2940
+ "sat": "W-SAT",
2941
+ "sun": "W-SUN",
2167
2942
  }
2168
2943
  if week_commencing.lower() not in day_offsets:
2169
- raise ValueError(f"Invalid week_commencing value: {week_commencing}. Must be one of {list(day_offsets.keys())}.")
2944
+ raise ValueError(
2945
+ f"Invalid week_commencing value: {week_commencing}. Must be one of {list(day_offsets.keys())}.",
2946
+ )
2170
2947
 
2171
- time_series_df = (time_series_df
2172
- .set_index("date")
2173
- .resample(day_offsets[week_commencing.lower()])
2174
- .max()
2175
- .reset_index())
2948
+ time_series_df = (
2949
+ time_series_df.set_index("date")
2950
+ .resample(day_offsets[week_commencing.lower()])
2951
+ .max()
2952
+ .reset_index()
2953
+ )
2176
2954
 
2177
2955
  time_series_df.rename(columns={"date": "OBS"}, inplace=True)
2178
2956
  time_series_df.fillna(0, inplace=True)
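The weekly aggregation of the daily event dummies boils down to an anchored resample; a toy example (note that pandas "W-MON" bins end on a Monday and are labelled by that closing Monday):

import pandas as pd

day_offsets = {"mon": "W-MON", "tue": "W-TUE", "wed": "W-WED", "thu": "W-THU",
               "fri": "W-FRI", "sat": "W-SAT", "sun": "W-SUN"}

dates = pd.date_range("2024-06-01", "2024-06-30", freq="D")
daily = pd.DataFrame({"date": dates, "seas_nfl": 0})
daily.loc[daily["date"].isin(pd.to_datetime(["2024-06-09", "2024-06-23"])), "seas_nfl"] = 1

weekly = (daily.set_index("date")
               .resample(day_offsets["mon"])
               .max()
               .reset_index()
               .rename(columns={"date": "OBS"}))
print(weekly)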
@@ -2184,32 +2962,47 @@ class datapull:
2184
2962
  ############################################################
2185
2963
  def fetch_events(start_date=start_date, week_commencing=week_commencing):
2186
2964
  # Initialize date range
2187
- start_date_obj = datetime.strptime(start_date, '%Y-%m-%d')
2965
+ start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
2188
2966
  end_date_obj = datetime.today()
2189
2967
  date_range = pd.date_range(start=start_date_obj, end=end_date_obj)
2190
- df = pd.DataFrame({'OBS': date_range}).set_index('OBS')
2968
+ df = pd.DataFrame({"OBS": date_range}).set_index("OBS")
2191
2969
 
2192
2970
  # Define columns for sports
2193
2971
  event_columns = {
2194
- 'seas_fifa_world_cup': {
2195
- 'league_id': 4429, 'start_year': 1950, 'interval': 4
2972
+ "seas_fifa_world_cup": {
2973
+ "league_id": 4429,
2974
+ "start_year": 1950,
2975
+ "interval": 4,
2196
2976
  },
2197
- 'seas_uefa_european_championship': {
2198
- 'league_id': 4502, 'start_year': 1960, 'interval': 4, 'extra_years': [2021]
2977
+ "seas_uefa_european_championship": {
2978
+ "league_id": 4502,
2979
+ "start_year": 1960,
2980
+ "interval": 4,
2981
+ "extra_years": [2021],
2199
2982
  },
2200
- 'seas_rugby_world_cup': {
2201
- 'league_id': 4574, 'start_year': 1987, 'interval': 4
2983
+ "seas_rugby_world_cup": {
2984
+ "league_id": 4574,
2985
+ "start_year": 1987,
2986
+ "interval": 4,
2202
2987
  },
2203
- 'seas_six_nations': {
2204
- 'league_id': 4714, 'start_year': 2000, 'interval': 1
2988
+ "seas_six_nations": {
2989
+ "league_id": 4714,
2990
+ "start_year": 2000,
2991
+ "interval": 1,
2205
2992
  },
2206
2993
  }
2207
2994
 
2208
2995
  # Initialize columns
2209
- for col in event_columns.keys():
2996
+ for col in event_columns:
2210
2997
  df[col] = 0
2211
2998
 
2212
- def fetch_league_events(league_id, column_name, start_year, interval, extra_years=None):
2999
+ def fetch_league_events(
3000
+ league_id,
3001
+ column_name,
3002
+ start_year,
3003
+ interval,
3004
+ extra_years=None,
3005
+ ):
2213
3006
  extra_years = extra_years or []
2214
3007
  # Fetch seasons
2215
3008
  seasons_url = f"https://www.thesportsdb.com/api/v1/json/3/search_all_seasons.php?id={league_id}"
@@ -2217,54 +3010,59 @@ class datapull:
2217
3010
  if seasons_response.status_code != 200:
2218
3011
  return # Skip on failure
2219
3012
 
2220
- seasons_data = seasons_response.json().get('seasons', [])
3013
+ seasons_data = seasons_response.json().get("seasons", [])
2221
3014
  for season in seasons_data:
2222
- season_name = season.get('strSeason', '')
3015
+ season_name = season.get("strSeason", "")
2223
3016
  if not season_name.isdigit():
2224
3017
  continue
2225
3018
 
2226
3019
  year = int(season_name)
2227
3020
  # Check if the year is valid for this competition
2228
- if year in extra_years or (year >= start_year and (year - start_year) % interval == 0):
3021
+ if year in extra_years or (
3022
+ year >= start_year and (year - start_year) % interval == 0
3023
+ ):
2229
3024
  # Fetch events
2230
3025
  events_url = f"https://www.thesportsdb.com/api/v1/json/3/eventsseason.php?id={league_id}&s={season_name}"
2231
3026
  events_response = requests.get(events_url)
2232
3027
  if events_response.status_code != 200:
2233
3028
  continue
2234
3029
 
2235
- events_data = events_response.json().get('events', [])
3030
+ events_data = events_response.json().get("events", [])
2236
3031
  for event in events_data:
2237
- event_date_str = event.get('dateEvent')
3032
+ event_date_str = event.get("dateEvent")
2238
3033
  if event_date_str:
2239
- event_date = datetime.strptime(event_date_str, '%Y-%m-%d')
3034
+ event_date = datetime.strptime(
3035
+ event_date_str,
3036
+ "%Y-%m-%d",
3037
+ )
2240
3038
  if event_date in df.index:
2241
3039
  df.loc[event_date, column_name] = 1
2242
3040
 
2243
3041
  # Fetch events for all defined leagues
2244
3042
  for column_name, params in event_columns.items():
2245
3043
  fetch_league_events(
2246
- league_id=params['league_id'],
3044
+ league_id=params["league_id"],
2247
3045
  column_name=column_name,
2248
- start_year=params['start_year'],
2249
- interval=params['interval'],
2250
- extra_years=params.get('extra_years', [])
3046
+ start_year=params["start_year"],
3047
+ interval=params["interval"],
3048
+ extra_years=params.get("extra_years", []),
2251
3049
  )
2252
3050
 
2253
3051
  # Resample by week
2254
3052
  day_offsets = {
2255
- 'mon': 'W-MON',
2256
- 'tue': 'W-TUE',
2257
- 'wed': 'W-WED',
2258
- 'thu': 'W-THU',
2259
- 'fri': 'W-FRI',
2260
- 'sat': 'W-SAT',
2261
- 'sun': 'W-SUN'
3053
+ "mon": "W-MON",
3054
+ "tue": "W-TUE",
3055
+ "wed": "W-WED",
3056
+ "thu": "W-THU",
3057
+ "fri": "W-FRI",
3058
+ "sat": "W-SAT",
3059
+ "sun": "W-SUN",
2262
3060
  }
2263
3061
 
2264
3062
  if week_commencing.lower() not in day_offsets:
2265
3063
  raise ValueError(
2266
3064
  f"Invalid week_commencing value: {week_commencing}. "
2267
- f"Must be one of {list(day_offsets.keys())}."
3065
+ f"Must be one of {list(day_offsets.keys())}.",
2268
3066
  )
2269
3067
 
2270
3068
  df = df.resample(day_offsets[week_commencing.lower()]).max()
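A hedged, network-dependent sketch of the TheSportsDB calls used above (same public v1 JSON endpoints and field names as the diff; the league id and the single-season slice are only there to keep the example cheap):

from datetime import datetime

import requests

LEAGUE_ID = 4480  # UEFA Champions League, as in the diff above

seasons_url = f"https://www.thesportsdb.com/api/v1/json/3/search_all_seasons.php?id={LEAGUE_ID}"
seasons = requests.get(seasons_url, timeout=30).json().get("seasons", []) or []

event_dates = []
for season in seasons[:1]:  # just the first season returned
    season_name = season.get("strSeason", "")
    events_url = f"https://www.thesportsdb.com/api/v1/json/3/eventsseason.php?id={LEAGUE_ID}&s={season_name}"
    for event in requests.get(events_url, timeout=30).json().get("events", []) or []:
        raw = event.get("dateEvent")
        if raw:
            event_dates.append(datetime.strptime(raw, "%Y-%m-%d"))

print(len(event_dates), "events found")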
@@ -2278,16 +3076,16 @@ class datapull:
2278
3076
  df_other_events = fetch_events(start_date, week_commencing)
2279
3077
 
2280
3078
  # Merge on "OBS" column (outer join to preserve all dates in range)
2281
- final_df = pd.merge(df_uefa_nfl, df_other_events, on='OBS', how='outer')
3079
+ final_df = pd.merge(df_uefa_nfl, df_other_events, on="OBS", how="outer")
2282
3080
 
2283
3081
  # Fill any NaNs with 0 for event columns
2284
3082
  # (Only fill numeric columns or everything except 'OBS')
2285
3083
  for col in final_df.columns:
2286
- if col != 'OBS':
3084
+ if col != "OBS":
2287
3085
  final_df[col] = final_df[col].fillna(0)
2288
3086
 
2289
3087
  # Sort by date just in case
2290
- final_df.sort_values(by='OBS', inplace=True)
3088
+ final_df.sort_values(by="OBS", inplace=True)
2291
3089
  final_df.reset_index(drop=True, inplace=True)
2292
3090
 
2293
3091
  return final_df
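Finally, the outer merge and NaN-filling in this last hunk behave as in this toy example (both frames are invented):

import pandas as pd

obs = pd.date_range("2024-06-03", periods=4, freq="7D")
df_uefa_nfl = pd.DataFrame({"OBS": obs[:3], "seas_nfl": [0, 1, 0]})
df_other_events = pd.DataFrame({"OBS": obs[1:], "seas_six_nations": [0, 0, 1]})

# Outer merge keeps every week from either frame; event columns are then zero-filled.
final_df = pd.merge(df_uefa_nfl, df_other_events, on="OBS", how="outer")
event_cols = [c for c in final_df.columns if c != "OBS"]
final_df[event_cols] = final_df[event_cols].fillna(0)
final_df = final_df.sort_values("OBS").reset_index(drop=True)
print(final_df)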