glinker 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glinker/__init__.py +54 -0
- glinker/core/__init__.py +56 -0
- glinker/core/base.py +103 -0
- glinker/core/builders.py +547 -0
- glinker/core/dag.py +898 -0
- glinker/core/factory.py +261 -0
- glinker/core/registry.py +31 -0
- glinker/l0/__init__.py +21 -0
- glinker/l0/component.py +472 -0
- glinker/l0/models.py +90 -0
- glinker/l0/processor.py +108 -0
- glinker/l1/__init__.py +15 -0
- glinker/l1/component.py +284 -0
- glinker/l1/models.py +47 -0
- glinker/l1/processor.py +152 -0
- glinker/l2/__init__.py +19 -0
- glinker/l2/component.py +1220 -0
- glinker/l2/models.py +99 -0
- glinker/l2/processor.py +170 -0
- glinker/l3/__init__.py +12 -0
- glinker/l3/component.py +184 -0
- glinker/l3/models.py +48 -0
- glinker/l3/processor.py +350 -0
- glinker/l4/__init__.py +9 -0
- glinker/l4/component.py +121 -0
- glinker/l4/models.py +21 -0
- glinker/l4/processor.py +156 -0
- glinker/py.typed +1 -0
- glinker-0.1.0.dist-info/METADATA +994 -0
- glinker-0.1.0.dist-info/RECORD +33 -0
- glinker-0.1.0.dist-info/WHEEL +5 -0
- glinker-0.1.0.dist-info/licenses/LICENSE +201 -0
- glinker-0.1.0.dist-info/top_level.txt +1 -0
glinker/l2/models.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from pydantic import Field, BaseModel
|
|
2
|
+
from typing import List, Dict, Any, Optional, Literal
|
|
3
|
+
from glinker.core.base import BaseConfig, BaseInput, BaseOutput
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DatabaseRecord(BaseModel):
    """A single entity record in the unified cross-layer format.

    Every storage backend (Dict, Redis, Elasticsearch, Postgres) reads and
    writes entities in this one shape, so results coming from different
    layers can be merged transparently.
    """

    entity_id: str = Field(..., description="Unique entity identifier")
    label: str = Field(..., description="Primary label/name")
    aliases: List[str] = Field(default_factory=list, description="Alternative names")
    description: str = Field(default="", description="Entity description")
    entity_type: str = Field(default="", description="Entity type/category")
    popularity: int = Field(default=0, description="Popularity score")
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Database-specific metadata",
    )
    source: str = Field(default="", description="Source layer: dict|redis|elasticsearch|postgres")

    # Optional precomputed label embedding, filled in by the embedding pipeline.
    embedding: Optional[List[float]] = Field(
        default=None,
        description="Precomputed label embedding vector",
    )
    embedding_model_id: Optional[str] = Field(
        default=None,
        description="Model ID used to compute the embedding",
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FuzzyConfig(BaseConfig):
    """Tuning knobs for approximate (fuzzy) label matching."""

    max_distance: int = Field(default=2, description="Maximum Levenshtein distance")
    min_similarity: float = Field(default=0.3, description="Minimum similarity threshold")
    n_gram_size: int = Field(default=3, description="N-gram size for matching")
    prefix_length: int = Field(default=1, description="Prefix length to preserve")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class LayerConfig(BaseConfig):
    """Configuration for one database layer in the search chain.

    Layers are searched in ascending ``priority`` order; each layer can
    independently enable exact and/or fuzzy matching and remap record
    fields to its own storage schema via ``field_mapping``.
    """

    type: str = Field(..., description="Layer type: dict|redis|elasticsearch|postgres")
    priority: int = Field(..., description="Search priority (0 = highest)")
    config: Dict[str, Any] = Field(default_factory=dict, description="Layer-specific config")

    search_mode: List[Literal["exact", "fuzzy"]] = Field(
        default=["exact"],
        description="Search methods: ['exact'], ['fuzzy'], or ['exact', 'fuzzy']",
    )

    write: bool = Field(default=True, description="Enable write operations")
    cache_policy: str = Field(default="always", description="Cache policy: always|miss|hit")
    ttl: int = Field(default=3600, description="TTL in seconds (0 = no expiry)")
    # Maps DatabaseRecord attribute names onto the backend's own field names.
    field_mapping: Dict[str, str] = Field(
        default_factory=lambda: dict(
            entity_id="entity_id",
            label="label",
            aliases="aliases",
            description="description",
            entity_type="entity_type",
            popularity="popularity",
        ),
        description="Field mapping: DatabaseRecord field -> storage field",
    )
    fuzzy: Optional[FuzzyConfig] = Field(default_factory=FuzzyConfig, description="Fuzzy search config")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# NOTE(review): inherits BaseModel directly, unlike sibling configs that
# extend BaseConfig — confirm this asymmetry is intentional.
class EmbeddingConfig(BaseModel):
    """Settings for precomputing label embeddings in the database layers."""

    enabled: bool = Field(default=False, description="Enable embedding support")
    model_name: Optional[str] = Field(default=None, description="Model name for encoding labels")
    dim: int = Field(default=768, description="Embedding dimension")
    precompute_on_load: bool = Field(default=False, description="Compute embeddings during load_bulk")
    batch_size: int = Field(default=32, description="Batch size for encoding")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class L2Config(BaseConfig):
    """Top-level configuration for the L2 candidate-retrieval processor."""

    layers: List[LayerConfig] = Field(..., description="Database layers in priority order")
    max_candidates: int = Field(default=30, description="Maximum candidates per mention")
    min_popularity: int = Field(default=0, description="Minimum popularity threshold")
    embeddings: Optional[EmbeddingConfig] = Field(
        default=None,
        description="Embedding configuration for precomputed labels",
    )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class L2Input(BaseInput):
    """Input for the L2 processor.

    Attributes:
        mentions: Flat list of mention strings to search for.
        structure: Optional nesting that maps the flat ``mentions`` back to
            per-text groups; ``None`` means no grouping.
    """

    mentions: List[str] = Field(..., description="List of mentions to search")
    # BUGFIX: the default is None, so the annotation must be Optional —
    # pydantic v2 rejects an explicit `structure=None` with a bare
    # List[List[str]] annotation even though None is the declared default.
    structure: Optional[List[List[str]]] = Field(None, description="Optional grouping structure")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class L2Output(BaseOutput):
    """Output of the L2 processor: candidate lists aligned with the input."""

    # One inner list per mention (or per group when a structure was supplied).
    candidates: List[List[DatabaseRecord]] = Field(
        ...,
        description="Candidates per mention/group",
    )
|
glinker/l2/processor.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from typing import Any, List, Union
|
|
2
|
+
from glinker.core.base import BaseProcessor
|
|
3
|
+
from glinker.core.registry import processor_registry
|
|
4
|
+
from .models import L2Config, L2Input, L2Output, DatabaseRecord
|
|
5
|
+
from .component import DatabaseChainComponent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class L2Processor(BaseProcessor[L2Config, L2Input, L2Output]):
    """Multi-layer database search processor.

    Wraps a DatabaseChainComponent and runs a configurable method pipeline
    (search → filter → dedupe → limit → sort) for each mention, returning
    candidate DatabaseRecords grouped per mention/text.
    """

    def __init__(
        self,
        config: L2Config,
        component: DatabaseChainComponent,
        pipeline: list[tuple[str, dict[str, Any]]] = None
    ):
        """Store config/component and the (optional) explicit pipeline.

        Args:
            config: Validated L2 configuration.
            component: Database chain component that performs the actual search.
            pipeline: Optional list of (method_name, kwargs) steps; when None
                the default pipeline from `_default_pipeline` is used.
        """
        super().__init__(config, component, pipeline)
        self.schema = {}  # Will be set by DAG executor from node config

    def format_label(self, record: DatabaseRecord) -> str:
        """Format a record's display label using the schema template.

        Falls back to the plain `record.label` when the template references
        a field the record does not have (KeyError from str.format).
        """
        template = self.schema.get('template', '{label}')
        try:
            return template.format(**record.model_dump())
        except KeyError:
            return record.label

    def precompute_embeddings(
        self,
        encoder_fn,
        target_layers: List[str] = None,
        batch_size: int = 32
    ):
        """
        Precompute embeddings for entities using schema template.

        Delegates to the component; labels are rendered with the schema
        template before encoding, and the configured embedding model name
        is recorded on each record (or 'unknown' if no embeddings config).

        Args:
            encoder_fn: Function that takes List[str] and returns embeddings
            target_layers: Layer types to update
            batch_size: Batch size for encoding
        """
        template = self.schema.get('template', '{label}')
        model_id = self.config.embeddings.model_name if self.config.embeddings else 'unknown'

        return self.component.precompute_embeddings(
            encoder_fn=encoder_fn,
            template=template,
            model_id=model_id,
            target_layers=target_layers,
            batch_size=batch_size
        )

    def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]:
        """Default per-mention processing steps, executed in order."""
        return [
            ("search", {}),
            ("filter_by_popularity", {}),
            ("deduplicate_candidates", {}),
            ("limit_candidates", {}),
            ("sort_by_position", {})
        ] if False else [
            ("search", {}),
            ("filter_by_popularity", {}),
            ("deduplicate_candidates", {}),
            ("limit_candidates", {}),
            ("sort_by_popularity", {})
        ]

    def __call__(
        self,
        mentions: Union[List[str], List[List[Any]], L2Input] = None,
        texts: List[str] = None,
        structure: List[List[str]] = None,
        input_data: L2Input = None
    ) -> L2Output:
        """
        Process mentions and return candidates

        Supports:
        - List[str]: flat list of mention strings
        - List[List[L1Entity]]: nested list of L1Entity objects (one list per text)
        - L2Input: structured input with mentions and structure
        - mentions=None: return entire entity database (one copy per text)
        """

        # Normalize the three input forms down to (mentions, structure).
        if input_data is not None:
            mentions = input_data.mentions
            structure = input_data.structure
        elif isinstance(mentions, L2Input):
            structure = mentions.structure
            mentions = mentions.mentions

        # No mentions → return entire database (simple pipeline mode)
        if mentions is None:
            all_entities = self.component.get_all_entities()
            # NOTE(review): every text shares the same list object — callers
            # must not mutate a group in place; confirm downstream treats
            # candidates as read-only.
            n = len(texts) if texts is not None else 1
            return L2Output(candidates=[all_entities for _ in range(n)])

        # Check if mentions is nested (list of lists - one per text)
        if mentions and isinstance(mentions[0], (list, tuple)):
            # Nested structure: [[entities_text1], [entities_text2], ...]
            all_candidates = []

            for text_entities in mentions:
                text_candidates = []

                for entity in text_entities:
                    # Extract text from L1Entity or dict
                    mention_text = self._extract_mention_text(entity)

                    # Search candidates for this mention
                    candidates = self._execute_pipeline(mention_text, self.pipeline)
                    text_candidates.extend(candidates)

                all_candidates.append(text_candidates)

            return L2Output(candidates=all_candidates)

        # Flat structure: ["mention1", "mention2", ...]
        else:
            all_candidates = []

            for mention in mentions:
                mention_text = self._extract_mention_text(mention)
                candidates = self._execute_pipeline(mention_text, self.pipeline)
                all_candidates.append(candidates)

            if structure:
                grouped = self._group_by_structure(all_candidates, structure)
            else:
                # Flatten all into one group
                grouped = [self._flatten(all_candidates)]

            return L2Output(candidates=grouped)

    def _extract_mention_text(self, mention: Any) -> str:
        """Extract text string from mention (can be L1Entity, dict, or str)"""
        if isinstance(mention, str):
            return mention
        elif hasattr(mention, 'text'):
            return mention.text
        elif isinstance(mention, dict):
            # Fall back to the dict's repr when no 'text' key exists.
            return mention.get('text', str(mention))
        else:
            return str(mention)

    def _group_by_structure(
        self,
        all_candidates: List[List[DatabaseRecord]],
        structure: List[List[str]]
    ) -> List[List[DatabaseRecord]]:
        """Group per-mention candidate lists according to structure.

        Walks the flat `all_candidates` with a running index; each group in
        `structure` consumes as many entries as it has mentions. Extra
        structure entries beyond the available candidates are silently
        skipped (the idx bound check).
        """
        grouped = []
        idx = 0
        for text_mentions in structure:
            text_candidates = []
            for _ in text_mentions:
                if idx < len(all_candidates):
                    text_candidates.extend(all_candidates[idx])
                    idx += 1
            grouped.append(text_candidates)
        return grouped

    def _flatten(self, nested: List[List[Any]]) -> List[Any]:
        """Flatten one level of nesting into a single list."""
        flat = []
        for sublist in nested:
            flat.extend(sublist)
        return flat
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@processor_registry.register("l2_chain")
def create_l2_processor(config_dict: dict, pipeline: Union[list, None] = None) -> L2Processor:
    """Factory registered as "l2_chain": builds component and processor.

    Args:
        config_dict: Raw configuration mapping, validated into L2Config.
        pipeline: Optional explicit (method, kwargs) pipeline; the processor
            falls back to its default pipeline when None.
            (Annotation fixed: the default is None, so a bare `list`
            annotation was misleading; `Union` is already imported here.)

    Returns:
        A ready-to-use L2Processor wrapping a DatabaseChainComponent.
    """
    config = L2Config(**config_dict)
    component = DatabaseChainComponent(config)
    return L2Processor(config, component, pipeline)
