corp-extractor 0.5.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +191 -24
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +1 -1
- statement_extractor/cli.py +1227 -10
- statement_extractor/data/statement_taxonomy.json +6949 -1159
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/models/__init__.py +16 -1
- statement_extractor/models/canonical.py +44 -1
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/labels.py +47 -18
- statement_extractor/models/qualifiers.py +51 -3
- statement_extractor/models/statement.py +26 -0
- statement_extractor/pipeline/config.py +6 -11
- statement_extractor/pipeline/orchestrator.py +80 -111
- statement_extractor/pipeline/registry.py +52 -46
- statement_extractor/plugins/__init__.py +20 -8
- statement_extractor/plugins/base.py +334 -64
- statement_extractor/plugins/extractors/gliner2.py +10 -0
- statement_extractor/plugins/labelers/taxonomy.py +18 -5
- statement_extractor/plugins/labelers/taxonomy_embedding.py +17 -6
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +11 -0
- statement_extractor/plugins/qualifiers/companies_house.py +14 -3
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +14 -3
- statement_extractor/plugins/qualifiers/person.py +578 -14
- statement_extractor/plugins/qualifiers/sec_edgar.py +14 -3
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/t5_gemma.py +158 -53
- statement_extractor/plugins/taxonomy/embedding.py +193 -46
- statement_extractor/plugins/taxonomy/mnli.py +16 -4
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.5.0.dist-info/RECORD +0 -55
- statement_extractor/plugins/canonicalizers/__init__.py +0 -17
- statement_extractor/plugins/canonicalizers/base.py +0 -9
- statement_extractor/plugins/canonicalizers/location.py +0 -219
- statement_extractor/plugins/canonicalizers/organization.py +0 -230
- statement_extractor/plugins/canonicalizers/person.py +0 -242
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
statement_extractor/plugins/qualifiers/sec_edgar.py

@@ -1,17 +1,20 @@
 """
 SECEdgarQualifierPlugin - Qualifies US ORG entities with SEC data.

+DEPRECATED: Use EmbeddingCompanyQualifier instead, which uses a local
+embedding database with pre-loaded SEC Edgar data for faster, offline matching.
+
 Uses the SEC EDGAR API to:
 - Look up CIK (Central Index Key) by company name
 - Retrieve ticker symbol, exchange, filing history
 """

 import logging
+import warnings
 from typing import Optional

 from ..base import BaseQualifierPlugin, PluginCapability
 from ...pipeline.context import PipelineContext
-from ...pipeline.registry import PluginRegistry
 from ...models import ExtractedEntity, EntityQualifiers, EntityType

 logger = logging.getLogger(__name__)
@@ -21,11 +24,12 @@ SEC_COMPANY_SEARCH = "https://efts.sec.gov/LATEST/search-index"
 SEC_COMPANY_TICKERS = "https://www.sec.gov/files/company_tickers.json"


-
+# DEPRECATED: Not auto-registered. Use EmbeddingCompanyQualifier instead.
 class SECEdgarQualifierPlugin(BaseQualifierPlugin):
     """
-
+    DEPRECATED: Use EmbeddingCompanyQualifier instead.

+    Qualifier plugin for US ORG entities using SEC EDGAR.
     Provides CIK and ticker symbol for publicly traded US companies.
     """

@@ -37,10 +41,17 @@ class SECEdgarQualifierPlugin(BaseQualifierPlugin):
         """
         Initialize the SEC EDGAR qualifier.

+        DEPRECATED: Use EmbeddingCompanyQualifier instead.
+
         Args:
             timeout: API request timeout in seconds
             cache_results: Whether to cache API results
         """
+        warnings.warn(
+            "SECEdgarQualifierPlugin is deprecated. Use EmbeddingCompanyQualifier instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         self._timeout = timeout
         self._cache_results = cache_results
         self._cache: dict[str, Optional[dict]] = {}
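The deprecation is advertised in three places: the module docstring, a comment above the class, and a runtime DeprecationWarning raised in __init__. A minimal sketch of how a caller could surface that warning while migrating; the constructor arguments below are example values, and the import path follows the file listing above:

import warnings

from statement_extractor.plugins.qualifiers.sec_edgar import SECEdgarQualifierPlugin

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Instantiating the plugin now triggers the DeprecationWarning added in 0.9.0
    SECEdgarQualifierPlugin(timeout=10.0, cache_results=True)

for w in caught:
    # Expected: DeprecationWarning pointing at EmbeddingCompanyQualifier
    print(w.category.__name__, w.message)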
statement_extractor/plugins/scrapers/http.py (new file)

@@ -0,0 +1,236 @@
+"""
+HTTP scraper plugin for fetching web content.
+
+Uses httpx for async HTTP requests with retries, timeouts, and CAPTCHA detection.
+"""
+
+import logging
+from typing import Optional
+
+from ..base import BaseScraperPlugin, ContentType, ScraperResult
+from ...pipeline.registry import PluginRegistry
+
+logger = logging.getLogger(__name__)
+
+
+@PluginRegistry.scraper
+class HttpScraperPlugin(BaseScraperPlugin):
+    """
+    Default HTTP scraper using httpx with retries and timeouts.
+
+    Features:
+    - Async HTTP requests with httpx
+    - Automatic redirect following
+    - Content type detection from headers and URL
+    - CAPTCHA page detection
+    - Configurable timeout and retries
+    """
+
+    def __init__(
+        self,
+        timeout: float = 30.0,
+        max_retries: int = 3,
+        user_agent: str = "Mozilla/5.0 (compatible; StatementExtractor/1.0; +https://github.com/corp-o-rate/statement-extractor)",
+        follow_redirects: bool = True,
+    ):
+        self._timeout = timeout
+        self._max_retries = max_retries
+        self._user_agent = user_agent
+        self._follow_redirects = follow_redirects
+
+    @property
+    def name(self) -> str:
+        return "http_scraper"
+
+    @property
+    def priority(self) -> int:
+        return 100  # Default scraper
+
+    @property
+    def description(self) -> str:
+        return "Default HTTP scraper using httpx with retries and CAPTCHA detection"
+
+    async def fetch(self, url: str, timeout: Optional[float] = None) -> ScraperResult:
+        """
+        Fetch content from a URL with retries and CAPTCHA detection.
+
+        Args:
+            url: The URL to fetch
+            timeout: Request timeout in seconds (uses instance default if None)
+
+        Returns:
+            ScraperResult with content, content type, and any errors
+        """
+        import httpx
+
+        timeout = timeout or self._timeout
+        last_error: Optional[str] = None
+
+        for attempt in range(self._max_retries):
+            try:
+                async with httpx.AsyncClient(
+                    timeout=timeout,
+                    follow_redirects=self._follow_redirects,
+                ) as client:
+                    logger.debug(f"Fetching URL: {url} (attempt {attempt + 1})")
+
+                    response = await client.get(
+                        url,
+                        headers={"User-Agent": self._user_agent},
+                    )
+
+                    content_type = self._detect_content_type(
+                        dict(response.headers), url
+                    )
+
+                    # Check for CAPTCHA if HTML
+                    error = None
+                    if content_type == ContentType.HTML:
+                        if self._is_captcha_page(response.content):
+                            error = "CAPTCHA or challenge page detected"
+                            logger.warning(f"CAPTCHA detected at {url}")
+
+                    return ScraperResult(
+                        url=url,
+                        final_url=str(response.url),
+                        content=response.content,
+                        content_type=content_type,
+                        headers=dict(response.headers),
+                        error=error,
+                    )
+
+            except httpx.TimeoutException as e:
+                last_error = f"Request timed out after {timeout}s"
+                logger.warning(f"Timeout fetching {url}: {e}")
+            except httpx.ConnectError as e:
+                last_error = f"Connection error: {e}"
+                logger.warning(f"Connection error fetching {url}: {e}")
+            except httpx.HTTPStatusError as e:
+                last_error = f"HTTP {e.response.status_code}: {e.response.reason_phrase}"
+                logger.warning(f"HTTP error fetching {url}: {e}")
+                # Don't retry on 4xx errors
+                if 400 <= e.response.status_code < 500:
+                    break
+            except Exception as e:
+                last_error = f"Unexpected error: {e}"
+                logger.exception(f"Error fetching {url}")
+
+        # All retries failed
+        return ScraperResult(
+            url=url,
+            final_url=url,
+            content=b"",
+            content_type=ContentType.UNKNOWN,
+            error=last_error or "Unknown error",
+        )
+
+    async def head(self, url: str, timeout: Optional[float] = None) -> ScraperResult:
+        """
+        Check content type without downloading the full body.
+
+        Args:
+            url: The URL to check
+            timeout: Request timeout in seconds
+
+        Returns:
+            ScraperResult with content_type populated (content is empty)
+        """
+        import httpx
+
+        timeout = timeout or self._timeout
+
+        try:
+            async with httpx.AsyncClient(
+                timeout=timeout,
+                follow_redirects=self._follow_redirects,
+            ) as client:
+                response = await client.head(
+                    url,
+                    headers={"User-Agent": self._user_agent},
+                )
+
+                content_type = self._detect_content_type(
+                    dict(response.headers), url
+                )
+
+                return ScraperResult(
+                    url=url,
+                    final_url=str(response.url),
+                    content=b"",
+                    content_type=content_type,
+                    headers=dict(response.headers),
+                )
+
+        except Exception as e:
+            logger.warning(f"HEAD request failed for {url}: {e}")
+            # Fall back to full fetch
+            return await self.fetch(url, timeout)
+
+    @staticmethod
+    def _detect_content_type(headers: dict[str, str], url: str) -> ContentType:
+        """
+        Detect content type from HTTP headers and URL.
+
+        Priority:
+        1. Content-Type header
+        2. URL file extension
+        """
+        content_type_header = headers.get("content-type", "").lower()
+
+        # Check Content-Type header
+        if "application/pdf" in content_type_header:
+            return ContentType.PDF
+        if any(mime in content_type_header for mime in [
+            "text/html",
+            "application/xhtml+xml",
+        ]):
+            return ContentType.HTML
+
+        # Check URL extension
+        url_lower = url.lower().split("?")[0]  # Remove query params
+        if url_lower.endswith(".pdf"):
+            return ContentType.PDF
+        if url_lower.endswith((".html", ".htm")):
+            return ContentType.HTML
+
+        # Default based on content-type
+        if content_type_header.startswith("text/"):
+            return ContentType.HTML
+        if content_type_header.startswith(("image/", "audio/", "video/")):
+            return ContentType.BINARY
+
+        return ContentType.UNKNOWN
+
+    @staticmethod
+    def _is_captcha_page(content: bytes) -> bool:
+        """
+        Detect CAPTCHA or challenge pages.
+
+        Checks for common CAPTCHA patterns in HTML content.
+        """
+        try:
+            html = content.decode("utf-8", errors="replace").lower()
+        except Exception:
+            return False
+
+        # Only check small pages (challenge pages are usually small)
+        if len(html) > 50000:
+            return False
+
+        # Common CAPTCHA/challenge indicators
+        captcha_patterns = [
+            "captcha",
+            "cloudflare",
+            "checking your browser",
+            "please verify you are a human",
+            "access denied",
+            "bot protection",
+            "ddos protection",
+            "just a moment",
+            "enable javascript",
+            "please enable cookies",
+            "verify you are human",
+            "security check",
+        ]
+
+        return any(pattern in html for pattern in captcha_patterns)
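A minimal sketch of standalone use of the new scraper, assuming the 0.9.0 wheel is installed and the class behaves as shown in the hunk above; the URL is a placeholder:

import asyncio

from statement_extractor.plugins.scrapers.http import HttpScraperPlugin

async def main() -> None:
    # Example configuration; constructor parameters match the new plugin above
    scraper = HttpScraperPlugin(timeout=10.0, max_retries=2)
    result = await scraper.fetch("https://example.com/annual-report.pdf")
    if result.error:
        print(f"Fetch failed: {result.error}")
    else:
        print(result.content_type, result.final_url, len(result.content))

asyncio.run(main())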
statement_extractor/plugins/splitters/t5_gemma.py

@@ -7,7 +7,6 @@ subject-predicate-object triples from text.

 import logging
 import re
-import xml.etree.ElementTree as ET
 from typing import Optional

 from ..base import BaseSplitterPlugin, PluginCapability
@@ -62,12 +61,22 @@ class T5GemmaSplitter(BaseSplitterPlugin):

     @property
     def capabilities(self) -> PluginCapability:
-        return PluginCapability.LLM_REQUIRED
+        return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING

     @property
     def description(self) -> str:
         return "T5-Gemma2 model for extracting triples using Diverse Beam Search"

+    @property
+    def model_vram_gb(self) -> float:
+        """T5-Gemma2 model weights ~2GB in bfloat16."""
+        return 2.0
+
+    @property
+    def per_item_vram_gb(self) -> float:
+        """Each text item during batch processing ~0.5GB for KV cache and activations."""
+        return 0.5
+
     def _get_extractor(self):
         """Lazy-load the StatementExtractor."""
         if self._extractor is None:
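The new model_vram_gb and per_item_vram_gb hints pair with the BATCH_PROCESSING capability flag; the split_batch method in the next hunk sizes its batches via get_optimal_batch_size(). That method's implementation is not part of this diff, so the following is only a sketch of the kind of arithmetic the hints support, with a made-up free-VRAM budget:

model_vram_gb = 2.0     # T5GemmaSplitter.model_vram_gb
per_item_vram_gb = 0.5  # T5GemmaSplitter.per_item_vram_gb
free_vram_gb = 8.0      # hypothetical GPU budget, not from the diff

batch_size = max(1, int((free_vram_gb - model_vram_gb) / per_item_vram_gb))
print(batch_size)  # 12 with these example numbers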
@@ -126,62 +135,158 @@
         logger.info(f"T5GemmaSplitter produced {len(raw_triples)} raw triples")
         return raw_triples

+    def split_batch(
+        self,
+        texts: list[str],
+        context: PipelineContext,
+    ) -> list[list[RawTriple]]:
+        """
+        Split multiple texts into atomic triples using batch processing.
+
+        Processes all texts through the T5-Gemma2 model in batches
+        sized for optimal GPU utilization.
+
+        Args:
+            texts: List of input texts to split
+            context: Pipeline context
+
+        Returns:
+            List of RawTriple lists, one per input text
+        """
+        if not texts:
+            return []
+
+        batch_size = self.get_optimal_batch_size()
+        logger.info(f"T5GemmaSplitter batch processing {len(texts)} texts with batch_size={batch_size}")
+
+        # Get options from context
+        splitter_options = context.source_metadata.get("splitter_options", {})
+        num_beams = splitter_options.get("num_beams", self._num_beams)
+        diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+        max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+        # Create extraction options
+        from ...models import ExtractionOptions as LegacyExtractionOptions
+        options = LegacyExtractionOptions(
+            num_beams=num_beams,
+            diversity_penalty=diversity_penalty,
+            max_new_tokens=max_new_tokens,
+            use_gliner_extraction=False,
+            embedding_dedup=False,
+            deduplicate=False,
+        )
+
+        extractor = self._get_extractor()
+        all_results: list[list[RawTriple]] = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            logger.debug(f"Processing batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+            batch_results = self._process_batch(batch_texts, extractor, options)
+            all_results.extend(batch_results)
+
+        total_triples = sum(len(r) for r in all_results)
+        logger.info(f"T5GemmaSplitter batch produced {total_triples} total triples from {len(texts)} texts")
+        return all_results
+
+    def _process_batch(
+        self,
+        texts: list[str],
+        extractor,
+        options,
+    ) -> list[list[RawTriple]]:
+        """
+        Process a batch of texts through the model.
+
+        Uses the model's batch generation capability for efficient GPU utilization.
+        """
+        import torch
+
+        # Wrap texts in page tags
+        wrapped_texts = [f"<page>{t}</page>" if not t.startswith("<page>") else t for t in texts]
+
+        # Tokenize batch
+        tokenizer = extractor.tokenizer
+        model = extractor.model
+
+        inputs = tokenizer(
+            wrapped_texts,
+            return_tensors="pt",
+            max_length=4096,
+            truncation=True,
+            padding=True,
+        ).to(extractor.device)
+
+        # Create stopping criteria
+        from ...extractor import StopOnSequence
+        from transformers import StoppingCriteriaList
+
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
+        # Generate for all texts in batch
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=options.max_new_tokens,
+                max_length=None,
+                num_beams=options.num_beams,
+                num_beam_groups=options.num_beams,
+                num_return_sequences=1,  # One sequence per input for batch
+                diversity_penalty=options.diversity_penalty,
+                do_sample=False,
+                top_p=None,
+                top_k=None,
+                trust_remote_code=True,
+                custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
+            )
+
+        # Decode and parse each output
+        results: list[list[RawTriple]] = []
+        end_tag = "</statements>"
+
+        for output in outputs:
+            decoded = tokenizer.decode(output, skip_special_tokens=True)
+
+            # Truncate at </statements>
+            if end_tag in decoded:
+                end_pos = decoded.find(end_tag) + len(end_tag)
+                decoded = decoded[:end_pos]
+
+            triples = self._parse_xml_to_raw_triples(decoded)
+            results.append(triples)
+
+        return results
+
+    # Regex pattern to extract <text> content from <stmt> blocks
+    _STMT_TEXT_PATTERN = re.compile(r'<stmt>.*?<text>(.*?)</text>.*?</stmt>', re.DOTALL)
+
     def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
-        """
+        """Extract source sentences from <stmt><text>...</text></stmt> blocks."""
         raw_triples = []

-
-
-
-            logger.warning(f"XML parse error: {e}")
-            # Try to repair
-            xml_output = self._repair_xml(xml_output)
-            try:
-                root = ET.fromstring(xml_output)
-            except ET.ParseError:
-                logger.error("XML repair failed")
-                return raw_triples
-
-        if root.tag != "statements":
-            logger.warning(f"Unexpected root tag: {root.tag}")
-            return raw_triples
-
-        for stmt_elem in root.findall("stmt"):
-            try:
-                subject_elem = stmt_elem.find("subject")
-                predicate_elem = stmt_elem.find("predicate")
-                object_elem = stmt_elem.find("object")
-                text_elem = stmt_elem.find("text")
-
-                subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
-                predicate_text = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-                object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
-                source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else ""
-
-                if subject_text and object_text and source_text:
-                    raw_triples.append(RawTriple(
-                        subject_text=subject_text,
-                        predicate_text=predicate_text,
-                        object_text=object_text,
-                        source_sentence=source_text,
-                    ))
-                else:
-                    logger.debug(f"Skipping incomplete triple: s={subject_text}, p={predicate_text}, o={object_text}")
-
-            except Exception as e:
-                logger.warning(f"Error parsing stmt element: {e}")
-                continue
+        # Find all <text> content within <stmt> blocks
+        text_matches = self._STMT_TEXT_PATTERN.findall(xml_output)
+        logger.debug(f"Found {len(text_matches)} stmt text blocks via regex")

-
+        for source_text in text_matches:
+            source_text = source_text.strip()
+            if source_text:
+                raw_triples.append(RawTriple(
+                    subject_text="",
+                    predicate_text="",
+                    object_text="",
+                    source_sentence=source_text,
+                ))

-
-        """Attempt to repair common XML syntax errors."""
-        # Use the repair function from extractor.py
-        from ...extractor import repair_xml
-        repaired, repairs = repair_xml(xml_string)
-        if repairs:
-            logger.debug(f"XML repairs: {', '.join(repairs)}")
-        return repaired
+        return raw_triples


 # Allow importing without decorator for testing
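With the ElementTree parsing and _repair_xml fallback removed, _parse_xml_to_raw_triples now pulls only the <text> spans out of the model output and leaves subject/predicate/object empty. A self-contained sketch of what the new regex extracts; the sample output string is invented for illustration:

import re

# Same pattern as the new _STMT_TEXT_PATTERN class attribute above
_STMT_TEXT_PATTERN = re.compile(r'<stmt>.*?<text>(.*?)</text>.*?</stmt>', re.DOTALL)

sample = (
    "<statements>"
    "<stmt><subject>Acme Corp</subject><predicate>acquired</predicate>"
    "<object>Widget Ltd</object><text>Acme Corp acquired Widget Ltd in 2023.</text></stmt>"
    "</statements>"
)

print(_STMT_TEXT_PATTERN.findall(sample))
# ['Acme Corp acquired Widget Ltd in 2023.']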