corp-extractor 0.3.0-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +460 -21
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +32 -47
  8. statement_extractor/gliner_extraction.py +218 -0
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +15 -6
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. statement_extractor/scoring.py +17 -69
  52. corp_extractor-0.3.0.dist-info/RECORD +0 -12
  53. statement_extractor/spacy_extraction.py +0 -386
  54. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  55. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/extractor.py
@@ -721,16 +721,17 @@ class StatementExtractor:
         Parse XML output into Statement objects.
 
         Uses model for subject, object, entity types, and source_text.
-        Always uses spaCy for predicate extraction (model predicates are unreliable).
+        Always uses GLiNER2 for predicate extraction (model predicates are unreliable).
 
         Produces two candidates for each statement:
-        1. Hybrid: model subject/object + spaCy predicate
-        2. spaCy-only: all components from spaCy
+        1. Hybrid: model subject/object + GLiNER2 predicate
+        2. GLiNER2-only: all components from GLiNER2
 
         Both go into the candidate pool; scoring/dedup picks the best.
         """
         statements: list[Statement] = []
-        use_spacy_extraction = options.use_spacy_extraction if options else True
+        use_gliner_extraction = options.use_gliner_extraction if options else True
+        predicates = options.predicates if options else None
 
         try:
             root = ET.fromstring(xml_output)
@@ -780,75 +781,59 @@
                 logger.debug(f"Skipping statement: missing subject or object from model")
                 continue
 
-            if use_spacy_extraction and source_text:
+            if use_gliner_extraction and source_text:
                 try:
-                    from .spacy_extraction import extract_triple_from_text, extract_triple_by_predicate_split
-                    spacy_result = extract_triple_from_text(
+                    from .gliner_extraction import extract_triple_from_text
+
+                    # Get model predicate for fallback/refinement
+                    predicate_elem = stmt_elem.find('predicate')
+                    model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                    gliner_result = extract_triple_from_text(
                         source_text=source_text,
                         model_subject=subject_text,
                         model_object=object_text,
-                        model_predicate="",  # Don't pass model predicate
+                        model_predicate=model_predicate,
+                        predicates=predicates,
                     )
-                    if spacy_result:
-                        spacy_subj, spacy_pred, spacy_obj = spacy_result
+                    if gliner_result:
+                        gliner_subj, gliner_pred, gliner_obj = gliner_result
 
-                        if spacy_pred:
-                            # Candidate 1: Hybrid (model subject/object + spaCy predicate)
+                        if gliner_pred:
+                            # Candidate 1: Hybrid (model subject/object + GLiNER2 predicate)
                             logger.debug(
-                                f"Adding hybrid candidate: '{subject_text}' --[{spacy_pred}]--> '{object_text}'"
+                                f"Adding hybrid candidate: '{subject_text}' --[{gliner_pred}]--> '{object_text}'"
                             )
                             statements.append(Statement(
                                 subject=Entity(text=subject_text, type=subject_type),
-                                predicate=spacy_pred,
+                                predicate=gliner_pred,
                                 object=Entity(text=object_text, type=object_type),
                                 source_text=source_text,
                                 extraction_method=ExtractionMethod.HYBRID,
                             ))
 
-                            # Candidate 2: spaCy-only (if different from hybrid)
-                            if spacy_subj and spacy_obj:
-                                is_different = (spacy_subj != subject_text or spacy_obj != object_text)
+                            # Candidate 2: GLiNER2-only (if different from hybrid)
+                            if gliner_subj and gliner_obj:
+                                is_different = (gliner_subj != subject_text or gliner_obj != object_text)
                                 if is_different:
                                     logger.debug(
-                                        f"Adding spaCy-only candidate: '{spacy_subj}' --[{spacy_pred}]--> '{spacy_obj}'"
-                                    )
-                                    statements.append(Statement(
-                                        subject=Entity(text=spacy_subj, type=subject_type),
-                                        predicate=spacy_pred,
-                                        object=Entity(text=spacy_obj, type=object_type),
-                                        source_text=source_text,
-                                        extraction_method=ExtractionMethod.SPACY,
-                                    ))
-
-                            # Candidate 3: Predicate-split (split source text around predicate)
-                            split_result = extract_triple_by_predicate_split(
-                                source_text=source_text,
-                                predicate=spacy_pred,
-                            )
-                            if split_result:
-                                split_subj, split_pred, split_obj = split_result
-                                # Only add if different from previous candidates
-                                is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
-                                is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
-                                if is_different_from_hybrid and is_different_from_spacy:
-                                    logger.debug(
-                                        f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                        f"Adding GLiNER2-only candidate: '{gliner_subj}' --[{gliner_pred}]--> '{gliner_obj}'"
                                     )
                                     statements.append(Statement(
-                                        subject=Entity(text=split_subj, type=subject_type),
-                                        predicate=split_pred,
-                                        object=Entity(text=split_obj, type=object_type),
+                                        subject=Entity(text=gliner_subj, type=subject_type),
+                                        predicate=gliner_pred,
+                                        object=Entity(text=gliner_obj, type=object_type),
                                         source_text=source_text,
-                                        extraction_method=ExtractionMethod.SPLIT,
+                                        extraction_method=ExtractionMethod.GLINER,
                                     ))
                         else:
                             logger.debug(
-                                f"spaCy found no predicate for: '{subject_text}' --> '{object_text}'"
+                                f"GLiNER2 found no predicate for: '{subject_text}' --> '{object_text}'"
                            )
                 except Exception as e:
-                    logger.debug(f"spaCy extraction failed: {e}")
+                    logger.debug(f"GLiNER2 extraction failed: {e}")
            else:
-                # spaCy disabled - fall back to model predicate
+                # GLiNER2 disabled - fall back to model predicate
                 predicate_elem = stmt_elem.find('predicate')
                 model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
 
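Review note: the two options read at the top of this file's first hunk (use_gliner_extraction, predicates) replace the old use_spacy_extraction flag. A minimal caller sketch, assuming ExtractionOptions (re-exported from statement_extractor.models) exposes fields with these exact names; the full options model is not shown in this diff:

    from statement_extractor.models import ExtractionOptions

    # Assumed field names, mirroring options.use_gliner_extraction and
    # options.predicates as read in the parse code above
    options = ExtractionOptions(
        use_gliner_extraction=True,
        predicates=["acquired", "partnered with", "invested in"],
    )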
statement_extractor/gliner_extraction.py
@@ -0,0 +1,218 @@
+"""
+GLiNER2-based triple extraction.
+
+Uses GLiNER2 for relation extraction and entity recognition to extract
+subject, predicate, and object from source text. T5-Gemma model provides
+triple structure and coreference resolution, while GLiNER2 handles
+linguistic analysis.
+
+The GLiNER2 model is loaded automatically on first use.
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+# Lazy-loaded GLiNER2 model
+_model = None
+
+
+def _get_model():
+    """
+    Lazy-load the GLiNER2 model.
+
+    Uses the base model (205M parameters) which is CPU-optimized.
+    """
+    global _model
+    if _model is None:
+        from gliner2 import GLiNER2
+
+        logger.info("Loading GLiNER2 model 'fastino/gliner2-base-v1'...")
+        _model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
+        logger.debug("GLiNER2 model loaded")
+    return _model
+
+
+def extract_triple_from_text(
+    source_text: str,
+    model_subject: str,
+    model_object: str,
+    model_predicate: str,
+    predicates: Optional[list[str]] = None,
+) -> tuple[str, str, str] | None:
+    """
+    Extract subject, predicate, object from source text using GLiNER2.
+
+    Returns a GLiNER2-based triple that can be added to the candidate pool
+    alongside the model's triple. The existing scoring/dedup logic will
+    pick the best one.
+
+    Args:
+        source_text: The source sentence to analyze
+        model_subject: Subject from T5-Gemma (used for matching and fallback)
+        model_object: Object from T5-Gemma (used for matching and fallback)
+        model_predicate: Predicate from T5-Gemma (used when no predicates provided)
+        predicates: Optional list of predefined relation types to extract
+
+    Returns:
+        Tuple of (subject, predicate, object) from GLiNER2, or None if extraction fails
+    """
+    if not source_text:
+        return None
+
+    try:
+        model = _get_model()
+
+        if predicates:
+            # Use relation extraction with predefined predicates
+            result = model.extract_relations(source_text, predicates)
+
+            # Find best matching relation
+            relation_data = result.get("relation_extraction", {})
+            best_match = None
+            best_confidence = 0.0
+
+            for rel_type, relations in relation_data.items():
+                for rel in relations:
+                    # Handle both tuple format and dict format
+                    if isinstance(rel, tuple):
+                        head, tail = rel
+                        confidence = 1.0
+                    else:
+                        head = rel.get("head", {}).get("text", "")
+                        tail = rel.get("tail", {}).get("text", "")
+                        confidence = min(
+                            rel.get("head", {}).get("confidence", 0.5),
+                            rel.get("tail", {}).get("confidence", 0.5)
+                        )
+
+                    # Score based on match with model hints
+                    score = confidence
+                    if model_subject.lower() in head.lower() or head.lower() in model_subject.lower():
+                        score += 0.2
+                    if model_object.lower() in tail.lower() or tail.lower() in model_object.lower():
+                        score += 0.2
+
+                    if score > best_confidence:
+                        best_confidence = score
+                        best_match = (head, rel_type, tail)
+
+            if best_match:
+                logger.debug(
+                    f"GLiNER2 extracted (relation): subj='{best_match[0]}', pred='{best_match[1]}', obj='{best_match[2]}'"
+                )
+                return best_match
+
+        else:
+            # No predicate list provided - use GLiNER2 for entity extraction
+            # and extract predicate from source text using the model's hint
+
+            # Extract entities to refine subject/object boundaries
+            entity_types = [
+                "person", "organization", "company", "location", "city", "country",
+                "product", "event", "date", "money", "quantity"
+            ]
+            result = model.extract_entities(source_text, entity_types)
+            entities = result.get("entities", {})
+
+            # Find entities that match model subject/object
+            refined_subject = model_subject
+            refined_object = model_object
+
+            for entity_type, entity_list in entities.items():
+                for entity in entity_list:
+                    entity_lower = entity.lower()
+                    # Check if this entity matches or contains the model's subject/object
+                    if model_subject.lower() in entity_lower or entity_lower in model_subject.lower():
+                        # Use the entity text if it's more complete
+                        if len(entity) >= len(refined_subject):
+                            refined_subject = entity
+                    if model_object.lower() in entity_lower or entity_lower in model_object.lower():
+                        if len(entity) >= len(refined_object):
+                            refined_object = entity
+
+            # Use model predicate directly (T5-Gemma provides the predicate)
+            if model_predicate:
+                logger.debug(
+                    f"GLiNER2 extracted (entity-refined): subj='{refined_subject}', pred='{model_predicate}', obj='{refined_object}'"
+                )
+                return (refined_subject, model_predicate, refined_object)
+
+        return None
+
+    except ImportError as e:
+        logger.warning(f"GLiNER2 not installed: {e}")
+        return None
+    except Exception as e:
+        logger.debug(f"GLiNER2 extraction failed: {e}")
+        return None
+
+
+def score_entity_content(text: str) -> float:
+    """
+    Score how entity-like a text is using GLiNER2 entity recognition.
+
+    Returns:
+        1.0 - Recognized as a named entity with high confidence
+        0.8 - Recognized as an entity with moderate confidence
+        0.6 - Partially recognized or contains entity-like content
+        0.2 - Not recognized as any entity type
+    """
+    if not text or not text.strip():
+        return 0.2
+
+    try:
+        model = _get_model()
+
+        # Check if text is recognized as common entity types
+        entity_types = [
+            "person", "organization", "company", "location", "city", "country",
+            "product", "event", "date", "money", "quantity"
+        ]
+
+        result = model.extract_entities(
+            text,
+            entity_types,
+            include_confidence=True
+        )
+
+        # Result format: {'entities': {'person': [{'text': '...', 'confidence': 0.99}], ...}}
+        entities_dict = result.get("entities", {})
+
+        # Find best matching entity across all types
+        best_confidence = 0.0
+        text_lower = text.lower().strip()
+
+        for entity_type, entity_list in entities_dict.items():
+            for entity in entity_list:
+                if isinstance(entity, dict):
+                    entity_text = entity.get("text", "").lower().strip()
+                    confidence = entity.get("confidence", 0.5)
+                else:
+                    # Fallback for string format
+                    entity_text = str(entity).lower().strip()
+                    confidence = 0.8
+
+                # Check if entity covers most of the input text
+                if entity_text == text_lower:
+                    # Exact match
+                    best_confidence = max(best_confidence, confidence)
+                elif entity_text in text_lower or text_lower in entity_text:
+                    # Partial match - reduce confidence
+                    best_confidence = max(best_confidence, confidence * 0.8)
+
+        if best_confidence >= 0.9:
+            return 1.0
+        elif best_confidence >= 0.7:
+            return 0.8
+        elif best_confidence >= 0.5:
+            return 0.6
+        elif best_confidence > 0:
+            return 0.4
+        else:
+            return 0.2
+
+    except Exception as e:
+        logger.debug(f"Entity scoring failed for '{text}': {e}")
+        return 0.5  # Neutral score on error
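The new module has two entry points: extract_triple_from_text() and score_entity_content(). A usage sketch based on the signatures above; the sentence and predicate list are illustrative only:

    from statement_extractor.gliner_extraction import (
        extract_triple_from_text,
        score_entity_content,
    )

    # Constrained mode: GLiNER2 relation extraction over a predicate list
    triple = extract_triple_from_text(
        source_text="Acme Corp acquired Widget Inc in 2021.",
        model_subject="Acme Corp",
        model_object="Widget Inc",
        model_predicate="acquired",
        predicates=["acquired", "merged with"],
    )
    if triple:
        subject, predicate, obj = triple  # best-scoring (head, relation, tail)

    # Entity-likeness score in [0.2, 1.0], used during candidate scoring
    print(score_entity_content("Acme Corp"))

Design note on the relation branch: head/tail overlap with the T5-Gemma subject/object adds +0.2 each to a relation's score, so model-consistent triples beat marginally higher-confidence but unrelated ones.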
statement_extractor/llm.py
@@ -0,0 +1,255 @@
+"""
+LLM module for text generation using local models.
+
+Supports:
+- GGUF models via llama-cpp-python (efficient quantized inference)
+- Transformers models via HuggingFace
+
+Usage:
+    from statement_extractor.llm import LLM
+
+    llm = LLM()  # Uses default Gemma3 12B GGUF
+    response = llm.generate("Your prompt here")
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class LLM:
+    """
+    LLM wrapper for text generation.
+
+    Automatically selects the best backend:
+    - GGUF models use llama-cpp-python (efficient, no de-quantization)
+    - Other models use HuggingFace transformers
+    """
+
+    def __init__(
+        self,
+        model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+        gguf_file: Optional[str] = None,
+        n_ctx: int = 8192,
+        use_4bit: bool = True,
+    ):
+        """
+        Initialize the LLM.
+
+        Args:
+            model_id: HuggingFace model ID
+            gguf_file: GGUF filename (auto-detected if model_id ends with -gguf)
+            n_ctx: Context size for GGUF models
+            use_4bit: Use 4-bit quantization for transformers models
+        """
+        self._model_id = model_id
+        self._gguf_file = gguf_file
+        self._n_ctx = n_ctx
+        self._use_4bit = use_4bit
+
+        # Model instances (lazy loaded)
+        self._llama_model = None  # llama-cpp-python
+        self._transformers_model = None  # HuggingFace transformers
+        self._tokenizer = None
+
+        self._load_failed = False
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if the model is loaded."""
+        return self._llama_model is not None or self._transformers_model is not None
+
+    def _is_gguf_model(self) -> bool:
+        """Check if the model ID is a GGUF model."""
+        return self._model_id.endswith("-gguf") or self._gguf_file is not None
+
+    def _get_gguf_filename(self) -> str:
+        """Get the GGUF filename from the model ID."""
+        if self._gguf_file:
+            return self._gguf_file
+        # Extract filename from model ID like "google/gemma-3-12b-it-qat-q4_0-gguf"
+        # The actual file is "gemma-3-12b-it-q4_0.gguf" (note: "qat" is removed)
+        model_name = self._model_id.split("/")[-1]
+        if model_name.endswith("-gguf"):
+            model_name = model_name[:-5]  # Remove "-gguf" suffix
+        # Remove "-qat" from the name (it's not in the actual filename)
+        model_name = model_name.replace("-qat", "")
+        return model_name + ".gguf"
+
+    def load(self) -> None:
+        """
+        Load the model.
+
+        Raises:
+            RuntimeError: If the model fails to load
+        """
+        if self.is_loaded or self._load_failed:
+            return
+
+        try:
+            logger.debug(f"Loading LLM: {self._model_id}")
+
+            if self._is_gguf_model():
+                self._load_gguf_model()
+            else:
+                self._load_transformers_model()
+
+            logger.debug("LLM loaded successfully")
+
+        except Exception as e:
+            self._load_failed = True
+            error_msg = f"Failed to load LLM ({self._model_id}): {e}"
+            if "llama_cpp" in str(e).lower() or "llama-cpp" in str(e).lower():
+                error_msg += "\n  Install with: pip install llama-cpp-python"
+            if "accelerate" in str(e):
+                error_msg += "\n  Install with: pip install accelerate"
+            raise RuntimeError(error_msg) from e
+
+    def _load_gguf_model(self) -> None:
+        """Load GGUF model using llama-cpp-python."""
+        try:
+            from llama_cpp import Llama
+            from huggingface_hub import hf_hub_download
+        except ImportError as e:
+            raise ImportError(
+                "llama-cpp-python is required for GGUF models. "
+                "Install with: pip install llama-cpp-python"
+            ) from e
+
+        gguf_file = self._get_gguf_filename()
+        logger.debug(f"Loading GGUF model with file: {gguf_file}")
+
+        # Download the GGUF file from HuggingFace
+        model_path = hf_hub_download(
+            repo_id=self._model_id,
+            filename=gguf_file,
+        )
+
+        # Load with llama-cpp-python
+        self._llama_model = Llama(
+            model_path=model_path,
+            n_ctx=self._n_ctx,
+            n_gpu_layers=-1,  # Use all GPU layers (Metal on Mac, CUDA on Linux)
+            verbose=False,
+        )
+
+    def _load_transformers_model(self) -> None:
+        """Load model using HuggingFace transformers."""
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model_id)
+
+        if self._use_4bit:
+            try:
+                from transformers import BitsAndBytesConfig
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                )
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    quantization_config=quantization_config,
+                    device_map="auto",
+                )
+            except ImportError:
+                logger.debug("bitsandbytes not available, loading full precision")
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                )
+        else:
+            self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                self._model_id,
+                device_map="auto",
+                torch_dtype=torch.float16,
+            )
+
+    def generate(
+        self,
+        prompt: str,
+        max_tokens: int = 100,
+        stop: Optional[list[str]] = None,
+    ) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: The input prompt
+            max_tokens: Maximum tokens to generate
+            stop: Stop sequences
+
+        Returns:
+            Generated text (not including the prompt)
+        """
+        self.load()
+
+        if self._llama_model is not None:
+            return self._generate_with_llama(prompt, max_tokens, stop)
+        else:
+            return self._generate_with_transformers(prompt, max_tokens)
+
+    def _generate_with_llama(
+        self,
+        prompt: str,
+        max_tokens: int,
+        stop: Optional[list[str]],
+    ) -> str:
+        """Generate response using llama-cpp-python."""
+        output = self._llama_model(
+            prompt,
+            max_tokens=max_tokens,
+            stop=stop or ["\n\n", "</s>"],
+            echo=False,
+        )
+        return output["choices"][0]["text"]
+
+    def _generate_with_transformers(
+        self,
+        prompt: str,
+        max_tokens: int,
+    ) -> str:
+        """Generate response using transformers."""
+        import torch
+
+        inputs = self._tokenizer(prompt, return_tensors="pt").to(self._transformers_model.device)
+
+        with torch.no_grad():
+            outputs = self._transformers_model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                do_sample=False,
+                pad_token_id=self._tokenizer.pad_token_id,
+            )
+
+        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+# Singleton instance for shared use
+_default_llm: Optional[LLM] = None
+
+
+def get_llm(
+    model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+    **kwargs,
+) -> LLM:
+    """
+    Get or create a shared LLM instance.
+
+    Uses a singleton pattern to avoid loading the model multiple times.
+
+    Args:
+        model_id: HuggingFace model ID
+        **kwargs: Additional arguments passed to LLM constructor
+
+    Returns:
+        LLM instance
+    """
+    global _default_llm
+
+    if _default_llm is None or _default_llm._model_id != model_id:
+        _default_llm = LLM(model_id=model_id, **kwargs)
+
+    return _default_llm
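Both backends sit behind the same generate() call. A usage sketch; the non-GGUF model ID is only an example of a transformers-backed configuration, not one the package ships with:

    from statement_extractor.llm import LLM, get_llm

    # Default: Gemma 3 12B QAT GGUF served via llama-cpp-python
    llm = get_llm()
    text = llm.generate("List three UK financial regulators:", max_tokens=64)

    # IDs without a "-gguf" suffix load via transformers;
    # use_4bit=False skips the bitsandbytes path entirely
    llm_fp16 = LLM(model_id="google/gemma-2-2b-it", use_4bit=False)

Note that get_llm() replaces the cached singleton whenever it is called with a different model_id, so alternating between two models reloads weights on every switch.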
statement_extractor/models/__init__.py
@@ -0,0 +1,74 @@
+"""
+Data models for the extraction pipeline.
+
+This module contains all Pydantic models used throughout the pipeline stages:
+- Stage 1 (Splitting): RawTriple
+- Stage 2 (Extraction): ExtractedEntity, PipelineStatement
+- Stage 3 (Qualification): EntityQualifiers, QualifiedEntity
+- Stage 4 (Canonicalization): CanonicalMatch, CanonicalEntity
+- Stage 5 (Labeling): StatementLabel, LabeledStatement
+
+It also re-exports all models from the original models.py for backward compatibility.
+"""
+
+# Import from the original models.py file (now a sibling at the same level)
+# We need to import these BEFORE the local modules to avoid circular imports
+import sys
+import importlib.util
+from pathlib import Path
+
+# Manually load the old models.py to avoid conflict with this package
+_models_py_path = Path(__file__).parent.parent / "models.py"
+if _models_py_path.exists():
+    _spec = importlib.util.spec_from_file_location("_old_models", _models_py_path)
+    _old_models = importlib.util.module_from_spec(_spec)
+    _spec.loader.exec_module(_old_models)
+
+    # Re-export everything from the old models
+    Entity = _old_models.Entity
+    ExtractionMethod = _old_models.ExtractionMethod
+    Statement = _old_models.Statement
+    ExtractionResult = _old_models.ExtractionResult
+    PredicateMatch = _old_models.PredicateMatch
+    PredicateTaxonomy = _old_models.PredicateTaxonomy
+    PredicateComparisonConfig = _old_models.PredicateComparisonConfig
+    ScoringConfig = _old_models.ScoringConfig
+    ExtractionOptions = _old_models.ExtractionOptions
+
+    # Use EntityType from old models
+    EntityType = _old_models.EntityType
+else:
+    # Fallback: define locally if old models.py doesn't exist
+    from .entity import EntityType
+
+# New pipeline models
+from .entity import ExtractedEntity
+from .statement import RawTriple, PipelineStatement
+from .qualifiers import EntityQualifiers, QualifiedEntity
+from .canonical import CanonicalMatch, CanonicalEntity
+from .labels import StatementLabel, LabeledStatement, TaxonomyResult
+
+__all__ = [
+    # Re-exported from original models.py (backward compatibility)
+    "Entity",
+    "EntityType",
+    "ExtractionMethod",
+    "Statement",
+    "ExtractionResult",
+    "PredicateMatch",
+    "PredicateTaxonomy",
+    "PredicateComparisonConfig",
+    "ScoringConfig",
+    "ExtractionOptions",
+    # New pipeline models
+    "ExtractedEntity",
+    "RawTriple",
+    "PipelineStatement",
+    "EntityQualifiers",
+    "QualifiedEntity",
+    "CanonicalMatch",
+    "CanonicalEntity",
+    "StatementLabel",
+    "LabeledStatement",
+    "TaxonomyResult",
+]
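The net effect of the loader shim above is that the 0.3.x import surface and the new stage models resolve from the same package (illustrative only):

    # Pre-0.5.0 imports keep working via the re-exports
    from statement_extractor.models import Statement, Entity, ExtractionOptions

    # New pipeline models live alongside them
    from statement_extractor.models import (
        RawTriple,
        PipelineStatement,
        QualifiedEntity,
        CanonicalEntity,
        LabeledStatement,
    )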