linkml-store 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of linkml-store has been flagged as potentially problematic.
- linkml_store/api/client.py +35 -8
- linkml_store/api/collection.py +40 -5
- linkml_store/api/config.py +20 -3
- linkml_store/api/database.py +24 -3
- linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
- linkml_store/cli.py +140 -13
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/rag_inference_engine.py +145 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +158 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +290 -0
- linkml_store/inference/inference_config.py +62 -0
- linkml_store/inference/inference_engine.py +173 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/utils/format_utils.py +21 -90
- linkml_store/utils/llm_utils.py +95 -0
- linkml_store/utils/object_utils.py +3 -1
- linkml_store/utils/pandas_utils.py +55 -2
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/stats_utils.py +53 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.1.14.dist-info}/METADATA +25 -2
- {linkml_store-0.1.13.dist-info → linkml_store-0.1.14.dist-info}/RECORD +25 -14
- {linkml_store-0.1.13.dist-info → linkml_store-0.1.14.dist-info}/LICENSE +0 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.1.14.dist-info}/WHEEL +0 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.1.14.dist-info}/entry_points.txt +0 -0
linkml_store/inference/implementations/rag_inference_engine.py
@@ -0,0 +1,145 @@
+import logging
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import yaml
+from llm import get_key
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = """
+You are a {llm_config.role}; your task is to infer the YAML
+object output given the YAML object input. I will provide you
+with a collection of examples that give guidance both on the
+desired structure of the response and on the kind of content.
+
+You should return ONLY valid YAML in your response.
+"""
+
+
+@dataclass
+class RAGInferenceEngine(InferenceEngine):
+    """
+    Retrieval Augmented Generation (RAG) based predictor.
+
+    >>> from linkml_store.api.client import Client
+    >>> from linkml_store.utils.format_utils import Format
+    >>> from linkml_store.inference.inference_config import InferenceConfig, LLMConfig
+    >>> client = Client()
+    >>> db = client.attach_database("duckdb", alias="test")
+    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
+    >>> db.list_collection_names()
+    ['countries']
+    >>> collection = db.get_collection("countries")
+    >>> features = ["name"]
+    >>> targets = ["code", "capital", "continent", "languages"]
+    >>> llm_config = LLMConfig(model_name="gpt-4o-mini")
+    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
+    >>> ie = RAGInferenceEngine(config=config)
+    >>> ie.load_and_split_data(collection)
+    >>> ie.initialize_model()
+    >>> prediction = ie.derive({"name": "Uruguay"})
+    >>> prediction.predicted_object
+    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
+    """
+
+    classifier: Any = None
+    encoders: dict = None
+    _model: "llm.Model" = None  # noqa: F821
+
+    rag_collection: Collection = None
+
+    def __post_init__(self):
+        if not self.config:
+            self.config = InferenceConfig()
+        if not self.config.llm_config:
+            self.config.llm_config = LLMConfig()
+
+    @property
+    def model(self) -> "llm.Model":  # noqa: F821
+        import llm
+
+        if self._model is None:
+            self._model = llm.get_model(self.config.llm_config.model_name)
+            if self._model.needs_key:
+                key = get_key(None, key_alias=self._model.needs_key)
+                self._model.key = key
+        return self._model
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        s = td.slice
+        if not s[0] and not s[1]:
+            rag_collection = td.collection
+        else:
+            # materialize the training slice as its own collection so that
+            # search only retrieves training examples
+            base_collection = td.collection
+            objs = base_collection.find({}, offset=s[0], limit=s[1] - s[0]).rows
+            db = base_collection.parent
+            rag_collection = db.get_collection(f"{base_collection.alias}__rag_{s[0]}_{s[1]}", create_if_not_exists=True)
+            rag_collection.insert(objs)
+        rag_collection.attach_indexer("llm", auto_index=False)
+        self.rag_collection = rag_collection
+
+    def object_to_text(self, object: OBJECT) -> str:
+        return yaml.dump(object)
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        import llm
+        from tiktoken import encoding_for_model
+
+        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
+
+        model: llm.Model = self.model
+        model_name = self.config.llm_config.model_name
+        feature_attributes = self.config.feature_attributes
+        target_attributes = self.config.target_attributes
+        num_examples = self.config.llm_config.number_of_few_shot_examples or 5
+        query_text = self.object_to_text(object)
+        if not self.rag_collection.indexers:
+            raise ValueError("RAG collection must have an indexer attached")
+        rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
+        examples = rs.rows
+        if not examples:
+            raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
+        prompt_clauses = []
+        for example in examples:
+            input_obj = {k: example.get(k, None) for k in feature_attributes}
+            output_obj = {k: example.get(k, None) for k in target_attributes}
+            prompt_clause = (
+                "---\nExample:\n"
+                f"## INPUT:\n{self.object_to_text(input_obj)}\n"
+                f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
+            )
+            prompt_clauses.append(prompt_clause)
+        query_obj = {k: object.get(k, None) for k in feature_attributes}
+        query_text = self.object_to_text(query_obj)
+        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
+        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+
+        def make_text(texts):
+            # use the (possibly trimmed) list passed in, so that
+            # render_formatted_text can drop examples to fit the token limit
+            return "\n".join(texts) + prompt_end
+
+        try:
+            encoding = encoding_for_model(model_name)
+        except KeyError:
+            encoding = encoding_for_model("gpt-4")
+        token_limit = get_token_limit(model_name)
+        prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
+        logger.info(f"Prompt: {prompt}")
+        response = model.prompt(prompt, system_prompt)
+        yaml_str = response.text()
+        logger.info(f"Response: {yaml_str}")
+        try:
+            predicted_object = yaml.safe_load(yaml_str)
+            return Inference(predicted_object=predicted_object)
+        except yaml.parser.ParserError as e:
+            logger.error(f"Error parsing response: {yaml_str}\n{e}")
+            return None
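The engine's doctest doubles as a usage guide. A minimal non-doctest sketch of the same flow, assuming the `llm` library has an API key configured and the countries fixture from the test suite is available:

```python
from linkml_store.api.client import Client
from linkml_store.inference.implementations.rag_inference_engine import RAGInferenceEngine
from linkml_store.inference.inference_config import InferenceConfig, LLMConfig
from linkml_store.utils.format_utils import Format

client = Client()
db = client.attach_database("duckdb", alias="test")
db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")

config = InferenceConfig(
    feature_attributes=["name"],
    target_attributes=["code", "capital", "continent", "languages"],
    llm_config=LLMConfig(model_name="gpt-4o-mini", number_of_few_shot_examples=5),
)
ie = RAGInferenceEngine(config=config)
ie.load_and_split_data(db.get_collection("countries"))  # sets ie.training_data
ie.initialize_model()  # builds the RAG collection and attaches the "llm" indexer
prediction = ie.derive({"name": "Uruguay"})
print(prediction.predicted_object)  # e.g. {'capital': 'Montevideo', 'code': 'UY', ...}
```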
linkml_store/inference/implementations/rule_based_inference_engine.py
@@ -0,0 +1,158 @@
+import logging
+from copy import copy
+from dataclasses import dataclass
+from io import StringIO
+from pathlib import Path
+from typing import Any, ClassVar, Dict, List, Optional, Union
+
+import yaml
+from linkml_map.utils.eval_utils import eval_expr
+from linkml_runtime import SchemaView
+from linkml_runtime.linkml_model.meta import AnonymousClassExpression, ClassRule
+from linkml_runtime.utils.formatutils import underscore
+from pydantic import BaseModel
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+
+logger = logging.getLogger(__name__)
+
+
+def expression_matches(ce: AnonymousClassExpression, object: OBJECT) -> bool:
+    """
+    Check if a class expression matches an object.
+
+    :param ce: The class expression
+    :param object: The object
+    :return: True if the class expression matches the object
+    """
+    if ce.any_of:
+        if not any(expression_matches(subce, object) for subce in ce.any_of):
+            return False
+    if ce.all_of:
+        if not all(expression_matches(subce, object) for subce in ce.all_of):
+            return False
+    if ce.none_of:
+        if any(expression_matches(subce, object) for subce in ce.none_of):
+            return False
+    if ce.slot_conditions:
+        for slot in ce.slot_conditions.values():
+            slot_name = slot.name
+            v = object.get(slot_name, None)
+            if slot.equals_string is not None:
+                if slot.equals_string != str(v):
+                    return False
+            if slot.equals_integer is not None:
+                if slot.equals_integer != v:
+                    return False
+            if slot.equals_expression is not None:
+                eval_v = eval_expr(slot.equals_expression, **object)
+                if v != eval_v:
+                    return False
+    return True
+
+
+def apply_rule(rule: ClassRule, object: OBJECT):
+    """
+    Apply a rule to an object.
+
+    Mutates the object in place.
+
+    :param rule: The rule to apply
+    :param object: The object to apply the rule to
+    """
+    for condition in rule.preconditions:
+        if expression_matches(condition, object):
+            for postcondition in rule.postconditions:
+                all_of = [x for x in postcondition.all_of] + [postcondition]
+                for pc in all_of:
+                    sc = pc.slot_condition
+                    if sc:
+                        if sc.equals_string:
+                            object[sc.name] = sc.equals_string
+                        if sc.equals_integer:
+                            object[sc.name] = sc.equals_integer
+                        if sc.equals_expression:
+                            object[sc.name] = eval_expr(sc.equals_expression, **object)
+    return object
+
+
+@dataclass
+class RuleBasedInferenceEngine(InferenceEngine):
+    """
+    Inference engine that derives values by applying class rules and
+    slot ``equals_expression`` declarations from the schema.
+    """
+
+    class_rules: Optional[List[ClassRule]] = None
+    slot_rules: Optional[Dict[str, List[ClassRule]]] = None
+    slot_expressions: Optional[Dict[str, str]] = None
+
+    PERSIST_COLS: ClassVar = ["config", "class_rules", "slot_rules", "slot_expressions"]
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        collection: Collection = td.collection
+        cd = collection.class_definition()
+        sv: SchemaView = collection.parent.schema_view
+        class_rules = cd.rules
+        if class_rules:
+            self.class_rules = class_rules
+        if self.slot_expressions is None:
+            self.slot_expressions = {}
+        for slot in sv.class_induced_slots(cd.name):
+            if slot.equals_expression:
+                self.slot_expressions[slot.name] = slot.equals_expression
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        object = copy(object)
+        if self.class_rules:
+            for rule in self.class_rules:
+                apply_rule(rule, object)
+        object = {underscore(k): v for k, v in object.items()}
+        if self.slot_expressions:
+            for slot, expr in self.slot_expressions.items():
+                logger.debug(f"Evaluating {expr} against {object}")
+                v = eval_expr(expr, **object)
+                if v is not None:
+                    object[slot] = v
+        return Inference(predicted_object=object)
+
+    def import_model_from(self, inference_engine: InferenceEngine, **kwargs):
+        io = StringIO()
+        inference_engine.export_model(io, model_serialization=ModelSerialization.LINKML_EXPRESSION)
+        config = inference_engine.config
+        if len(config.target_attributes) != 1:
+            raise ValueError("Can only import models with a single target attribute")
+        target_attribute = config.target_attributes[0]
+        if self.slot_expressions is None:
+            self.slot_expressions = {}
+        self.slot_expressions[target_attribute] = io.getvalue()
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+
+        def _serialize_value(v: Any) -> Any:
+            if isinstance(v, BaseModel):
+                return v.model_dump(exclude_unset=True)
+            return v
+
+        model_data = {k: _serialize_value(getattr(self, k)) for k in self.PERSIST_COLS}
+        with open(output, "w", encoding="utf-8") as f:
+            yaml.dump(model_data, f)
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "RuleBasedInferenceEngine":
+        with open(file_path) as f:
+            model_data = yaml.safe_load(f)
+
+        engine = cls(config=model_data["config"])
+        for k, v in model_data.items():
+            if k == "config":
+                continue
+            setattr(engine, k, v)
+
+        logger.info(f"Model loaded from {file_path}")
+        return engine
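The rule-based engine can also be driven without a schema by setting `slot_expressions` directly. A sketch, assuming `eval_expr` supports the `{slot}`-reference syntax used by LinkML `equals_expression` (slot names here are hypothetical):

```python
from linkml_store.inference.implementations.rule_based_inference_engine import RuleBasedInferenceEngine
from linkml_store.inference.inference_config import InferenceConfig

# Hypothetical slots: derive full_name from first_name and last_name.
ie = RuleBasedInferenceEngine(config=InferenceConfig(target_attributes=["full_name"]))
ie.slot_expressions = {"full_name": "{first_name} + ' ' + {last_name}"}

inference = ie.derive({"first_name": "Ada", "last_name": "Lovelace"})
print(inference.predicted_object["full_name"])  # Ada Lovelace
```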
linkml_store/inference/implementations/sklearn_inference_engine.py
@@ -0,0 +1,290 @@
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, ClassVar, Dict, List, Optional, TextIO, Type, Union
+
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import cross_val_score
+from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer, OneHotEncoder
+from sklearn.tree import DecisionTreeClassifier
+
+from linkml_store.api.collection import OBJECT
+from linkml_store.inference.implementations.rule_based_inference_engine import RuleBasedInferenceEngine
+from linkml_store.inference.inference_config import Inference, InferenceConfig
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+from linkml_store.utils.sklearn_utils import tree_to_nested_expression, visualize_decision_tree
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SklearnInferenceEngine(InferenceEngine):
+    """
+    Inference engine backed by a scikit-learn decision tree classifier.
+    """
+
+    config: InferenceConfig
+    classifier: Any = None
+    encoders: Dict[str, Any] = field(default_factory=dict)
+    transformed_features: List[str] = field(default_factory=list)
+    transformed_targets: List[str] = field(default_factory=list)
+    skip_features: List[str] = field(default_factory=list)
+    categorical_encoder_class: Optional[Type[Union[OneHotEncoder, MultiLabelBinarizer]]] = None
+    maximum_proportion_distinct_features: float = 0.2
+    confidence: float = 0.0
+
+    strict: bool = False
+
+    PERSIST_COLS: ClassVar = [
+        "config",
+        "classifier",
+        "encoders",
+        "transformed_features",
+        "transformed_targets",
+        "skip_features",
+        "confidence",
+    ]
+
+    def _get_encoder(self, v: Union[List[Any], Any]) -> Any:
+        if isinstance(v, list):
+            if all(isinstance(x, list) for x in v):
+                return MultiLabelBinarizer()
+            elif all(isinstance(x, str) for x in v):
+                return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+            elif all(isinstance(x, (int, float)) for x in v):
+                return None
+            else:
+                raise ValueError("Mixed data types in the list are not supported")
+        else:
+            if hasattr(v, "dtype"):
+                if v.dtype == "object" or v.dtype.name == "category":
+                    if isinstance(v.iloc[0], list):
+                        return MultiLabelBinarizer()
+                    elif self.categorical_encoder_class:
+                        return self.categorical_encoder_class(handle_unknown="ignore")
+                    else:
+                        return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+                elif v.dtype.kind in "biufc":
+                    # numeric dtypes need no encoding
+                    return None
+            raise ValueError("Unable to determine appropriate encoder for the input data")
+
+    def _is_complex_column(self, column: pd.Series) -> bool:
+        """Check if the column contains complex data types like lists or dicts."""
+        # MV_TYPE = (list, dict)
+        MV_TYPE = (list,)
+        return (column.dtype == "object" or column.dtype == "category") and any(
+            isinstance(x, MV_TYPE) for x in column.dropna()
+        )
+
+    def _get_unique_values(self, column: pd.Series) -> set:
+        """Get unique values from a column, handling list-type data."""
+        if self._is_complex_column(column):
+            # For columns with lists, flatten the lists and get unique values
+            return set(
+                item for sublist in column.dropna() for item in (sublist if isinstance(sublist, list) else [sublist])
+            )
+        else:
+            return set(column.unique())
+
+    def initialize_model(self, **kwargs):
+        logger.info(f"Initializing model with config: {self.config}")
+        df = self.training_data.as_dataframe(flattened=True)
+        logger.info(f"Training data shape: {df.shape}")
+        target_cols = self.config.target_attributes
+        feature_cols = self.config.feature_attributes
+        if len(target_cols) != 1:
+            raise ValueError("Only one target column is supported")
+        if not feature_cols:
+            feature_cols = df.columns.difference(target_cols).tolist()
+            self.config.feature_attributes = feature_cols
+        target_col = target_cols[0]
+        logger.info(f"Feature columns: {feature_cols}")
+        X = df[feature_cols].copy()
+        logger.info(f"Target column: {target_col}")
+        y = df[target_col].copy()
+
+        # find features to skip (too many distinct values relative to the column length)
+        skip_features = []
+        for col in X.columns:
+            unique_values = self._get_unique_values(X[col])
+            if len(unique_values) > self.maximum_proportion_distinct_features * len(X[col]):
+                skip_features.append(col)
+        self.skip_features = skip_features
+        X = X.drop(skip_features, axis=1)
+        logger.info(f"Skipping features: {skip_features}")
+
+        # Encode features
+        encoded_features = []
+        for col in X.columns:
+            logger.info(f"Checking whether to encode: {col}")
+            col_encoder = self._get_encoder(X[col])
+            if col_encoder:
+                self.encoders[col] = col_encoder
+                if isinstance(col_encoder, OneHotEncoder):
+                    encoded = col_encoder.fit_transform(X[[col]])
+                    feature_names = col_encoder.get_feature_names_out([col])
+                    encoded_df = pd.DataFrame(encoded, columns=feature_names, index=X.index)
+                    X = pd.concat([X.drop(col, axis=1), encoded_df], axis=1)
+                    encoded_features.extend(feature_names)
+                elif isinstance(col_encoder, MultiLabelBinarizer):
+                    encoded = col_encoder.fit_transform(X[col])
+                    feature_names = [f"{col}_{c}" for c in col_encoder.classes_]
+                    encoded_df = pd.DataFrame(encoded, columns=feature_names, index=X.index)
+                    X = pd.concat([X.drop(col, axis=1), encoded_df], axis=1)
+                    encoded_features.extend(feature_names)
+                else:
+                    X[col] = col_encoder.fit_transform(X[col])
+                    encoded_features.append(col)
+            else:
+                encoded_features.append(col)
+
+        self.transformed_features = encoded_features
+        logger.info(f"Encoded features: {self.transformed_features}")
+        logger.info(f"Number of features after encoding: {len(self.transformed_features)}")
+
+        # Encode target
+        y_encoder = self._get_encoder(y)
+        if isinstance(y_encoder, OneHotEncoder):
+            y_encoder = LabelEncoder()
+        if y_encoder:
+            self.encoders[target_col] = y_encoder
+            y = y_encoder.fit_transform(y.values.ravel())  # Convert to 1D numpy array
+            self.transformed_targets = y_encoder.classes_
+
+        logger.info(f"Fitting model with features: {X.columns}")
+        clf = DecisionTreeClassifier(random_state=42)
+        clf.fit(X, y)
+        self.classifier = clf
+        logger.info("Model fit complete")
+        cv_scores = cross_val_score(self.classifier, X, y, cv=5)
+        self.confidence = cv_scores.mean()
+        logger.info(f"Cross-validation scores: {cv_scores}")
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        object = self._normalize(object)
+        new_X = pd.DataFrame([object])
+
+        # Apply encodings
+        encoded_features = {}
+        for col in self.config.feature_attributes:
+            if col in self.skip_features:
+                continue
+            if col in self.encoders:
+                encoder = self.encoders[col]
+                if isinstance(encoder, OneHotEncoder):
+                    encoded = encoder.transform(new_X[[col]])
+                    feature_names = encoder.get_feature_names_out([col])
+                    for i, name in enumerate(feature_names):
+                        encoded_features[name] = encoded[0, i]
+                elif isinstance(encoder, MultiLabelBinarizer):
+                    encoded = encoder.transform(new_X[col])
+                    feature_names = [f"{col}_{c}" for c in encoder.classes_]
+                    for i, name in enumerate(feature_names):
+                        encoded_features[name] = encoded[0, i]
+                else:  # LabelEncoder or similar
+                    encoded_features[col] = encoder.transform(new_X[col].astype(str))[0]
+            else:
+                encoded_features[col] = new_X[col].iloc[0]
+
+        # Ensure all expected features are present and in the correct order
+        final_features = []
+        for feature in self.transformed_features:
+            if feature in encoded_features:
+                final_features.append(encoded_features[feature])
+            else:
+                final_features.append(0)  # or some other default value
+
+        # Create the final input array
+        new_X_array = np.array(final_features).reshape(1, -1)
+
+        logger.info(f"Input features: {self.transformed_features}")
+        logger.info(f"Number of input features: {len(self.transformed_features)}")
+
+        predictions = self.classifier.predict(new_X_array)
+        target_attribute = self.config.target_attributes[0]
+        y_encoder = self.encoders.get(target_attribute)
+
+        if y_encoder:
+            v = y_encoder.inverse_transform(predictions)
+        else:
+            v = predictions
+
+        predicted_object = {target_attribute: v[0]}
+        logger.info(f"Predicted object: {predicted_object}")
+        return Inference(predicted_object=predicted_object, confidence=self.confidence)
+
+    def _normalize(self, object: OBJECT) -> OBJECT:
+        return {k: object.get(k, None) for k in self.config.feature_attributes}
+
+    def export_model(
+        self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
+    ):
+        def as_file():
+            if isinstance(output, (str, Path)):
+                return open(output, "w")
+            return output
+
+        if model_serialization is None:
+            if isinstance(output, (str, Path)):
+                model_serialization = ModelSerialization.from_filepath(output)
+            if model_serialization is None:
+                model_serialization = ModelSerialization.JOBLIB
+
+        if model_serialization == ModelSerialization.LINKML_EXPRESSION:
+            expr = tree_to_nested_expression(
+                self.classifier,
+                self.transformed_features,
+                self.encoders.keys(),
+                feature_encoders=self.encoders,
+                target_encoder=self.encoders.get(self.config.target_attributes[0]),
+            )
+            f = as_file()
+            f.write(expr)
+            if f is not output:
+                f.close()  # only close file handles we opened ourselves
+        elif model_serialization == ModelSerialization.JOBLIB:
+            self.save_model(output)
+        elif model_serialization == ModelSerialization.RULE_BASED:
+            rbie = RuleBasedInferenceEngine(config=self.config)
+            rbie.import_model_from(self)
+            rbie.save_model(output)
+        elif model_serialization == ModelSerialization.PNG:
+            visualize_decision_tree(self.classifier, self.transformed_features, self.transformed_targets, output)
+        else:
+            raise ValueError(f"Unsupported model serialization: {model_serialization}")
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+        import joblib
+
+        if self.classifier is None:
+            raise ValueError("Model has not been trained. Call initialize_model() first.")
+
+        model_data = {k: getattr(self, k) for k in self.PERSIST_COLS}
+
+        joblib.dump(model_data, output)
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "SklearnInferenceEngine":
+        """
+        Load a trained model and related data from a file.
+
+        :param file_path: Path to the saved model
+        :return: SklearnInferenceEngine instance with loaded model
+        """
+        import joblib
+
+        model_data = joblib.load(file_path)
+
+        engine = cls(config=model_data["config"])
+        for k, v in model_data.items():
+            if k == "config":
+                continue
+            setattr(engine, k, v)
+
+        logger.info(f"Model loaded from {file_path}")
+        return engine
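A sketch of the decision-tree workflow (the collection name and rows are hypothetical, and the 5-fold cross-validation used to compute `confidence` needs a reasonable number of rows per class):

```python
from linkml_store.api.client import Client
from linkml_store.inference.implementations.sklearn_inference_engine import SklearnInferenceEngine
from linkml_store.inference.inference_config import InferenceConfig
from linkml_store.inference.inference_engine import ModelSerialization

client = Client()
db = client.attach_database("duckdb", alias="test")
collection = db.get_collection("samples", create_if_not_exists=True)  # hypothetical collection
collection.insert([
    {"sepal_length": 5.1, "sepal_width": 3.5, "species": "setosa"},
    {"sepal_length": 6.7, "sepal_width": 3.0, "species": "versicolor"},
    # ... more training rows
])

config = InferenceConfig(feature_attributes=["sepal_length", "sepal_width"], target_attributes=["species"])
ie = SklearnInferenceEngine(config=config)
ie.load_and_split_data(collection)
ie.initialize_model()  # fits the DecisionTreeClassifier; confidence = mean CV accuracy

print(ie.derive({"sepal_length": 5.0, "sepal_width": 3.4}).predicted_object)
ie.save_model("model.joblib")
ie.export_model("model.png", model_serialization=ModelSerialization.PNG)  # render the tree
```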
linkml_store/inference/inference_config.py
@@ -0,0 +1,62 @@
+import logging
+from typing import List, Optional, Tuple
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from linkml_store.api.collection import OBJECT
+from linkml_store.utils.format_utils import Format, load_objects
+
+logger = logging.getLogger(__name__)
+
+
+class LLMConfig(BaseModel, extra="forbid"):
+    """
+    Configuration for the LLM indexer.
+    """
+
+    model_config = ConfigDict(protected_namespaces=())
+
+    model_name: str = "gpt-4o-mini"
+    token_limit: Optional[int] = None
+    number_of_few_shot_examples: Optional[int] = None
+    role: str = "Domain Expert"
+    cached_embeddings_database: Optional[str] = None
+    cached_embeddings_collection: Optional[str] = None
+    text_template: Optional[str] = None
+    text_template_syntax: Optional[str] = None
+
+
+class InferenceConfig(BaseModel, extra="forbid"):
+    """
+    Configuration for inference engines.
+    """
+
+    target_attributes: Optional[List[str]] = None
+    feature_attributes: Optional[List[str]] = None
+    train_test_split: Optional[Tuple[float, float]] = None
+    llm_config: Optional[LLMConfig] = None
+
+    @classmethod
+    def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":
+        """
+        Load an inference config from a file.
+
+        :param file_path: Path to the file.
+        :param format: Format of the file (YAML is recommended).
+        :return: InferenceConfig
+        """
+        if format and format.is_xsv():
+            logger.warning("XSV format is not recommended for inference config files")
+        objs = load_objects(file_path, format=format)
+        if len(objs) != 1:
+            raise ValueError(f"Expected 1 object, got {len(objs)}")
+        return cls(**objs[0])
+
+
+class Inference(BaseModel, extra="forbid"):
+    """
+    Result of an inference derivation.
+    """
+
+    predicted_object: OBJECT = Field(..., description="The predicted object.")
+    confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)