imsciences 0.9.6.6__py3-none-any.whl → 0.9.6.8__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of imsciences might be problematic. Click here for more details.

imsciences/pull.py CHANGED
@@ -13,6 +13,8 @@ import holidays
13
13
  from dateutil.easter import easter
14
14
  import urllib.request
15
15
  from geopy.geocoders import Nominatim
16
+ import importlib
17
+ import workalendar
16
18
 
17
19
  from imsciences.mmm import dataprocessing
18
20
 
@@ -379,7 +381,7 @@ class datapull:
379
381
 
380
382
  ############################################################### Seasonality ##########################################################################
381
383
 
382
- def pull_seasonality(self, week_commencing, start_date, countries):
384
+ def pull_seasonality(week_commencing, start_date, countries):
383
385
  """
384
386
  Generates a DataFrame with weekly seasonality features.
385
387
 
@@ -388,7 +390,7 @@ class datapull:
388
390
  start_date (str): The start date in 'YYYY-MM-DD' format.
389
391
  countries (list): A list of country codes (e.g., ['GB', 'US']) for holidays.
390
392
 
391
- Returns:
393
+ Returns:
392
394
  pd.DataFrame: A DataFrame indexed by week start date, containing various
393
395
  seasonal dummy variables, holidays, trend, and constant.
394
396
  The date column is named 'OBS'.
@@ -399,6 +401,343 @@ class datapull:
399
401
  day_dict = {"mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6}
400
402
  if week_commencing not in day_dict:
401
403
  raise ValueError(f"Invalid week_commencing value: {week_commencing}. Use one of {list(day_dict.keys())}")
404
+
405
+ # ---------------------------------------------------------------------
406
+ # 0.2 Setup: dictionary continents and countries
407
+ # ---------------------------------------------------------------------
408
+ COUNTRY_TO_CONTINENT = {
409
+ #Europe
410
+ "Austria": "europe",
411
+ "Belarus": "europe",
412
+ "Belgium": "europe",
413
+ "Bulgaria": "europe",
414
+ "Croatia": "europe",
415
+ "Cyprus": "europe",
416
+ "Czechia": "europe",
417
+ "CzechRepublic": "europe",
418
+ "Denmark": "europe",
419
+ "Estonia": "europe",
420
+ "EuropeanCentralBank": "europe",
421
+ "Finland": "europe",
422
+ "France": "europe",
423
+ "FranceAlsaceMoselle": "europe",
424
+ "Germany": "europe",
425
+ "GermanyBaden": "europe",
426
+ "GermanyBavaria": "europe",
427
+ "GermanyBerlin": "europe",
428
+ "GermanyBrandenburg": "europe",
429
+ "GermanyBremen": "europe",
430
+ "GermanyHamburg": "europe",
431
+ "GermanyHesse": "europe",
432
+ "GermanyLowerSaxony": "europe",
433
+ "GermanyMecklenburgVorpommern": "europe",
434
+ "GermanyNorthRhineWestphalia": "europe",
435
+ "GermanyRhinelandPalatinate": "europe",
436
+ "GermanySaarland": "europe",
437
+ "GermanySaxony": "europe",
438
+ "GermanySaxonyAnhalt": "europe",
439
+ "GermanySchleswigHolstein": "europe",
440
+ "GermanyThuringia": "europe",
441
+ "Greece": "europe",
442
+ "Hungary": "europe",
443
+ "Iceland": "europe",
444
+ "Ireland": "europe",
445
+ "Italy": "europe",
446
+ "Latvia": "europe",
447
+ "Lithuania": "europe",
448
+ "Luxembourg": "europe",
449
+ "Malta": "europe",
450
+ "Monaco": "europe",
451
+ "Netherlands": "europe",
452
+ "Norway": "europe",
453
+ "Poland": "europe",
454
+ "Portugal": "europe",
455
+ "Romania": "europe",
456
+ "Russia": "europe",
457
+ "Serbia": "europe",
458
+ "Slovakia": "europe",
459
+ "Slovenia": "europe",
460
+ "Spain": "europe",
461
+ "SpainAndalusia": "europe",
462
+ "SpainAragon": "europe",
463
+ "SpainAsturias": "europe",
464
+ "SpainBalearicIslands": "europe",
465
+ "SpainBasqueCountry": "europe",
466
+ "SpainCanaryIslands": "europe",
467
+ "SpainCantabria": "europe",
468
+ "SpainCastileAndLeon": "europe",
469
+ "SpainCastillaLaMancha": "europe",
470
+ "SpainCatalonia": "europe",
471
+ "SpainExtremadura": "europe",
472
+ "SpainGalicia": "europe",
473
+ "SpainLaRioja": "europe",
474
+ "SpainMadrid": "europe",
475
+ "SpainMurcia": "europe",
476
+ "SpainNavarre": "europe",
477
+ "SpainValencia": "europe",
478
+ "Sweden": "europe",
479
+ "Switzerland": "europe",
480
+ "Ukraine": "europe",
481
+ "UnitedKingdom": "europe",
482
+
483
+ # Americas
484
+ "Argentina": "america",
485
+ "Barbados": "america",
486
+ "Brazil": "america",
487
+ "Canada": "america",
488
+ "Chile": "america",
489
+ "Colombia": "america",
490
+ "Mexico": "america",
491
+ "Panama": "america",
492
+ "Paraguay": "america",
493
+ "Peru": "america",
494
+ "UnitedStates": "usa",
495
+
496
+ #US States
497
+ "Alabama": "usa.states",
498
+ "Alaska": "usa.states",
499
+ "Arizona": "usa.states",
500
+ "Arkansas": "usa.states",
501
+ "California": "usa.states",
502
+ "Colorado": "usa.states",
503
+ "Connecticut": "usa.states",
504
+ "Delaware": "usa.states",
505
+ "DistrictOfColumbia": "usa.states",
506
+ "Florida": "usa.states",
507
+ "Georgia": "usa.states",
508
+ "Hawaii": "usa.states",
509
+ "Idaho": "usa.states",
510
+ "Illinois": "usa.states",
511
+ "Indiana": "usa.states",
512
+ "Iowa": "usa.states",
513
+ "Kansas": "usa.states",
514
+ "Kentucky": "usa.states",
515
+ "Louisiana": "usa.states",
516
+ "Maine": "usa.states",
517
+ "Maryland": "usa.states",
518
+ "Massachusetts": "usa.states",
519
+ "Michigan": "usa.states",
520
+ "Minnesota": "usa.states",
521
+ "Mississippi": "usa.states",
522
+ "Missouri": "usa.states",
523
+ "Montana": "usa.states",
524
+ "Nebraska": "usa.states",
525
+ "Nevada": "usa.states",
526
+ "NewHampshire": "usa.states",
527
+ "NewJersey": "usa.states",
528
+ "NewMexico": "usa.states",
529
+ "NewYork": "usa.states",
530
+ "NorthCarolina": "usa.states",
531
+ "NorthDakota": "usa.states",
532
+ "Ohio": "usa.states",
533
+ "Oklahoma": "usa.states",
534
+ "Oregon": "usa.states",
535
+ "Pennsylvania": "usa.states",
536
+ "RhodeIsland": "usa.states",
537
+ "SouthCarolina": "usa.states",
538
+ "SouthDakota": "usa.states",
539
+ "Tennessee": "usa.states",
540
+ "Texas": "usa.states",
541
+ "Utah": "usa.states",
542
+ "Vermont": "usa.states",
543
+ "Virginia": "usa.states",
544
+ "Washington": "usa.states",
545
+ "WestVirginia": "usa.states",
546
+ "Wisconsin": "usa.states",
547
+ "Wyoming": "usa.states",
548
+
549
+ #Oceania
550
+ "Australia": "oceania",
551
+ "AustraliaCapitalTerritory": "oceania",
552
+ "AustraliaNewSouthWales": "oceania",
553
+ "AustraliaNorthernTerritory": "oceania",
554
+ "AustraliaQueensland": "oceania",
555
+ "AustraliaSouthAustralia": "oceania",
556
+ "AustraliaTasmania": "oceania",
557
+ "AustraliaVictoria": "oceania",
558
+ "AustraliaWesternAustralia": "oceania",
559
+ "MarshallIslands": "oceania",
560
+ "NewZealand": "oceania",
561
+
562
+ #Asia
563
+ "China": "asia",
564
+ "HongKong": "asia",
565
+ "India": "asia",
566
+ "Israel": "asia",
567
+ "Japan": "asia",
568
+ "Kazakhstan": "asia",
569
+ "Malaysia": "asia",
570
+ "Qatar": "asia",
571
+ "Singapore": "asia",
572
+ "SouthKorea": "asia",
573
+ "Taiwan": "asia",
574
+ "Turkey": "asia",
575
+ "Vietnam": "asia",
576
+
577
+ #Africa
578
+ "Algeria": "africa",
579
+ "Angola": "africa",
580
+ "Benin": "africa",
581
+ "IvoryCoast": "africa",
582
+ "Kenya": "africa",
583
+ "Madagascar": "africa",
584
+ "Nigeria": "africa",
585
+ "SaoTomeAndPrincipe": "africa",
586
+ "SouthAfrica": "africa",
587
+ }
588
+
589
+ # Dictionary mapping ISO country codes to their corresponding workalendar country names
590
+ holiday_country = {
591
+ # Major countries with required formats
592
+ "GB": "UnitedKingdom",
593
+ "US": "UnitedStates",
594
+ "USA": "UnitedStates", # Alternative code for US
595
+ "CA": "Canada",
596
+ "ZA": "SouthAfrica",
597
+ "FR": "France",
598
+ "DE": "Germany",
599
+ "AU": "Australia",
600
+ "AUS": "Australia", # Alternative code for Australia
601
+
602
+ # European countries
603
+ "AT": "Austria",
604
+ "BY": "Belarus",
605
+ "BE": "Belgium",
606
+ "BG": "Bulgaria",
607
+ "HR": "Croatia",
608
+ "CY": "Cyprus",
609
+ "CZ": "Czechia",
610
+ "DK": "Denmark",
611
+ "EE": "Estonia",
612
+ "FI": "Finland",
613
+ "GR": "Greece",
614
+ "HU": "Hungary",
615
+ "IS": "Iceland",
616
+ "IE": "Ireland",
617
+ "IT": "Italy",
618
+ "LV": "Latvia",
619
+ "LT": "Lithuania",
620
+ "LU": "Luxembourg",
621
+ "MT": "Malta",
622
+ "MC": "Monaco",
623
+ "NL": "Netherlands",
624
+ "NO": "Norway",
625
+ "PL": "Poland",
626
+ "PT": "Portugal",
627
+ "RO": "Romania",
628
+ "RU": "Russia",
629
+ "RS": "Serbia",
630
+ "SK": "Slovakia",
631
+ "SI": "Slovenia",
632
+ "ES": "Spain",
633
+ "SE": "Sweden",
634
+ "CH": "Switzerland",
635
+ "UA": "Ukraine",
636
+
637
+ # Americas
638
+ "AR": "Argentina",
639
+ "BB": "Barbados",
640
+ "BR": "Brazil",
641
+ "CL": "Chile",
642
+ "CO": "Colombia",
643
+ "MX": "Mexico",
644
+ "PA": "Panama",
645
+ "PY": "Paraguay",
646
+ "PE": "Peru",
647
+
648
+ # USA States (using common abbreviations)
649
+ "AL": "Alabama",
650
+ "AK": "Alaska",
651
+ "AZ": "Arizona",
652
+ "AR": "Arkansas",
653
+ "CA_US": "California",
654
+ "CO_US": "Colorado",
655
+ "CT": "Connecticut",
656
+ "DE_US": "Delaware",
657
+ "DC": "DistrictOfColumbia",
658
+ "FL": "Florida",
659
+ "GA": "Georgia",
660
+ "HI": "Hawaii",
661
+ "ID": "Idaho",
662
+ "IL": "Illinois",
663
+ "IN": "Indiana",
664
+ "IA": "Iowa",
665
+ "KS": "Kansas",
666
+ "KY": "Kentucky",
667
+ "LA": "Louisiana",
668
+ "ME": "Maine",
669
+ "MD": "Maryland",
670
+ "MA": "Massachusetts",
671
+ "MI": "Michigan",
672
+ "MN": "Minnesota",
673
+ "MS": "Mississippi",
674
+ "MO": "Missouri",
675
+ "MT": "Montana",
676
+ "NE": "Nebraska",
677
+ "NV": "Nevada",
678
+ "NH": "NewHampshire",
679
+ "NJ": "NewJersey",
680
+ "NM": "NewMexico",
681
+ "NY": "NewYork",
682
+ "NC": "NorthCarolina",
683
+ "ND": "NorthDakota",
684
+ "OH": "Ohio",
685
+ "OK": "Oklahoma",
686
+ "OR": "Oregon",
687
+ "PA_US": "Pennsylvania",
688
+ "RI": "RhodeIsland",
689
+ "SC": "SouthCarolina",
690
+ "SD": "SouthDakota",
691
+ "TN": "Tennessee",
692
+ "TX": "Texas",
693
+ "UT": "Utah",
694
+ "VT": "Vermont",
695
+ "VA": "Virginia",
696
+ "WA": "Washington",
697
+ "WV": "WestVirginia",
698
+ "WI": "Wisconsin",
699
+ "WY": "Wyoming",
700
+
701
+ # Australia territories
702
+ "ACT": "AustraliaCapitalTerritory",
703
+ "NSW": "AustraliaNewSouthWales",
704
+ "NT": "AustraliaNorthernTerritory",
705
+ "QLD": "AustraliaQueensland",
706
+ "SA": "AustraliaSouthAustralia",
707
+ "TAS": "AustraliaTasmania",
708
+ "VIC": "AustraliaVictoria",
709
+ "WA_AU": "AustraliaWesternAustralia",
710
+
711
+ # Asian countries
712
+ "CN": "China",
713
+ "HK": "HongKong",
714
+ "IN": "India",
715
+ "IL": "Israel",
716
+ "JP": "Japan",
717
+ "KZ": "Kazakhstan",
718
+ "MY": "Malaysia",
719
+ "QA": "Qatar",
720
+ "SG": "Singapore",
721
+ "KR": "SouthKorea",
722
+ "TW": "Taiwan",
723
+ "TR": "Turkey",
724
+ "VN": "Vietnam",
725
+
726
+ # Other Oceania countries
727
+ "MH": "MarshallIslands",
728
+ "NZ": "NewZealand",
729
+
730
+ # African countries
731
+ "DZ": "Algeria",
732
+ "AO": "Angola",
733
+ "BJ": "Benin",
734
+ "CI": "IvoryCoast",
735
+ "KE": "Kenya",
736
+ "MG": "Madagascar",
737
+ "NG": "Nigeria",
738
+ "ST": "SaoTomeAndPrincipe"
739
+ }
740
+
402
741
 
403
742
  # ---------------------------------------------------------------------
404
743
  # 1. Create daily date range from start_date to today
@@ -452,44 +791,136 @@ class datapull:
452
791
  df_dummies = pd.DataFrame(dummy_columns, index=df_weekly_start.index)
453
792
  df_weekly_start = pd.concat([df_weekly_start, df_dummies], axis=1)
454
793
 
794
+
455
795
  # ---------------------------------------------------------------------
456
- # 3. Public holidays (daily) from 'holidays' package + each holiday name
796
+ # 3. Public holidays (daily) using WorkCalendar
457
797
  # ---------------------------------------------------------------------
458
798
  start_year = start_dt.year
459
799
  end_year = end_dt.year
460
800
  years_range = range(start_year, end_year + 1)
461
-
462
- for country in countries:
801
+
802
+ # Dictionary to store holiday dummies for each country
803
+ country_holiday_dummies = {}
804
+
805
+ for country_code in countries:
806
+ # Skip if country code not found in holiday_country dictionary
807
+ if country_code not in holiday_country:
808
+ print(f"Warning: Country code '{country_code}' not found in country code dictionary. Skipping.")
809
+ continue
810
+
811
+ country = holiday_country[country_code]
812
+
813
+ # Skip if country not found in continent lookup dictionary
814
+ if country not in COUNTRY_TO_CONTINENT:
815
+ print(f"Warning: Country '{country}' not found in continent lookup dictionary. Skipping.")
816
+ continue
817
+
818
+ continent = COUNTRY_TO_CONTINENT[country]
819
+ module_path = f'workalendar.{continent}'
463
820
  try:
464
- country_holidays = holidays.CountryHoliday(
465
- country,
466
- years=years_range,
467
- observed=False # Typically you want the actual date, not observed substitute
468
- )
469
- # Handle cases like UK where specific subdivisions might be needed for some holidays
470
- # Example: if country == 'GB': country_holidays.observed = True # If observed are needed
471
- except KeyError:
472
- print(f"Warning: Country code '{country}' not found in holidays library. Skipping.")
473
- continue # Skip to next country
474
-
475
- # Daily indicator: 1 if that date is a holiday
476
- df_daily[f"seas_holiday_{country.lower()}"] = df_daily["Date"].apply(
477
- lambda x: 1 if x in country_holidays else 0
478
- )
479
- # Create columns for specific holiday names
480
- for date_hol, name in sorted(country_holidays.items()): # Sort for consistent column order
481
- # Clean name: lower, replace space with underscore, remove non-alphanumeric (except underscore)
482
- clean_name = ''.join(c for c in name if c.isalnum() or c == ' ').strip().replace(' ', '_').lower()
483
- clean_name = clean_name.replace('_(observed)', '').replace("'", "") # specific cleaning
484
- col_name = f"seas_{clean_name}_{country.lower()}"
485
-
486
- # Only create column if the holiday occurs within the df_daily date range
487
- if pd.Timestamp(date_hol).year in years_range:
488
- if col_name not in df_daily.columns:
489
- df_daily[col_name] = 0
490
- # Ensure date_hol is within the actual daily range before assigning
491
- if (pd.Timestamp(date_hol) >= df_daily["Date"].min()) and (pd.Timestamp(date_hol) <= df_daily["Date"].max()):
492
- df_daily.loc[df_daily["Date"] == pd.Timestamp(date_hol), col_name] = 1
821
+ module = importlib.import_module(module_path)
822
+ calendar_class = getattr(module, country)
823
+ cal = calendar_class()
824
+ except (ImportError, AttributeError) as e:
825
+ print(f"Error importing calendar for {country}: {e}. Skipping.")
826
+ continue
827
+
828
+ # Collect holidays
829
+ holidays_list = []
830
+ for year in years_range:
831
+ holidays_list.extend(cal.holidays(year))
832
+
833
+ holidays_df = pd.DataFrame(holidays_list, columns=['Date', 'Holiday'])
834
+ holidays_df['Date'] = pd.to_datetime(holidays_df['Date'])
835
+
836
+ # Filter out any holidays with "shift" or "substitute" in their name
837
+ holidays_df = holidays_df[~(holidays_df['Holiday'].str.lower().str.contains('shift') |
838
+ holidays_df['Holiday'].str.lower().str.contains('substitute'))]
839
+
840
+ # Filter by date range
841
+ holidays_df = holidays_df[(holidays_df['Date'] >= start_dt) & (holidays_df['Date'] <= end_dt)]
842
+ # ---------------------------------------------------------------------
843
+ # 3.1 Additional Public Holidays for Canada due to poor API data
844
+ # ---------------------------------------------------------------------
845
+ if country_code == 'CA':
846
+ # Add Canada Day (July 1st) if not already in the list
847
+ for year in years_range:
848
+ canada_day = pd.Timestamp(f"{year}-07-01")
849
+ if canada_day >= start_dt and canada_day <= end_dt:
850
+ if not ((holidays_df['Date'] == canada_day) &
851
+ (holidays_df['Holiday'].str.lower().str.contains('canada day'))).any():
852
+ holidays_df = pd.concat([holidays_df,
853
+ pd.DataFrame({'Date': [canada_day],
854
+ 'Holiday': ['Canada Day']})],
855
+ ignore_index=True)
856
+
857
+ # Add Labour Day (first Monday in September)
858
+ for year in years_range:
859
+ # Get first day of September
860
+ first_day = pd.Timestamp(f"{year}-09-01")
861
+ # Calculate days until first Monday (Monday is weekday 0)
862
+ days_until_monday = (7 - first_day.weekday()) % 7
863
+ if days_until_monday == 0: # If first day is already Monday
864
+ labour_day = first_day
865
+ else:
866
+ labour_day = first_day + pd.Timedelta(days=days_until_monday)
867
+
868
+ if labour_day >= start_dt and labour_day <= end_dt:
869
+ if not ((holidays_df['Date'] == labour_day) &
870
+ (holidays_df['Holiday'].str.lower().str.contains('labour day'))).any():
871
+ holidays_df = pd.concat([holidays_df,
872
+ pd.DataFrame({'Date': [labour_day],
873
+ 'Holiday': ['Labour Day']})],
874
+ ignore_index=True)
875
+
876
+ # Add Thanksgiving (second Monday in October)
877
+ for year in years_range:
878
+ # Get first day of October
879
+ first_day = pd.Timestamp(f"{year}-10-01")
880
+ # Calculate days until first Monday
881
+ days_until_monday = (7 - first_day.weekday()) % 7
882
+ if days_until_monday == 0: # If first day is already Monday
883
+ first_monday = first_day
884
+ else:
885
+ first_monday = first_day + pd.Timedelta(days=days_until_monday)
886
+
887
+ # Second Monday is 7 days after first Monday
888
+ thanksgiving = first_monday + pd.Timedelta(days=7)
889
+
890
+ if thanksgiving >= start_dt and thanksgiving <= end_dt:
891
+ if not ((holidays_df['Date'] == thanksgiving) &
892
+ (holidays_df['Holiday'].str.lower().str.contains('thanksgiving'))).any():
893
+ holidays_df = pd.concat([holidays_df,
894
+ pd.DataFrame({'Date': [thanksgiving],
895
+ 'Holiday': ['Thanksgiving']})],
896
+ ignore_index=True)
897
+
898
+ # Now process the collected holidays and add to df_daily
899
+ for _, row in holidays_df.iterrows():
900
+ holiday_date = row['Date']
901
+ # Create column name without modifying original holiday names
902
+ holiday_name = row['Holiday'].lower().replace(' ', '_')
903
+
904
+ # Remove "_shift" or "_substitute" if they appear as standalone suffixes
905
+ if holiday_name.endswith('_shift'):
906
+ holiday_name = holiday_name[:-6]
907
+ elif holiday_name.endswith('_substitute'):
908
+ holiday_name = holiday_name[:-11]
909
+
910
+ column_name = f"seas_{holiday_name}_{country_code.lower()}"
911
+
912
+ if column_name not in df_daily.columns:
913
+ df_daily[column_name] = 0
914
+
915
+ # Mark the specific holiday date
916
+ df_daily.loc[df_daily["Date"] == holiday_date, column_name] = 1
917
+
918
+ # Also mark a general holiday indicator for each country
919
+ holiday_indicator = f"seas_holiday_{country_code.lower()}"
920
+ if holiday_indicator not in df_daily.columns:
921
+ df_daily[holiday_indicator] = 0
922
+ df_daily.loc[df_daily["Date"] == holiday_date, holiday_indicator] = 1
923
+
493
924
 
494
925
  # ---------------------------------------------------------------------
495
926
  # 3.1 Additional Special Days (Father's Day, Mother's Day, etc.)
@@ -597,7 +1028,7 @@ class datapull:
597
1028
 
598
1029
 
599
1030
  # ---------------------------------------------------------------------
600
- # 4. Add daily indicators for last day & last Friday of month
1031
+ # 4. Add daily indicators for last day & last Friday of month & payday
601
1032
  # ---------------------------------------------------------------------
602
1033
  df_daily["is_last_day_of_month"] = df_daily["Date"].dt.is_month_end
603
1034
 
@@ -608,22 +1039,27 @@ class datapull:
608
1039
  # Check if next Friday is in the next month
609
1040
  next_friday = date + timedelta(days=7)
610
1041
  return 1 if next_friday.month != date.month else 0
1042
+
1043
+ def is_payday(date):
1044
+ return 1 if date.day >= 25 else 0
611
1045
 
612
1046
  df_daily["is_last_friday_of_month"] = df_daily["Date"].apply(is_last_friday)
613
-
1047
+
1048
+ df_daily["is_payday"] = df_daily["Date"].apply(is_payday)
1049
+
614
1050
  # Rename for clarity prefix
615
1051
  df_daily.rename(columns={
616
1052
  "is_last_day_of_month": "seas_last_day_of_month",
617
- "is_last_friday_of_month": "seas_last_friday_of_month"
1053
+ "is_last_friday_of_month": "seas_last_friday_of_month",
1054
+ "is_payday": "seas_payday"
618
1055
  }, inplace=True)
619
1056
 
620
-
621
1057
  # ---------------------------------------------------------------------
622
1058
  # 5. Weekly aggregation
623
1059
  # ---------------------------------------------------------------------
624
1060
 
625
1061
  # Select only columns that are indicators/flags (intended for max aggregation)
626
- flag_cols = [col for col in df_daily.columns if col.startswith('seas_') or col.startswith('is_')]
1062
+ flag_cols = [col for col in df_daily.columns if (col.startswith('seas_') or col.startswith('is_')) and col != "seas_payday"]
627
1063
  # Ensure 'week_start' is present for grouping
628
1064
  df_to_agg = df_daily[['week_start'] + flag_cols]
629
1065
 
@@ -635,7 +1071,26 @@ class datapull:
635
1071
  .rename(columns={'week_start': "Date"})
636
1072
  .set_index("Date")
637
1073
  )
1074
+
1075
+ # Do specific aggregation for payday
1076
+ # Make sure 'date' column exists in df_daily
1077
+ df_daily["month"] = df_daily["Date"].dt.month
1078
+ df_daily["year"] = df_daily["Date"].dt.year
1079
+
1080
+ # Sum of seas_payday flags per week
1081
+ week_payday_sum = df_daily.groupby("week_start")["seas_payday"].sum()
1082
+
1083
+ # Divide the number of payday flags by number of paydays per month
1084
+ payday_days_in_month = (
1085
+ df_daily.groupby(["year", "month"])["seas_payday"].sum()
1086
+ )
1087
+ week_month = df_daily.groupby("week_start").first()[["month", "year"]]
1088
+ week_days_in_month = week_month.apply(lambda row: payday_days_in_month.loc[(row["year"], row["month"])], axis=1)
1089
+ df_weekly_flags["seas_payday"] = (week_payday_sum / week_days_in_month).fillna(0).values
638
1090
 
1091
+ # # Drop intermediate columns
1092
+ # df_weekly_flags = df_weekly_flags.drop(columns=["month", "year"])
1093
+
639
1094
  # --- Aggregate Week Number using MODE ---
640
1095
  # Define aggregation function for mode (handling potential multi-modal cases by taking the first)
641
1096
  def get_mode(x):
@@ -678,7 +1133,6 @@ class datapull:
678
1133
  df_weekly_monthly_dummies.rename(columns={'week_start': 'Date'}, inplace=True)
679
1134
  df_weekly_monthly_dummies.set_index('Date', inplace=True)
680
1135
 
681
-
682
1136
  # ---------------------------------------------------------------------
683
1137
  # 6. Combine all weekly components
684
1138
  # ---------------------------------------------------------------------
@@ -697,15 +1151,15 @@ class datapull:
697
1151
 
698
1152
  # Ensure correct types for flag columns (int)
699
1153
  for col in df_weekly_flags.columns:
700
- if col in df_combined.columns:
701
- df_combined[col] = df_combined[col].astype(int)
1154
+ if col != 'seas_payday':
1155
+ if col in df_combined.columns:
1156
+ df_combined[col] = df_combined[col].astype(int)
702
1157
 
703
1158
  # Ensure correct types for month columns (float)
704
1159
  for col in df_weekly_monthly_dummies.columns:
705
1160
  if col in df_combined.columns:
706
1161
  df_combined[col] = df_combined[col].astype(float)
707
1162
 
708
-
709
1163
  # ---------------------------------------------------------------------
710
1164
  # 7. Create weekly dummies for Week of Year & yearly dummies from aggregated cols
711
1165
  # ---------------------------------------------------------------------
@@ -737,7 +1191,7 @@ class datapull:
737
1191
  # Filter out columns not in the desired order list (handles case where dum_ cols are off)
738
1192
  final_cols = [col for col in cols_order if col in df_combined.columns]
739
1193
  df_combined = df_combined[final_cols]
740
-
1194
+
741
1195
  return df_combined
742
1196
 
743
1197
  def pull_weather(self, week_commencing, start_date, country_codes) -> pd.DataFrame:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imsciences
3
- Version: 0.9.6.6
3
+ Version: 0.9.6.8
4
4
  Summary: IMS Data Processing Package
5
5
  Author: IMS
6
6
  Author-email: cam@im-sciences.com
@@ -1,12 +1,12 @@
1
1
  imsciences/__init__.py,sha256=_HuYeLbDMTdt7GpKI4r6-d7yRPZgcAQ7yOW0-ydR2Yo,117
2
2
  imsciences/geo.py,sha256=eenng7_BP_E2WD5Wt1G_oNxQS8W3t6lycRwJ91ngysY,15808
3
3
  imsciences/mmm.py,sha256=qMh0ccOepehfCcux7EeG8cq6piSEoFEz5iiJbDBWOS4,82214
4
- imsciences/pull.py,sha256=4NGKzNmsvfzADMucR8iLGTkYDyb5wdnqphe1CzepyWw,94992
4
+ imsciences/pull.py,sha256=rHC4__gvfpcgAjgELABYvZrGyb1Ucg0qOW7724qiNns,114816
5
5
  imsciences/unittesting.py,sha256=U177_Txg0Lqn49zYRu5bl9OVe_X7MkNJ6V_Zd6DHOsU,45656
6
6
  imsciences/vis.py,sha256=2izdHQhmWEReerRqIxhY4Ai10VjL7xoUqyWyZC7-2XI,8931
7
- imsciences-0.9.6.6.dist-info/LICENSE.txt,sha256=lVq2QwcExPX4Kl2DHeEkRrikuItcDB1Pr7yF7FQ8_z8,1108
8
- imsciences-0.9.6.6.dist-info/METADATA,sha256=VO2fDn0xBsVhkPf42oeGOD5TEyC3uo1rHdylOmdaMu8,18846
9
- imsciences-0.9.6.6.dist-info/PKG-INFO-TomG-HP-290722,sha256=RMcthCSyWmU6IBsXGL-nYqw0RP06pzjPKK3dzOQcU-8,18846
10
- imsciences-0.9.6.6.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
11
- imsciences-0.9.6.6.dist-info/top_level.txt,sha256=hsENS-AlDVRh8tQJ6-426iUQlla9bPcGc0-UlFF0_iU,11
12
- imsciences-0.9.6.6.dist-info/RECORD,,
7
+ imsciences-0.9.6.8.dist-info/LICENSE.txt,sha256=lVq2QwcExPX4Kl2DHeEkRrikuItcDB1Pr7yF7FQ8_z8,1108
8
+ imsciences-0.9.6.8.dist-info/METADATA,sha256=8M6I7Tw6M3mZsQYv08QxzyLKNQ4MclT8BvCRhhkuAgs,18846
9
+ imsciences-0.9.6.8.dist-info/PKG-INFO-TomG-HP-290722,sha256=RMcthCSyWmU6IBsXGL-nYqw0RP06pzjPKK3dzOQcU-8,18846
10
+ imsciences-0.9.6.8.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
11
+ imsciences-0.9.6.8.dist-info/top_level.txt,sha256=hsENS-AlDVRh8tQJ6-426iUQlla9bPcGc0-UlFF0_iU,11
12
+ imsciences-0.9.6.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (70.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5