gsppy 4.1.0__py3-none-any.whl → 5.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gsppy/cli.py CHANGED
@@ -35,7 +35,8 @@ import csv
35
35
  import sys
36
36
  import json
37
37
  import logging
38
- from typing import Any, List, Tuple, Union, Optional, cast
38
+ import importlib
39
+ from typing import Any, List, Tuple, Union, Callable, Optional, cast
39
40
 
40
41
  import click
41
42
 
@@ -51,6 +52,54 @@ from gsppy.enums import (
51
52
  from gsppy.utils import has_timestamps
52
53
 
53
54
 
55
+ def _load_hook_function(import_path: str, hook_type: str) -> Callable[..., Any]:
56
+ """
57
+ Load a hook function from a Python module import path.
58
+
59
+ Parameters:
60
+ import_path (str): Import path in format 'module.submodule.function_name'
61
+ hook_type (str): Type of hook for error messages ('preprocess', 'postprocess', 'candidate_filter')
62
+
63
+ Returns:
64
+ Callable: The loaded hook function
65
+
66
+ Raises:
67
+ ValueError: If the import path is invalid or function cannot be loaded
68
+ """
69
+ try:
70
+ # Split into module path and function name
71
+ parts = import_path.rsplit(".", 1)
72
+ if len(parts) != 2:
73
+ raise ValueError(f"Invalid import path format. Expected 'module.function', got '{import_path}'")
74
+
75
+ module_name, function_name = parts
76
+
77
+ # Import the module
78
+ module = importlib.import_module(module_name)
79
+
80
+ # Get the function from the module
81
+ if not hasattr(module, function_name):
82
+ raise ValueError(f"Function '{function_name}' not found in module '{module_name}'")
83
+
84
+ hook_fn = getattr(module, function_name)
85
+
86
+ # Verify it's callable
87
+ if not callable(hook_fn):
88
+ raise ValueError(f"'{import_path}' is not a callable function")
89
+
90
+ return hook_fn
91
+
92
+ except ImportError as e:
93
+ # Extract module name from import path for error message
94
+ module_part = import_path.rsplit(".", 1)[0] if "." in import_path else import_path
95
+ raise ValueError(f"Failed to import {hook_type} hook module '{module_part}': {e}") from e
96
+ except ValueError:
97
+ # Re-raise ValueError as-is
98
+ raise
99
+ except Exception as e:
100
+ raise ValueError(f"Failed to load {hook_type} hook function '{import_path}': {e}") from e
101
+
102
+
54
103
  def setup_logging(verbose: bool) -> None:
55
104
  """
56
105
  Configure logging with standardized format based on verbosity level.
@@ -515,20 +564,26 @@ def _load_transactions_by_format(
515
564
  help="File format to use. 'auto' detects format from file extension.",
516
565
  )
517
566
  @click.option("--verbose", is_flag=True, help="Enable verbose output for debugging purposes.")
518
- def main(
519
- file_path: str,
520
- min_support: float,
521
- backend: str,
522
- mingap: Optional[float],
523
- maxgap: Optional[float],
524
- maxspan: Optional[float],
525
- transaction_col: Optional[str],
526
- item_col: Optional[str],
527
- timestamp_col: Optional[str],
528
- sequence_col: Optional[str],
529
- format: str, # noqa: A002
530
- verbose: bool,
531
- ) -> None:
567
+ @click.option(
568
+ "--preprocess-hook",
569
+ type=str,
570
+ default=None,
571
+ help="Python import path to preprocessing hook function (e.g., 'mymodule.preprocess_fn').",
572
+ )
573
+ @click.option(
574
+ "--postprocess-hook",
575
+ type=str,
576
+ default=None,
577
+ help="Python import path to postprocessing hook function (e.g., 'mymodule.postprocess_fn').",
578
+ )
579
+ @click.option(
580
+ "--candidate-filter-hook",
581
+ type=str,
582
+ default=None,
583
+ help="Python import path to candidate filter hook function (e.g., 'mymodule.filter_fn').",
584
+ )
585
+ @click.pass_context
586
+ def main(ctx: click.Context, **kwargs: Any) -> None:
532
587
  """
533
588
  Run the GSP algorithm on transactional data from a file.
534
589
 
@@ -573,9 +628,59 @@ def main(
573
628
  ```bash
574
629
  gsppy --file data.txt --format spm --min_support 0.3
575
630
  ```
631
+
632
+ With custom hooks (requires Python module with hook functions):
633
+
634
+ ```bash
635
+ # Create a hooks module first (hooks.py):
636
+ # def my_filter(candidate, support, context):
637
+ # return len(candidate) <= 2 # Keep only short patterns
638
+ #
639
+ # def my_postprocess(patterns):
640
+ # return patterns[:2] # Keep only first 2 levels
641
+
642
+ gsppy --file data.json --min_support 0.3 \
643
+ --candidate-filter-hook hooks.my_filter \
644
+ --postprocess-hook hooks.my_postprocess
645
+ ```
576
646
  """
647
+ # Extract parameters from kwargs
648
+ file_path = kwargs['file_path']
649
+ min_support = kwargs['min_support']
650
+ backend = kwargs['backend']
651
+ mingap = kwargs.get('mingap')
652
+ maxgap = kwargs.get('maxgap')
653
+ maxspan = kwargs.get('maxspan')
654
+ transaction_col = kwargs.get('transaction_col')
655
+ item_col = kwargs.get('item_col')
656
+ timestamp_col = kwargs.get('timestamp_col')
657
+ sequence_col = kwargs.get('sequence_col')
658
+ file_format = kwargs['format']
659
+ verbose = kwargs['verbose']
660
+ preprocess_hook = kwargs.get('preprocess_hook')
661
+ postprocess_hook = kwargs.get('postprocess_hook')
662
+ candidate_filter_hook = kwargs.get('candidate_filter_hook')
663
+
577
664
  setup_logging(verbose)
578
665
 
666
+ # Load hook functions if specified
667
+ try:
668
+ preprocess_fn = _load_hook_function(preprocess_hook, "preprocess") if preprocess_hook else None
669
+ postprocess_fn = _load_hook_function(postprocess_hook, "postprocess") if postprocess_hook else None
670
+ candidate_filter_fn = (
671
+ _load_hook_function(candidate_filter_hook, "candidate_filter") if candidate_filter_hook else None
672
+ )
673
+
674
+ if preprocess_fn:
675
+ logger.info(f"Loaded preprocessing hook: {preprocess_hook}")
676
+ if postprocess_fn:
677
+ logger.info(f"Loaded postprocessing hook: {postprocess_hook}")
678
+ if candidate_filter_fn:
679
+ logger.info(f"Loaded candidate filter hook: {candidate_filter_hook}")
680
+ except ValueError as e:
681
+ logger.error(f"Error loading hook function: {e}")
682
+ sys.exit(1)
683
+
579
684
  # Detect file extension to determine if DataFrame column params are needed
580
685
  _, file_extension = os.path.splitext(file_path)
581
686
  file_extension = file_extension.lower()
@@ -583,10 +688,10 @@ def main(
583
688
 
584
689
  # Automatically detect and load transactions
585
690
  try:
586
- file_format = format.lower()
691
+ file_format_lower = file_format.lower()
587
692
  transactions = _load_transactions_by_format(
588
693
  file_path,
589
- file_format,
694
+ file_format_lower,
590
695
  file_extension,
591
696
  is_dataframe_format,
592
697
  transaction_col,
@@ -608,7 +713,13 @@ def main(
608
713
  # Initialize and run GSP algorithm
609
714
  try:
610
715
  gsp = GSP(transactions, mingap=mingap, maxgap=maxgap, maxspan=maxspan, verbose=verbose)
611
- patterns = gsp.search(min_support=min_support, return_sequences=False)
716
+ patterns = gsp.search(
717
+ min_support=min_support,
718
+ return_sequences=False,
719
+ preprocess_fn=preprocess_fn,
720
+ postprocess_fn=postprocess_fn,
721
+ candidate_filter_fn=candidate_filter_fn,
722
+ )
612
723
  logger.info("Frequent Patterns Found:")
613
724
  for i, level in enumerate(patterns, start=1):
614
725
  logger.info(f"\n{i}-Sequence Patterns:")