glinker-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
glinker/l2/models.py ADDED
@@ -0,0 +1,99 @@
+ from pydantic import Field, BaseModel
+ from typing import List, Dict, Any, Optional, Literal
+ from glinker.core.base import BaseConfig, BaseInput, BaseOutput
+
+
+ class DatabaseRecord(BaseModel):
+     """
+     Unified record format shared by all database layers.
+
+     All layers (Dict, Redis, Elasticsearch, Postgres) use this format.
+     """
+     entity_id: str = Field(..., description="Unique entity identifier")
+     label: str = Field(..., description="Primary label/name")
+     aliases: List[str] = Field(default_factory=list, description="Alternative names")
+     description: str = Field(default="", description="Entity description")
+     entity_type: str = Field(default="", description="Entity type/category")
+     popularity: int = Field(default=0, description="Popularity score")
+     metadata: Dict[str, Any] = Field(
+         default_factory=dict,
+         description="Database-specific metadata"
+     )
+     source: str = Field(default="", description="Source layer: dict|redis|elasticsearch|postgres")
+
+     # Embedding fields for precomputed label embeddings
+     embedding: Optional[List[float]] = Field(
+         default=None,
+         description="Precomputed label embedding vector"
+     )
+     embedding_model_id: Optional[str] = Field(
+         default=None,
+         description="Model ID used to compute the embedding"
+     )
+
+
+ class FuzzyConfig(BaseConfig):
+     """Fuzzy search configuration"""
+     max_distance: int = Field(2, description="Maximum Levenshtein distance")
+     min_similarity: float = Field(0.3, description="Minimum similarity threshold")
+     n_gram_size: int = Field(3, description="N-gram size for matching")
+     prefix_length: int = Field(1, description="Prefix length to preserve")
+
+
+ class LayerConfig(BaseConfig):
+     """Database layer configuration"""
+     type: str = Field(..., description="Layer type: dict|redis|elasticsearch|postgres")
+     priority: int = Field(..., description="Search priority (0 = highest)")
+     config: Dict[str, Any] = Field(default_factory=dict, description="Layer-specific config")
+
+     search_mode: List[Literal["exact", "fuzzy"]] = Field(
+         ["exact"],
+         description="Search methods: ['exact'], ['fuzzy'], or ['exact', 'fuzzy']"
+     )
+
+     write: bool = Field(True, description="Enable write operations")
+     cache_policy: str = Field("always", description="Cache policy: always|miss|hit")
+     ttl: int = Field(3600, description="TTL in seconds (0 = no expiry)")
+     field_mapping: Dict[str, str] = Field(
+         default_factory=lambda: {
+             "entity_id": "entity_id",
+             "label": "label",
+             "aliases": "aliases",
+             "description": "description",
+             "entity_type": "entity_type",
+             "popularity": "popularity"
+         },
+         description="Field mapping: DatabaseRecord field -> storage field"
+     )
+     fuzzy: Optional[FuzzyConfig] = Field(default_factory=FuzzyConfig, description="Fuzzy search config")
+
+
+ class EmbeddingConfig(BaseModel):
+     """Configuration for precomputed label embeddings"""
+     enabled: bool = Field(False, description="Enable embedding support")
+     model_name: Optional[str] = Field(None, description="Model name for encoding labels")
+     dim: int = Field(768, description="Embedding dimension")
+     precompute_on_load: bool = Field(False, description="Compute embeddings during load_bulk")
+     batch_size: int = Field(32, description="Batch size for encoding")
+
+
+ class L2Config(BaseConfig):
+     """L2 processor configuration"""
+     layers: List[LayerConfig] = Field(..., description="Database layers in priority order")
+     max_candidates: int = Field(30, description="Maximum candidates per mention")
+     min_popularity: int = Field(0, description="Minimum popularity threshold")
+     embeddings: Optional[EmbeddingConfig] = Field(
+         default=None,
+         description="Embedding configuration for precomputed labels"
+     )
+
+
+ class L2Input(BaseInput):
+     """L2 processor input"""
+     mentions: List[str] = Field(..., description="List of mentions to search")
+     structure: Optional[List[List[str]]] = Field(None, description="Optional grouping structure")
+
+
+ class L2Output(BaseOutput):
+     """L2 processor output"""
+     candidates: List[List[DatabaseRecord]] = Field(..., description="Candidates per mention/group")
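For reference, a minimal sketch of how these config models compose. The values are illustrative, and it assumes BaseConfig adds no further required fields beyond those shown here:

from glinker.l2.models import L2Config, LayerConfig, FuzzyConfig

# Two layers searched in priority order: an in-memory dict first,
# then Elasticsearch with fuzzy matching enabled on top of exact search.
config = L2Config(
    layers=[
        LayerConfig(type="dict", priority=0, search_mode=["exact"]),
        LayerConfig(
            type="elasticsearch",
            priority=1,
            search_mode=["exact", "fuzzy"],
            fuzzy=FuzzyConfig(max_distance=1),
        ),
    ],
    max_candidates=10,
)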
glinker/l2/processor.py ADDED
@@ -0,0 +1,170 @@
+ from typing import Any, List, Optional, Union
+ from glinker.core.base import BaseProcessor
+ from glinker.core.registry import processor_registry
+ from .models import L2Config, L2Input, L2Output, DatabaseRecord
+ from .component import DatabaseChainComponent
+
+
+ class L2Processor(BaseProcessor[L2Config, L2Input, L2Output]):
+     """Multi-layer database search processor"""
+
+     def __init__(
+         self,
+         config: L2Config,
+         component: DatabaseChainComponent,
+         pipeline: Optional[list[tuple[str, dict[str, Any]]]] = None
+     ):
+         super().__init__(config, component, pipeline)
+         self.schema = {}  # Will be set by the DAG executor from node config
+
+     def format_label(self, record: DatabaseRecord) -> str:
+         """Format a record's label using the schema template"""
+         template = self.schema.get('template', '{label}')
+         try:
+             return template.format(**record.model_dump())
+         except KeyError:
+             return record.label
+
+     def precompute_embeddings(
+         self,
+         encoder_fn,
+         target_layers: Optional[List[str]] = None,
+         batch_size: int = 32
+     ):
+         """
+         Precompute embeddings for entities using the schema template.
+
+         Args:
+             encoder_fn: Function that takes a List[str] and returns embeddings
+             target_layers: Layer types to update
+             batch_size: Batch size for encoding
+         """
+         template = self.schema.get('template', '{label}')
+         model_id = self.config.embeddings.model_name if self.config.embeddings else 'unknown'
+
+         return self.component.precompute_embeddings(
+             encoder_fn=encoder_fn,
+             template=template,
+             model_id=model_id,
+             target_layers=target_layers,
+             batch_size=batch_size
+         )
+
+     def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]:
+         return [
+             ("search", {}),
+             ("filter_by_popularity", {}),
+             ("deduplicate_candidates", {}),
+             ("limit_candidates", {}),
+             ("sort_by_popularity", {})
+         ]
+
+     def __call__(
+         self,
+         mentions: Optional[Union[List[str], List[List[Any]], L2Input]] = None,
+         texts: Optional[List[str]] = None,
+         structure: Optional[List[List[str]]] = None,
+         input_data: Optional[L2Input] = None
+     ) -> L2Output:
+         """
+         Process mentions and return candidates.
+
+         Supports:
+         - List[str]: flat list of mention strings
+         - List[List[L1Entity]]: nested list of L1Entity objects (one list per text)
+         - L2Input: structured input with mentions and structure
+         - mentions=None: return the entire entity database (one copy per text)
+         """
+
+         if input_data is not None:
+             mentions = input_data.mentions
+             structure = input_data.structure
+         elif isinstance(mentions, L2Input):
+             structure = mentions.structure
+             mentions = mentions.mentions
+
+         # No mentions -> return the entire database (simple pipeline mode)
+         if mentions is None:
+             all_entities = self.component.get_all_entities()
+             n = len(texts) if texts is not None else 1
+             return L2Output(candidates=[all_entities for _ in range(n)])
+
+         # Check whether mentions is nested (list of lists, one per text)
+         if mentions and isinstance(mentions[0], (list, tuple)):
+             # Nested structure: [[entities_text1], [entities_text2], ...]
+             all_candidates = []
+
+             for text_entities in mentions:
+                 text_candidates = []
+
+                 for entity in text_entities:
+                     # Extract text from L1Entity or dict
+                     mention_text = self._extract_mention_text(entity)
+
+                     # Search candidates for this mention
+                     candidates = self._execute_pipeline(mention_text, self.pipeline)
+                     text_candidates.extend(candidates)
+
+                 all_candidates.append(text_candidates)
+
+             return L2Output(candidates=all_candidates)
+
+         # Flat structure: ["mention1", "mention2", ...]
+         else:
+             all_candidates = []
+
+             for mention in mentions:
+                 mention_text = self._extract_mention_text(mention)
+                 candidates = self._execute_pipeline(mention_text, self.pipeline)
+                 all_candidates.append(candidates)
+
+             if structure:
+                 grouped = self._group_by_structure(all_candidates, structure)
+             else:
+                 # Flatten all candidates into one group
+                 grouped = [self._flatten(all_candidates)]
+
+             return L2Output(candidates=grouped)
+
+     def _extract_mention_text(self, mention: Any) -> str:
+         """Extract the text string from a mention (L1Entity, dict, or str)"""
+         if isinstance(mention, str):
+             return mention
+         elif hasattr(mention, 'text'):
+             return mention.text
+         elif isinstance(mention, dict):
+             return mention.get('text', str(mention))
+         else:
+             return str(mention)
+
+     def _group_by_structure(
+         self,
+         all_candidates: List[List[DatabaseRecord]],
+         structure: List[List[str]]
+     ) -> List[List[DatabaseRecord]]:
+         """Group candidates according to the given structure"""
+         grouped = []
+         idx = 0
+         for text_mentions in structure:
+             text_candidates = []
+             for _ in text_mentions:
+                 if idx < len(all_candidates):
+                     text_candidates.extend(all_candidates[idx])
+                 idx += 1
+             grouped.append(text_candidates)
+         return grouped
+
+     def _flatten(self, nested: List[List[Any]]) -> List[Any]:
+         """Flatten a nested list"""
+         flat = []
+         for sublist in nested:
+             flat.extend(sublist)
+         return flat
+
+
+ @processor_registry.register("l2_chain")
+ def create_l2_processor(config_dict: dict, pipeline: list = None) -> L2Processor:
+     """Factory: creates component + processor"""
+     config = L2Config(**config_dict)
+     component = DatabaseChainComponent(config)
+     return L2Processor(config, component, pipeline)
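A hedged usage sketch of the flat-mentions path through __call__. It assumes the "l2_chain" factory is registered, that the config dict above satisfies BaseConfig, and that the DatabaseChainComponent has already been loaded with entity records:

processor = create_l2_processor({"layers": [{"type": "dict", "priority": 0}]})

# Flat mentions plus a grouping structure: candidates for "Paris" and
# "France" land in one group, "Berlin" in another.
out = processor(
    mentions=["Paris", "France", "Berlin"],
    structure=[["Paris", "France"], ["Berlin"]],
)
assert len(out.candidates) == 2  # one candidate list per group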
glinker/l3/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .models import L3Config, L3Input, L3Output, L3Entity
+ from .component import L3Component
+ from .processor import L3Processor
+
+ __all__ = [
+     "L3Config",
+     "L3Input",
+     "L3Output",
+     "L3Entity",
+     "L3Component",
+     "L3Processor",
+ ]
glinker/l3/component.py ADDED
@@ -0,0 +1,184 @@
+ from typing import List, Optional
+ import torch
+ from gliner import GLiNER
+ from glinker.core.base import BaseComponent
+ from .models import L3Config, L3Entity
+
+
+ class L3Component(BaseComponent[L3Config]):
+     """GLiNER-based entity linking component"""
+
+     def _setup(self):
+         """Initialize the GLiNER model"""
+         self.model = GLiNER.from_pretrained(
+             self.config.model_name,
+             token=self.config.token,
+             max_length=self.config.max_length
+         )
+         self.model.to(self.config.device)
+
+         # Fix the labels tokenizer max_length for BiEncoder models:
+         # some models ship with model_max_length not properly set (> 10^18)
+         if (self.config.max_length is not None and
+                 hasattr(self.model, 'data_processor') and
+                 hasattr(self.model.data_processor, 'labels_tokenizer')):
+             tok = self.model.data_processor.labels_tokenizer
+             if tok.model_max_length > 100000:
+                 tok.model_max_length = self.config.max_length
+
+     @property
+     def device(self):
+         return self.config.device
+
+     @property
+     def supports_precomputed_embeddings(self) -> bool:
+         """Check whether the model supports precomputed embeddings (BiEncoder)"""
+         return hasattr(self.model, 'encode_labels') and self.model.config.labels_encoder is not None
+
+     def get_available_methods(self) -> List[str]:
+         return [
+             "predict_entities",
+             "predict_with_embeddings",
+             "encode_labels",
+             "filter_by_score",
+             "sort_by_position",
+             "deduplicate_entities"
+         ]
+
+     def encode_labels(self, labels: List[str], batch_size: int = 32) -> torch.Tensor:
+         """
+         Encode labels using GLiNER's native label encoder.
+
+         Args:
+             labels: List of label strings to encode
+             batch_size: Batch size for encoding
+
+         Returns:
+             Tensor of shape (num_labels, hidden_size)
+
+         Raises:
+             NotImplementedError: If the model doesn't support label encoding (not a BiEncoder)
+         """
+         if not self.supports_precomputed_embeddings:
+             raise NotImplementedError(
+                 f"Model {self.config.model_name} doesn't support label precomputation. "
+                 "Only BiEncoder models support this feature."
+             )
+
+         return self.model.encode_labels(labels, batch_size=batch_size)
+
+     def predict_with_embeddings(
+         self,
+         text: str,
+         labels: List[str],
+         embeddings: torch.Tensor,
+         input_spans: Optional[List[List[dict]]] = None
+     ) -> List[L3Entity]:
+         """
+         Predict entities using precomputed label embeddings.
+
+         Args:
+             text: Input text
+             labels: List of label strings (for output mapping)
+             embeddings: Precomputed embeddings tensor (num_labels, hidden_size)
+             input_spans: Optional list of span dicts with 'start' and 'end' keys
+                 to constrain prediction to specific spans from L1
+
+         Returns:
+             List of L3Entity predictions
+         """
+         if not self.supports_precomputed_embeddings:
+             # Fall back to regular prediction
+             return self.predict_entities(text, labels, input_spans=input_spans)
+
+         kwargs = dict(
+             threshold=self.config.threshold,
+             flat_ner=self.config.flat_ner,
+             multi_label=self.config.multi_label,
+             return_class_probs=True
+         )
+         if input_spans is not None:
+             kwargs["input_spans"] = input_spans
+
+         entities = self.model.predict_with_embeds(
+             text,
+             embeddings,
+             labels,
+             **kwargs
+         )
+
+         return [
+             L3Entity(
+                 text=e["text"],
+                 label=e["label"],
+                 start=e["start"],
+                 end=e["end"],
+                 score=e["score"],
+                 class_probs=e.get("class_probs")
+             )
+             for e in entities
+         ]
+
+     def predict_entities(
+         self,
+         text: str,
+         labels: List[str],
+         input_spans: Optional[List[List[dict]]] = None
+     ) -> List[L3Entity]:
+         """Predict entities using GLiNER.
+
+         Args:
+             text: Input text
+             labels: List of label strings
+             input_spans: Optional list of span dicts with 'start' and 'end' keys
+                 to constrain prediction to specific spans from L1
+         """
+         if not labels:
+             return []
+
+         kwargs = dict(
+             threshold=self.config.threshold,
+             flat_ner=self.config.flat_ner,
+             multi_label=self.config.multi_label,
+             return_class_probs=True
+         )
+         if input_spans is not None:
+             kwargs["input_spans"] = input_spans
+
+         entities = self.model.predict_entities(
+             text,
+             labels,
+             **kwargs
+         )
+
+         return [
+             L3Entity(
+                 text=e["text"],
+                 label=e["label"],
+                 start=e["start"],
+                 end=e["end"],
+                 score=e["score"],
+                 class_probs=e.get("class_probs")
+             )
+             for e in entities
+         ]
+
+     def filter_by_score(self, entities: List[L3Entity], threshold: Optional[float] = None) -> List[L3Entity]:
+         """Filter entities by confidence score"""
+         threshold = threshold if threshold is not None else self.config.threshold
+         return [e for e in entities if e.score >= threshold]
+
+     def sort_by_position(self, entities: List[L3Entity]) -> List[L3Entity]:
+         """Sort entities by position in the text"""
+         return sorted(entities, key=lambda e: e.start)
+
+     def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]:
+         """Remove duplicate entities"""
+         seen = set()
+         unique = []
+         for entity in entities:
+             key = (entity.text, entity.start, entity.end)
+             if key not in seen:
+                 unique.append(entity)
+                 seen.add(key)
+         return unique
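A sketch of the BiEncoder fast path this component exposes, assuming `component` is an initialized L3Component wrapping a BiEncoder checkpoint (the label strings are illustrative):

text = "She flew from Paris to Berlin."
labels = ["Paris (capital of France)", "Berlin (capital of Germany)"]

if component.supports_precomputed_embeddings:
    # Encode candidate labels once, then reuse the embeddings for prediction
    embeds = component.encode_labels(labels)  # shape: (num_labels, hidden_size)
    entities = component.predict_with_embeddings(text, labels, embeds)
else:
    # Slow path: labels are re-encoded inside the model on every call
    entities = component.predict_entities(text, labels)

entities = component.sort_by_position(component.deduplicate_entities(entities))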
glinker/l3/models.py ADDED
@@ -0,0 +1,48 @@
+ from pydantic import Field
+ from typing import Dict, List, Any, Optional
+ from glinker.core.base import BaseConfig, BaseInput, BaseOutput
+
+
+ class L3Config(BaseConfig):
+     model_name: str = Field(...)
+     token: Optional[str] = Field(None)
+     device: str = Field("cpu")
+     threshold: float = Field(0.5)
+     flat_ner: bool = Field(True)
+     multi_label: bool = Field(False)
+     batch_size: int = Field(8)
+
+     # Embedding settings
+     use_precomputed_embeddings: bool = Field(
+         True,
+         description="Use precomputed embeddings from L2 candidates if available"
+     )
+     cache_embeddings: bool = Field(
+         False,
+         description="Cache computed embeddings back to L2"
+     )
+     max_length: Optional[int] = Field(
+         None,
+         description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained."
+     )
+
+
+ # TODO: replace candidates with labels
+ class L3Input(BaseInput):
+     texts: List[str] = Field(...)
+     labels: List[List[Any]] = Field(...)
+
+
+ class L3Entity(BaseOutput):
+     text: str
+     label: str
+     start: int
+     end: int
+     score: float
+     class_probs: Optional[Dict[str, float]] = Field(
+         None, description="Per-label class probabilities from GLiNER"
+     )
+
+
+ class L3Output(BaseOutput):
+     entities: List[List[L3Entity]] = Field(...)
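A minimal sketch of how these models are instantiated. The checkpoint name is illustrative only, and it assumes BaseConfig and BaseInput add no further required fields:

config = L3Config(
    model_name="knowledgator/gliner-bi-small-v1.0",  # hypothetical BiEncoder checkpoint
    device="cpu",
    threshold=0.4,
    max_length=512,
)

# texts and labels are parallel lists: one label list per input text
inp = L3Input(
    texts=["She flew from Paris to Berlin."],
    labels=[["Paris (capital of France)", "Berlin (capital of Germany)"]],
)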