|
glinker/l3/__init__.py
ADDED
glinker/l3/component.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
import torch
|
|
3
|
+
from gliner import GLiNER
|
|
4
|
+
from glinker.core.base import BaseComponent
|
|
5
|
+
from .models import L3Config, L3Entity
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class L3Component(BaseComponent[L3Config]):
    """GLiNER-based entity linking component.

    Thin wrapper around a GLiNER model that adds: device placement, a
    tokenizer max-length workaround for BiEncoder models, precomputed
    label-embedding support, and small post-processing helpers.
    """

    def _setup(self):
        """Initialize GLiNER model, move it to the configured device, and
        patch the labels tokenizer when its max length is unset."""
        self.model = GLiNER.from_pretrained(
            self.config.model_name,
            token=self.config.token,
            max_length=self.config.max_length
        )
        self.model.to(self.config.device)

        # Fix labels tokenizer max_length for BiEncoder models
        # Some models have model_max_length not properly set (> 10^18)
        if (self.config.max_length is not None and
            hasattr(self.model, 'data_processor') and
            hasattr(self.model.data_processor, 'labels_tokenizer')):
            tok = self.model.data_processor.labels_tokenizer
            # 100000 is a heuristic cutoff: real limits are far below it,
            # "unset" sentinels are far above.
            if tok.model_max_length > 100000:
                tok.model_max_length = self.config.max_length

    @property
    def device(self):
        """Configured device string (e.g. 'cpu' or 'cuda')."""
        return self.config.device

    @property
    def supports_precomputed_embeddings(self) -> bool:
        """Check if model supports precomputed embeddings (BiEncoder)"""
        return hasattr(self.model, 'encode_labels') and self.model.config.labels_encoder is not None

    def get_available_methods(self) -> List[str]:
        """Names of methods that DAG pipelines may invoke on this component."""
        return [
            "predict_entities",
            "predict_with_embeddings",
            "encode_labels",
            "filter_by_score",
            "sort_by_position",
            "deduplicate_entities"
        ]

    def encode_labels(self, labels: List[str], batch_size: int = 32) -> torch.Tensor:
        """
        Encode labels using GLiNER's native label encoder.

        Args:
            labels: List of label strings to encode
            batch_size: Batch size for encoding

        Returns:
            Tensor of shape (num_labels, hidden_size)

        Raises:
            NotImplementedError: If model doesn't support label encoding (not BiEncoder)
        """
        if not self.supports_precomputed_embeddings:
            raise NotImplementedError(
                f"Model {self.config.model_name} doesn't support label precomputation. "
                "Only BiEncoder models support this feature."
            )

        return self.model.encode_labels(labels, batch_size=batch_size)

    def predict_with_embeddings(
        self,
        text: str,
        labels: List[str],
        embeddings: torch.Tensor,
        input_spans: List[List[dict]] = None
    ) -> List[L3Entity]:
        """
        Predict entities using pre-computed label embeddings.

        Falls back to `predict_entities` (re-encoding labels) when the
        model is not a BiEncoder.

        Args:
            text: Input text
            labels: List of label strings (for output mapping)
            embeddings: Pre-computed embeddings tensor (num_labels, hidden_size)
            input_spans: Optional list of span dicts with 'start' and 'end' keys
                to constrain prediction to specific spans from L1

        Returns:
            List of L3Entity predictions
        """
        if not self.supports_precomputed_embeddings:
            # Fallback to regular prediction
            return self.predict_entities(text, labels, input_spans=input_spans)

        kwargs = dict(
            threshold=self.config.threshold,
            flat_ner=self.config.flat_ner,
            multi_label=self.config.multi_label,
            return_class_probs=True
        )
        # Only forward input_spans when provided so the model default applies.
        if input_spans is not None:
            kwargs["input_spans"] = input_spans

        entities = self.model.predict_with_embeds(
            text,
            embeddings,
            labels,
            **kwargs
        )

        return [
            L3Entity(
                text=e["text"],
                label=e["label"],
                start=e["start"],
                end=e["end"],
                score=e["score"],
                class_probs=e.get("class_probs")
            )
            for e in entities
        ]

    def predict_entities(
        self,
        text: str,
        labels: List[str],
        input_spans: List[List[dict]] = None
    ) -> List[L3Entity]:
        """Predict entities using GLiNER

        Returns an empty list immediately when no labels are given.

        Args:
            text: Input text
            labels: List of label strings
            input_spans: Optional list of span dicts with 'start' and 'end' keys
                to constrain prediction to specific spans from L1
        """
        if not labels:
            return []

        kwargs = dict(
            threshold=self.config.threshold,
            flat_ner=self.config.flat_ner,
            multi_label=self.config.multi_label,
            return_class_probs=True
        )
        # Only forward input_spans when provided so the model default applies.
        if input_spans is not None:
            kwargs["input_spans"] = input_spans

        entities = self.model.predict_entities(
            text,
            labels,
            **kwargs
        )

        return [
            L3Entity(
                text=e["text"],
                label=e["label"],
                start=e["start"],
                end=e["end"],
                score=e["score"],
                class_probs=e.get("class_probs")
            )
            for e in entities
        ]

    def filter_by_score(self, entities: List[L3Entity], threshold: float = None) -> List[L3Entity]:
        """Filter entities by confidence score (config threshold when None)."""
        threshold = threshold if threshold is not None else self.config.threshold
        return [e for e in entities if e.score >= threshold]

    def sort_by_position(self, entities: List[L3Entity]) -> List[L3Entity]:
        """Sort entities by start offset in the text."""
        return sorted(entities, key=lambda e: e.start)

    def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]:
        """Remove duplicate entities, keyed on (text, start, end).

        First occurrence wins; note the label is NOT part of the key, so two
        different labels on the same span collapse to one entity.
        """
        seen = set()
        unique = []
        for entity in entities:
            key = (entity.text, entity.start, entity.end)
            if key not in seen:
                unique.append(entity)
                seen.add(key)
        return unique
