imsciences 0.9.6.6__tar.gz → 0.9.6.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/PKG-INFO +1 -1
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/pull.py +498 -44
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/PKG-INFO +1 -1
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/setup.py +1 -1
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/LICENSE.txt +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/README.md +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/__init__.py +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/geo.py +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/mmm.py +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/unittesting.py +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences/vis.py +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/PKG-INFO-TomG-HP-290722 +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/SOURCES.txt +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/dependency_links.txt +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/requires.txt +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/imsciences.egg-info/top_level.txt +0 -0
- {imsciences-0.9.6.6 → imsciences-0.9.6.8}/setup.cfg +0 -0
|
@@ -13,6 +13,8 @@ import holidays
|
|
|
13
13
|
from dateutil.easter import easter
|
|
14
14
|
import urllib.request
|
|
15
15
|
from geopy.geocoders import Nominatim
|
|
16
|
+
import importlib
|
|
17
|
+
import workalendar
|
|
16
18
|
|
|
17
19
|
from imsciences.mmm import dataprocessing
|
|
18
20
|
|
|
@@ -379,7 +381,7 @@ class datapull:
|
|
|
379
381
|
|
|
380
382
|
############################################################### Seasonality ##########################################################################
|
|
381
383
|
|
|
382
|
-
def pull_seasonality(
|
|
384
|
+
def pull_seasonality(week_commencing, start_date, countries):
|
|
383
385
|
"""
|
|
384
386
|
Generates a DataFrame with weekly seasonality features.
|
|
385
387
|
|
|
@@ -388,7 +390,7 @@ class datapull:
|
|
|
388
390
|
start_date (str): The start date in 'YYYY-MM-DD' format.
|
|
389
391
|
countries (list): A list of country codes (e.g., ['GB', 'US']) for holidays.
|
|
390
392
|
|
|
391
|
-
|
|
393
|
+
Returns:
|
|
392
394
|
pd.DataFrame: A DataFrame indexed by week start date, containing various
|
|
393
395
|
seasonal dummy variables, holidays, trend, and constant.
|
|
394
396
|
The date column is named 'OBS'.
|
|
@@ -399,6 +401,343 @@ class datapull:
|
|
|
399
401
|
day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
|
|
400
402
|
if week_commencing not in day_dict:
|
|
401
403
|
raise ValueError(f"Invalid week_commencing value: {week_commencing}. Use one of {list(day_dict.keys())}")
|
|
404
|
+
|
|
405
|
+
# ---------------------------------------------------------------------
|
|
406
|
+
# 0.2 Setup: lookup dictionaries mapping workalendar calendar class names to the workalendar submodule they are imported from
|
|
407
|
+
# ---------------------------------------------------------------------
|
|
408
|
+
COUNTRY_TO_CONTINENT = {
|
|
409
|
+
#Europe
|
|
410
|
+
"Austria": "europe",
|
|
411
|
+
"Belarus": "europe",
|
|
412
|
+
"Belgium": "europe",
|
|
413
|
+
"Bulgaria": "europe",
|
|
414
|
+
"Croatia": "europe",
|
|
415
|
+
"Cyprus": "europe",
|
|
416
|
+
"Czechia": "europe",
|
|
417
|
+
"CzechRepublic": "europe",
|
|
418
|
+
"Denmark": "europe",
|
|
419
|
+
"Estonia": "europe",
|
|
420
|
+
"EuropeanCentralBank": "europe",
|
|
421
|
+
"Finland": "europe",
|
|
422
|
+
"France": "europe",
|
|
423
|
+
"FranceAlsaceMoselle": "europe",
|
|
424
|
+
"Germany": "europe",
|
|
425
|
+
"GermanyBaden": "europe",
|
|
426
|
+
"GermanyBavaria": "europe",
|
|
427
|
+
"GermanyBerlin": "europe",
|
|
428
|
+
"GermanyBrandenburg": "europe",
|
|
429
|
+
"GermanyBremen": "europe",
|
|
430
|
+
"GermanyHamburg": "europe",
|
|
431
|
+
"GermanyHesse": "europe",
|
|
432
|
+
"GermanyLowerSaxony": "europe",
|
|
433
|
+
"GermanyMecklenburgVorpommern": "europe",
|
|
434
|
+
"GermanyNorthRhineWestphalia": "europe",
|
|
435
|
+
"GermanyRhinelandPalatinate": "europe",
|
|
436
|
+
"GermanySaarland": "europe",
|
|
437
|
+
"GermanySaxony": "europe",
|
|
438
|
+
"GermanySaxonyAnhalt": "europe",
|
|
439
|
+
"GermanySchleswigHolstein": "europe",
|
|
440
|
+
"GermanyThuringia": "europe",
|
|
441
|
+
"Greece": "europe",
|
|
442
|
+
"Hungary": "europe",
|
|
443
|
+
"Iceland": "europe",
|
|
444
|
+
"Ireland": "europe",
|
|
445
|
+
"Italy": "europe",
|
|
446
|
+
"Latvia": "europe",
|
|
447
|
+
"Lithuania": "europe",
|
|
448
|
+
"Luxembourg": "europe",
|
|
449
|
+
"Malta": "europe",
|
|
450
|
+
"Monaco": "europe",
|
|
451
|
+
"Netherlands": "europe",
|
|
452
|
+
"Norway": "europe",
|
|
453
|
+
"Poland": "europe",
|
|
454
|
+
"Portugal": "europe",
|
|
455
|
+
"Romania": "europe",
|
|
456
|
+
"Russia": "europe",
|
|
457
|
+
"Serbia": "europe",
|
|
458
|
+
"Slovakia": "europe",
|
|
459
|
+
"Slovenia": "europe",
|
|
460
|
+
"Spain": "europe",
|
|
461
|
+
"SpainAndalusia": "europe",
|
|
462
|
+
"SpainAragon": "europe",
|
|
463
|
+
"SpainAsturias": "europe",
|
|
464
|
+
"SpainBalearicIslands": "europe",
|
|
465
|
+
"SpainBasqueCountry": "europe",
|
|
466
|
+
"SpainCanaryIslands": "europe",
|
|
467
|
+
"SpainCantabria": "europe",
|
|
468
|
+
"SpainCastileAndLeon": "europe",
|
|
469
|
+
"SpainCastillaLaMancha": "europe",
|
|
470
|
+
"SpainCatalonia": "europe",
|
|
471
|
+
"SpainExtremadura": "europe",
|
|
472
|
+
"SpainGalicia": "europe",
|
|
473
|
+
"SpainLaRioja": "europe",
|
|
474
|
+
"SpainMadrid": "europe",
|
|
475
|
+
"SpainMurcia": "europe",
|
|
476
|
+
"SpainNavarre": "europe",
|
|
477
|
+
"SpainValencia": "europe",
|
|
478
|
+
"Sweden": "europe",
|
|
479
|
+
"Switzerland": "europe",
|
|
480
|
+
"Ukraine": "europe",
|
|
481
|
+
"UnitedKingdom": "europe",
|
|
482
|
+
|
|
483
|
+
# Americas
|
|
484
|
+
"Argentina": "america",
|
|
485
|
+
"Barbados": "america",
|
|
486
|
+
"Brazil": "america",
|
|
487
|
+
"Canada": "america",
|
|
488
|
+
"Chile": "america",
|
|
489
|
+
"Colombia": "america",
|
|
490
|
+
"Mexico": "america",
|
|
491
|
+
"Panama": "america",
|
|
492
|
+
"Paraguay": "america",
|
|
493
|
+
"Peru": "america",
|
|
494
|
+
"UnitedStates": "usa",
|
|
495
|
+
|
|
496
|
+
#US States
|
|
497
|
+
"Alabama": "usa.states",
|
|
498
|
+
"Alaska": "usa.states",
|
|
499
|
+
"Arizona": "usa.states",
|
|
500
|
+
"Arkansas": "usa.states",
|
|
501
|
+
"California": "usa.states",
|
|
502
|
+
"Colorado": "usa.states",
|
|
503
|
+
"Connecticut": "usa.states",
|
|
504
|
+
"Delaware": "usa.states",
|
|
505
|
+
"DistrictOfColumbia": "usa.states",
|
|
506
|
+
"Florida": "usa.states",
|
|
507
|
+
"Georgia": "usa.states",
|
|
508
|
+
"Hawaii": "usa.states",
|
|
509
|
+
"Idaho": "usa.states",
|
|
510
|
+
"Illinois": "usa.states",
|
|
511
|
+
"Indiana": "usa.states",
|
|
512
|
+
"Iowa": "usa.states",
|
|
513
|
+
"Kansas": "usa.states",
|
|
514
|
+
"Kentucky": "usa.states",
|
|
515
|
+
"Louisiana": "usa.states",
|
|
516
|
+
"Maine": "usa.states",
|
|
517
|
+
"Maryland": "usa.states",
|
|
518
|
+
"Massachusetts": "usa.states",
|
|
519
|
+
"Michigan": "usa.states",
|
|
520
|
+
"Minnesota": "usa.states",
|
|
521
|
+
"Mississippi": "usa.states",
|
|
522
|
+
"Missouri": "usa.states",
|
|
523
|
+
"Montana": "usa.states",
|
|
524
|
+
"Nebraska": "usa.states",
|
|
525
|
+
"Nevada": "usa.states",
|
|
526
|
+
"NewHampshire": "usa.states",
|
|
527
|
+
"NewJersey": "usa.states",
|
|
528
|
+
"NewMexico": "usa.states",
|
|
529
|
+
"NewYork": "usa.states",
|
|
530
|
+
"NorthCarolina": "usa.states",
|
|
531
|
+
"NorthDakota": "usa.states",
|
|
532
|
+
"Ohio": "usa.states",
|
|
533
|
+
"Oklahoma": "usa.states",
|
|
534
|
+
"Oregon": "usa.states",
|
|
535
|
+
"Pennsylvania": "usa.states",
|
|
536
|
+
"RhodeIsland": "usa.states",
|
|
537
|
+
"SouthCarolina": "usa.states",
|
|
538
|
+
"SouthDakota": "usa.states",
|
|
539
|
+
"Tennessee": "usa.states",
|
|
540
|
+
"Texas": "usa.states",
|
|
541
|
+
"Utah": "usa.states",
|
|
542
|
+
"Vermont": "usa.states",
|
|
543
|
+
"Virginia": "usa.states",
|
|
544
|
+
"Washington": "usa.states",
|
|
545
|
+
"WestVirginia": "usa.states",
|
|
546
|
+
"Wisconsin": "usa.states",
|
|
547
|
+
"Wyoming": "usa.states",
|
|
548
|
+
|
|
549
|
+
#Oceania
|
|
550
|
+
"Australia": "oceania",
|
|
551
|
+
"AustraliaCapitalTerritory": "oceania",
|
|
552
|
+
"AustraliaNewSouthWales": "oceania",
|
|
553
|
+
"AustraliaNorthernTerritory": "oceania",
|
|
554
|
+
"AustraliaQueensland": "oceania",
|
|
555
|
+
"AustraliaSouthAustralia": "oceania",
|
|
556
|
+
"AustraliaTasmania": "oceania",
|
|
557
|
+
"AustraliaVictoria": "oceania",
|
|
558
|
+
"AustraliaWesternAustralia": "oceania",
|
|
559
|
+
"MarshallIslands": "oceania",
|
|
560
|
+
"NewZealand": "oceania",
|
|
561
|
+
|
|
562
|
+
#Asia
|
|
563
|
+
"China": "asia",
|
|
564
|
+
"HongKong": "asia",
|
|
565
|
+
"India": "asia",
|
|
566
|
+
"Israel": "asia",
|
|
567
|
+
"Japan": "asia",
|
|
568
|
+
"Kazakhstan": "asia",
|
|
569
|
+
"Malaysia": "asia",
|
|
570
|
+
"Qatar": "asia",
|
|
571
|
+
"Singapore": "asia",
|
|
572
|
+
"SouthKorea": "asia",
|
|
573
|
+
"Taiwan": "asia",
|
|
574
|
+
"Turkey": "asia",
|
|
575
|
+
"Vietnam": "asia",
|
|
576
|
+
|
|
577
|
+
#Africa
|
|
578
|
+
"Algeria": "africa",
|
|
579
|
+
"Angola": "africa",
|
|
580
|
+
"Benin": "africa",
|
|
581
|
+
"IvoryCoast": "africa",
|
|
582
|
+
"Kenya": "africa",
|
|
583
|
+
"Madagascar": "africa",
|
|
584
|
+
"Nigeria": "africa",
|
|
585
|
+
"SaoTomeAndPrincipe": "africa",
|
|
586
|
+
"SouthAfrica": "africa",
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
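# Dictionary mapping country/region codes (ISO country codes, US state and Australian territory abbreviations) to workalendar calendar class names. NOTE: the keys "AR", "MT", "IN" and "IL" each appear twice below; in a dict literal the later entry silently overrides the earlier one.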
|
|
590
|
+
holiday_country = {
|
|
591
|
+
# Major countries with required formats
|
|
592
|
+
"GB": "UnitedKingdom",
|
|
593
|
+
"US": "UnitedStates",
|
|
594
|
+
"USA": "UnitedStates", # Alternative code for US
|
|
595
|
+
"CA": "Canada",
|
|
596
|
+
"ZA": "SouthAfrica",
|
|
597
|
+
"FR": "France",
|
|
598
|
+
"DE": "Germany",
|
|
599
|
+
"AU": "Australia",
|
|
600
|
+
"AUS": "Australia", # Alternative code for Australia
|
|
601
|
+
|
|
602
|
+
# European countries
|
|
603
|
+
"AT": "Austria",
|
|
604
|
+
"BY": "Belarus",
|
|
605
|
+
"BE": "Belgium",
|
|
606
|
+
"BG": "Bulgaria",
|
|
607
|
+
"HR": "Croatia",
|
|
608
|
+
"CY": "Cyprus",
|
|
609
|
+
"CZ": "Czechia",
|
|
610
|
+
"DK": "Denmark",
|
|
611
|
+
"EE": "Estonia",
|
|
612
|
+
"FI": "Finland",
|
|
613
|
+
"GR": "Greece",
|
|
614
|
+
"HU": "Hungary",
|
|
615
|
+
"IS": "Iceland",
|
|
616
|
+
"IE": "Ireland",
|
|
617
|
+
"IT": "Italy",
|
|
618
|
+
"LV": "Latvia",
|
|
619
|
+
"LT": "Lithuania",
|
|
620
|
+
"LU": "Luxembourg",
|
|
621
|
+
"MT": "Malta",
|
|
622
|
+
"MC": "Monaco",
|
|
623
|
+
"NL": "Netherlands",
|
|
624
|
+
"NO": "Norway",
|
|
625
|
+
"PL": "Poland",
|
|
626
|
+
"PT": "Portugal",
|
|
627
|
+
"RO": "Romania",
|
|
628
|
+
"RU": "Russia",
|
|
629
|
+
"RS": "Serbia",
|
|
630
|
+
"SK": "Slovakia",
|
|
631
|
+
"SI": "Slovenia",
|
|
632
|
+
"ES": "Spain",
|
|
633
|
+
"SE": "Sweden",
|
|
634
|
+
"CH": "Switzerland",
|
|
635
|
+
"UA": "Ukraine",
|
|
636
|
+
|
|
637
|
+
# Americas
|
|
638
|
+
"AR": "Argentina",
|
|
639
|
+
"BB": "Barbados",
|
|
640
|
+
"BR": "Brazil",
|
|
641
|
+
"CL": "Chile",
|
|
642
|
+
"CO": "Colombia",
|
|
643
|
+
"MX": "Mexico",
|
|
644
|
+
"PA": "Panama",
|
|
645
|
+
"PY": "Paraguay",
|
|
646
|
+
"PE": "Peru",
|
|
647
|
+
|
|
648
|
+
# USA States (using common abbreviations)
|
|
649
|
+
"AL": "Alabama",
|
|
650
|
+
"AK": "Alaska",
|
|
651
|
+
"AZ": "Arizona",
|
|
652
|
+
"AR": "Arkansas",
|
|
653
|
+
"CA_US": "California",
|
|
654
|
+
"CO_US": "Colorado",
|
|
655
|
+
"CT": "Connecticut",
|
|
656
|
+
"DE_US": "Delaware",
|
|
657
|
+
"DC": "DistrictOfColumbia",
|
|
658
|
+
"FL": "Florida",
|
|
659
|
+
"GA": "Georgia",
|
|
660
|
+
"HI": "Hawaii",
|
|
661
|
+
"ID": "Idaho",
|
|
662
|
+
"IL": "Illinois",
|
|
663
|
+
"IN": "Indiana",
|
|
664
|
+
"IA": "Iowa",
|
|
665
|
+
"KS": "Kansas",
|
|
666
|
+
"KY": "Kentucky",
|
|
667
|
+
"LA": "Louisiana",
|
|
668
|
+
"ME": "Maine",
|
|
669
|
+
"MD": "Maryland",
|
|
670
|
+
"MA": "Massachusetts",
|
|
671
|
+
"MI": "Michigan",
|
|
672
|
+
"MN": "Minnesota",
|
|
673
|
+
"MS": "Mississippi",
|
|
674
|
+
"MO": "Missouri",
|
|
675
|
+
"MT": "Montana",
|
|
676
|
+
"NE": "Nebraska",
|
|
677
|
+
"NV": "Nevada",
|
|
678
|
+
"NH": "NewHampshire",
|
|
679
|
+
"NJ": "NewJersey",
|
|
680
|
+
"NM": "NewMexico",
|
|
681
|
+
"NY": "NewYork",
|
|
682
|
+
"NC": "NorthCarolina",
|
|
683
|
+
"ND": "NorthDakota",
|
|
684
|
+
"OH": "Ohio",
|
|
685
|
+
"OK": "Oklahoma",
|
|
686
|
+
"OR": "Oregon",
|
|
687
|
+
"PA_US": "Pennsylvania",
|
|
688
|
+
"RI": "RhodeIsland",
|
|
689
|
+
"SC": "SouthCarolina",
|
|
690
|
+
"SD": "SouthDakota",
|
|
691
|
+
"TN": "Tennessee",
|
|
692
|
+
"TX": "Texas",
|
|
693
|
+
"UT": "Utah",
|
|
694
|
+
"VT": "Vermont",
|
|
695
|
+
"VA": "Virginia",
|
|
696
|
+
"WA": "Washington",
|
|
697
|
+
"WV": "WestVirginia",
|
|
698
|
+
"WI": "Wisconsin",
|
|
699
|
+
"WY": "Wyoming",
|
|
700
|
+
|
|
701
|
+
# Australia territories
|
|
702
|
+
"ACT": "AustraliaCapitalTerritory",
|
|
703
|
+
"NSW": "AustraliaNewSouthWales",
|
|
704
|
+
"NT": "AustraliaNorthernTerritory",
|
|
705
|
+
"QLD": "AustraliaQueensland",
|
|
706
|
+
"SA": "AustraliaSouthAustralia",
|
|
707
|
+
"TAS": "AustraliaTasmania",
|
|
708
|
+
"VIC": "AustraliaVictoria",
|
|
709
|
+
"WA_AU": "AustraliaWesternAustralia",
|
|
710
|
+
|
|
711
|
+
# Asian countries
|
|
712
|
+
"CN": "China",
|
|
713
|
+
"HK": "HongKong",
|
|
714
|
+
"IN": "India",
|
|
715
|
+
"IL": "Israel",
|
|
716
|
+
"JP": "Japan",
|
|
717
|
+
"KZ": "Kazakhstan",
|
|
718
|
+
"MY": "Malaysia",
|
|
719
|
+
"QA": "Qatar",
|
|
720
|
+
"SG": "Singapore",
|
|
721
|
+
"KR": "SouthKorea",
|
|
722
|
+
"TW": "Taiwan",
|
|
723
|
+
"TR": "Turkey",
|
|
724
|
+
"VN": "Vietnam",
|
|
725
|
+
|
|
726
|
+
# Other Oceania countries
|
|
727
|
+
"MH": "MarshallIslands",
|
|
728
|
+
"NZ": "NewZealand",
|
|
729
|
+
|
|
730
|
+
# African countries
|
|
731
|
+
"DZ": "Algeria",
|
|
732
|
+
"AO": "Angola",
|
|
733
|
+
"BJ": "Benin",
|
|
734
|
+
"CI": "IvoryCoast",
|
|
735
|
+
"KE": "Kenya",
|
|
736
|
+
"MG": "Madagascar",
|
|
737
|
+
"NG": "Nigeria",
|
|
738
|
+
"ST": "SaoTomeAndPrincipe"
|
|
739
|
+
}
|
|
740
|
+
|
|
402
741
|
|
|
403
742
|
# ---------------------------------------------------------------------
|
|
404
743
|
# 1. Create daily date range from start_date to today
|
|
@@ -452,44 +791,136 @@ class datapull:
|
|
|
452
791
|
df_dummies = pd.DataFrame(dummy_columns, index=df_weekly_start.index)
|
|
453
792
|
df_weekly_start = pd.concat([df_weekly_start, df_dummies], axis=1)
|
|
454
793
|
|
|
794
|
+
|
|
455
795
|
# ---------------------------------------------------------------------
|
|
456
|
-
# 3. Public holidays (daily)
|
|
796
|
+
# 3. Public holidays (daily) using workalendar
|
|
457
797
|
# ---------------------------------------------------------------------
|
|
458
798
|
start_year = start_dt.year
|
|
459
799
|
end_year = end_dt.year
|
|
460
800
|
years_range = range(start_year, end_year + 1)
|
|
461
|
-
|
|
462
|
-
for country
|
|
801
|
+
|
|
802
|
+
# Dictionary to store holiday dummies for each country
|
|
803
|
+
country_holiday_dummies = {}
|
|
804
|
+
|
|
805
|
+
for country_code in countries:
|
|
806
|
+
# Skip if country code not found in holiday_country dictionary
|
|
807
|
+
if country_code not in holiday_country:
|
|
808
|
+
print(f"Warning: Country code '{country_code}' not found in country code dictionary. Skipping.")
|
|
809
|
+
continue
|
|
810
|
+
|
|
811
|
+
country = holiday_country[country_code]
|
|
812
|
+
|
|
813
|
+
# Skip if country not found in continent lookup dictionary
|
|
814
|
+
if country not in COUNTRY_TO_CONTINENT:
|
|
815
|
+
print(f"Warning: Country '{country}' not found in continent lookup dictionary. Skipping.")
|
|
816
|
+
continue
|
|
817
|
+
|
|
818
|
+
continent = COUNTRY_TO_CONTINENT[country]
|
|
819
|
+
module_path = f'workalendar.{continent}'
|
|
463
820
|
try:
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
)
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
#
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
821
|
+
module = importlib.import_module(module_path)
|
|
822
|
+
calendar_class = getattr(module, country)
|
|
823
|
+
cal = calendar_class()
|
|
824
|
+
except (ImportError, AttributeError) as e:
|
|
825
|
+
print(f"Error importing calendar for {country}: {e}. Skipping.")
|
|
826
|
+
continue
|
|
827
|
+
|
|
828
|
+
# Collect holidays
|
|
829
|
+
holidays_list = []
|
|
830
|
+
for year in years_range:
|
|
831
|
+
holidays_list.extend(cal.holidays(year))
|
|
832
|
+
|
|
833
|
+
holidays_df = pd.DataFrame(holidays_list, columns=['Date', 'Holiday'])
|
|
834
|
+
holidays_df['Date'] = pd.to_datetime(holidays_df['Date'])
|
|
835
|
+
|
|
836
|
+
# Filter out any holidays with "shift" or "substitute" in their name
|
|
837
|
+
holidays_df = holidays_df[~(holidays_df['Holiday'].str.lower().str.contains('shift') |
|
|
838
|
+
holidays_df['Holiday'].str.lower().str.contains('substitute'))]
|
|
839
|
+
|
|
840
|
+
# Filter by date range
|
|
841
|
+
holidays_df = holidays_df[(holidays_df['Date'] >= start_dt) & (holidays_df['Date'] <= end_dt)]
|
|
842
|
+
# ---------------------------------------------------------------------
|
|
843
|
+
# 3.1 Additional Public Holidays for Canada due to poor API data
|
|
844
|
+
# ---------------------------------------------------------------------
|
|
845
|
+
if country_code == 'CA':
|
|
846
|
+
# Add Canada Day (July 1st) if not already in the list
|
|
847
|
+
for year in years_range:
|
|
848
|
+
canada_day = pd.Timestamp(f"{year}-07-01")
|
|
849
|
+
if canada_day >= start_dt and canada_day <= end_dt:
|
|
850
|
+
if not ((holidays_df['Date'] == canada_day) &
|
|
851
|
+
(holidays_df['Holiday'].str.lower().str.contains('canada day'))).any():
|
|
852
|
+
holidays_df = pd.concat([holidays_df,
|
|
853
|
+
pd.DataFrame({'Date': [canada_day],
|
|
854
|
+
'Holiday': ['Canada Day']})],
|
|
855
|
+
ignore_index=True)
|
|
856
|
+
|
|
857
|
+
# Add Labour Day (first Monday in September)
|
|
858
|
+
for year in years_range:
|
|
859
|
+
# Get first day of September
|
|
860
|
+
first_day = pd.Timestamp(f"{year}-09-01")
|
|
861
|
+
# Calculate days until first Monday (Monday is weekday 0)
|
|
862
|
+
days_until_monday = (7 - first_day.weekday()) % 7
|
|
863
|
+
if days_until_monday == 0: # If first day is already Monday
|
|
864
|
+
labour_day = first_day
|
|
865
|
+
else:
|
|
866
|
+
labour_day = first_day + pd.Timedelta(days=days_until_monday)
|
|
867
|
+
|
|
868
|
+
if labour_day >= start_dt and labour_day <= end_dt:
|
|
869
|
+
if not ((holidays_df['Date'] == labour_day) &
|
|
870
|
+
(holidays_df['Holiday'].str.lower().str.contains('labour day'))).any():
|
|
871
|
+
holidays_df = pd.concat([holidays_df,
|
|
872
|
+
pd.DataFrame({'Date': [labour_day],
|
|
873
|
+
'Holiday': ['Labour Day']})],
|
|
874
|
+
ignore_index=True)
|
|
875
|
+
|
|
876
|
+
# Add Thanksgiving (second Monday in October)
|
|
877
|
+
for year in years_range:
|
|
878
|
+
# Get first day of October
|
|
879
|
+
first_day = pd.Timestamp(f"{year}-10-01")
|
|
880
|
+
# Calculate days until first Monday
|
|
881
|
+
days_until_monday = (7 - first_day.weekday()) % 7
|
|
882
|
+
if days_until_monday == 0: # If first day is already Monday
|
|
883
|
+
first_monday = first_day
|
|
884
|
+
else:
|
|
885
|
+
first_monday = first_day + pd.Timedelta(days=days_until_monday)
|
|
886
|
+
|
|
887
|
+
# Second Monday is 7 days after first Monday
|
|
888
|
+
thanksgiving = first_monday + pd.Timedelta(days=7)
|
|
889
|
+
|
|
890
|
+
if thanksgiving >= start_dt and thanksgiving <= end_dt:
|
|
891
|
+
if not ((holidays_df['Date'] == thanksgiving) &
|
|
892
|
+
(holidays_df['Holiday'].str.lower().str.contains('thanksgiving'))).any():
|
|
893
|
+
holidays_df = pd.concat([holidays_df,
|
|
894
|
+
pd.DataFrame({'Date': [thanksgiving],
|
|
895
|
+
'Holiday': ['Thanksgiving']})],
|
|
896
|
+
ignore_index=True)
|
|
897
|
+
|
|
898
|
+
# Now process the collected holidays and add to df_daily
|
|
899
|
+
for _, row in holidays_df.iterrows():
|
|
900
|
+
holiday_date = row['Date']
|
|
901
|
+
# Create column name without modifying original holiday names
|
|
902
|
+
holiday_name = row['Holiday'].lower().replace(' ', '_')
|
|
903
|
+
|
|
904
|
+
# Remove "_shift" or "_substitute" if they appear as standalone suffixes
|
|
905
|
+
if holiday_name.endswith('_shift'):
|
|
906
|
+
holiday_name = holiday_name[:-6]
|
|
907
|
+
elif holiday_name.endswith('_substitute'):
|
|
908
|
+
holiday_name = holiday_name[:-11]
|
|
909
|
+
|
|
910
|
+
column_name = f"seas_{holiday_name}_{country_code.lower()}"
|
|
911
|
+
|
|
912
|
+
if column_name not in df_daily.columns:
|
|
913
|
+
df_daily[column_name] = 0
|
|
914
|
+
|
|
915
|
+
# Mark the specific holiday date
|
|
916
|
+
df_daily.loc[df_daily["Date"] == holiday_date, column_name] = 1
|
|
917
|
+
|
|
918
|
+
# Also mark a general holiday indicator for each country
|
|
919
|
+
holiday_indicator = f"seas_holiday_{country_code.lower()}"
|
|
920
|
+
if holiday_indicator not in df_daily.columns:
|
|
921
|
+
df_daily[holiday_indicator] = 0
|
|
922
|
+
df_daily.loc[df_daily["Date"] == holiday_date, holiday_indicator] = 1
|
|
923
|
+
|
|
493
924
|
|
|
494
925
|
# ---------------------------------------------------------------------
|
|
495
926
|
# 3.1 Additional Special Days (Father's Day, Mother's Day, etc.)
|
|
@@ -597,7 +1028,7 @@ class datapull:
|
|
|
597
1028
|
|
|
598
1029
|
|
|
599
1030
|
# ---------------------------------------------------------------------
|
|
600
|
-
# 4. Add daily indicators for last day & last Friday of month
|
|
1031
|
+
# 4. Add daily indicators for last day & last Friday of month & payday
|
|
601
1032
|
# ---------------------------------------------------------------------
|
|
602
1033
|
df_daily["is_last_day_of_month"] = df_daily["Date"].dt.is_month_end
|
|
603
1034
|
|
|
@@ -608,22 +1039,27 @@ class datapull:
|
|
|
608
1039
|
# Check if next Friday is in the next month
|
|
609
1040
|
next_friday = date + timedelta(days=7)
|
|
610
1041
|
return 1 if next_friday.month != date.month else 0
|
|
1042
|
+
|
|
1043
|
+
def is_payday(date):
|
|
1044
|
+
return 1 if date.day >= 25 else 0
|
|
611
1045
|
|
|
612
1046
|
df_daily["is_last_friday_of_month"] = df_daily["Date"].apply(is_last_friday)
|
|
613
|
-
|
|
1047
|
+
|
|
1048
|
+
df_daily["is_payday"] = df_daily["Date"].apply(is_payday)
|
|
1049
|
+
|
|
614
1050
|
# Rename for clarity prefix
|
|
615
1051
|
df_daily.rename(columns={
|
|
616
1052
|
"is_last_day_of_month": "seas_last_day_of_month",
|
|
617
|
-
"is_last_friday_of_month": "seas_last_friday_of_month"
|
|
1053
|
+
"is_last_friday_of_month": "seas_last_friday_of_month",
|
|
1054
|
+
"is_payday": "seas_payday"
|
|
618
1055
|
}, inplace=True)
|
|
619
1056
|
|
|
620
|
-
|
|
621
1057
|
# ---------------------------------------------------------------------
|
|
622
1058
|
# 5. Weekly aggregation
|
|
623
1059
|
# ---------------------------------------------------------------------
|
|
624
1060
|
|
|
625
1061
|
# Select only columns that are indicators/flags (intended for max aggregation)
|
|
626
|
-
flag_cols = [col for col in df_daily.columns if col.startswith('seas_') or col.startswith('is_')]
|
|
1062
|
+
flag_cols = [col for col in df_daily.columns if (col.startswith('seas_') or col.startswith('is_')) and col != "seas_payday"]
|
|
627
1063
|
# Ensure 'week_start' is present for grouping
|
|
628
1064
|
df_to_agg = df_daily[['week_start'] + flag_cols]
|
|
629
1065
|
|
|
@@ -635,7 +1071,26 @@ class datapull:
|
|
|
635
1071
|
.rename(columns={'week_start': "Date"})
|
|
636
1072
|
.set_index("Date")
|
|
637
1073
|
)
|
|
1074
|
+
|
|
1075
|
+
# Do specific aggregation for payday
|
|
1076
|
+
# Add 'month' and 'year' columns to df_daily, derived from the 'Date' column
|
|
1077
|
+
df_daily["month"] = df_daily["Date"].dt.month
|
|
1078
|
+
df_daily["year"] = df_daily["Date"].dt.year
|
|
1079
|
+
|
|
1080
|
+
# Sum of seas_payday flags per week
|
|
1081
|
+
week_payday_sum = df_daily.groupby("week_start")["seas_payday"].sum()
|
|
1082
|
+
|
|
1083
|
+
# Divide each week's payday-flag count by the number of payday-flagged days in the month of that week's first row
|
|
1084
|
+
payday_days_in_month = (
|
|
1085
|
+
df_daily.groupby(["year", "month"])["seas_payday"].sum()
|
|
1086
|
+
)
|
|
1087
|
+
week_month = df_daily.groupby("week_start").first()[["month", "year"]]
|
|
1088
|
+
week_days_in_month = week_month.apply(lambda row: payday_days_in_month.loc[(row["year"], row["month"])], axis=1)
|
|
1089
|
+
df_weekly_flags["seas_payday"] = (week_payday_sum / week_days_in_month).fillna(0).values
|
|
638
1090
|
|
|
1091
|
+
# # Drop intermediate columns
|
|
1092
|
+
# df_weekly_flags = df_weekly_flags.drop(columns=["month", "year"])
|
|
1093
|
+
|
|
639
1094
|
# --- Aggregate Week Number using MODE ---
|
|
640
1095
|
# Define aggregation function for mode (handling potential multi-modal cases by taking the first)
|
|
641
1096
|
def get_mode(x):
|
|
@@ -678,7 +1133,6 @@ class datapull:
|
|
|
678
1133
|
df_weekly_monthly_dummies.rename(columns={'week_start': 'Date'}, inplace=True)
|
|
679
1134
|
df_weekly_monthly_dummies.set_index('Date', inplace=True)
|
|
680
1135
|
|
|
681
|
-
|
|
682
1136
|
# ---------------------------------------------------------------------
|
|
683
1137
|
# 6. Combine all weekly components
|
|
684
1138
|
# ---------------------------------------------------------------------
|
|
@@ -697,15 +1151,15 @@ class datapull:
|
|
|
697
1151
|
|
|
698
1152
|
# Ensure correct types for flag columns (int)
|
|
699
1153
|
for col in df_weekly_flags.columns:
|
|
700
|
-
if col
|
|
701
|
-
|
|
1154
|
+
if col != 'seas_payday':
|
|
1155
|
+
if col in df_combined.columns:
|
|
1156
|
+
df_combined[col] = df_combined[col].astype(int)
|
|
702
1157
|
|
|
703
1158
|
# Ensure correct types for month columns (float)
|
|
704
1159
|
for col in df_weekly_monthly_dummies.columns:
|
|
705
1160
|
if col in df_combined.columns:
|
|
706
1161
|
df_combined[col] = df_combined[col].astype(float)
|
|
707
1162
|
|
|
708
|
-
|
|
709
1163
|
# ---------------------------------------------------------------------
|
|
710
1164
|
# 7. Create weekly dummies for Week of Year & yearly dummies from aggregated cols
|
|
711
1165
|
# ---------------------------------------------------------------------
|
|
@@ -737,7 +1191,7 @@ class datapull:
|
|
|
737
1191
|
# Filter out columns not in the desired order list (handles case where dum_ cols are off)
|
|
738
1192
|
final_cols = [col for col in cols_order if col in df_combined.columns]
|
|
739
1193
|
df_combined = df_combined[final_cols]
|
|
740
|
-
|
|
1194
|
+
|
|
741
1195
|
return df_combined
|
|
742
1196
|
|
|
743
1197
|
def pull_weather(self, week_commencing, start_date, country_codes) -> pd.DataFrame:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|