corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +10 -1
- statement_extractor/cli.py +1663 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +6972 -0
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +89 -0
- statement_extractor/models/canonical.py +182 -0
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +220 -0
- statement_extractor/models/qualifiers.py +139 -0
- statement_extractor/models/statement.py +101 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +129 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +416 -0
- statement_extractor/pipeline/registry.py +303 -0
- statement_extractor/plugins/__init__.py +55 -0
- statement_extractor/plugins/base.py +716 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +546 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +386 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +30 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +185 -0
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +197 -0
- statement_extractor/plugins/qualifiers/person.py +785 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +293 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +484 -0
- statement_extractor/plugins/taxonomy/mnli.py +291 -0
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
statement_extractor/document/summarizer.py ADDED
@@ -0,0 +1,195 @@
+"""
+DocumentSummarizer - Generate document summaries using Gemma3.
+
+Creates concise summaries focused on entities, events, and relationships
+that are useful for providing context during extraction.
+"""
+
+import logging
+from typing import Optional
+
+from ..models.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class DocumentSummarizer:
+    """
+    Generates document summaries using the Gemma3 LLM.
+
+    Summaries focus on:
+    - Key entities mentioned
+    - Important events and actions
+    - Relationships between entities
+    """
+
+    MAX_INPUT_TOKENS = 10_000
+    DEFAULT_MAX_OUTPUT_TOKENS = 300
+
+    def __init__(
+        self,
+        max_input_tokens: int = MAX_INPUT_TOKENS,
+        max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
+    ):
+        """
+        Initialize the summarizer.
+
+        Args:
+            max_input_tokens: Maximum tokens of input to send to the LLM
+            max_output_tokens: Maximum tokens for the summary output
+        """
+        self._max_input_tokens = max_input_tokens
+        self._max_output_tokens = max_output_tokens
+        self._llm = None
+        self._tokenizer = None
+
+    @property
+    def llm(self):
+        """Lazy-load the LLM."""
+        if self._llm is None:
+            from ..llm import get_llm
+            logger.debug("Loading LLM for summarization")
+            self._llm = get_llm()
+        return self._llm
+
+    @property
+    def tokenizer(self):
+        """Lazy-load tokenizer for token counting."""
+        if self._tokenizer is None:
+            from transformers import AutoTokenizer
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                "Corp-o-Rate-Community/statement-extractor",
+                trust_remote_code=True,
+            )
+        return self._tokenizer
+
+    def _count_tokens(self, text: str) -> int:
+        """Count tokens in text."""
+        return len(self.tokenizer.encode(text, add_special_tokens=False))
+
+    def _truncate_to_tokens(self, text: str, max_tokens: int) -> str:
+        """
+        Truncate text to a maximum number of tokens.
+
+        Tries to truncate at sentence boundaries when possible.
+        """
+        token_count = self._count_tokens(text)
+
+        if token_count <= max_tokens:
+            return text
+
+        # Estimate chars per token
+        chars_per_token = len(text) / token_count
+        target_chars = int(max_tokens * chars_per_token * 0.95)  # 5% buffer
+
+        # Truncate
+        truncated = text[:target_chars]
+
+        # Try to end at a sentence boundary
+        last_period = truncated.rfind(". ")
+        last_newline = truncated.rfind("\n")
+        split_pos = max(last_period, last_newline)
+
+        if split_pos > target_chars * 0.7:  # Don't lose too much text
+            truncated = truncated[:split_pos + 1]
+
+        logger.debug(f"Truncated text from {len(text)} to {len(truncated)} chars")
+        return truncated
+
+    def summarize(
+        self,
+        document: Document,
+        custom_prompt: Optional[str] = None,
+    ) -> str:
+        """
+        Generate a summary of the document.
+
+        Args:
+            document: Document to summarize
+            custom_prompt: Optional custom prompt (uses default if not provided)
+
+        Returns:
+            Summary string
+        """
+        if not document.full_text.strip():
+            logger.warning("Cannot summarize empty document")
+            return ""
+
+        logger.info(f"Generating summary for document {document.document_id}")
+
+        # Truncate text to max input tokens
+        text = self._truncate_to_tokens(document.full_text, self._max_input_tokens)
+
+        # Build prompt
+        if custom_prompt:
+            prompt = f"{custom_prompt}\n\n{text}"
+        else:
+            prompt = self._build_prompt(text, document)
+
+        # Generate summary
+        try:
+            summary = self.llm.generate(
+                prompt=prompt,
+                max_tokens=self._max_output_tokens,
+                stop=["\n\n\n", "---"],
+            )
+            summary = summary.strip()
+            logger.info(f"Generated summary ({len(summary)} chars):")
+            # Log summary with indentation for readability
+            for line in summary.split("\n"):
+                logger.info(f"  {line}")
+            return summary
+
+        except Exception as e:
+            logger.error(f"Summary generation failed: {e}")
+            raise
+
+    def _build_prompt(self, text: str, document: Document) -> str:
+        """Build the summarization prompt."""
+        # Include document metadata context if available
+        context_parts = []
+        if document.metadata.title:
+            context_parts.append(f"Title: {document.metadata.title}")
+        if document.metadata.authors:
+            context_parts.append(f"Authors: {', '.join(document.metadata.authors)}")
+        if document.metadata.source_type:
+            context_parts.append(f"Source type: {document.metadata.source_type}")
+
+        context = "\n".join(context_parts) if context_parts else ""
+
+        prompt = f"""Summarize the following document, focusing on:
+1. Key entities (companies, people, locations) mentioned
+2. Important events, actions, and decisions
+3. Relationships between entities
+4. Main topics and themes
+
+Keep the summary concise (2-3 paragraphs) and factual.
+
+{context}
+
+Document text:
+{text}
+
+Summary:"""
+
+        return prompt
+
+    def summarize_text(
+        self,
+        text: str,
+        title: Optional[str] = None,
+    ) -> str:
+        """
+        Generate a summary from plain text.
+
+        Convenience method that creates a temporary Document.
+
+        Args:
+            text: Text to summarize
+            title: Optional document title for context
+
+        Returns:
+            Summary string
+        """
+        document = Document.from_text(text, title=title)
+        return self.summarize(document)
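For orientation, here is a minimal usage sketch of the new DocumentSummarizer. It is not part of the diff; it assumes the 0.9.0 wheel is installed and that the default Gemma3 GGUF weights can be downloaded on first use. The sample text and title are illustrative only.

from statement_extractor.document.summarizer import DocumentSummarizer

# Build a summarizer with a smaller output budget than the 300-token default
summarizer = DocumentSummarizer(max_output_tokens=150)

# summarize_text() wraps plain text in a temporary Document via Document.from_text()
summary = summarizer.summarize_text(
    "Acme Corp announced the acquisition of Widget Ltd for $1.2bn...",  # hypothetical input
    title="Acme press release",
)
print(summary)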
statement_extractor/extractor.py CHANGED
@@ -783,7 +783,7 @@ class StatementExtractor:
 
         if use_gliner_extraction and source_text:
             try:
-                from .gliner_extraction import extract_triple_from_text
+                from .gliner_extraction import extract_triple_from_text
 
                 # Get model predicate for fallback/refinement
                 predicate_elem = stmt_elem.find('predicate')
@@ -826,28 +826,6 @@ class StatementExtractor:
                         source_text=source_text,
                         extraction_method=ExtractionMethod.GLINER,
                     ))
-
-                    # Candidate 3: Predicate-split (split source text around predicate)
-                    split_result = extract_triple_by_predicate_split(
-                        source_text=source_text,
-                        predicate=gliner_pred,
-                    )
-                    if split_result:
-                        split_subj, split_pred, split_obj = split_result
-                        # Only add if different from previous candidates
-                        is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
-                        is_different_from_gliner = (split_subj != gliner_subj or split_obj != gliner_obj)
-                        if is_different_from_hybrid and is_different_from_gliner:
-                            logger.debug(
-                                f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
-                            )
-                            statements.append(Statement(
-                                subject=Entity(text=split_subj, type=subject_type),
-                                predicate=split_pred,
-                                object=Entity(text=split_obj, type=object_type),
-                                source_text=source_text,
-                                extraction_method=ExtractionMethod.SPLIT,
-                            ))
                 else:
                     logger.debug(
                         f"GLiNER2 found no predicate for: '{subject_text}' --> '{object_text}'"
statement_extractor/gliner_extraction.py CHANGED
@@ -132,18 +132,12 @@ def extract_triple_from_text(
         if len(entity) >= len(refined_object):
             refined_object = entity
 
-    #
-
-    if predicate_result:
-        _, extracted_predicate, _ = predicate_result
-    else:
-        extracted_predicate = model_predicate
-
-    if extracted_predicate:
+    # Use model predicate directly (T5-Gemma provides the predicate)
+    if model_predicate:
         logger.debug(
-            f"GLiNER2 extracted (entity-refined): subj='{refined_subject}', pred='{
+            f"GLiNER2 extracted (entity-refined): subj='{refined_subject}', pred='{model_predicate}', obj='{refined_object}'"
        )
-        return (refined_subject,
+        return (refined_subject, model_predicate, refined_object)
 
     return None
 
@@ -155,70 +149,6 @@ def extract_triple_from_text(
     return None
 
 
-def extract_triple_by_predicate_split(
-    source_text: str,
-    predicate: str,
-) -> tuple[str, str, str] | None:
-    """
-    Extract subject and object by splitting the source text around the predicate.
-
-    This is useful when the predicate is known but subject/object boundaries
-    are uncertain. Uses the predicate as an anchor point.
-
-    Args:
-        source_text: The source sentence
-        predicate: The predicate (verb phrase) to split on
-
-    Returns:
-        Tuple of (subject, predicate, object) or None if split fails
-    """
-    if not source_text or not predicate:
-        return None
-
-    # Find the predicate in the source text (case-insensitive)
-    source_lower = source_text.lower()
-    pred_lower = predicate.lower()
-
-    pred_pos = source_lower.find(pred_lower)
-    if pred_pos < 0:
-        # Try finding just the main verb (first word of predicate)
-        main_verb = pred_lower.split()[0] if pred_lower.split() else ""
-        if main_verb and len(main_verb) > 2:
-            pred_pos = source_lower.find(main_verb)
-            if pred_pos >= 0:
-                # Adjust to use the actual predicate length for splitting
-                predicate = main_verb
-
-    if pred_pos < 0:
-        return None
-
-    # Extract subject (text before predicate, trimmed)
-    subject = source_text[:pred_pos].strip()
-
-    # Extract object (text after predicate, trimmed)
-    pred_end = pred_pos + len(predicate)
-    obj = source_text[pred_end:].strip()
-
-    # Clean up: remove trailing punctuation from object
-    obj = obj.rstrip('.,;:!?')
-
-    # Clean up: remove leading articles/prepositions from object if very short
-    obj_words = obj.split()
-    if obj_words and obj_words[0].lower() in ('a', 'an', 'the', 'to', 'of', 'for'):
-        if len(obj_words) > 1:
-            obj = ' '.join(obj_words[1:])
-
-    # Validate: both subject and object should have meaningful content
-    if len(subject) < 2 or len(obj) < 2:
-        return None
-
-    logger.debug(
-        f"Predicate-split extracted: subj='{subject}', pred='{predicate}', obj='{obj}'"
-    )
-
-    return (subject, predicate, obj)
-
-
 def score_entity_content(text: str) -> float:
     """
     Score how entity-like a text is using GLiNER2 entity recognition.
statement_extractor/llm.py ADDED
@@ -0,0 +1,255 @@
+"""
+LLM module for text generation using local models.
+
+Supports:
+- GGUF models via llama-cpp-python (efficient quantized inference)
+- Transformers models via HuggingFace
+
+Usage:
+    from statement_extractor.llm import LLM
+
+    llm = LLM()  # Uses default Gemma3 12B GGUF
+    response = llm.generate("Your prompt here")
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class LLM:
+    """
+    LLM wrapper for text generation.
+
+    Automatically selects the best backend:
+    - GGUF models use llama-cpp-python (efficient, no de-quantization)
+    - Other models use HuggingFace transformers
+    """
+
+    def __init__(
+        self,
+        model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+        gguf_file: Optional[str] = None,
+        n_ctx: int = 8192,
+        use_4bit: bool = True,
+    ):
+        """
+        Initialize the LLM.
+
+        Args:
+            model_id: HuggingFace model ID
+            gguf_file: GGUF filename (auto-detected if model_id ends with -gguf)
+            n_ctx: Context size for GGUF models
+            use_4bit: Use 4-bit quantization for transformers models
+        """
+        self._model_id = model_id
+        self._gguf_file = gguf_file
+        self._n_ctx = n_ctx
+        self._use_4bit = use_4bit
+
+        # Model instances (lazy loaded)
+        self._llama_model = None  # llama-cpp-python
+        self._transformers_model = None  # HuggingFace transformers
+        self._tokenizer = None
+
+        self._load_failed = False
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if the model is loaded."""
+        return self._llama_model is not None or self._transformers_model is not None
+
+    def _is_gguf_model(self) -> bool:
+        """Check if the model ID is a GGUF model."""
+        return self._model_id.endswith("-gguf") or self._gguf_file is not None
+
+    def _get_gguf_filename(self) -> str:
+        """Get the GGUF filename from the model ID."""
+        if self._gguf_file:
+            return self._gguf_file
+        # Extract filename from model ID like "google/gemma-3-12b-it-qat-q4_0-gguf"
+        # The actual file is "gemma-3-12b-it-q4_0.gguf" (note: "qat" is removed)
+        model_name = self._model_id.split("/")[-1]
+        if model_name.endswith("-gguf"):
+            model_name = model_name[:-5]  # Remove "-gguf" suffix
+        # Remove "-qat" from the name (it's not in the actual filename)
+        model_name = model_name.replace("-qat", "")
+        return model_name + ".gguf"
+
+    def load(self) -> None:
+        """
+        Load the model.
+
+        Raises:
+            RuntimeError: If the model fails to load
+        """
+        if self.is_loaded or self._load_failed:
+            return
+
+        try:
+            logger.debug(f"Loading LLM: {self._model_id}")
+
+            if self._is_gguf_model():
+                self._load_gguf_model()
+            else:
+                self._load_transformers_model()
+
+            logger.debug("LLM loaded successfully")
+
+        except Exception as e:
+            self._load_failed = True
+            error_msg = f"Failed to load LLM ({self._model_id}): {e}"
+            if "llama_cpp" in str(e).lower() or "llama-cpp" in str(e).lower():
+                error_msg += "\n Install with: pip install llama-cpp-python"
+            if "accelerate" in str(e):
+                error_msg += "\n Install with: pip install accelerate"
+            raise RuntimeError(error_msg) from e
+
+    def _load_gguf_model(self) -> None:
+        """Load GGUF model using llama-cpp-python."""
+        try:
+            from llama_cpp import Llama
+            from huggingface_hub import hf_hub_download
+        except ImportError as e:
+            raise ImportError(
+                "llama-cpp-python is required for GGUF models. "
+                "Install with: pip install llama-cpp-python"
+            ) from e
+
+        gguf_file = self._get_gguf_filename()
+        logger.debug(f"Loading GGUF model with file: {gguf_file}")
+
+        # Download the GGUF file from HuggingFace
+        model_path = hf_hub_download(
+            repo_id=self._model_id,
+            filename=gguf_file,
+        )
+
+        # Load with llama-cpp-python
+        self._llama_model = Llama(
+            model_path=model_path,
+            n_ctx=self._n_ctx,
+            n_gpu_layers=-1,  # Use all GPU layers (Metal on Mac, CUDA on Linux)
+            verbose=False,
+        )
+
+    def _load_transformers_model(self) -> None:
+        """Load model using HuggingFace transformers."""
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model_id)
+
+        if self._use_4bit:
+            try:
+                from transformers import BitsAndBytesConfig
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                )
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    quantization_config=quantization_config,
+                    device_map="auto",
+                )
+            except ImportError:
+                logger.debug("bitsandbytes not available, loading full precision")
+                self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self._model_id,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                )
+        else:
+            self._transformers_model = AutoModelForCausalLM.from_pretrained(
+                self._model_id,
+                device_map="auto",
+                torch_dtype=torch.float16,
+            )
+
+    def generate(
+        self,
+        prompt: str,
+        max_tokens: int = 100,
+        stop: Optional[list[str]] = None,
+    ) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: The input prompt
+            max_tokens: Maximum tokens to generate
+            stop: Stop sequences
+
+        Returns:
+            Generated text (not including the prompt)
+        """
+        self.load()
+
+        if self._llama_model is not None:
+            return self._generate_with_llama(prompt, max_tokens, stop)
+        else:
+            return self._generate_with_transformers(prompt, max_tokens)
+
+    def _generate_with_llama(
+        self,
+        prompt: str,
+        max_tokens: int,
+        stop: Optional[list[str]],
+    ) -> str:
+        """Generate response using llama-cpp-python."""
+        output = self._llama_model(
+            prompt,
+            max_tokens=max_tokens,
+            stop=stop or ["\n\n", "</s>"],
+            echo=False,
+        )
+        return output["choices"][0]["text"]
+
+    def _generate_with_transformers(
+        self,
+        prompt: str,
+        max_tokens: int,
+    ) -> str:
+        """Generate response using transformers."""
+        import torch
+
+        inputs = self._tokenizer(prompt, return_tensors="pt").to(self._transformers_model.device)
+
+        with torch.no_grad():
+            outputs = self._transformers_model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                do_sample=False,
+                pad_token_id=self._tokenizer.pad_token_id,
+            )
+
+        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+# Singleton instance for shared use
+_default_llm: Optional[LLM] = None
+
+
+def get_llm(
+    model_id: str = "google/gemma-3-12b-it-qat-q4_0-gguf",
+    **kwargs,
+) -> LLM:
+    """
+    Get or create a shared LLM instance.
+
+    Uses a singleton pattern to avoid loading the model multiple times.
+
+    Args:
+        model_id: HuggingFace model ID
+        **kwargs: Additional arguments passed to LLM constructor
+
+    Returns:
+        LLM instance
+    """
+    global _default_llm
+
+    if _default_llm is None or _default_llm._model_id != model_id:
+        _default_llm = LLM(model_id=model_id, **kwargs)
+
+    return _default_llm
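A short sketch of how the new module is meant to be driven, following the usage shown in its own docstring; it is not part of the diff, and the prompt text is illustrative. The default GGUF model requires llama-cpp-python and is downloaded from HuggingFace on first generate().

from statement_extractor.llm import LLM, get_llm

# Shared singleton, the same path DocumentSummarizer uses via get_llm()
llm = get_llm()  # defaults to google/gemma-3-12b-it-qat-q4_0-gguf

# Or an explicit instance with a smaller context window
llm = LLM(n_ctx=4096)
response = llm.generate("Summarize GGUF in one sentence.", max_tokens=64)
print(response)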
statement_extractor/models/__init__.py ADDED
@@ -0,0 +1,89 @@
+"""
+Data models for the extraction pipeline.
+
+This module contains all Pydantic models used throughout the pipeline stages:
+- Stage 1 (Splitting): RawTriple
+- Stage 2 (Extraction): ExtractedEntity, PipelineStatement
+- Stage 3 (Qualification): EntityQualifiers, QualifiedEntity
+- Stage 4 (Canonicalization): CanonicalMatch, CanonicalEntity
+- Stage 5 (Labeling): StatementLabel, LabeledStatement
+
+It also re-exports all models from the original models.py for backward compatibility.
+"""
+
+# Import from the original models.py file (now a sibling at the same level)
+# We need to import these BEFORE the local modules to avoid circular imports
+import sys
+import importlib.util
+from pathlib import Path
+
+# Manually load the old models.py to avoid conflict with this package
+_models_py_path = Path(__file__).parent.parent / "models.py"
+if _models_py_path.exists():
+    _spec = importlib.util.spec_from_file_location("_old_models", _models_py_path)
+    _old_models = importlib.util.module_from_spec(_spec)
+    _spec.loader.exec_module(_old_models)
+
+    # Re-export everything from the old models
+    Entity = _old_models.Entity
+    ExtractionMethod = _old_models.ExtractionMethod
+    Statement = _old_models.Statement
+    ExtractionResult = _old_models.ExtractionResult
+    PredicateMatch = _old_models.PredicateMatch
+    PredicateTaxonomy = _old_models.PredicateTaxonomy
+    PredicateComparisonConfig = _old_models.PredicateComparisonConfig
+    ScoringConfig = _old_models.ScoringConfig
+    ExtractionOptions = _old_models.ExtractionOptions
+
+    # Use EntityType from old models
+    EntityType = _old_models.EntityType
+else:
+    # Fallback: define locally if old models.py doesn't exist
+    from .entity import EntityType
+
+# New pipeline models
+from .entity import ExtractedEntity
+from .statement import RawTriple, PipelineStatement
+from .qualifiers import EntityQualifiers, QualifiedEntity, ResolvedRole, ResolvedOrganization
+from .canonical import CanonicalMatch, CanonicalEntity
+from .labels import StatementLabel, LabeledStatement, TaxonomyResult
+from .document import (
+    Document,
+    DocumentMetadata,
+    DocumentPage,
+    TextChunk,
+    ChunkingConfig,
+)
+
+__all__ = [
+    # Re-exported from original models.py (backward compatibility)
+    "Entity",
+    "EntityType",
+    "ExtractionMethod",
+    "Statement",
+    "ExtractionResult",
+    "PredicateMatch",
+    "PredicateTaxonomy",
+    "PredicateComparisonConfig",
+    "ScoringConfig",
+    "ExtractionOptions",
+    # New pipeline models
+    "ExtractedEntity",
+    "RawTriple",
+    "PipelineStatement",
+    "EntityQualifiers",
+    "QualifiedEntity",
+    "ResolvedRole",
+    "ResolvedOrganization",
+    "CanonicalMatch",
+    "CanonicalEntity",
+    "StatementLabel",
+    "LabeledStatement",
+    "TaxonomyResult",
+    # Document models
+    "Document",
+    "DocumentMetadata",
+    "DocumentPage",
+    "TextChunk",
+    "ChunkingConfig",
+]
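The practical effect of the re-export shim above, sketched outside the diff for clarity: code written against 0.4.0's flat models.py keeps importing from statement_extractor.models unchanged, while the new stage models resolve from the same namespace. Names are taken from the __all__ list in the file.

# Old-style imports (re-exported from the original models.py) still resolve
from statement_extractor.models import Entity, Statement, ExtractionResult

# New pipeline-stage models live in the same package namespace
from statement_extractor.models import RawTriple, QualifiedEntity, LabeledStatement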