corp-extractor 0.5.0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/METADATA +228 -30
  2. corp_extractor-0.9.3.dist-info/RECORD +79 -0
  3. statement_extractor/__init__.py +1 -1
  4. statement_extractor/cli.py +2030 -24
  5. statement_extractor/data/statement_taxonomy.json +6949 -1159
  6. statement_extractor/database/__init__.py +52 -0
  7. statement_extractor/database/embeddings.py +186 -0
  8. statement_extractor/database/hub.py +428 -0
  9. statement_extractor/database/importers/__init__.py +32 -0
  10. statement_extractor/database/importers/companies_house.py +559 -0
  11. statement_extractor/database/importers/companies_house_officers.py +431 -0
  12. statement_extractor/database/importers/gleif.py +561 -0
  13. statement_extractor/database/importers/sec_edgar.py +392 -0
  14. statement_extractor/database/importers/sec_form4.py +512 -0
  15. statement_extractor/database/importers/wikidata.py +1120 -0
  16. statement_extractor/database/importers/wikidata_dump.py +1951 -0
  17. statement_extractor/database/importers/wikidata_people.py +1130 -0
  18. statement_extractor/database/models.py +254 -0
  19. statement_extractor/database/resolver.py +245 -0
  20. statement_extractor/database/store.py +3034 -0
  21. statement_extractor/document/__init__.py +62 -0
  22. statement_extractor/document/chunker.py +410 -0
  23. statement_extractor/document/context.py +171 -0
  24. statement_extractor/document/deduplicator.py +171 -0
  25. statement_extractor/document/html_extractor.py +246 -0
  26. statement_extractor/document/loader.py +303 -0
  27. statement_extractor/document/pipeline.py +388 -0
  28. statement_extractor/document/summarizer.py +195 -0
  29. statement_extractor/extractor.py +1 -1
  30. statement_extractor/models/__init__.py +19 -3
  31. statement_extractor/models/canonical.py +44 -1
  32. statement_extractor/models/document.py +308 -0
  33. statement_extractor/models/labels.py +47 -18
  34. statement_extractor/models/qualifiers.py +51 -3
  35. statement_extractor/models/statement.py +39 -15
  36. statement_extractor/models.py +1 -1
  37. statement_extractor/pipeline/config.py +6 -11
  38. statement_extractor/pipeline/context.py +5 -5
  39. statement_extractor/pipeline/orchestrator.py +90 -121
  40. statement_extractor/pipeline/registry.py +52 -46
  41. statement_extractor/plugins/__init__.py +20 -8
  42. statement_extractor/plugins/base.py +348 -78
  43. statement_extractor/plugins/extractors/gliner2.py +38 -28
  44. statement_extractor/plugins/labelers/taxonomy.py +18 -5
  45. statement_extractor/plugins/labelers/taxonomy_embedding.py +17 -6
  46. statement_extractor/plugins/pdf/__init__.py +10 -0
  47. statement_extractor/plugins/pdf/pypdf.py +291 -0
  48. statement_extractor/plugins/qualifiers/__init__.py +11 -0
  49. statement_extractor/plugins/qualifiers/companies_house.py +14 -3
  50. statement_extractor/plugins/qualifiers/embedding_company.py +422 -0
  51. statement_extractor/plugins/qualifiers/gleif.py +14 -3
  52. statement_extractor/plugins/qualifiers/person.py +588 -14
  53. statement_extractor/plugins/qualifiers/sec_edgar.py +14 -3
  54. statement_extractor/plugins/scrapers/__init__.py +10 -0
  55. statement_extractor/plugins/scrapers/http.py +236 -0
  56. statement_extractor/plugins/splitters/t5_gemma.py +176 -75
  57. statement_extractor/plugins/taxonomy/embedding.py +193 -46
  58. statement_extractor/plugins/taxonomy/mnli.py +16 -4
  59. statement_extractor/scoring.py +8 -8
  60. corp_extractor-0.5.0.dist-info/RECORD +0 -55
  61. statement_extractor/plugins/canonicalizers/__init__.py +0 -17
  62. statement_extractor/plugins/canonicalizers/base.py +0 -9
  63. statement_extractor/plugins/canonicalizers/location.py +0 -219
  64. statement_extractor/plugins/canonicalizers/organization.py +0 -230
  65. statement_extractor/plugins/canonicalizers/person.py +0 -242
  66. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/WHEEL +0 -0
  67. {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/entry_points.txt +0 -0
statement_extractor/plugins/qualifiers/sec_edgar.py
@@ -1,17 +1,20 @@
 """
 SECEdgarQualifierPlugin - Qualifies US ORG entities with SEC data.
 
+DEPRECATED: Use EmbeddingCompanyQualifier instead, which uses a local
+embedding database with pre-loaded SEC Edgar data for faster, offline matching.
+
 Uses the SEC EDGAR API to:
 - Look up CIK (Central Index Key) by company name
 - Retrieve ticker symbol, exchange, filing history
 """
 
 import logging
+import warnings
 from typing import Optional
 
 from ..base import BaseQualifierPlugin, PluginCapability
 from ...pipeline.context import PipelineContext
-from ...pipeline.registry import PluginRegistry
 from ...models import ExtractedEntity, EntityQualifiers, EntityType
 
 logger = logging.getLogger(__name__)
@@ -21,11 +24,12 @@ SEC_COMPANY_SEARCH = "https://efts.sec.gov/LATEST/search-index"
 SEC_COMPANY_TICKERS = "https://www.sec.gov/files/company_tickers.json"
 
 
-@PluginRegistry.qualifier
+# DEPRECATED: Not auto-registered. Use EmbeddingCompanyQualifier instead.
 class SECEdgarQualifierPlugin(BaseQualifierPlugin):
     """
-    Qualifier plugin for US ORG entities using SEC EDGAR.
+    DEPRECATED: Use EmbeddingCompanyQualifier instead.
 
+    Qualifier plugin for US ORG entities using SEC EDGAR.
     Provides CIK and ticker symbol for publicly traded US companies.
     """
 
@@ -37,10 +41,17 @@ class SECEdgarQualifierPlugin(BaseQualifierPlugin):
         """
         Initialize the SEC EDGAR qualifier.
 
+        DEPRECATED: Use EmbeddingCompanyQualifier instead.
+
         Args:
             timeout: API request timeout in seconds
             cache_results: Whether to cache API results
         """
+        warnings.warn(
+            "SECEdgarQualifierPlugin is deprecated. Use EmbeddingCompanyQualifier instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         self._timeout = timeout
         self._cache_results = cache_results
         self._cache: dict[str, Optional[dict]] = {}
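The deprecation is enforced at construction time rather than import time, so existing code keeps working but now emits a warning. A minimal sketch of what a caller sees (assuming the constructor's timeout/cache_results parameters keep default values, which this diff does not show):

import warnings

from statement_extractor.plugins.qualifiers.sec_edgar import SECEdgarQualifierPlugin

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plugin = SECEdgarQualifierPlugin()  # assumption: default arguments exist

assert any(issubclass(w.category, DeprecationWarning) for w in caught)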
statement_extractor/plugins/scrapers/__init__.py
@@ -0,0 +1,10 @@
+"""
+Scraper plugins for fetching content from URLs.
+
+Built-in scrapers:
+- http_scraper: Default HTTP scraper using httpx with retries
+"""
+
+from .http import HttpScraperPlugin
+
+__all__ = ["HttpScraperPlugin"]
statement_extractor/plugins/scrapers/http.py
@@ -0,0 +1,236 @@
+"""
+HTTP scraper plugin for fetching web content.
+
+Uses httpx for async HTTP requests with retries, timeouts, and CAPTCHA detection.
+"""
+
+import logging
+from typing import Optional
+
+from ..base import BaseScraperPlugin, ContentType, ScraperResult
+from ...pipeline.registry import PluginRegistry
+
+logger = logging.getLogger(__name__)
+
+
+@PluginRegistry.scraper
+class HttpScraperPlugin(BaseScraperPlugin):
+    """
+    Default HTTP scraper using httpx with retries and timeouts.
+
+    Features:
+    - Async HTTP requests with httpx
+    - Automatic redirect following
+    - Content type detection from headers and URL
+    - CAPTCHA page detection
+    - Configurable timeout and retries
+    """
+
+    def __init__(
+        self,
+        timeout: float = 30.0,
+        max_retries: int = 3,
+        user_agent: str = "Mozilla/5.0 (compatible; StatementExtractor/1.0; +https://github.com/corp-o-rate/statement-extractor)",
+        follow_redirects: bool = True,
+    ):
+        self._timeout = timeout
+        self._max_retries = max_retries
+        self._user_agent = user_agent
+        self._follow_redirects = follow_redirects
+
+    @property
+    def name(self) -> str:
+        return "http_scraper"
+
+    @property
+    def priority(self) -> int:
+        return 100  # Default scraper
+
+    @property
+    def description(self) -> str:
+        return "Default HTTP scraper using httpx with retries and CAPTCHA detection"
+
+    async def fetch(self, url: str, timeout: Optional[float] = None) -> ScraperResult:
+        """
+        Fetch content from a URL with retries and CAPTCHA detection.
+
+        Args:
+            url: The URL to fetch
+            timeout: Request timeout in seconds (uses instance default if None)
+
+        Returns:
+            ScraperResult with content, content type, and any errors
+        """
+        import httpx
+
+        timeout = timeout or self._timeout
+        last_error: Optional[str] = None
+
+        for attempt in range(self._max_retries):
+            try:
+                async with httpx.AsyncClient(
+                    timeout=timeout,
+                    follow_redirects=self._follow_redirects,
+                ) as client:
+                    logger.debug(f"Fetching URL: {url} (attempt {attempt + 1})")
+
+                    response = await client.get(
+                        url,
+                        headers={"User-Agent": self._user_agent},
+                    )
+
+                    content_type = self._detect_content_type(
+                        dict(response.headers), url
+                    )
+
+                    # Check for CAPTCHA if HTML
+                    error = None
+                    if content_type == ContentType.HTML:
+                        if self._is_captcha_page(response.content):
+                            error = "CAPTCHA or challenge page detected"
+                            logger.warning(f"CAPTCHA detected at {url}")
+
+                    return ScraperResult(
+                        url=url,
+                        final_url=str(response.url),
+                        content=response.content,
+                        content_type=content_type,
+                        headers=dict(response.headers),
+                        error=error,
+                    )
+
+            except httpx.TimeoutException as e:
+                last_error = f"Request timed out after {timeout}s"
+                logger.warning(f"Timeout fetching {url}: {e}")
+            except httpx.ConnectError as e:
+                last_error = f"Connection error: {e}"
+                logger.warning(f"Connection error fetching {url}: {e}")
+            except httpx.HTTPStatusError as e:
+                last_error = f"HTTP {e.response.status_code}: {e.response.reason_phrase}"
+                logger.warning(f"HTTP error fetching {url}: {e}")
+                # Don't retry on 4xx errors
+                if 400 <= e.response.status_code < 500:
+                    break
+            except Exception as e:
+                last_error = f"Unexpected error: {e}"
+                logger.exception(f"Error fetching {url}")
+
+        # All retries failed
+        return ScraperResult(
+            url=url,
+            final_url=url,
+            content=b"",
+            content_type=ContentType.UNKNOWN,
+            error=last_error or "Unknown error",
+        )
+
+    async def head(self, url: str, timeout: Optional[float] = None) -> ScraperResult:
+        """
+        Check content type without downloading the full body.
+
+        Args:
+            url: The URL to check
+            timeout: Request timeout in seconds
+
+        Returns:
+            ScraperResult with content_type populated (content is empty)
+        """
+        import httpx
+
+        timeout = timeout or self._timeout
+
+        try:
+            async with httpx.AsyncClient(
+                timeout=timeout,
+                follow_redirects=self._follow_redirects,
+            ) as client:
+                response = await client.head(
+                    url,
+                    headers={"User-Agent": self._user_agent},
+                )
+
+                content_type = self._detect_content_type(
+                    dict(response.headers), url
+                )
+
+                return ScraperResult(
+                    url=url,
+                    final_url=str(response.url),
+                    content=b"",
+                    content_type=content_type,
+                    headers=dict(response.headers),
+                )
+
+        except Exception as e:
+            logger.warning(f"HEAD request failed for {url}: {e}")
+            # Fall back to full fetch
+            return await self.fetch(url, timeout)
+
+    @staticmethod
+    def _detect_content_type(headers: dict[str, str], url: str) -> ContentType:
+        """
+        Detect content type from HTTP headers and URL.
+
+        Priority:
+        1. Content-Type header
+        2. URL file extension
+        """
+        content_type_header = headers.get("content-type", "").lower()
+
+        # Check Content-Type header
+        if "application/pdf" in content_type_header:
+            return ContentType.PDF
+        if any(mime in content_type_header for mime in [
+            "text/html",
+            "application/xhtml+xml",
+        ]):
+            return ContentType.HTML
+
+        # Check URL extension
+        url_lower = url.lower().split("?")[0]  # Remove query params
+        if url_lower.endswith(".pdf"):
+            return ContentType.PDF
+        if url_lower.endswith((".html", ".htm")):
+            return ContentType.HTML
+
+        # Default based on content-type
+        if content_type_header.startswith("text/"):
+            return ContentType.HTML
+        if content_type_header.startswith(("image/", "audio/", "video/")):
+            return ContentType.BINARY
+
+        return ContentType.UNKNOWN
+
+    @staticmethod
+    def _is_captcha_page(content: bytes) -> bool:
+        """
+        Detect CAPTCHA or challenge pages.
+
+        Checks for common CAPTCHA patterns in HTML content.
+        """
+        try:
+            html = content.decode("utf-8", errors="replace").lower()
+        except Exception:
+            return False
+
+        # Only check small pages (challenge pages are usually small)
+        if len(html) > 50000:
+            return False
+
+        # Common CAPTCHA/challenge indicators
+        captcha_patterns = [
+            "captcha",
+            "cloudflare",
+            "checking your browser",
+            "please verify you are a human",
+            "access denied",
+            "bot protection",
+            "ddos protection",
+            "just a moment",
+            "enable javascript",
+            "please enable cookies",
+            "verify you are human",
+            "security check",
+        ]
+
+        return any(pattern in html for pattern in captcha_patterns)
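For reference, a hedged usage sketch of the new scraper outside the pipeline. Import paths follow the file list and the relative imports in this diff (ContentType and ScraperResult live in statement_extractor.plugins.base); this is an illustration, not documented API:

import asyncio

from statement_extractor.plugins.base import ContentType
from statement_extractor.plugins.scrapers.http import HttpScraperPlugin

async def main() -> None:
    scraper = HttpScraperPlugin(timeout=10.0, max_retries=2)

    # HEAD request: cheap content-type probe; falls back to a full fetch on failure
    probe = await scraper.head("https://example.com/")
    print(probe.content_type)

    # Full fetch: result.error is set when a CAPTCHA/challenge page is detected
    result = await scraper.fetch("https://example.com/")
    if result.error is None and result.content_type == ContentType.HTML:
        print(f"Fetched {len(result.content)} bytes from {result.final_url}")

asyncio.run(main())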
statement_extractor/plugins/splitters/t5_gemma.py
@@ -1,19 +1,18 @@
 """
-T5GemmaSplitter - Stage 1 plugin that wraps the existing StatementExtractor.
+T5GemmaSplitter - Stage 1 plugin that splits text into atomic sentences.
 
-Uses T5-Gemma2 model with Diverse Beam Search to generate high-quality
-subject-predicate-object triples from text.
+Uses T5-Gemma2 model with Diverse Beam Search to split unstructured text
+into atomic statements that can be converted to triples in Stage 2.
 """
 
 import logging
 import re
-import xml.etree.ElementTree as ET
 from typing import Optional
 
 from ..base import BaseSplitterPlugin, PluginCapability
 from ...pipeline.context import PipelineContext
 from ...pipeline.registry import PluginRegistry
-from ...models import RawTriple
+from ...models import SplitSentence
 
 logger = logging.getLogger(__name__)
 
@@ -21,10 +20,11 @@ logger = logging.getLogger(__name__)
 @PluginRegistry.splitter
 class T5GemmaSplitter(BaseSplitterPlugin):
     """
-    Splitter plugin that uses T5-Gemma2 for triple extraction.
+    Splitter plugin that uses T5-Gemma2 to split text into atomic sentences.
 
-    Wraps the existing StatementExtractor from extractor.py to produce
-    RawTriple objects for the pipeline.
+    Uses the T5-Gemma2 model to identify and extract atomic statements
+    from unstructured text. Each sentence can be converted to a
+    subject-predicate-object triple in Stage 2.
     """
 
     def __init__(
@@ -62,11 +62,21 @@ class T5GemmaSplitter(BaseSplitterPlugin):
 
     @property
     def capabilities(self) -> PluginCapability:
-        return PluginCapability.LLM_REQUIRED
+        return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING
 
     @property
     def description(self) -> str:
-        return "T5-Gemma2 model for extracting triples using Diverse Beam Search"
+        return "T5-Gemma2 model for splitting text into atomic sentences"
+
+    @property
+    def model_vram_gb(self) -> float:
+        """T5-Gemma2 model weights ~2GB in bfloat16."""
+        return 2.0
+
+    @property
+    def per_item_vram_gb(self) -> float:
+        """Each text item during batch processing ~0.5GB for KV cache and activations."""
+        return 0.5
 
     def _get_extractor(self):
        """Lazy-load the StatementExtractor."""
@@ -85,16 +95,16 @@ class T5GemmaSplitter(BaseSplitterPlugin):
         self,
         text: str,
         context: PipelineContext,
-    ) -> list[RawTriple]:
+    ) -> list[SplitSentence]:
         """
-        Split text into raw triples using T5-Gemma2.
+        Split text into atomic sentences using T5-Gemma2.
 
         Args:
             text: Input text to split
             context: Pipeline context
 
         Returns:
-            List of RawTriple objects
+            List of SplitSentence objects
         """
         logger.debug(f"T5GemmaSplitter processing {len(text)} chars")
 
@@ -120,68 +130,159 @@ class T5GemmaSplitter(BaseSplitterPlugin):
         extractor = self._get_extractor()
         xml_output = extractor.extract_as_xml(text, options)
 
-        # Parse XML to RawTriple objects
-        raw_triples = self._parse_xml_to_raw_triples(xml_output)
-
-        logger.info(f"T5GemmaSplitter produced {len(raw_triples)} raw triples")
-        return raw_triples
-
-    def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
-        """Parse XML output into RawTriple objects."""
-        raw_triples = []
-
-        try:
-            root = ET.fromstring(xml_output)
-        except ET.ParseError as e:
-            logger.warning(f"XML parse error: {e}")
-            # Try to repair
-            xml_output = self._repair_xml(xml_output)
-            try:
-                root = ET.fromstring(xml_output)
-            except ET.ParseError:
-                logger.error("XML repair failed")
-                return raw_triples
-
-        if root.tag != "statements":
-            logger.warning(f"Unexpected root tag: {root.tag}")
-            return raw_triples
-
-        for stmt_elem in root.findall("stmt"):
-            try:
-                subject_elem = stmt_elem.find("subject")
-                predicate_elem = stmt_elem.find("predicate")
-                object_elem = stmt_elem.find("object")
-                text_elem = stmt_elem.find("text")
-
-                subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
-                predicate_text = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-                object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
-                source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else ""
-
-                if subject_text and object_text and source_text:
-                    raw_triples.append(RawTriple(
-                        subject_text=subject_text,
-                        predicate_text=predicate_text,
-                        object_text=object_text,
-                        source_sentence=source_text,
-                    ))
-                else:
-                    logger.debug(f"Skipping incomplete triple: s={subject_text}, p={predicate_text}, o={object_text}")
-
-            except Exception as e:
-                logger.warning(f"Error parsing stmt element: {e}")
-                continue
-
-        return raw_triples
-
-    def _repair_xml(self, xml_string: str) -> str:
-        """Attempt to repair common XML syntax errors."""
-        # Use the repair function from extractor.py
-        from ...extractor import repair_xml
-        repaired, repairs = repair_xml(xml_string)
-        if repairs:
-            logger.debug(f"XML repairs: {', '.join(repairs)}")
-        return repaired
+        # Parse XML to SplitSentence objects
+        sentences = self._parse_xml_to_sentences(xml_output)
+
+        logger.info(f"T5GemmaSplitter produced {len(sentences)} sentences")
+        return sentences
+
+    def split_batch(
+        self,
+        texts: list[str],
+        context: PipelineContext,
+    ) -> list[list[SplitSentence]]:
+        """
+        Split multiple texts into atomic sentences using batch processing.
+
+        Processes all texts through the T5-Gemma2 model in batches
+        sized for optimal GPU utilization.
+
+        Args:
+            texts: List of input texts to split
+            context: Pipeline context
+
+        Returns:
+            List of SplitSentence lists, one per input text
+        """
+        if not texts:
+            return []
+
+        batch_size = self.get_optimal_batch_size()
+        logger.info(f"T5GemmaSplitter batch processing {len(texts)} texts with batch_size={batch_size}")
+
+        # Get options from context
+        splitter_options = context.source_metadata.get("splitter_options", {})
+        num_beams = splitter_options.get("num_beams", self._num_beams)
+        diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+        max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+        # Create extraction options
+        from ...models import ExtractionOptions as LegacyExtractionOptions
+        options = LegacyExtractionOptions(
+            num_beams=num_beams,
+            diversity_penalty=diversity_penalty,
+            max_new_tokens=max_new_tokens,
+            use_gliner_extraction=False,
+            embedding_dedup=False,
+            deduplicate=False,
+        )
+
+        extractor = self._get_extractor()
+        all_results: list[list[SplitSentence]] = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            logger.debug(f"Processing batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+            batch_results = self._process_batch(batch_texts, extractor, options)
+            all_results.extend(batch_results)
+
+        total_sentences = sum(len(r) for r in all_results)
+        logger.info(f"T5GemmaSplitter batch produced {total_sentences} total sentences from {len(texts)} texts")
+        return all_results
+
+    def _process_batch(
+        self,
+        texts: list[str],
+        extractor,
+        options,
+    ) -> list[list[SplitSentence]]:
+        """
+        Process a batch of texts through the model.
+
+        Uses the model's batch generation capability for efficient GPU utilization.
+        """
+        import torch
+
+        # Wrap texts in page tags
+        wrapped_texts = [f"<page>{t}</page>" if not t.startswith("<page>") else t for t in texts]
+
+        # Tokenize batch
+        tokenizer = extractor.tokenizer
+        model = extractor.model
+
+        inputs = tokenizer(
+            wrapped_texts,
+            return_tensors="pt",
+            max_length=4096,
+            truncation=True,
+            padding=True,
+        ).to(extractor.device)
+
+        # Create stopping criteria
+        from ...extractor import StopOnSequence
+        from transformers import StoppingCriteriaList
+
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
+        # Generate for all texts in batch
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=options.max_new_tokens,
+                max_length=None,
+                num_beams=options.num_beams,
+                num_beam_groups=options.num_beams,
+                num_return_sequences=1,  # One sequence per input for batch
+                diversity_penalty=options.diversity_penalty,
+                do_sample=False,
+                top_p=None,
+                top_k=None,
+                trust_remote_code=True,
+                custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
+            )
+
+        # Decode and parse each output
+        results: list[list[SplitSentence]] = []
+        end_tag = "</statements>"
+
+        for output in outputs:
+            decoded = tokenizer.decode(output, skip_special_tokens=True)
+
+            # Truncate at </statements>
+            if end_tag in decoded:
+                end_pos = decoded.find(end_tag) + len(end_tag)
+                decoded = decoded[:end_pos]
+
+            sentences = self._parse_xml_to_sentences(decoded)
+            results.append(sentences)
+
+        return results
+
+    # Regex pattern to extract <text> content from <stmt> blocks
+    _STMT_TEXT_PATTERN = re.compile(r'<stmt>.*?<text>(.*?)</text>.*?</stmt>', re.DOTALL)
+
+    def _parse_xml_to_sentences(self, xml_output: str) -> list[SplitSentence]:
+        """Extract atomic sentences from <stmt><text>...</text></stmt> blocks."""
+        sentences = []
+
+        # Find all <text> content within <stmt> blocks
+        text_matches = self._STMT_TEXT_PATTERN.findall(xml_output)
+        logger.debug(f"Found {len(text_matches)} stmt text blocks via regex")
+
+        for sentence_text in text_matches:
+            sentence_text = sentence_text.strip()
+            if sentence_text:
+                sentences.append(SplitSentence(text=sentence_text))
+
+        return sentences
 
 
 # Allow importing without decorator for testing
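The splitter no longer round-trips model output through xml.etree.ElementTree; sentences are pulled straight from the <stmt><text>...</text></stmt> blocks with the regex above. A small illustration with made-up model output (the tag layout follows the removed ElementTree parser in this diff):

import re

STMT_TEXT_PATTERN = re.compile(r"<stmt>.*?<text>(.*?)</text>.*?</stmt>", re.DOTALL)

xml_output = (
    "<statements>"
    "<stmt><subject>Acme Corp</subject><predicate>acquired</predicate>"
    "<object>Widget Ltd</object><text>Acme Corp acquired Widget Ltd in 2021.</text></stmt>"
    "<stmt><text>Widget Ltd is based in London.</text></stmt>"
    "</statements>"
)

sentences = [m.strip() for m in STMT_TEXT_PATTERN.findall(xml_output) if m.strip()]
print(sentences)
# ['Acme Corp acquired Widget Ltd in 2021.', 'Widget Ltd is based in London.']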