corp-extractor 0.3.0-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +460 -21
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +32 -47
  8. statement_extractor/gliner_extraction.py +218 -0
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +15 -6
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. statement_extractor/scoring.py +17 -69
  52. corp_extractor-0.3.0.dist-info/RECORD +0 -12
  53. statement_extractor/spacy_extraction.py +0 -386
  54. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  55. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/pipeline/config.py
@@ -0,0 +1,134 @@
+"""
+PipelineConfig - Configuration for stage/plugin selection.
+
+Controls which stages are enabled and which plugins to use.
+"""
+
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field
+
+
+class PipelineConfig(BaseModel):
+    """
+    Configuration for the extraction pipeline.
+
+    Controls which stages are enabled, which plugins to use,
+    and stage-specific options.
+    """
+    # Stage selection (1=Splitting, 2=Extraction, 3=Qualification, 4=Canonicalization, 5=Labeling, 6=Taxonomy)
+    enabled_stages: set[int] = Field(
+        default={1, 2, 3, 4, 5, 6},
+        description="Set of enabled stage numbers (1-6)"
+    )
+
+    # Plugin selection
+    enabled_plugins: Optional[set[str]] = Field(
+        None,
+        description="Set of enabled plugin names (None = all enabled)"
+    )
+    disabled_plugins: set[str] = Field(
+        default_factory=lambda: {
+            "mnli_taxonomy_classifier",  # Disabled by default - use embedding_taxonomy_classifier instead (faster)
+        },
+        description="Set of disabled plugin names"
+    )
+
+    # Stage-specific options
+    splitter_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to splitter plugins"
+    )
+    extractor_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to extractor plugins"
+    )
+    qualifier_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to qualifier plugins"
+    )
+    canonicalizer_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to canonicalizer plugins"
+    )
+    labeler_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to labeler plugins"
+    )
+    taxonomy_options: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Options passed to taxonomy plugins"
+    )
+
+    # General options
+    fail_fast: bool = Field(
+        default=True,
+        description="Stop processing on first error (otherwise continue and collect errors)"
+    )
+    parallel_processing: bool = Field(
+        default=False,
+        description="Enable parallel processing where possible"
+    )
+    max_statements: Optional[int] = Field(
+        None,
+        description="Maximum number of statements to process (None = unlimited)"
+    )
+
+    def is_stage_enabled(self, stage: int) -> bool:
+        """Check if a stage is enabled."""
+        return stage in self.enabled_stages
+
+    def is_plugin_enabled(self, plugin_name: str) -> bool:
+        """Check if a plugin is enabled."""
+        if plugin_name in self.disabled_plugins:
+            return False
+        if self.enabled_plugins is None:
+            return True
+        return plugin_name in self.enabled_plugins
+
+    @classmethod
+    def from_stage_string(cls, stages: str, **kwargs) -> "PipelineConfig":
+        """
+        Create config from a stage string.
+
+        Examples:
+            "1,2,3" -> stages 1, 2, 3
+            "1-3"   -> stages 1, 2, 3
+            "1-6"   -> all stages
+        """
+        enabled = set()
+        for part in stages.split(","):
+            part = part.strip()
+            if "-" in part:
+                start, end = part.split("-", 1)
+                for i in range(int(start), int(end) + 1):
+                    enabled.add(i)
+            else:
+                enabled.add(int(part))
+        return cls(enabled_stages=enabled, **kwargs)
+
+    @classmethod
+    def default(cls) -> "PipelineConfig":
+        """Create a default configuration with all stages enabled."""
+        return cls()
+
+    @classmethod
+    def minimal(cls) -> "PipelineConfig":
+        """Create a minimal configuration with only splitting and extraction."""
+        return cls(enabled_stages={1, 2})
+
+
+# Stage name mapping
+STAGE_NAMES = {
+    1: "splitting",
+    2: "extraction",
+    3: "qualification",
+    4: "canonicalization",
+    5: "labeling",
+    6: "taxonomy",
+}
+
+
+def get_stage_name(stage: int) -> str:
+    """Get the human-readable name for a stage."""
+    return STAGE_NAMES.get(stage, f"stage_{stage}")
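
PipelineConfig is plain data plus a few helpers, so it can be exercised without loading any models. Below is a minimal sketch of how the stage-string parser and the plugin gating behave; the "threshold" option key is illustrative only, not an option defined by this package:

from statement_extractor.pipeline.config import PipelineConfig, get_stage_name

# Run only stages 1-3; each option dict is passed through to the matching
# plugin family ("threshold" is a hypothetical option key for illustration).
config = PipelineConfig.from_stage_string("1-3", extractor_options={"threshold": 0.5})

assert config.is_stage_enabled(2)
assert not config.is_stage_enabled(6)

# disabled_plugins always wins; enabled_plugins=None means "everything else on".
assert not config.is_plugin_enabled("mnli_taxonomy_classifier")
assert config.is_plugin_enabled("embedding_taxonomy_classifier")

print([get_stage_name(s) for s in sorted(config.enabled_stages)])
# -> ['splitting', 'extraction', 'qualification']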
statement_extractor/pipeline/context.py
@@ -0,0 +1,177 @@
+"""
+PipelineContext - Data container that flows through all pipeline stages.
+
+The context accumulates outputs from each stage:
+- Stage 1 (Splitting): raw_triples
+- Stage 2 (Extraction): statements
+- Stage 3 (Qualification): qualified_entities
+- Stage 4 (Canonicalization): canonical_entities
+- Stage 5 (Labeling): labeled_statements
+"""
+
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field
+
+from ..models import (
+    RawTriple,
+    PipelineStatement,
+    QualifiedEntity,
+    CanonicalEntity,
+    LabeledStatement,
+    TaxonomyResult,
+)
+
+
+class PipelineContext(BaseModel):
+    """
+    Context object that flows through all pipeline stages.
+
+    Accumulates outputs from each stage and provides access to
+    source text, metadata, and intermediate results.
+    """
+    # Input
+    source_text: str = Field(..., description="Original input text")
+    source_metadata: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Metadata about the source (e.g., document ID, URL, timestamp)"
+    )
+
+    # Stage 1 output: Raw triples from splitting
+    raw_triples: list[RawTriple] = Field(
+        default_factory=list,
+        description="Raw triples from Stage 1 (Splitting)"
+    )
+
+    # Stage 2 output: Statements with extracted entities
+    statements: list[PipelineStatement] = Field(
+        default_factory=list,
+        description="Statements from Stage 2 (Extraction)"
+    )
+
+    # Stage 3 output: Qualified entities (keyed by entity_ref)
+    qualified_entities: dict[str, QualifiedEntity] = Field(
+        default_factory=dict,
+        description="Qualified entities from Stage 3 (Qualification)"
+    )
+
+    # Stage 4 output: Canonical entities (keyed by entity_ref)
+    canonical_entities: dict[str, CanonicalEntity] = Field(
+        default_factory=dict,
+        description="Canonical entities from Stage 4 (Canonicalization)"
+    )
+
+    # Stage 5 output: Final labeled statements
+    labeled_statements: list[LabeledStatement] = Field(
+        default_factory=list,
+        description="Final labeled statements from Stage 5 (Labeling)"
+    )
+
+    # Classification results from extractor (populated by GLiNER2 or similar)
+    # Keyed by source_text -> label_type -> (label_value, confidence)
+    classification_results: dict[str, dict[str, tuple[str, float]]] = Field(
+        default_factory=dict,
+        description="Pre-computed classification results from Stage 2 extractor"
+    )
+
+    # Stage 6 output: Taxonomy classifications
+    # Keyed by (source_text, taxonomy_name) -> list of TaxonomyResult
+    # Multiple labels may match a single statement above threshold
+    taxonomy_results: dict[tuple[str, str], list[TaxonomyResult]] = Field(
+        default_factory=dict,
+        description="Taxonomy classifications from Stage 6 (multiple labels per statement)"
+    )
+
+    # Processing metadata
+    processing_errors: list[str] = Field(
+        default_factory=list,
+        description="Errors encountered during processing"
+    )
+    processing_warnings: list[str] = Field(
+        default_factory=list,
+        description="Warnings generated during processing"
+    )
+    stage_timings: dict[str, float] = Field(
+        default_factory=dict,
+        description="Timing information for each stage (stage_name -> seconds)"
+    )
+
+    def add_error(self, error: str) -> None:
+        """Add a processing error."""
+        self.processing_errors.append(error)
+
+    def add_warning(self, warning: str) -> None:
+        """Add a processing warning."""
+        self.processing_warnings.append(warning)
+
+    def record_timing(self, stage: str, duration: float) -> None:
+        """Record timing for a stage."""
+        self.stage_timings[stage] = duration
+
+    def get_entity_refs(self) -> set[str]:
+        """Get all unique entity refs from statements."""
+        refs = set()
+        for stmt in self.statements:
+            refs.add(stmt.subject.entity_ref)
+            refs.add(stmt.object.entity_ref)
+        return refs
+
+    def get_qualified_entity(self, entity_ref: str) -> Optional[QualifiedEntity]:
+        """Get qualified entity by ref, or None if not found."""
+        return self.qualified_entities.get(entity_ref)
+
+    def get_canonical_entity(self, entity_ref: str) -> Optional[CanonicalEntity]:
+        """Get canonical entity by ref, or None if not found."""
+        return self.canonical_entities.get(entity_ref)
+
+    def get_classification(
+        self,
+        source_text: str,
+        label_type: str,
+    ) -> Optional[tuple[str, float]]:
+        """
+        Get pre-computed classification result for a source text.
+
+        Args:
+            source_text: The source text that was classified
+            label_type: The type of label (e.g., "sentiment")
+
+        Returns:
+            Tuple of (label_value, confidence) or None if not found
+        """
+        if source_text in self.classification_results:
+            return self.classification_results[source_text].get(label_type)
+        return None
+
+    def set_classification(
+        self,
+        source_text: str,
+        label_type: str,
+        label_value: str,
+        confidence: float,
+    ) -> None:
+        """
+        Store a classification result for a source text.
+
+        Args:
+            source_text: The source text that was classified
+            label_type: The type of label (e.g., "sentiment")
+            label_value: The classification result (e.g., "positive")
+            confidence: Confidence score (0.0 to 1.0)
+        """
+        if source_text not in self.classification_results:
+            self.classification_results[source_text] = {}
+        self.classification_results[source_text][label_type] = (label_value, confidence)
+
+    @property
+    def has_errors(self) -> bool:
+        """Check if any errors occurred during processing."""
+        return len(self.processing_errors) > 0
+
+    @property
+    def statement_count(self) -> int:
+        """Get the number of statements in the final output."""
+        return len(self.labeled_statements) if self.labeled_statements else len(self.statements)
+
+    class Config:
+        arbitrary_types_allowed = True
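
The context is the only mutable state shared between stages, so a stage plugin interacts with it roughly as sketched below. The sentence, label, and scores are invented for illustration, and construction normally happens inside the pipeline orchestrator rather than by hand:

from statement_extractor.pipeline.context import PipelineContext

ctx = PipelineContext(source_text="Acme Corp acquired Widget Ltd in 2023.")

# A Stage 2 extractor can cache a classification alongside extraction ...
ctx.set_classification(ctx.source_text, "sentiment", "positive", 0.91)

# ... and a Stage 5 labeler reads it back instead of re-running a model.
print(ctx.get_classification(ctx.source_text, "sentiment"))  # ('positive', 0.91)

ctx.record_timing("extraction", 0.42)
ctx.add_warning("no qualifier matched 'Widget Ltd'")

print(ctx.has_errors)       # False - warnings do not count as errors
print(ctx.statement_count)  # 0 - falls back to statements when nothing is labeled yet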