corp-extractor 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +460 -21
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +32 -47
- statement_extractor/gliner_extraction.py +218 -0
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +15 -6
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- statement_extractor/scoring.py +17 -69
- corp_extractor-0.3.0.dist-info/RECORD +0 -12
- statement_extractor/spacy_extraction.py +0 -386
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/extractor.py
CHANGED
@@ -721,16 +721,17 @@ class StatementExtractor:
         Parse XML output into Statement objects.
 
         Uses model for subject, object, entity types, and source_text.
-        Always uses
+        Always uses GLiNER2 for predicate extraction (model predicates are unreliable).
 
         Produces two candidates for each statement:
-        1. Hybrid: model subject/object +
-        2.
+        1. Hybrid: model subject/object + GLiNER2 predicate
+        2. GLiNER2-only: all components from GLiNER2
 
         Both go into the candidate pool; scoring/dedup picks the best.
         """
         statements: list[Statement] = []
-
+        use_gliner_extraction = options.use_gliner_extraction if options else True
+        predicates = options.predicates if options else None
 
         try:
             root = ET.fromstring(xml_output)
@@ -780,75 +781,59 @@ class StatementExtractor:
                     logger.debug(f"Skipping statement: missing subject or object from model")
                     continue
 
-                if
+                if use_gliner_extraction and source_text:
                     try:
-                        from .
-
+                        from .gliner_extraction import extract_triple_from_text
+
+                        # Get model predicate for fallback/refinement
+                        predicate_elem = stmt_elem.find('predicate')
+                        model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                        gliner_result = extract_triple_from_text(
                             source_text=source_text,
                             model_subject=subject_text,
                             model_object=object_text,
-                            model_predicate=
+                            model_predicate=model_predicate,
+                            predicates=predicates,
                         )
-                        if
-
+                        if gliner_result:
+                            gliner_subj, gliner_pred, gliner_obj = gliner_result
 
-                            if
-                                # Candidate 1: Hybrid (model subject/object +
+                            if gliner_pred:
+                                # Candidate 1: Hybrid (model subject/object + GLiNER2 predicate)
                                 logger.debug(
-                                    f"Adding hybrid candidate: '{subject_text}' --[{
+                                    f"Adding hybrid candidate: '{subject_text}' --[{gliner_pred}]--> '{object_text}'"
                                 )
                                 statements.append(Statement(
                                     subject=Entity(text=subject_text, type=subject_type),
-                                    predicate=
+                                    predicate=gliner_pred,
                                     object=Entity(text=object_text, type=object_type),
                                     source_text=source_text,
                                     extraction_method=ExtractionMethod.HYBRID,
                                 ))
 
-                                # Candidate 2:
-                                if
-                                    is_different = (
+                                # Candidate 2: GLiNER2-only (if different from hybrid)
+                                if gliner_subj and gliner_obj:
+                                    is_different = (gliner_subj != subject_text or gliner_obj != object_text)
                                     if is_different:
                                         logger.debug(
-                                            f"Adding
-                                        )
-                                        statements.append(Statement(
-                                            subject=Entity(text=spacy_subj, type=subject_type),
-                                            predicate=spacy_pred,
-                                            object=Entity(text=spacy_obj, type=object_type),
-                                            source_text=source_text,
-                                            extraction_method=ExtractionMethod.SPACY,
-                                        ))
-
-                                # Candidate 3: Predicate-split (split source text around predicate)
-                                split_result = extract_triple_by_predicate_split(
-                                    source_text=source_text,
-                                    predicate=spacy_pred,
-                                )
-                                if split_result:
-                                    split_subj, split_pred, split_obj = split_result
-                                    # Only add if different from previous candidates
-                                    is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
-                                    is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
-                                    if is_different_from_hybrid and is_different_from_spacy:
-                                        logger.debug(
-                                            f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                            f"Adding GLiNER2-only candidate: '{gliner_subj}' --[{gliner_pred}]--> '{gliner_obj}'"
                                         )
                                         statements.append(Statement(
-                                            subject=Entity(text=
-                                            predicate=
-                                            object=Entity(text=
+                                            subject=Entity(text=gliner_subj, type=subject_type),
+                                            predicate=gliner_pred,
+                                            object=Entity(text=gliner_obj, type=object_type),
                                             source_text=source_text,
-                                            extraction_method=ExtractionMethod.
+                                            extraction_method=ExtractionMethod.GLINER,
                                         ))
                             else:
                                 logger.debug(
-                                    f"
+                                    f"GLiNER2 found no predicate for: '{subject_text}' --> '{object_text}'"
                                 )
                     except Exception as e:
-                        logger.debug(f"
+                        logger.debug(f"GLiNER2 extraction failed: {e}")
                 else:
-                    #
+                    # GLiNER2 disabled - fall back to model predicate
                     predicate_elem = stmt_elem.find('predicate')
                     model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
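The two hunks above make GLiNER2 the default predicate source: when `options.use_gliner_extraction` is true (the default) the parser builds hybrid and GLiNER2-only candidates, otherwise it falls back to the model predicate. A minimal calling sketch, assuming `ExtractionOptions` accepts these fields as keyword arguments; the `extract` method name below is an assumption, since this diff only shows the XML-parsing helper:

    # Hypothetical usage; only the option fields are confirmed by this diff.
    from statement_extractor.extractor import StatementExtractor
    from statement_extractor.models import ExtractionOptions

    options = ExtractionOptions(
        use_gliner_extraction=True,              # route predicates through GLiNER2
        predicates=["acquired", "invested in"],  # optional predefined relation types
    )
    extractor = StatementExtractor()
    result = extractor.extract("Acme Corp acquired Widget Inc.", options=options)  # method name assumed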
statement_extractor/gliner_extraction.py
ADDED
@@ -0,0 +1,218 @@
+"""
+GLiNER2-based triple extraction.
+
+Uses GLiNER2 for relation extraction and entity recognition to extract
+subject, predicate, and object from source text. T5-Gemma model provides
+triple structure and coreference resolution, while GLiNER2 handles
+linguistic analysis.
+
+The GLiNER2 model is loaded automatically on first use.
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+# Lazy-loaded GLiNER2 model
+_model = None
+
+
+def _get_model():
+    """
+    Lazy-load the GLiNER2 model.
+
+    Uses the base model (205M parameters) which is CPU-optimized.
+    """
+    global _model
+    if _model is None:
+        from gliner2 import GLiNER2
+
+        logger.info("Loading GLiNER2 model 'fastino/gliner2-base-v1'...")
+        _model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
+        logger.debug("GLiNER2 model loaded")
+    return _model
+
+
+def extract_triple_from_text(
+    source_text: str,
+    model_subject: str,
+    model_object: str,
+    model_predicate: str,
+    predicates: Optional[list[str]] = None,
+) -> tuple[str, str, str] | None:
+    """
+    Extract subject, predicate, object from source text using GLiNER2.
+
+    Returns a GLiNER2-based triple that can be added to the candidate pool
+    alongside the model's triple. The existing scoring/dedup logic will
+    pick the best one.
+
+    Args:
+        source_text: The source sentence to analyze
+        model_subject: Subject from T5-Gemma (used for matching and fallback)
+        model_object: Object from T5-Gemma (used for matching and fallback)
+        model_predicate: Predicate from T5-Gemma (used when no predicates provided)
+        predicates: Optional list of predefined relation types to extract
+
+    Returns:
+        Tuple of (subject, predicate, object) from GLiNER2, or None if extraction fails
+    """
+    if not source_text:
+        return None
+
+    try:
+        model = _get_model()
+
+        if predicates:
+            # Use relation extraction with predefined predicates
+            result = model.extract_relations(source_text, predicates)
+
+            # Find best matching relation
+            relation_data = result.get("relation_extraction", {})
+            best_match = None
+            best_confidence = 0.0
+
+            for rel_type, relations in relation_data.items():
+                for rel in relations:
+                    # Handle both tuple format and dict format
+                    if isinstance(rel, tuple):
+                        head, tail = rel
+                        confidence = 1.0
+                    else:
+                        head = rel.get("head", {}).get("text", "")
+                        tail = rel.get("tail", {}).get("text", "")
+                        confidence = min(
+                            rel.get("head", {}).get("confidence", 0.5),
+                            rel.get("tail", {}).get("confidence", 0.5)
+                        )
+
+                    # Score based on match with model hints
+                    score = confidence
+                    if model_subject.lower() in head.lower() or head.lower() in model_subject.lower():
+                        score += 0.2
+                    if model_object.lower() in tail.lower() or tail.lower() in model_object.lower():
+                        score += 0.2
+
+                    if score > best_confidence:
+                        best_confidence = score
+                        best_match = (head, rel_type, tail)
+
+            if best_match:
+                logger.debug(
+                    f"GLiNER2 extracted (relation): subj='{best_match[0]}', pred='{best_match[1]}', obj='{best_match[2]}'"
+                )
+                return best_match
+
+        else:
+            # No predicate list provided - use GLiNER2 for entity extraction
+            # and extract predicate from source text using the model's hint
+
+            # Extract entities to refine subject/object boundaries
+            entity_types = [
+                "person", "organization", "company", "location", "city", "country",
+                "product", "event", "date", "money", "quantity"
+            ]
+            result = model.extract_entities(source_text, entity_types)
+            entities = result.get("entities", {})
+
+            # Find entities that match model subject/object
+            refined_subject = model_subject
+            refined_object = model_object
+
+            for entity_type, entity_list in entities.items():
+                for entity in entity_list:
+                    entity_lower = entity.lower()
+                    # Check if this entity matches or contains the model's subject/object
+                    if model_subject.lower() in entity_lower or entity_lower in model_subject.lower():
+                        # Use the entity text if it's more complete
+                        if len(entity) >= len(refined_subject):
+                            refined_subject = entity
+                    if model_object.lower() in entity_lower or entity_lower in model_object.lower():
+                        if len(entity) >= len(refined_object):
+                            refined_object = entity
+
+            # Use model predicate directly (T5-Gemma provides the predicate)
+            if model_predicate:
+                logger.debug(
+                    f"GLiNER2 extracted (entity-refined): subj='{refined_subject}', pred='{model_predicate}', obj='{refined_object}'"
+                )
+                return (refined_subject, model_predicate, refined_object)
+
+        return None
+
+    except ImportError as e:
+        logger.warning(f"GLiNER2 not installed: {e}")
+        return None
+    except Exception as e:
+        logger.debug(f"GLiNER2 extraction failed: {e}")
+        return None
+
+
+def score_entity_content(text: str) -> float:
+    """
+    Score how entity-like a text is using GLiNER2 entity recognition.
+
+    Returns:
+        1.0 - Recognized as a named entity with high confidence
+        0.8 - Recognized as an entity with moderate confidence
+        0.6 - Partially recognized or contains entity-like content
+        0.2 - Not recognized as any entity type
+    """
+    if not text or not text.strip():
+        return 0.2
+
+    try:
+        model = _get_model()
+
+        # Check if text is recognized as common entity types
+        entity_types = [
+            "person", "organization", "company", "location", "city", "country",
+            "product", "event", "date", "money", "quantity"
+        ]
+
+        result = model.extract_entities(
+            text,
+            entity_types,
+            include_confidence=True
+        )
+
+        # Result format: {'entities': {'person': [{'text': '...', 'confidence': 0.99}], ...}}
+        entities_dict = result.get("entities", {})
+
+        # Find best matching entity across all types
+        best_confidence = 0.0
+        text_lower = text.lower().strip()
+
+        for entity_type, entity_list in entities_dict.items():
+            for entity in entity_list:
+                if isinstance(entity, dict):
+                    entity_text = entity.get("text", "").lower().strip()
+                    confidence = entity.get("confidence", 0.5)
+                else:
+                    # Fallback for string format
+                    entity_text = str(entity).lower().strip()
+                    confidence = 0.8
+
+                # Check if entity covers most of the input text
+                if entity_text == text_lower:
+                    # Exact match
+                    best_confidence = max(best_confidence, confidence)
+                elif entity_text in text_lower or text_lower in entity_text:
+                    # Partial match - reduce confidence
+                    best_confidence = max(best_confidence, confidence * 0.8)
+
+        if best_confidence >= 0.9:
+            return 1.0
+        elif best_confidence >= 0.7:
+            return 0.8
+        elif best_confidence >= 0.5:
+            return 0.6
+        elif best_confidence > 0:
+            return 0.4
+        else:
+            return 0.2
+
+    except Exception as e:
+        logger.debug(f"Entity scoring failed for '{text}': {e}")
+        return 0.5  # Neutral score on error
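For orientation, a short sketch of calling the two public helpers this file adds; the sentence and hints are illustrative only. Passing `predicates` switches the function from entity-boundary refinement to GLiNER2 relation extraction:

    from statement_extractor.gliner_extraction import (
        extract_triple_from_text,
        score_entity_content,
    )

    # With a predicate list: GLiNER2 relation extraction, scored against the model hints.
    triple = extract_triple_from_text(
        source_text="Acme Corp acquired Widget Inc.",
        model_subject="Acme Corp",
        model_object="Widget Inc",
        model_predicate="acquired",
        predicates=["acquired", "merged with"],
    )
    if triple:
        subject, predicate, obj = triple

    # Without predicates, the model predicate is kept and only entity boundaries
    # are refined. score_entity_content rates how entity-like a span is (0.2-1.0).
    confidence = score_entity_content("Acme Corp")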
statement_extractor/llm.py
ADDED
@@ -0,0 +1,255 @@
+"""
+LLM module for text generation using local models.
+
+Supports:
+- GGUF models via llama-cpp-python (efficient quantized inference)
+- Transformers models via HuggingFace
+
+Usage:
+    from statement_extractor.llm import LLM
+
+    llm = LLM()  # Uses default Gemma3 12B GGUF
+    response = llm.generate("Your prompt here")
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class LLM:
+    """
+    LLM wrapper for text generation.
+
+    Automatically selects the best backend:
+    - GGUF models use llama-cpp-python (efficient, no de-quantization)
+    - Other models use HuggingFace transformers
+    """
+
+    def __init__(
+        self,
+        model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+        gguf_file: Optional[str] = None,
+        n_ctx: int = 8192,
+        use_4bit: bool = True,
+    ):
+        """
+        Initialize the LLM.
+
+        Args:
+            model_id: HuggingFace model ID
+            gguf_file: GGUF filename (auto-detected if model_id ends with -gguf)
+            n_ctx: Context size for GGUF models
+            use_4bit: Use 4-bit quantization for transformers models
+        """
+        self._model_id = model_id
+        self._gguf_file = gguf_file
+        self._n_ctx = n_ctx
+        self._use_4bit = use_4bit
+
+        # Model instances (lazy loaded)
+        self._llama_model = None  # llama-cpp-python
+        self._transformers_model = None  # HuggingFace transformers
+        self._tokenizer = None
+
+        self._load_failed = False
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if the model is loaded."""
+        return self._llama_model is not None or self._transformers_model is not None
+
+    def _is_gguf_model(self) -> bool:
+        """Check if the model ID is a GGUF model."""
+        return self._model_id.endswith("-gguf") or self._gguf_file is not None
+
+    def _get_gguf_filename(self) -> str:
+        """Get the GGUF filename from the model ID."""
+        if self._gguf_file:
+            return self._gguf_file
+        # Extract filename from model ID like "google/gemma-3-12b-it-qat-q4_0-gguf"
+        # The actual file is "gemma-3-12b-it-q4_0.gguf" (note: "qat" is removed)
+        model_name = self._model_id.split("/")[-1]
+        if model_name.endswith("-gguf"):
+            model_name = model_name[:-5]  # Remove "-gguf" suffix
+        # Remove "-qat" from the name (it's not in the actual filename)
+        model_name = model_name.replace("-qat", "")
+        return model_name + ".gguf"
+
+    def load(self) -> None:
+        """
+        Load the model.
+
+        Raises:
+            RuntimeError: If the model fails to load
+        """
+        if self.is_loaded or self._load_failed:
+            return
+
+        try:
+            logger.debug(f"Loading LLM: {self._model_id}")
+
+            if self._is_gguf_model():
+                self._load_gguf_model()
+            else:
+                self._load_transformers_model()
+
+            logger.debug("LLM loaded successfully")
+
+        except Exception as e:
+            self._load_failed = True
+            error_msg = f"Failed to load LLM ({self._model_id}): {e}"
+            if "llama_cpp" in str(e).lower() or "llama-cpp" in str(e).lower():
+                error_msg += "\n  Install with: pip install llama-cpp-python"
+            if "accelerate" in str(e):
+                error_msg += "\n  Install with: pip install accelerate"
+            raise RuntimeError(error_msg) from e
+
+    def _load_gguf_model(self) -> None:
+        """Load GGUF model using llama-cpp-python."""
+        try:
+            from llama_cpp import Llama
+            from huggingface_hub import hf_hub_download
+        except ImportError as e:
+            raise ImportError(
+                "llama-cpp-python is required for GGUF models. "
+                "Install with: pip install llama-cpp-python"
+            ) from e
+
+        gguf_file = self._get_gguf_filename()
+        logger.debug(f"Loading GGUF model with file: {gguf_file}")
+
+        # Download the GGUF file from HuggingFace
+        model_path = hf_hub_download(
+            repo_id=self._model_id,
+            filename=gguf_file,
+        )
+
+        # Load with llama-cpp-python
+        self._llama_model = Llama(
+            model_path=model_path,
+            n_ctx=self._n_ctx,
+            n_gpu_layers=-1,  # Use all GPU layers (Metal on Mac, CUDA on Linux)
+            verbose=False,
+        )
+
+    def _load_transformers_model(self) -> None:
+        """Load model using HuggingFace transformers."""
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model_id)
+
+        if self._use_4bit:
+            try:
+                from transformers import BitsAndBytesConfig
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                )
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    quantization_config=quantization_config,
+                    device_map="auto",
+                )
+            except ImportError:
+                logger.debug("bitsandbytes not available, loading full precision")
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                )
+        else:
+            self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                self._model_id,
+                device_map="auto",
+                torch_dtype=torch.float16,
+            )
+
+    def generate(
+        self,
+        prompt: str,
+        max_tokens: int = 100,
+        stop: Optional[list[str]] = None,
+    ) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: The input prompt
+            max_tokens: Maximum tokens to generate
+            stop: Stop sequences
+
+        Returns:
+            Generated text (not including the prompt)
+        """
+        self.load()
+
+        if self._llama_model is not None:
+            return self._generate_with_llama(prompt, max_tokens, stop)
+        else:
+            return self._generate_with_transformers(prompt, max_tokens)
+
+    def _generate_with_llama(
+        self,
+        prompt: str,
+        max_tokens: int,
+        stop: Optional[list[str]],
+    ) -> str:
+        """Generate response using llama-cpp-python."""
+        output = self._llama_model(
+            prompt,
+            max_tokens=max_tokens,
+            stop=stop or ["\n\n", "</s>"],
+            echo=False,
+        )
+        return output["choices"][0]["text"]
+
+    def _generate_with_transformers(
+        self,
+        prompt: str,
+        max_tokens: int,
+    ) -> str:
+        """Generate response using transformers."""
+        import torch
+
+        inputs = self._tokenizer(prompt, return_tensors="pt").to(self._transformers_model.device)
+
+        with torch.no_grad():
+            outputs = self._transformers_model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                do_sample=False,
+                pad_token_id=self._tokenizer.pad_token_id,
+            )
+
+        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+# Singleton instance for shared use
+_default_llm: Optional[LLM] = None
+
+
+def get_llm(
+    model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+    **kwargs,
+) -> LLM:
+    """
+    Get or create a shared LLM instance.
+
+    Uses a singleton pattern to avoid loading the model multiple times.
+
+    Args:
+        model_id: HuggingFace model ID
+        **kwargs: Additional arguments passed to LLM constructor
+
+    Returns:
+        LLM instance
+    """
+    global _default_llm
+
+    if _default_llm is None or _default_llm._model_id != model_id:
+        _default_llm = LLM(model_id=model_id, **kwargs)
+
+    return _default_llm
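The module docstring above already sketches basic usage; slightly expanded, including the singleton accessor defined at the bottom of the file:

    from statement_extractor.llm import LLM, get_llm

    llm = LLM()  # defaults to google/gemma-3-12b-it-qat-q4_0-gguf via llama-cpp-python
    response = llm.generate("Your prompt here", max_tokens=64, stop=["\n\n"])

    shared = get_llm()  # reuses one loaded instance across callers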
statement_extractor/models/__init__.py
ADDED
@@ -0,0 +1,74 @@
+"""
+Data models for the extraction pipeline.
+
+This module contains all Pydantic models used throughout the pipeline stages:
+- Stage 1 (Splitting): RawTriple
+- Stage 2 (Extraction): ExtractedEntity, PipelineStatement
+- Stage 3 (Qualification): EntityQualifiers, QualifiedEntity
+- Stage 4 (Canonicalization): CanonicalMatch, CanonicalEntity
+- Stage 5 (Labeling): StatementLabel, LabeledStatement
+
+It also re-exports all models from the original models.py for backward compatibility.
+"""
+
+# Import from the original models.py file (now a sibling at the same level)
+# We need to import these BEFORE the local modules to avoid circular imports
+import sys
+import importlib.util
+from pathlib import Path
+
+# Manually load the old models.py to avoid conflict with this package
+_models_py_path = Path(__file__).parent.parent / "models.py"
+if _models_py_path.exists():
+    _spec = importlib.util.spec_from_file_location("_old_models", _models_py_path)
+    _old_models = importlib.util.module_from_spec(_spec)
+    _spec.loader.exec_module(_old_models)
+
+    # Re-export everything from the old models
+    Entity = _old_models.Entity
+    ExtractionMethod = _old_models.ExtractionMethod
+    Statement = _old_models.Statement
+    ExtractionResult = _old_models.ExtractionResult
+    PredicateMatch = _old_models.PredicateMatch
+    PredicateTaxonomy = _old_models.PredicateTaxonomy
+    PredicateComparisonConfig = _old_models.PredicateComparisonConfig
+    ScoringConfig = _old_models.ScoringConfig
+    ExtractionOptions = _old_models.ExtractionOptions
+
+    # Use EntityType from old models
+    EntityType = _old_models.EntityType
+else:
+    # Fallback: define locally if old models.py doesn't exist
+    from .entity import EntityType
+
+# New pipeline models
+from .entity import ExtractedEntity
+from .statement import RawTriple, PipelineStatement
+from .qualifiers import EntityQualifiers, QualifiedEntity
+from .canonical import CanonicalMatch, CanonicalEntity
+from .labels import StatementLabel, LabeledStatement, TaxonomyResult
+
+__all__ = [
+    # Re-exported from original models.py (backward compatibility)
+    "Entity",
+    "EntityType",
+    "ExtractionMethod",
+    "Statement",
+    "ExtractionResult",
+    "PredicateMatch",
+    "PredicateTaxonomy",
+    "PredicateComparisonConfig",
+    "ScoringConfig",
+    "ExtractionOptions",
+    # New pipeline models
+    "ExtractedEntity",
+    "RawTriple",
+    "PipelineStatement",
+    "EntityQualifiers",
+    "QualifiedEntity",
+    "CanonicalMatch",
+    "CanonicalEntity",
+    "StatementLabel",
+    "LabeledStatement",
+    "TaxonomyResult",
+]
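Since the new package `__init__` re-exports the legacy classes, existing imports keep working alongside the new pipeline models; a quick compatibility sketch based on the `__all__` list above:

    # Both the legacy re-exports and the new pipeline models resolve from
    # statement_extractor.models after this change.
    from statement_extractor.models import Entity, Statement, ExtractionOptions
    from statement_extractor.models import RawTriple, PipelineStatement, LabeledStatement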