|
glinker/l3/models.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
from typing import Dict, List, Any, Optional
|
|
3
|
+
from glinker.core.base import BaseConfig, BaseInput, BaseOutput
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class L3Config(BaseConfig):
    """Configuration for the GLiNER-based L3 linking component."""

    model_name: str = Field(..., description="Model name passed to GLiNER.from_pretrained")
    # BUGFIX: token and max_length default to None, so their annotations must
    # be Optional — with bare `str`/`int` annotations pydantic v2 rejects an
    # explicit None even though None is the declared default.
    token: Optional[str] = Field(None, description="HF auth token, if required")
    device: str = Field("cpu", description="Torch device for the model")
    threshold: float = Field(0.5, description="Minimum prediction score")
    flat_ner: bool = Field(True, description="Disallow nested entity spans")
    multi_label: bool = Field(False, description="Allow multiple labels per span")
    batch_size: int = Field(8, description="Prediction batch size")

    # Embedding settings
    use_precomputed_embeddings: bool = Field(
        True,
        description="Use precomputed embeddings from L2 candidates if available"
    )
    cache_embeddings: bool = Field(
        False,
        description="Cache computed embeddings back to L2"
    )
    max_length: Optional[int] = Field(
        None,
        description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained."
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# TODO replace candidates with labels
|
|
31
|
+
class L3Input(BaseInput):
    """Input batch for L3: one label list per input text.

    Both fields are required; ``labels[i]`` supplies the candidate labels
    used when predicting over ``texts[i]``.
    """

    # Required fields — a bare annotation is equivalent to Field(...).
    texts: List[str]
    labels: List[List[Any]]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class L3Entity(BaseOutput):
    """One predicted entity span with its label and confidence."""

    text: str
    label: str
    start: int
    end: int
    score: float
    class_probs: Optional[Dict[str, float]] = Field(
        default=None,
        description="Per-label class probabilities from GLiNER",
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class L3Output(BaseOutput):
    """L3 result: one entity list per input text."""

    # Required field — a bare annotation is equivalent to Field(...).
    entities: List[List[L3Entity]]
|