glinker-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
glinker/core/dag.py ADDED
@@ -0,0 +1,898 @@
+from typing import Dict, List, Set, Any, Optional, Literal, Union
+from collections import defaultdict, deque, OrderedDict
+from pydantic import BaseModel, Field
+from datetime import datetime
+from pathlib import Path
+import re
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# INPUT/OUTPUT CONFIG
+# ============================================================================
+
+class ReshapeConfig(BaseModel):
+    """Configuration for data reshaping"""
+    by: str = Field(..., description="Reference structure path: 'l1_result.entities'")
+    mode: Literal["flatten_per_group", "preserve_structure"] = Field(
+        "flatten_per_group",
+        description="Reshape mode"
+    )
+
+
+class InputConfig(BaseModel):
+    """
+    Unified input data specification
+
+    Examples:
+        source: "l1_result"
+        fields: "entities[*].text"
+        reduce: "flatten"
+    """
+    source: str = Field(
+        ...,
+        description="Data source: key ('l1_result'), index ('outputs[-1]'), or '$input'"
+    )
+
+    fields: Union[str, List[str], None] = Field(
+        None,
+        description="JSONPath fields: 'entities[*].text' or ['label', 'score']"
+    )
+
+    reduce: Literal["all", "first", "last", "flatten"] = Field(
+        "all",
+        description="Reduction mode for lists"
+    )
+
+    reshape: Optional[ReshapeConfig] = Field(
+        None,
+        description="Data reshaping configuration"
+    )
+
+    template: Optional[str] = Field(
+        None,
+        description="Field concatenation template: '{label}: {description}'"
+    )
+
+    filter: Optional[str] = Field(
+        None,
+        description="Filter expression: 'score > 0.5'"
+    )
+
+    default: Any = None
+
+
+class OutputConfig(BaseModel):
+    """Output specification"""
+    key: str = Field(..., description="Key for storing in context")
+    fields: Union[str, List[str], None] = Field(
+        None,
+        description="Fields to save (optional, defaults to all)"
+    )
+
+
+# ============================================================================
+# PIPE NODE
+# ============================================================================
+
+class PipeNode(BaseModel):
+    """
+    Single node in DAG pipeline
+
+    Represents one processing stage with:
+    - Inputs (where to get data)
+    - Processor (what to do)
+    - Output (where to store result)
+    - Dependencies (execution order)
+    """
+
+    id: str = Field(..., description="Unique node identifier")
+
+    processor: str = Field(..., description="Processor name from registry")
+
+    inputs: Dict[str, InputConfig] = Field(
+        default_factory=dict,
+        description="Input parameter mappings"
+    )
+
+    output: OutputConfig = Field(..., description="Output specification")
+
+    requires: List[str] = Field(
+        default_factory=list,
+        description="Explicit dependencies (node IDs)"
+    )
+
+    config: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Processor configuration"
+    )
+
+    schema: Optional[Dict[str, Any]] = Field(
+        None,
+        description="Schema for field mappings/transformations"
+    )
+
+    condition: Optional[str] = Field(
+        None,
+        description="Conditional execution expression"
+    )
+
+    class Config:
+        fields = {'schema': 'schema'}
+
+
+# ============================================================================
+# PIPE CONTEXT
+# ============================================================================
+
+class PipeContext:
+    """
+    Pipeline execution context
+
+    Stores all outputs from pipeline stages and provides unified access:
+    - By key: "l1_result"
+    - By index: "outputs[-1]" (last output)
+    - Pipeline input: "$input"
+    """
+
+    def __init__(self, pipeline_input: Any = None):
+        self._outputs: OrderedDict[str, Any] = OrderedDict()
+        self._execution_order: List[str] = []
+        self._pipeline_input = pipeline_input
+        self._metadata: Dict[str, Any] = {}
+
+    def set(self, key: str, value: Any, metadata: Optional[Dict] = None):
+        """
+        Store output
+
+        Args:
+            key: Output key
+            value: Output value
+            metadata: Optional metadata (timing, source, etc.)
+        """
+        self._outputs[key] = value
+        self._execution_order.append(key)
+
+        if metadata:
+            self._metadata[key] = metadata
+
+    def get(self, source: str) -> Any:
+        """
+        Unified data access
+
+        Examples:
+            - "$input" → pipeline input
+            - "outputs[-1]" → last output
+            - "outputs[0]" → first output
+            - "l1_result" → by key
+        """
+        if source == "$input":
+            return self._pipeline_input
+
+        if source.startswith("outputs["):
+            index_str = source.replace("outputs[", "").replace("]", "")
+            index = int(index_str)
+
+            if index < 0:
+                index = len(self._execution_order) + index
+
+            if 0 <= index < len(self._execution_order):
+                key = self._execution_order[index]
+                return self._outputs[key]
+
+            return None
+
+        return self._outputs.get(source)
+
+    def has(self, key: str) -> bool:
+        """Check if output exists"""
+        return key in self._outputs
+
+    def get_all_outputs(self) -> Dict[str, Any]:
+        """Get all outputs as dict"""
+        return dict(self._outputs)
+
+    def get_metadata(self, key: str) -> Optional[Dict[str, Any]]:
+        """Get metadata for output"""
+        return self._metadata.get(key)
+
+    def get_execution_order(self) -> List[str]:
+        """Get list of output keys in execution order"""
+        return list(self._execution_order)
+
+    @property
+    def data(self) -> Dict[str, Any]:
+        """For compatibility"""
+        return dict(self._outputs)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize context to dict
+
+        Returns:
+            Dict with full context state
+        """
+        def serialize(value):
+            if hasattr(value, 'dict'):
+                return {'__type__': 'pydantic', 'data': value.dict()}
+            elif isinstance(value, list):
+                return [serialize(item) for item in value]
+            elif isinstance(value, OrderedDict):
+                return {'__type__': 'OrderedDict', 'data': list(value.items())}
+            elif isinstance(value, dict):
+                return {k: serialize(v) for k, v in value.items()}
+            elif isinstance(value, (str, int, float, bool, type(None))):
+                return value
+            else:
+                return {'__type__': 'object', 'repr': repr(value)}
+
+        return {
+            'outputs': {k: serialize(v) for k, v in self._outputs.items()},
+            'execution_order': self._execution_order,
+            'pipeline_input': serialize(self._pipeline_input),
+            'metadata': self._metadata,
+            'saved_at': datetime.now().isoformat()
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> 'PipeContext':
+        """
+        Deserialize context from dict
+
+        Args:
+            data: Dict with saved state
+
+        Returns:
+            Restored PipeContext
+        """
+        def deserialize(value):
+            if isinstance(value, dict):
+                if '__type__' in value:
+                    if value['__type__'] == 'OrderedDict':
+                        return OrderedDict(value['data'])
+                    elif value['__type__'] == 'pydantic':
+                        return value['data']
+                    elif value['__type__'] == 'object':
+                        return value['repr']
+                return {k: deserialize(v) for k, v in value.items()}
+            elif isinstance(value, list):
+                return [deserialize(item) for item in value]
+            else:
+                return value
+
+        pipeline_input = deserialize(data.get('pipeline_input'))
+        context = cls(pipeline_input)
+
+        outputs_data = data.get('outputs', {})
+        for key, value in outputs_data.items():
+            context._outputs[key] = deserialize(value)
+
+        context._execution_order = data.get('execution_order', [])
+        context._metadata = data.get('metadata', {})
+
+        return context
+
+    def to_json(self, filepath: str = None, indent: int = 2) -> str:
+        """
+        Serialize to JSON
+
+        Args:
+            filepath: Path to save (optional)
+            indent: Indentation for formatting
+
+        Returns:
+            JSON string
+        """
+        data = self.to_dict()
+        json_str = json.dumps(data, indent=indent, ensure_ascii=False)
+
+        if filepath:
+            Path(filepath).write_text(json_str, encoding='utf-8')
+            logger.info(f"Context saved to {filepath}")
+
+        return json_str
+
+    @classmethod
+    def from_json(cls, json_data: str = None, filepath: str = None) -> 'PipeContext':
+        """
+        Load from JSON
+
+        Args:
+            json_data: JSON string (optional)
+            filepath: Path to JSON file (optional)
+
+        Returns:
+            Restored PipeContext
+        """
+        if filepath:
+            json_data = Path(filepath).read_text(encoding='utf-8')
+            logger.info(f"Context loaded from {filepath}")
+
+        if not json_data:
+            raise ValueError("Either json_data or filepath must be provided")
+
+        data = json.loads(json_data)
+        return cls.from_dict(data)
+
+
+# ============================================================================
+# FIELD RESOLVER
+# ============================================================================
+
+class FieldResolver:
+    """Resolve fields from data using path expressions"""
+
+    @staticmethod
+    def resolve(context: PipeContext, config: InputConfig) -> Any:
+        """Main resolve method"""
+        data = context.get(config.source)
+        if data is None:
+            return config.default
+
+        if config.fields:
+            data = FieldResolver._extract_fields(data, config.fields)
+
+        if config.template:
+            data = FieldResolver._format_template(data, config.template)
+
+        if isinstance(data, list) and config.reduce:
+            data = FieldResolver._apply_reduce(data, config.reduce)
+
+        if config.filter:
+            data = FieldResolver._apply_filter(data, config.filter)
+
+        return data
+
+    @staticmethod
+    def _extract_fields(data: Any, path: str) -> Any:
+        """
+        Extract field from data using path
+
+        Examples:
+            'entities' -> data.entities
+            'entities[*]' -> [item for item in data.entities]
+            'entities[*].text' -> [item.text for item in data.entities]
+            'entities[*][*].text' -> [[e.text for e in group] for group in data.entities]
+        """
+        parts = path.split('.')
+        current = data
+
+        for part in parts:
+            if '[' in part:
+                current = FieldResolver._handle_brackets(current, part)
+            else:
+                current = FieldResolver._get_attr(current, part)
+
+            if current is None:
+                return None
+
+        return current
+
+    @staticmethod
+    def _handle_brackets(data: Any, part: str) -> Any:
+        """Handle parts with brackets like 'entities[*]' or 'items[0]' or '[*][*]'"""
+        if part.startswith('['):
+            field_name = None
+            brackets = part
+        else:
+            bracket_idx = part.index('[')
+            field_name = part[:bracket_idx]
+            brackets = part[bracket_idx:]
+
+        current = data
+        if field_name:
+            current = FieldResolver._get_attr(current, field_name)
+
+        if current is None:
+            return None
+
+        while '[' in brackets:
+            start = brackets.index('[')
+            end = brackets.index(']')
+            content = brackets[start+1:end]
+            brackets = brackets[end+1:]
+
+            if content == '*':
+                if not isinstance(current, list):
+                    current = [current]
+            elif ':' in content:
+                parts = content.split(':')
+                s = int(parts[0]) if parts[0] else None
+                e = int(parts[1]) if parts[1] else None
+                current = current[s:e]
+            else:
+                idx = int(content)
+                current = current[idx]
+
+        return current
+
+    @staticmethod
+    def _get_attr(data: Any, field: str) -> Any:
+        """Get attribute from data (works with dict, object, list)"""
+        if isinstance(data, list):
+            return [FieldResolver._get_attr(item, field) for item in data]
+
+        if isinstance(data, dict):
+            return data.get(field)
+
+        return getattr(data, field, None)
+
+    @staticmethod
+    def _format_template(data: Union[List[Any], Any], template: str) -> Union[List[str], str]:
+        """Format data using template"""
+        if isinstance(data, list):
+            results = []
+            for item in data:
+                try:
+                    if hasattr(item, 'dict'):
+                        results.append(template.format(**item.dict()))
+                    elif isinstance(item, dict):
+                        results.append(template.format(**item))
+                    else:
+                        results.append(str(item))
+                except Exception:
+                    results.append(str(item))
+            return results
+        else:
+            try:
+                if hasattr(data, 'dict'):
+                    return template.format(**data.dict())
+                elif isinstance(data, dict):
+                    return template.format(**data)
+                else:
+                    return str(data)
+            except Exception:
+                return str(data)
+
+    @staticmethod
+    def _apply_reduce(data: List[Any], mode: str) -> Any:
+        """Reduce list based on mode"""
+        if mode == "first":
+            return data[0] if data else None
+
+        elif mode == "last":
+            return data[-1] if data else None
+
+        elif mode == "flatten":
+            def flatten(lst):
+                result = []
+                for item in lst:
+                    if isinstance(item, list):
+                        result.extend(flatten(item))
+                    else:
+                        result.append(item)
+                return result
+
+            return flatten(data)
+
+        return data
+
+    @staticmethod
+    def _apply_filter(data: List[Any], filter_expr: str) -> List[Any]:
+        """Filter list based on expression"""
+        if not isinstance(data, list):
+            return data
+
+        pattern = r'(\w+)\s*(>=|<=|>|<|==|!=)\s*(.+)'
+        match = re.match(pattern, filter_expr)
+
+        if not match:
+            return data
+
+        field, operator, value = match.groups()
+
+        try:
+            if value.startswith("'") or value.startswith('"'):
+                value = value.strip("'\"")
+            elif '.' in value:
+                value = float(value)
+            else:
+                value = int(value)
+        except Exception:
+            pass
+
+        result = []
+        for item in data:
+            try:
+                if isinstance(item, dict):
+                    item_value = item.get(field)
+                else:
+                    item_value = getattr(item, field, None)
+
+                if item_value is None:
+                    continue
+
+                passes = False
+                if operator == '>':
+                    passes = item_value > value
+                elif operator == '>=':
+                    passes = item_value >= value
+                elif operator == '<':
+                    passes = item_value < value
+                elif operator == '<=':
+                    passes = item_value <= value
+                elif operator == '==':
+                    passes = item_value == value
+                elif operator == '!=':
+                    passes = item_value != value
+
+                if passes:
+                    result.append(item)
+            except Exception:
+                continue
+
+        return result
+
+
+# ============================================================================
+# DAG EXECUTOR
+# ============================================================================
+
+class DAGPipeline(BaseModel):
+    """DAG pipeline configuration"""
+    name: str = Field(...)
+    nodes: List[PipeNode] = Field(...)
+    description: Optional[str] = None
+
+
+class DAGExecutor:
+    """Executes DAG pipeline with topological sort"""
+
+    def __init__(self, pipeline: DAGPipeline, verbose: bool = False):
+        self.pipeline = pipeline
+        self.verbose = verbose
+
+        self.nodes_map: Dict[str, PipeNode] = {}
+        self.dependency_graph: Dict[str, List[str]] = defaultdict(list)
+        self.reverse_graph: Dict[str, Set[str]] = defaultdict(set)
+        self.processors: Dict[str, Any] = {}
+
+        self._build_dependency_graph()
+        self._initialize_processors()
+
+    def _build_dependency_graph(self):
+        """Build dependency graph from nodes"""
+        for node in self.pipeline.nodes:
+            self.nodes_map[node.id] = node
+
+        for node in self.pipeline.nodes:
+            # Explicit dependencies
+            for dep_id in node.requires:
+                self.dependency_graph[dep_id].append(node.id)
+                self.reverse_graph[node.id].add(dep_id)
+
+            # Implicit dependencies from inputs
+            for input_config in node.inputs.values():
+                source = input_config.source
+
+                if source == "$input" or source.startswith("outputs["):
+                    continue
+
+                if source in self.nodes_map:
+                    self.dependency_graph[source].append(node.id)
+                    self.reverse_graph[node.id].add(source)
+
+                if input_config.reshape and input_config.reshape.by:
+                    reshape_source = input_config.reshape.by.split('.')[0]
+                    if reshape_source in self.nodes_map:
+                        if reshape_source not in self.reverse_graph[node.id]:
+                            self.dependency_graph[reshape_source].append(node.id)
+                            self.reverse_graph[node.id].add(reshape_source)
+
+    def _initialize_processors(self):
+        """Initialize all processors once"""
+        from .registry import processor_registry
+
+        if self.verbose:
+            logger.info(f"Initializing {len(self.nodes_map)} processors...")
+
+        for node_id, node in self.nodes_map.items():
+            try:
+                processor_factory = processor_registry.get(node.processor)
+                processor = processor_factory(config_dict=node.config, pipeline=None)
+                self.processors[node_id] = processor
+
+                if self.verbose:
+                    logger.info(f" Created processor for '{node_id}' ({node.processor})")
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to create processor '{node.processor}' for node '{node_id}': {e}"
+                )
+
+        if self.verbose:
+            logger.info(f"All processors initialized and cached")
+
+    def _topological_sort(self) -> List[List[str]]:
+        """Topological sort with level grouping"""
+        in_degree = {}
+        for node_id in self.nodes_map:
+            in_degree[node_id] = len(self.reverse_graph.get(node_id, set()))
+
+        queue = deque([
+            node_id for node_id, degree in in_degree.items()
+            if degree == 0
+        ])
+
+        levels = []
+        visited = set()
+
+        while queue:
+            current_level = list(queue)
+            levels.append(current_level)
+
+            next_queue = deque()
+            for node_id in current_level:
+                visited.add(node_id)
+
+                for dependent_id in self.dependency_graph.get(node_id, []):
+                    in_degree[dependent_id] -= 1
+
+                    if in_degree[dependent_id] == 0:
+                        next_queue.append(dependent_id)
+
+            queue = next_queue
+
+        if len(visited) != len(self.nodes_map):
+            unvisited = set(self.nodes_map.keys()) - visited
+            raise ValueError(
+                f"Cycle detected in pipeline DAG! "
+                f"Unvisited nodes: {unvisited}"
+            )
+
+        return levels
+
+    def load_entities(
+        self,
+        source,
+        target_layers: List[str] = None,
+        batch_size: int = 1000,
+        overwrite: bool = False
+    ) -> Dict[str, Dict[str, int]]:
+        """
+        Load entities into database layers
+
+        Finds all L2 processors and loads entities into their database layers.
+
+        Args:
+            source: entity data — file path (str/Path), list of dicts, or
+                dict mapping entity_id to entity data
+            target_layers: ['dict', 'redis', 'elasticsearch', 'postgres'] or None (all writable)
+            batch_size: batch size for bulk operations
+            overwrite: overwrite existing entities
+
+        Returns:
+            {'l2_node_id': {'redis': 1500, 'elasticsearch': 1500}}
+        """
+        results = {}
+
+        for node_id, processor in self.processors.items():
+            if hasattr(processor, 'component') and hasattr(processor.component, 'load_entities'):
+                if self.verbose:
+                    logger.info(f"\nLoading entities for node '{node_id}'")
+
+                result = processor.component.load_entities(
+                    source=source,
+                    target_layers=target_layers,
+                    batch_size=batch_size,
+                    overwrite=overwrite
+                )
+                results[node_id] = result
+
+        return results
+
+    def clear_databases(self, layer_names: List[str] = None) -> Dict[str, bool]:
+        """Clear database layers in all L2 processors"""
+        results = {}
+
+        for node_id, processor in self.processors.items():
+            if hasattr(processor, 'component') and hasattr(processor.component, 'clear_layers'):
+                processor.component.clear_layers(layer_names)
+                results[node_id] = True
+
+        return results
+
+    def count_entities(self) -> Dict[str, Dict[str, int]]:
+        """Count entities in all database layers"""
+        results = {}
+
+        for node_id, processor in self.processors.items():
+            if hasattr(processor, 'component') and hasattr(processor.component, 'count_entities'):
+                counts = processor.component.count_entities()
+                results[node_id] = counts
+
+        return results
+
+    def precompute_embeddings(
+        self,
+        target_layers: List[str] = None,
+        batch_size: int = 32
+    ) -> Dict[str, int]:
+        """
+        Precompute embeddings for all entities using L3 model and L2 schema.
+
+        Uses:
+        - L3 processor's model for encoding
+        - L2 processor's schema for label formatting
+
+        Args:
+            target_layers: Layer types to update (e.g., ['dict', 'postgres'])
+            batch_size: Batch size for encoding
+
+        Returns:
+            Dict with count of updated entities per layer
+        """
+        # Find L2 and L3 processors
+        l2_processor = None
+        l3_processor = None
+        l2_node = None
+        l3_node = None
+
+        for node_id, processor in self.processors.items():
+            node = self.nodes_map[node_id]
+
+            # Check for L2 (has component with precompute_embeddings)
+            if hasattr(processor, 'component') and hasattr(processor.component, 'precompute_embeddings'):
+                l2_processor = processor
+                l2_node = node
+
+            # Check for L3 (has component with encode_labels)
+            if hasattr(processor, 'component') and hasattr(processor.component, 'encode_labels'):
+                l3_processor = processor
+                l3_node = node
+
+        if not l2_processor:
+            raise ValueError("No L2 processor found with precompute_embeddings support")
+
+        if not l3_processor:
+            raise ValueError("No L3 processor found with encode_labels support")
+
+        # Check if L3 model supports precomputed embeddings
+        if not l3_processor.component.supports_precomputed_embeddings:
+            raise ValueError(
+                f"L3 model '{l3_processor.config.model_name}' doesn't support label precomputation. "
+                "Only BiEncoder models support this feature."
+            )
+
+        # Get schema from L2 node (or L3 as fallback)
+        template = '{label}'
+        if l2_node and l2_node.schema:
+            template = l2_node.schema.get('template', '{label}')
+        elif l3_node and l3_node.schema:
+            template = l3_node.schema.get('template', '{label}')
+
+        # Apply schema to L2 processor
+        if l2_node and l2_node.schema:
+            l2_processor.schema = l2_node.schema
+
+        model_id = l3_processor.config.model_name
+
+        if self.verbose:
+            logger.info(f"\nPrecomputing embeddings:")
+            logger.info(f" Model: {model_id}")
+            logger.info(f" Template: {template}")
+            logger.info(f" Target layers: {target_layers or 'all'}")
+
+        # Create encoder function using L3 component
+        def encoder_fn(labels: List[str]):
+            return l3_processor.component.encode_labels(labels, batch_size=batch_size)
+
+        # Run precompute through L2 component
+        results = l2_processor.component.precompute_embeddings(
+            encoder_fn=encoder_fn,
+            template=template,
+            model_id=model_id,
+            target_layers=target_layers,
+            batch_size=batch_size
+        )
+
+        if self.verbose:
+            logger.info(f"\nPrecompute completed:")
+            for layer, count in results.items():
+                logger.info(f" {layer}: {count} entities")
+
+        return results
+
+    def setup_l3_cache_writeback(self):
+        """Setup L3 processor to write back embeddings to L2"""
+        l2_processor = None
+        l3_processor = None
+
+        for node_id, processor in self.processors.items():
+            if hasattr(processor, 'component') and hasattr(processor.component, 'precompute_embeddings'):
+                l2_processor = processor
+            if hasattr(processor, '_l2_processor'):
+                l3_processor = processor
+
+        if l2_processor and l3_processor:
+            l3_processor._l2_processor = l2_processor
+            if self.verbose:
+                logger.info("L3 cache write-back enabled")
+
+    def execute(self, pipeline_input: Any) -> PipeContext:
+        """Execute full pipeline"""
+        context = PipeContext(pipeline_input)
+        execution_levels = self._topological_sort()
+
+        if self.verbose:
+            logger.info(f"Executing pipeline: {self.pipeline.name}")
+            logger.info(f"Total nodes: {len(self.nodes_map)}")
+            logger.info(f"Execution levels: {len(execution_levels)}")
+
+        for level_idx, level_nodes in enumerate(execution_levels):
+            if self.verbose:
+                logger.info(f"\n{'='*60}")
+                logger.info(
+                    f"Level {level_idx + 1}/{len(execution_levels)} "
+                    f"({len(level_nodes)} nodes)"
+                )
+                logger.info(f"{'='*60}")
+
+            for node_id in level_nodes:
+                self._run_node(node_id, context)
+
+        if self.verbose:
+            logger.info(f"\nPipeline completed successfully!")
+
+        return context
+
+    def _run_node(self, node_id: str, context: PipeContext):
+        """Execute single node"""
+        node = self.nodes_map[node_id]
+
+        if self.verbose:
+            logger.info(f"\nExecuting: {node.id} (processor: {node.processor})")
+
+        if node.condition and not self._evaluate_condition(node.condition, context):
+            if self.verbose:
+                logger.info(f" Skipped (condition not met)")
+            return
+
+        # Resolve inputs
+        kwargs = {}
+        for param_name, input_config in node.inputs.items():
+            try:
+                value = FieldResolver.resolve(context, input_config)
+                kwargs[param_name] = value
+
+                if self.verbose:
+                    logger.info(f" Input '{param_name}': {input_config.source}")
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to resolve input '{param_name}' for node '{node_id}': {e}"
+                )
+
+        # Get cached processor
+        processor = self.processors[node_id]
+
+        # Apply schema if needed
+        if node.schema and hasattr(processor, 'schema'):
+            processor.schema = node.schema
+
+        # Execute processor
+        try:
+            result = processor(**kwargs)
+
+            if self.verbose:
+                logger.info(f" Processing...")
+        except Exception as e:
+            if self.verbose:
+                logger.error(f" Failed: {e}")
+            raise RuntimeError(f"Node '{node_id}' failed: {e}")
+
+        # Extract output fields if specified
+        if node.output.fields:
+            result = FieldResolver._extract_fields(result, node.output.fields)
+
+        # Store output
+        context.set(node.output.key, result)
+
+        if self.verbose:
+            logger.info(f" Output: '{node.output.key}'")
+            logger.info(f" Success")
+
+    def _evaluate_condition(self, condition: str, context: PipeContext) -> bool:
+        """Evaluate conditional expression (stub: currently always returns True)"""
+        return True
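
The input-resolution machinery above is easiest to follow with a small sketch. The snippet below is not part of the package; the entity data and keys are made up, and the import path is inferred from the file location shown above. It shows how PipeContext and FieldResolver combine a source, a field path, and a filter from an InputConfig:

    from glinker.core.dag import PipeContext, InputConfig, FieldResolver

    # Store a fake upstream output under the key "l1_result"
    ctx = PipeContext(pipeline_input="Berlin is the capital of Germany.")
    ctx.set("l1_result", {
        "entities": [
            {"text": "Berlin", "score": 0.93},
            {"text": "Germany", "score": 0.41},
        ]
    })

    # "take l1_result.entities[*], but only items where score > 0.5"
    spec = InputConfig(source="l1_result", fields="entities[*]", filter="score > 0.5")
    print(FieldResolver.resolve(ctx, spec))   # [{'text': 'Berlin', 'score': 0.93}]

    # "$input" always resolves to the original pipeline input
    print(FieldResolver.resolve(ctx, InputConfig(source="$input")))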
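A DAG itself is declared as plain pydantic models, so it can be built without any backing services. In the sketch below the processor names ("l1_extractor", "l2_linker") are hypothetical; DAGExecutor resolves each `processor` string through glinker's internal processor registry (glinker.core.registry), so those names would have to exist there for the executor to instantiate them:

    from glinker.core.dag import DAGPipeline, PipeNode, InputConfig, OutputConfig

    pipeline = DAGPipeline(
        name="link_entities",
        nodes=[
            PipeNode(
                id="l1",
                processor="l1_extractor",   # hypothetical registry name
                inputs={"text": InputConfig(source="$input")},
                output=OutputConfig(key="l1_result"),
            ),
            PipeNode(
                id="l2",
                processor="l2_linker",      # hypothetical registry name
                inputs={
                    "mentions": InputConfig(
                        source="l1_result",
                        fields="entities[*].text",
                        reduce="flatten",
                    )
                },
                output=OutputConfig(key="l2_result"),
                requires=["l1"],            # explicit dependency on node "l1"
            ),
        ],
    )

    # DAGExecutor(pipeline, verbose=True).execute("raw input text") would then run
    # "l1" before "l2" (grouped into topological levels) and return a PipeContext
    # whose outputs are stored under "l1_result" and "l2_result".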
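Because PipeContext.to_json/from_json round-trip the full context state, intermediate results can be checkpointed and restored between runs. A minimal sketch, again with made-up data and a hypothetical filename:

    from glinker.core.dag import PipeContext

    ctx = PipeContext(pipeline_input={"doc_id": 1})
    ctx.set("l1_result", {"entities": [{"text": "Berlin", "score": 0.93}]},
            metadata={"duration_s": 0.12})

    # Checkpoint to disk and restore later; outputs, execution order and metadata survive.
    ctx.to_json(filepath="run_checkpoint.json")
    restored = PipeContext.from_json(filepath="run_checkpoint.json")
    assert restored.get("l1_result")["entities"][0]["text"] == "Berlin"
    assert restored.get_execution_order() == ["l1_result"]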