OntoLearner 1.4.7-py3-none-any.whl → 1.4.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1262 @@
+ # Copyright (c) 2025 SciKnowOrg
+ #
+ # Licensed under the MIT License (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://opensource.org/licenses/MIT
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Learners for supervised and retrieval-augmented *term typing*.
+
+ This module implements two learners:
+
+ - **AlexbekRFLearner** (retriever/classifier):
+   Encodes terms with a Hugging Face encoder, optionally augments with simple
+   graph features, and trains a One-vs-Rest RandomForest for multi-label typing.
+
+ - **AlexbekRAGLearner** (retrieval-augmented generation):
+   Builds an in-memory example index with sentence embeddings, retrieves
+   nearest examples for each query term, then prompts an instruction-tuned
+   causal LLM to produce types, parsing the JSON response.
+
+ Both learners conform to the `AutoLearner` / `AutoRetriever` APIs used in
+ the outer pipeline.
+ """
+
+ import gc
+ import json
+ import re
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import networkx as nx
+ from tqdm import tqdm
+ from sklearn.preprocessing import MultiLabelBinarizer
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.multiclass import OneVsRestClassifier
+
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+
+ from ...base import AutoLearner, AutoRetriever
+
+
+ class AlexbekRFLearner(AutoRetriever):
+     """
+     Embedding-based multi-label classifier for *term typing*.
+
+     Pipeline
+         1) Load a Hugging Face encoder (tokenizer + model).
+         2) Encode input terms into sentence embeddings.
+         3) Optionally augment with simple graph (co-occurrence) features.
+         4) Train a One-vs-Rest RandomForest on the concatenated features.
+         5) Predict multi-label types with a probability threshold (fallback to top-1).
+
+     Implements the `AutoRetriever` interface used by the outer pipeline.
+     """
+
+     def __init__(
+         self,
+         device: str = "cpu",
+         batch_size: int = 16,
+         max_length: int = 256,
+         threshold: float = 0.30,
+         use_graph_features: bool = True,
+         rf_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """Configure the RF-based multi-label learner.
+
+         Parameters
+             device:
+                 Torch device spec ('cpu' or 'cuda').
+             batch_size:
+                 Encoding mini-batch size for the transformer.
+             max_length:
+                 Maximum input token length for the encoder tokenizer.
+             threshold:
+                 Per-label probability threshold at prediction time.
+             use_graph_features:
+                 If True, add simple graph features to embeddings.
+             rf_kwargs:
+                 Optional RandomForest hyperparameters dictionary.
+
+         """
+         # Runtime / inference settings
+         self.device = torch.device(device)
+         self.batch_size = batch_size
+         self.max_length = max_length
+         self.threshold = threshold  # probability cutoff for selecting labels
+         self.use_graph_features = use_graph_features
+
+         # RandomForest hyperparameters (with sensible defaults)
+         self.rf_kwargs = rf_kwargs or dict(
+             n_estimators=200, max_depth=20, class_weight="balanced", random_state=42
+         )
+
+         # Filled during load/fit
+         self.model_name: Optional[str] = None
+         self.tokenizer: Optional[AutoTokenizer] = None
+         self.embedding_model: Optional[AutoModel] = None
+
+         # Label processing / classifier / optional graph
+         self.label_binarizer = MultiLabelBinarizer()
+         self.ovr_random_forest: Optional[OneVsRestClassifier] = None
+         self.term_graph: Optional[nx.Graph] = None
+
+     def load(self, model_id: str, **_: Any) -> None:
+         """Load a Hugging Face encoder by model id (tokenizer + base model).
+
+         Parameters
+             model_id:
+                 HF model identifier or local path for an encoder backbone.
+
+         Side Effects
+             - Sets `self.model_name`, `self.tokenizer`, `self.embedding_model`.
+             - Puts the model in eval mode and moves it to `self.device`.
+         """
+         self.model_name = model_id
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+         self.embedding_model = AutoModel.from_pretrained(model_id)
+         self.embedding_model.eval().to(self.device)
+
+     def fit(self, data: Any, task: str, ontologizer: bool = True, **_: Any) -> None:
+         """Train the One-vs-Rest RandomForest on term embeddings (+ optional graph features).
+
+         Parameters
+             data:
+                 Training payload; supported formats are routed via `_as_term_types_dicts`.
+                 Each example must contain at least `{"term": str, "types": List[str]}`.
+             task:
+                 Must be `'term-typing'`.
+             ontologizer:
+                 Unused here; accepted for API compatibility.
+             **_:
+                 Ignored extra arguments.
+
+         Raises
+             ValueError
+                 If `task` is not `'term-typing'` or if no valid examples are found.
+         """
+         if task != "term-typing":
+             raise ValueError(
+ "OntologyTypeRFClassifier supports only task='term-typing'."
+             )
+
+         # Normalize incoming training data into a list of dicts: {term, types, RAG}
+         training_rows = self._as_term_types_dicts(data)
+         if not training_rows:
+             raise ValueError(
+                 "No valid training examples found (need 'term' and 'types')."
+             )
+
+         # Split out terms and raw labels
+         training_terms: List[str] = [row["term"] for row in training_rows]
+         raw_label_lists: List[List[str]] = [row["types"] for row in training_rows]
+
+         # Fit label binarizer to learn label space/order
+         self.label_binarizer.fit(raw_label_lists)
+
+         # Encode terms to sentence embeddings
+         term_embeddings_train = self._encode(training_terms)
+
+         # Optionally build a light-weight co-occurrence graph and extract features
+         if self.use_graph_features:
+             self.term_graph = self._create_term_graph(training_rows)
+             graph_features_train = self._extract_graph_features(
+                 self.term_graph, training_terms
+             )
+             X_train = np.hstack([term_embeddings_train, graph_features_train])
+         else:
+             self.term_graph = None
+             X_train = term_embeddings_train
+
+         # Multi-label targets (multi-hot)
+         Y_train = self.label_binarizer.transform(raw_label_lists)
+
+         # One-vs-Rest RandomForest (one binary RF per label)
+         self.ovr_random_forest = OneVsRestClassifier(
+             RandomForestClassifier(**self.rf_kwargs)
+         )
+         self.ovr_random_forest.fit(X_train, Y_train)
+
+     def predict(
+         self, data: Any, task: str, ontologizer: bool = True, **_: Any
+     ) -> List[Dict[str, Any]]:
+         """Predict multi-label types for input terms.
+
+         Parameters
+             data:
+                 Evaluation payload; formats normalized by `_as_predict_terms_ids`.
+             task:
+                 Must be `'term-typing'`.
+             ontologizer:
+                 Unused here; accepted for API compatibility.
+             **_:
+                 Ignored extra arguments.
+
+         Returns
+             List[Dict[str, Any]]
+                 A list of dictionaries with keys:
+                 - `id`: Original example id (if provided).
+                 - `term`: Input term string.
+                 - `types`: List of predicted label strings (selected by threshold or top-1).
+
+         Raises
+             ValueError
+                 If `task` is not `'term-typing'`.
+             RuntimeError
+                 If `load()` and `fit()` have not been called.
+         """
+         if task != "term-typing":
+             raise ValueError(
+ "OntologyTypeRFClassifier supports only task='term-typing'."
+             )
+         if (
+             self.ovr_random_forest is None
+             or self.tokenizer is None
+             or self.embedding_model is None
+         ):
+             raise RuntimeError("Call load() and fit() before predict().")
+
+         # Normalize prediction input into parallel lists of terms and example ids
+         test_terms, example_ids = self._as_predict_terms_ids(data)
+
+         # Encode terms
+         term_embeddings_test = self._encode(test_terms)
+
+         # Match feature layout used during training
+         if self.use_graph_features and self.term_graph is not None:
+             graph_features_test = self._extract_graph_features(
+                 self.term_graph, test_terms
+             )
+             X_test = np.hstack([term_embeddings_test, graph_features_test])
+         else:
+             X_test = term_embeddings_test
+
+         # Probabilities per label (shape: [n_samples, n_labels])
+         probability_matrix = self.ovr_random_forest.predict_proba(X_test)
+
+         predictions: List[Dict[str, Any]] = []
+         label_names = self.label_binarizer.classes_
+         threshold = float(self.threshold)
+
+         # Select labels above threshold; fallback to argmax if none exceed it
+         for row_index, label_probabilities in enumerate(probability_matrix):
+             selected_label_indices = np.where(label_probabilities > threshold)[0]
+             if len(selected_label_indices) == 0:
+                 selected_label_indices = [int(np.argmax(label_probabilities))]
+
+             predicted_types = [
+                 label_names[label_idx] for label_idx in selected_label_indices
+             ]
+
+             predictions.append(
+                 {
+                     "id": example_ids[row_index],
+                     "term": test_terms[row_index],
+                     "types": predicted_types,
+                 }
+             )
+         return predictions
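The threshold-or-argmax selection above is easy to verify in isolation; a small self-contained sketch with invented probabilities:

```python
import numpy as np

label_names = np.array(["Drug", "Disease", "Protein"])
threshold = 0.30

prob_rows = np.array([
    [0.55, 0.40, 0.10],   # two labels clear the threshold
    [0.20, 0.10, 0.25],   # none do -> fall back to the argmax label
])
for probs in prob_rows:
    picked = np.where(probs > threshold)[0]
    if len(picked) == 0:
        picked = [int(np.argmax(probs))]
    print([str(label_names[i]) for i in picked])
# ['Drug', 'Disease']
# ['Protein']
```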
+
+     def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, Any]]:
+         """Normalize ground-truth into a list of {id, term, types} dicts for evaluation.
+
+         Parameters
+             data:
+                 Ground-truth payload; supported formats include objects exposing
+                 `.term_typings`, a list of dicts, or a list of tuples/lists.
+             task:
+                 Must be `'term-typing'`.
+
+         Returns
+             List[Dict[str, Any]]
+                 A list of dictionaries with keys `id`, `term`, `types` (list of str).
+
+         Raises
+             ValueError
+                 If `task` is not `'term-typing'`.
+         """
+         if task != "term-typing":
+             raise ValueError(
+ "OntologyTypeRFClassifier supports only task='term-typing'."
+             )
+         return self._as_gold_id_term_types(data)
+
+     def _encode(self, texts: List[str]) -> np.ndarray:
+         """Encode a list of strings into L2-normalized sentence embeddings.
+
+         Parameters
+             texts:
+                 List of input texts/terms.
+
+         Returns
+             np.ndarray
+                 Array of shape `(len(texts), hidden_size)` with L2-normalized
+                 embeddings. If `texts` is empty, returns a `(0, hidden_size)` array.
+         """
+         assert self.tokenizer is not None and self.embedding_model is not None, (
+             "Call load(model_id) first."
+         )
+
+         if not texts:
+             hidden_size = getattr(
+                 getattr(self.embedding_model, "config", None), "hidden_size", 768
+             )
+             return np.zeros((0, hidden_size), dtype=np.float32)
+
+         batch_embeddings: List[torch.Tensor] = []
+
+         for start_idx in tqdm(range(0, len(texts), self.batch_size), desc="Embedding"):
+             end_idx = start_idx + self.batch_size
+             batch_texts = texts[start_idx:end_idx]
+
+             # Tokenize and move to device
+             tokenized_batch = self.tokenizer(
+                 batch_texts,
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors="pt",
+             ).to(self.device)
+
+             # Forward pass without gradients
+             with torch.no_grad():
+                 model_output = self.embedding_model(**tokenized_batch)
+
+             # Prefer dedicated pooler if provided; otherwise pool by last valid token
+             if (
+                 hasattr(model_output, "pooler_output")
+                 and model_output.pooler_output is not None
+             ):
+                 sentence_embeddings = model_output.pooler_output
+             else:
+                 sentence_embeddings = self._last_token_pool(
+                     model_output.last_hidden_state,
+                     tokenized_batch["attention_mask"],
+                 )
+
+             # L2-normalize embeddings for stability
+             sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+             # Detach, move to CPU, collect
+             batch_embeddings.append(sentence_embeddings.detach().cpu())
+
+             # Best-effort memory cleanup (especially useful on CUDA)
+             del tokenized_batch, model_output, sentence_embeddings
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+             gc.collect()
+
+         # Concatenate all batches and convert to NumPy
+         return torch.cat(batch_embeddings, dim=0).numpy()
+
+     def _last_token_pool(
+         self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
+     ) -> torch.Tensor:
+         """Select the last *non-padding* token embedding for each sequence.
+
+         Parameters
+             last_hidden_states:
+                 Tensor of shape `(batch, seq_len, hidden)`.
+             attention_mask:
+                 Tensor of shape `(batch, seq_len)` with 1 for real tokens.
+
+         Returns
+             torch.Tensor
+                 Tensor of shape `(batch, hidden)` with per-sequence pooled embeddings.
+         """
+         last_valid_token_idx = attention_mask.sum(dim=1) - 1  # (batch,)
+         batch_row_idx = torch.arange(
+             last_hidden_states.size(0), device=last_hidden_states.device
+         )
+         return last_hidden_states[batch_row_idx, last_valid_token_idx]
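A toy check of the pooling rule: with two sequences where the second has two padding positions, the gather picks position 3 for the first row and position 1 for the second:

```python
import torch

# Batch of two sequences, hidden size 3; the second sequence has two pad slots.
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

last_valid = mask.sum(dim=1) - 1             # tensor([3, 1])
rows = torch.arange(hidden.size(0))
pooled = hidden[rows, last_valid]            # shape (2, 3)
assert torch.equal(pooled[1], hidden[1, 1])  # row 1 pools its last real token
```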
+
+     def _create_term_graph(self, training_rows: List[Dict[str, Any]]) -> nx.Graph:
+         """Create a simple undirected co-occurrence graph from training rows.
+
+         Graph Structure
+             Nodes
+                 Terms (node attribute `'types'` is stored per term).
+             Edges
+                 Between a term and each neighbor from its optional RAG list.
+                 Edge weight = number of shared types (or 0.1 if none shared).
+
+         Parameters
+             training_rows:
+                 Normalized rows with keys: `'term'`, `'types'`, optional `'RAG'`.
+
+         Returns
+             networkx.Graph
+                 The constructed undirected graph.
+         """
+         graph = nx.Graph()
+
+         for row in training_rows:
+             term = row["term"]
+             term_types = row.get("types", [])
+             graph.add_node(term, types=term_types)
+
+             # RAG may be a list of neighbor dicts like {"term": ..., "types": [...]}
+             for neighbor in row.get("RAG", []) or []:
+                 neighbor_term = neighbor.get("term")
+                 if not neighbor_term:
+                     continue  # skip malformed neighbors instead of adding a None node
+                 neighbor_types = neighbor.get("types", [])
+
+                 # Shared-type-based edge weight (weak edge if no overlap)
+                 shared_types = set(term_types).intersection(set(neighbor_types))
+                 edge_weight = float(len(shared_types)) if shared_types else 0.1
+
+                 graph.add_edge(term, neighbor_term, weight=edge_weight)
+
+         return graph
+
+     def _extract_graph_features(
+         self, term_graph: nx.Graph, terms: List[str]
+     ) -> np.ndarray:
+         """Compute simple per-term graph features.
+
+         Feature Vector
+             For each term we compute a 4-dim vector:
+             `[degree, clustering_coefficient, degree_centrality, pagerank_score]`
+
+         Parameters
+             term_graph:
+                 Graph built over training terms.
+             terms:
+                 List of term strings to extract features for.
+
+         Returns
+             np.ndarray
+                 Array of shape `(len(terms), 4)` (dtype float32).
+         """
+         if len(term_graph):
+             degree_centrality = nx.degree_centrality(term_graph)
+             pagerank_scores = nx.pagerank(term_graph)
+         else:
+             degree_centrality, pagerank_scores = {}, {}
+
+         feature_rows: List[List[float]] = []
+         for term in terms:
+             if term in term_graph:
+                 feature_rows.append(
+                     [
+                         float(term_graph.degree(term)),
+                         float(nx.clustering(term_graph, term)),
+                         float(degree_centrality.get(term, 0.0)),
+                         float(pagerank_scores.get(term, 0.0)),
+                     ]
+                 )
+             else:
+                 feature_rows.append([0.0, 0.0, 0.0, 0.0])
+
+         return np.asarray(feature_rows, dtype=np.float32)
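On a toy graph the 4-dim feature vector looks like this (the sketch depends only on networkx, so it is self-contained):

```python
import networkx as nx
import numpy as np

graph = nx.Graph()
graph.add_node("aspirin", types=["Drug"])
graph.add_edge("aspirin", "ibuprofen", weight=1.0)
graph.add_edge("aspirin", "paracetamol", weight=0.1)

centrality = nx.degree_centrality(graph)
pagerank = nx.pagerank(graph)
features = [
    float(graph.degree("aspirin")),           # degree -> 2.0
    float(nx.clustering(graph, "aspirin")),   # clustering coefficient -> 0.0
    float(centrality.get("aspirin", 0.0)),    # degree centrality -> 1.0
    float(pagerank.get("aspirin", 0.0)),      # PageRank score
]
print(np.asarray([features], dtype=np.float32).shape)  # (1, 4)
```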
+
+     def _as_term_types_dicts(self, data: Any) -> List[Dict[str, Any]]:
+         """Normalize diverse training data formats to a list of dicts: {term, types, RAG}.
+
+         Supported Inputs
+             - Object with attribute `.term_typings` (iterable of items exposing
+               `.term`, `.types`, optional `.RAG`).
+             - List of dicts with keys `term`, `types`, optional `RAG`.
+             - List/tuple of `(term, types[, RAG])`.
+
+         Parameters
+             data:
+                 Training payload.
+
+         Returns
+             List[Dict[str, Any]]
+                 Normalized dictionaries ready for training.
+
+         Raises
+             ValueError
+                 If `data` is neither a list/tuple nor exposes `.term_typings`.
+         """
+         normalized_rows: List[Dict[str, Any]] = []
+
+         # Case 1: object with attribute `.term_typings`
+         term_typings_attr = getattr(data, "term_typings", None)
+         if term_typings_attr is not None:
+             for item in term_typings_attr:
+                 term_text = getattr(item, "term", None)
+                 type_list = getattr(item, "types", None)
+                 rag_neighbors = getattr(item, "RAG", None)
+                 if term_text is None or type_list is None:
+                     continue
+                 if not isinstance(type_list, list):
+                     type_list = [type_list]
+                 normalized_rows.append(
+                     {
+                         "term": str(term_text),
+                         "types": [str(x) for x in type_list],
+                         "RAG": rag_neighbors,
+                     }
+                 )
+             return normalized_rows
+
+         # Otherwise: must be a list/tuple-like container
+         if not isinstance(data, (list, tuple)):
+             raise ValueError(
+                 "Training data must be a list/tuple or expose .term_typings"
+             )
+
+         if not data:
+             return normalized_rows
+
+         # Case 2: list of dicts
+         if isinstance(data[0], dict):
+             for row in data:
+                 term_text = row.get("term")
+                 type_list = row.get("types")
+                 rag_neighbors = row.get("RAG")
+                 if term_text is None or type_list is None:
+                     continue
+                 if not isinstance(type_list, list):
+                     type_list = [type_list]
+                 normalized_rows.append(
+                     {
+                         "term": str(term_text),
+                         "types": [str(x) for x in type_list],
+                         "RAG": rag_neighbors,
+                     }
+                 )
+             return normalized_rows
+
+         # Case 3: list of tuples/lists: (term, types[, RAG])
+         for item in data:
+             if not isinstance(item, (list, tuple)) or len(item) < 2:
+                 continue
+             term_text, type_list = item[0], item[1]
+             rag_neighbors = item[2] if len(item) > 2 else None
+             if term_text is None or type_list is None:
+                 continue
+             if not isinstance(type_list, list):
+                 type_list = [type_list]
+             normalized_rows.append(
+                 {
+                     "term": str(term_text),
+                     "types": [str(x) for x in type_list],
+                     "RAG": rag_neighbors,
+                 }
+             )
+
+         return normalized_rows
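Continuing the earlier sketch's `learner`, the accepted payload shapes normalize identically; calling the private normalizer directly, purely for illustration:

```python
# Sketch only: dict rows and tuple rows produce the same normalized output.
dict_rows = learner._as_term_types_dicts([{"term": "aspirin", "types": ["Drug"]}])
tuple_rows = learner._as_term_types_dicts([("aspirin", ["Drug"])])

assert dict_rows == tuple_rows == [
    {"term": "aspirin", "types": ["Drug"], "RAG": None}
]
```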
+
+     def _as_predict_terms_ids(self, data: Any) -> Tuple[List[str], List[Any]]:
+         """Normalize prediction input into parallel lists: (terms, ids).
+
+         Supported Inputs
+             - Object with `.term_typings`.
+             - List of dicts with `term` and optional `id`.
+             - List of tuples/lists `(term, id[, ...])`.
+             - List of plain term strings.
+
+         Parameters
+             data:
+                 Evaluation payload.
+
+         Returns
+             Tuple[List[str], List[Any]]
+                 `(terms, example_ids)` lists aligned by index.
+
+         Raises
+             ValueError
+                 If the input format is unsupported.
+         """
+         terms: List[str] = []
+         example_ids: List[Any] = []
+
+         # Case 1: object with attribute `.term_typings`
+         term_typings_attr = getattr(data, "term_typings", None)
+         if term_typings_attr is not None:
+             for idx, item in enumerate(term_typings_attr):
+                 terms.append(str(getattr(item, "term", "")))
+                 example_ids.append(getattr(item, "id", getattr(item, "ID", idx)))
+             return terms, example_ids
+
+         # Case 2: list/tuple container
+         if isinstance(data, (list, tuple)) and data:
+             first_element = data[0]
+
+             # 2a) list of dicts
+             if isinstance(first_element, dict):
+                 for i, row in enumerate(data):
+                     terms.append(str(row.get("term", "")))
+                     example_ids.append(row.get("id", row.get("ID", i)))
+                 return terms, example_ids
+
+             # 2b) list of tuples/lists: (term, id[, ...])
+             if isinstance(first_element, (list, tuple)):
+                 for i, tuple_row in enumerate(data):
+                     if not tuple_row:
+                         continue
+                     terms.append(str(tuple_row[0]))
+                     example_ids.append(tuple_row[1] if len(tuple_row) > 1 else i)
+                 return terms, example_ids
+
+             # 2c) list of strings (terms only)
+             if isinstance(first_element, str):
+                 terms = [str(x) for x in data]  # type: ignore[arg-type]
+                 example_ids = list(range(len(terms)))
+                 return terms, example_ids
+
+         raise ValueError("Unsupported predict() input format.")
+
+     def _as_gold_id_term_types(self, data: Any) -> List[Dict[str, Any]]:
+         """Normalize gold labels into a list of dicts: {id, term, types}.
+
+         Supported Inputs
+             Mirrors `_as_term_types_dicts`, but ensures an `id` is set.
+
+         Parameters
+             data:
+                 Ground-truth payload.
+
+         Returns
+             List[Dict[str, Any]]
+                 `{'id': Any, 'term': str, 'types': List[str]}` entries.
+
+         """
+         gold_rows: List[Dict[str, Any]] = []
+
+         # Case 1: object with attribute `.term_typings`
+         term_typings_attr = getattr(data, "term_typings", None)
+         if term_typings_attr is not None:
+             for idx, item in enumerate(term_typings_attr):
+                 gold_id = getattr(item, "id", getattr(item, "ID", idx))
+                 term_text = str(getattr(item, "term", ""))
+                 type_list = getattr(item, "types", [])
+                 if not isinstance(type_list, list):
+                     type_list = [type_list]
+                 gold_rows.append(
+                     {
+                         "id": gold_id,
+                         "term": term_text,
+                         "types": [str(t) for t in type_list],
+                     }
+                 )
+             return gold_rows
+
+         # Case 2: list/tuple container
+         if isinstance(data, (list, tuple)) and data:
+             first_element = data[0]
+
+             # 2a) list of dicts
+             if isinstance(first_element, dict):
+                 for i, row in enumerate(data):
+                     gold_id = row.get("id", row.get("ID", i))
+                     term_text = str(row.get("term", ""))
+                     type_list = row.get("types", [])
+                     if not isinstance(type_list, list):
+                         type_list = [type_list]
+                     gold_rows.append(
+                         {
+                             "id": gold_id,
+                             "term": term_text,
+                             "types": [str(t) for t in type_list],
+                         }
+                     )
+                 return gold_rows
+
+             # 2b) list of tuples/lists: (term, types[, id])
+             if isinstance(first_element, (list, tuple)):
+                 for i, tuple_row in enumerate(data):
+                     if not tuple_row or len(tuple_row) < 2:
+                         continue
+                     term_text = str(tuple_row[0])
+                     type_list = tuple_row[1]
+                     gold_id = tuple_row[2] if len(tuple_row) > 2 else i
+                     if not isinstance(type_list, list):
+                         type_list = [type_list]
+                     gold_rows.append(
+                         {
+                             "id": gold_id,
+                             "term": term_text,
+                             "types": [str(t) for t in type_list],
+                         }
+                     )
+                 return gold_rows
+
+         raise ValueError(
+             "Unsupported ground-truth input format for tasks_ground_truth_former()."
+         )
+
+
+ class AlexbekRAGLearner(AutoLearner):
+     """Retrieval-Augmented Term Typing learner (single task: term-typing).
+
+     Flow
+         1) `fit`: collect (term -> [types]) examples, build an in-memory index
+            using a sentence-embedding model.
+         2) `predict`: for each new term, retrieve top-k similar examples, compose a
+            structured prompt, query an instruction-tuned causal LLM, and parse types.
+
+     Returns
+         List[Dict[str, Any]]
+             `{"term": str, "types": List[str], "id": Optional[str]}` rows.
+     """
+
+     def __init__(
+         self,
+         llm_model_id: str = "Qwen/Qwen2.5-0.5B-Instruct",
+         retriever_model_id: str = "sentence-transformers/all-MiniLM-L6-v2",
+         device: str = "auto",  # "auto" | "cuda" | "cpu"
+         token: str = "",  # HF token if needed
+         top_k: int = 3,
+         max_new_tokens: int = 256,
+         gen_batch_size: int = 4,  # generation batch size
+         enc_batch_size: int = 64,  # embedding batch size
+         **kwargs: Any,  # absorb extra pipeline-style args
+     ) -> None:
+         """Configure the RAG learner.
+
+         Parameters
+             llm_model_id:
+                 HF model id/path for the instruction-tuned causal LLM.
+             retriever_model_id:
+                 Sentence-embedding model id for retrieval.
+             device:
+                 Device policy ('auto'|'cuda'|'cpu') for the LLM.
+             token:
+                 Optional HF token for gated models.
+             top_k:
+                 Number of nearest examples to retrieve per query term.
+             max_new_tokens:
+                 Decoding budget for the LLM.
+             gen_batch_size:
+                 Number of prompts per generation batch.
+             enc_batch_size:
+                 Number of texts per embedding batch.
+             **kwargs:
+                 Extra configuration captured for downstream use.
+         """
+         super().__init__()
+
+         # Consolidated configuration for simple serialization
+         self.cfg: Dict[str, Any] = {
+             "llm_model_id": llm_model_id,
+             "retriever_model_id": retriever_model_id,
+             "device": device,
+             "token": token,
+             "top_k": int(top_k),
+             "max_new_tokens": int(max_new_tokens),
+             "gen_batch_size": int(gen_batch_size),
+             "enc_batch_size": int(enc_batch_size),
+         }
+         self.extra_cfg: Dict[str, Any] = dict(kwargs)
+
+         # LLM components
+         self.tokenizer: Optional[AutoTokenizer] = None
+         self.generation_model: Optional[AutoModelForCausalLM] = None
+
+         # Retriever components
+         self.embedder: Optional[SentenceTransformer] = None
+         self.indexed_corpus: List[str] = []  # items: "<term> || [<types>...]"
+         self.corpus_embeddings: Optional[torch.Tensor] = None
+
+         # Training cache of (term, [types]) tuples
+         self.train_term_types: List[Tuple[str, List[str]]] = []
+
+         # Prompt templates
+         self._system_prompt: str = (
+             "You are an expert in ontologies and semantic term classification.\n"
+             "Task: determine semantic types for the TERM using the EXAMPLES provided.\n"
+             "Rules:\n"
+             "1) Types must be generalizing categories from the domain ontology.\n"
+             "2) Be concise. Respond ONLY in JSON using double quotes.\n"
+             'Format: {"term":"...", "reasoning":"<<=100 words>>", "types":["...", "..."]}\n'
+         )
+         self._user_prompt_template: str = """{examples}
+
+ TERM: {term}
+
+ TASK: Determine semantic types for the given term based on the domain ontology.
+ Remember: types are generalizing categories, not the term itself. Respond in JSON.
+ """
+
+     def load(
+         self,
+         model_id: Optional[str] = None,
+         retriever_id: Optional[str] = None,
+         device: Optional[str] = None,
+         token: Optional[str] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Load the LLM and the embedding retriever. Overrides constructor values if provided.
+
+         Parameters
+             model_id:
+                 Optional override for the LLM model id.
+             retriever_id:
+                 Optional override for the embedding model id.
+             device:
+                 Optional override for device selection policy.
+             token:
+                 Optional override for HF token.
+             **kwargs:
+                 Extra values to store in `extra_cfg`.
+
+         """
+         if model_id is not None:
+             self.cfg["llm_model_id"] = model_id
+         if retriever_id is not None:
+             self.cfg["retriever_model_id"] = retriever_id
+         if device is not None:
+             self.cfg["device"] = device
+         if token is not None:
+             self.cfg["token"] = token
+         self.extra_cfg.update(kwargs)
+
+         # Choose device & dtype for the LLM
+         cuda_available: bool = torch.cuda.is_available()
+         use_cuda: bool = cuda_available and (self.cfg["device"] != "cpu")
+         device_map: str = "auto" if use_cuda else "cpu"
+         torch_dtype = torch.bfloat16 if use_cuda else torch.float32
+
+         # Tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.cfg["llm_model_id"], padding_side="left", token=self.cfg["token"]
+         )
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         # LLM
+         self.generation_model = AutoModelForCausalLM.from_pretrained(
+             self.cfg["llm_model_id"],
+             device_map=device_map,
+             torch_dtype=torch_dtype,
+             token=self.cfg["token"],
+         )
+
+         # Deterministic decoding defaults
+         generation_cfg = self.generation_model.generation_config
+         generation_cfg.do_sample = False
+         generation_cfg.temperature = None
+         generation_cfg.top_p = None
+         generation_cfg.top_k = None
+         generation_cfg.num_beams = 1
+
+         # Retriever
+         self.embedder = SentenceTransformer(
+             self.cfg["retriever_model_id"], trust_remote_code=True
+         )
+
+     def fit(self, train_data: Any, task: str, ontologizer: bool = True) -> None:
+         """Prepare the retrieval index from training examples.
+
+         Parameters
+             train_data:
+                 Training payload containing terms and their types.
+             task:
+                 Must be `'term-typing'`; other tasks are forwarded to base.
+             ontologizer:
+                 Unused flag for API compatibility.
+
+         Side Effects
+             - Normalizes to a list of `(term, [types])`.
+             - Builds an indexable text corpus and (if embedder is loaded)
+               computes embeddings for retrieval.
+         """
+         if task != "term-typing":
+             return super().fit(train_data, task, ontologizer)
+
+         # Normalize incoming training data -> list[(term, [types])]
+         self.train_term_types = self._unpack_train(train_data)
+
+         # Build the textual corpus to index
+         self.indexed_corpus = [
+             f"{term} || {json.dumps(types, ensure_ascii=False)}"
+             for term, types in self.train_term_types
+         ]
+
+         # Embed the corpus if available; else fall back to zero-shot prompting
+         if self.indexed_corpus and self.embedder is not None:
+             self.corpus_embeddings = self._encode_texts(self.indexed_corpus)
+         else:
+             self.corpus_embeddings = None
+
+     def predict(self, eval_data: Any, task: str, ontologizer: bool = True) -> Any:
+         """Predict types for evaluation items; returns a list of {term, types, id?}.
+
+         Parameters
+             eval_data:
+                 Evaluation payload to type (terms + optional ids).
+             task:
+                 Must be `'term-typing'`; other tasks are forwarded to base.
+             ontologizer:
+                 Unused flag for API compatibility.
+
+         Returns
+             List[Dict[str, Any]]
+                 For each input term, a dictionary with keys:
+                 - `term`: The input term.
+                 - `types`: A (unique, sorted) list of predicted types.
+                 - `id`: Optional example id (if provided in input).
+         """
+         if task != "term-typing":
+             return super().predict(eval_data, task, ontologizer)
+
+         eval_terms, eval_ids = self._unpack_eval(eval_data)
+         if not eval_terms:
+             return []
+
+         # Use RAG if we have an indexed corpus & embeddings; otherwise zero-shot
+         rag_available = (
+             self.corpus_embeddings is not None
+             and self.embedder is not None
+             and len(self.indexed_corpus) > 0
+         )
+
+         if rag_available:
+             neighbor_docs_per_query = self._retrieve_batch(
+                 eval_terms, top_k=int(self.cfg["top_k"])
+             )
+         else:
+             neighbor_docs_per_query = [[] for _ in eval_terms]
+
+         # Compose prompts
+         prompts: List[str] = []
+         for term, neighbor_docs in zip(eval_terms, neighbor_docs_per_query):
+             example_pairs = self._decode_examples(neighbor_docs)
+             examples_block = self._format_examples(example_pairs)
+             prompt_text = self._compose_prompt(examples_block, term)
+             prompts.append(prompt_text)
+
+         predicted_types_lists = self._generate_and_parse(prompts)
+
+         # Build standardized results
+         results: List[Dict[str, Any]] = []
+         for term, example_id, predicted_types in zip(
+             eval_terms, eval_ids, predicted_types_lists
+         ):
+             result_row: Dict[str, Any] = {
+                 "term": term,
+                 "types": sorted({t for t in predicted_types}),  # unique + sorted
+             }
+             if example_id is not None:
+                 result_row["id"] = example_id
+             results.append(result_row)
+
+         assert all(("term" in row and "types" in row) for row in results), (
+             "predict() must return term + types"
+         )
+         return results
+
+     def _unpack_train(self, data: Any) -> List[Tuple[str, List[str]]]:
+         """Extract `(term, [types])` tuples from supported training payloads.
+
+         Supported Inputs
+             - `data.term_typings` (objects exposing `.term` & `.types`)
+             - `list[dict]` with keys `'term'` and `'types'`
+             - `list[str]` → returns empty (nothing to index)
+             - other formats → empty
+
+         Parameters
+             data:
+                 Training payload.
+
+         Returns
+             List[Tuple[str, List[str]]]
+                 (term, types) tuples (types kept as strings).
+         """
+         term_typings = getattr(data, "term_typings", None)
+         if term_typings is not None:
+             parsed_pairs: List[Tuple[str, List[str]]] = []
+             for item in term_typings:
+                 term = getattr(item, "term", None)
+                 types = list(getattr(item, "types", []) or [])
+                 if term and types:
+                     parsed_pairs.append(
+                         (term, [t for t in types if isinstance(t, str)])
+                     )
+             return parsed_pairs
+
+         if isinstance(data, list) and data and isinstance(data[0], dict):
+             parsed_pairs = []
+             for row in data:
+                 term = row.get("term")
+                 types = row.get("types") or []
+                 if term and isinstance(types, list) and types:
+                     parsed_pairs.append(
+                         (term, [t for t in types if isinstance(t, str)])
+                     )
+             return parsed_pairs
+
+         # If only a list of strings is provided, there's nothing to index for RAG
+         if isinstance(data, (list, set, tuple)) and all(
+             isinstance(x, str) for x in data
+         ):
+             return []
+
+         return []
+
+     def _unpack_eval(self, data: Any) -> Tuple[List[str], List[Optional[str]]]:
+         """Extract `(terms, ids)` from supported evaluation payloads.
+
+         Supported Inputs
+             - `data.term_typings` (objects exposing `.term` & optional `.id`)
+             - `list[str]`
+             - `list[dict]` with `term` and optional `id`
+
+         Parameters
+             data:
+                 Evaluation payload.
+
+         Returns
+             Tuple[List[str], List[Optional[str]]]
+                 Two lists aligned by index: terms and ids (ids may contain `None`).
+         """
+         term_typings = getattr(data, "term_typings", None)
+         if term_typings is not None:
+             terms: List[str] = []
+             ids: List[Optional[str]] = []
+             for item in term_typings:
+                 terms.append(getattr(item, "term", ""))
+                 ids.append(getattr(item, "id", None))
+             return terms, ids
+
+         if isinstance(data, list) and data and isinstance(data[0], str):
+             return list(data), [None] * len(data)
+
+         if isinstance(data, list) and data and isinstance(data[0], dict):
+             terms: List[str] = []
+             ids: List[Optional[str]] = []
+             for row in data:
+                 terms.append(row.get("term", ""))
+                 ids.append(row.get("id"))
+             return terms, ids
+
+         return [], []
+
+     def _encode_texts(self, texts: List[str]) -> torch.Tensor:
+         """Encode a batch of texts with the sentence-embedding model.
+
+         Parameters
+             texts:
+                 List of strings to embed.
+
+         Returns
+             torch.Tensor
+                 Tensor of shape `(len(texts), hidden_dim)`. If `texts` is empty,
+                 returns an empty tensor with 0 rows.
+         """
+         batch_size = int(self.cfg["enc_batch_size"])
+         batch_embeddings: List[torch.Tensor] = []
+
+         for batch_start in range(0, len(texts), batch_size):
+             batch_texts = texts[batch_start : batch_start + batch_size]
+             embeddings = self.embedder.encode(
+                 batch_texts, convert_to_tensor=True, show_progress_bar=False
+             )
+             batch_embeddings.append(embeddings)
+
+         return (
+             torch.cat(batch_embeddings, dim=0) if batch_embeddings else torch.empty(0)
+         )
+
+     def _retrieve_batch(self, queries: List[str], top_k: int) -> List[List[str]]:
+         """Return for each query the top-k most similar corpus entries.
+
+         Parameters
+             queries:
+                 List of query terms.
+             top_k:
+                 Number of neighbors to retrieve for each query.
+
+         Returns
+             List[List[str]]
+                 For each query, a list of raw corpus strings formatted as
+                 `"<term> || [\\"type1\\", ...]"`.
+         """
+         if self.corpus_embeddings is None or not self.indexed_corpus:
+             return [[] for _ in queries]
+
+         query_embeddings = self._encode_texts(queries)  # [Q, D]
+         doc_embeddings = self.corpus_embeddings  # [N, D]
+         if query_embeddings.shape[-1] != doc_embeddings.shape[-1]:
+             raise ValueError(
+                 f"Embedding dim mismatch: {query_embeddings.shape[-1]} vs {doc_embeddings.shape[-1]}"
+             )
+
+         # Cosine similarity via L2-normalized dot product
+         q_norm = F.normalize(query_embeddings, p=2, dim=1)
+         d_norm = F.normalize(doc_embeddings, p=2, dim=1)
+         cos_sim = torch.matmul(q_norm, d_norm.T)  # [Q, N]
+
+         k = min(max(1, top_k), len(self.indexed_corpus))
+         _, top_indices = torch.topk(cos_sim, k=k, dim=1)
+         return [[self.indexed_corpus[j] for j in row.tolist()] for row in top_indices]
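The retrieval math reduces to normalized dot products plus `torch.topk`; a self-contained sketch with random vectors:

```python
import torch
import torch.nn.functional as F

queries = F.normalize(torch.randn(2, 8), p=2, dim=1)  # 2 query embeddings
corpus = F.normalize(torch.randn(5, 8), p=2, dim=1)   # 5 indexed rows

cosine = queries @ corpus.T                           # [2, 5] similarity matrix
_, top_idx = torch.topk(cosine, k=min(3, corpus.size(0)), dim=1)
print(top_idx.shape)  # torch.Size([2, 3]) -> neighbor indices per query
```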
+
+     def _decode_examples(self, docs: List[str]) -> List[Tuple[str, List[str]]]:
+         """Parse raw corpus rows ('term || [types]') into `(term, [types])` pairs.
+
+         Parameters
+             docs:
+                 Raw strings from the index/corpus.
+
+         Returns
+             List[Tuple[str, List[str]]]
+                 Parsed (term, types) pairs; malformed rows are skipped.
+         """
+         example_pairs: List[Tuple[str, List[str]]] = []
+         for raw_row in docs:
+             try:
+                 term_raw, types_json = raw_row.split("||", 1)
+                 term = term_raw.strip()
+                 types_list = json.loads(types_json.strip())
+                 if isinstance(types_list, list):
+                     example_pairs.append(
+                         (term, [t for t in types_list if isinstance(t, str)])
+                     )
+             except Exception:
+                 continue
+         return example_pairs
+
+     def _format_examples(self, pairs: List[Tuple[str, List[str]]]) -> str:
+         """Format retrieved example pairs into a compact block for the prompt.
+
+         Parameters
+             pairs:
+                 Retrieved `(term, [types])` examples.
+
+         Returns
+             str
+                 Human-readable lines to provide *light* guidance to the LLM.
+         """
+         if not pairs:
+             return "EXAMPLES: (none provided)"
+         lines: List[str] = ["CLASSIFICATION EXAMPLES:"]
+         for idx, (term, types) in enumerate(pairs, 1):
+             preview_types = types[:3]  # keep context small
+             lines.append(f"{idx}. Term: '{term}' → Types: {list(preview_types)}")
+         lines.append("END OF EXAMPLES.")
+         return "\n".join(lines)
+
+     def _compose_prompt(self, examples_block: str, term: str) -> str:
+         """Compose the final prompt from system + user blocks.
+
+         Parameters
+             examples_block:
+                 Text block with retrieved examples.
+             term:
+                 The query term to classify.
+
+         Returns
+             str
+                 Full prompt string passed to the LLM.
+         """
+         user_block = self._user_prompt_template.format(
+             examples=examples_block, term=term
+         )
+         return f"{self._system_prompt}\n\n{user_block}\n"
+
+     def _generate_and_parse(self, prompts: List[str]) -> List[List[str]]:
+         """Run generation for a batch of prompts and parse the JSON `'types'` from outputs.
+
+         Parameters
+             prompts:
+                 Finalized prompts for the LLM.
+
+         Returns
+             List[List[str]]
+                 For each prompt, a list of predicted type strings.
+         """
+         batch_size = int(self.cfg["gen_batch_size"])
+         all_predicted_types: List[List[str]] = []
+
+         for batch_start in range(0, len(prompts), batch_size):
+             prompt_batch = prompts[batch_start : batch_start + batch_size]
+
+             # Tokenize and move to the LLM's device
+             model_device = getattr(self.generation_model, "device", None)
+             encodings = self.tokenizer(
+                 prompt_batch, return_tensors="pt", padding=True
+             ).to(model_device)
+             input_token_length = encodings["input_ids"].shape[1]
+
+             # Deterministic decoding (greedy)
+             with torch.no_grad():
+                 generated_tokens = self.generation_model.generate(
+                     **encodings,
+                     do_sample=False,
+                     num_beams=1,
+                     temperature=None,
+                     top_p=None,
+                     top_k=None,
+                     max_new_tokens=int(self.cfg["max_new_tokens"]),
+                     pad_token_id=self.tokenizer.eos_token_id,
+                 )
+
+             # Slice off the prompt tokens and decode only newly generated tokens
+             new_token_span = generated_tokens[:, input_token_length:]
+             decoded_texts = [
+                 self.tokenizer.decode(seq, skip_special_tokens=True)
+                 for seq in new_token_span
+             ]
+
+             parsed_types_per_prompt = [
+                 self._parse_types(text) for text in decoded_texts
+             ]
+             all_predicted_types.extend(parsed_types_per_prompt)
+
+         return all_predicted_types
+
+     def _parse_types(self, text: str) -> List[str]:
+         """Extract a list of type strings from LLM output.
+
+         Parsing Strategy (in order)
+             1) Strict JSON object with `"types"`.
+             2) Regex-extract JSON object containing `"types"`.
+             3) Regex-extract first bracketed list.
+             4) Comma-split fallback.
+
+         Parameters
+             text:
+                 Raw LLM output to parse.
+
+         Returns
+             List[str]
+                 Parsed list of type strings (possibly empty if parsing fails).
+         """
+         try:
+             obj = json.loads(text)
+             if isinstance(obj, dict) and isinstance(obj.get("types"), list):
+                 return [t for t in obj["types"] if isinstance(t, str)]
+         except Exception:
+             pass
+
+         try:
+             obj_match = re.search(
+                 r'\{[^{}]*"types"\s*:\s*\[[^\]]*\][^{}]*\}', text, re.S
+             )
+             if obj_match:
+                 obj = json.loads(obj_match.group(0))
+                 types = obj.get("types", [])
+                 return [t for t in types if isinstance(t, str)]
+         except Exception:
+             pass
+
+         try:
+             list_match = re.search(r"\[([^\]]+)\]", text)
+             if list_match:
+                 items = [
+                     x.strip().strip('"').strip("'")
+                     for x in list_match.group(1).split(",")
+                 ]
+                 return [t for t in items if t]
+         except Exception:
+             pass
+
+         if "," in text:
+             items = [x.strip().strip('"').strip("'") for x in text.split(",")]
+             return [t for t in items if t]
+
+         return []
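A quick check of the four fallbacks, reusing the `rag` instance from the earlier sketch (the model outputs are invented):

```python
samples = [
    '{"term": "aspirin", "reasoning": "...", "types": ["Drug", "NSAID"]}',  # 1) strict JSON
    'Sure! {"term":"x","types":["Disease"]} Hope that helps.',              # 2) embedded object
    'The types are ["Protein", "Enzyme"].',                                 # 3) bare list
    'Drug, Small molecule',                                                 # 4) comma fallback
]
for text in samples:
    print(rag._parse_types(text))
# ['Drug', 'NSAID'] / ['Disease'] / ['Protein', 'Enzyme'] / ['Drug', 'Small molecule']
```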