linkml-store 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of linkml-store might be problematic. Click here for more details.
- linkml_store/api/client.py +35 -8
- linkml_store/api/collection.py +40 -5
- linkml_store/api/config.py +20 -3
- linkml_store/api/database.py +24 -3
- linkml_store/api/stores/duckdb/duckdb_collection.py +3 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
- linkml_store/cli.py +149 -13
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/evaluation.py +189 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/rag_inference_engine.py +145 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +308 -0
- linkml_store/inference/inference_config.py +62 -0
- linkml_store/inference/inference_engine.py +200 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/utils/format_utils.py +27 -90
- linkml_store/utils/llm_utils.py +96 -0
- linkml_store/utils/object_utils.py +103 -2
- linkml_store/utils/pandas_utils.py +55 -2
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/stats_utils.py +53 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/METADATA +28 -2
- {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/RECORD +27 -15
- {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/LICENSE +0 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/WHEEL +0 -0
- {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import Any, List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from linkml_store.inference import InferenceEngine
|
|
10
|
+
from linkml_store.utils.object_utils import select_nested
|
|
11
|
+
|
|
12
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def score_match(target: Optional[Any], candidate: Optional[Any], match_function: Optional[Callable] = None) -> float:
    """
    Compute a score in [0, 1] for a match between two objects.

    - Equal values score 1.0; if exactly one side is None the score is 0.0.
    - Two collections (list/set) are compared all-by-all: each element is paired with its
      best-scoring counterpart on the other side, and the best scores of both sides are
      averaged. Two empty collections score 1.0; an empty vs. a non-empty one scores 0.0.
    - Two dicts score the mean of the per-key scores over the union of their keys
      (a missing key counts as None).
    - Other unequal values are scored by ``match_function`` if given, else 0.0.

    >>> score_match("a", "a")
    1.0
    >>> score_match("a", "b")
    0.0
    >>> score_match("a", None)
    0.0
    >>> score_match(None, "a")
    0.0
    >>> score_match(None, None)
    1.0
    >>> score_match(["a", "b"], ["a", "b"])
    1.0
    >>> score_match(["a", "b"], ["b", "a"])
    1.0
    >>> round(score_match(["a"], ["b", "a"]), 2)
    0.67
    >>> score_match([], ["a"])
    0.0
    >>> score_match({"a": 1}, {"a": 1})
    1.0
    >>> score_match({"a": 1}, {"a": 2})
    0.0
    >>> score_match({"a": 1, "b": None}, {"a": 1})
    1.0
    >>> score_match([{"a": 1, "b": 2}, {"a": 3, "b": 4}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
    1.0
    >>> score_match([{"a": 1, "b": 4}, {"a": 3, "b": 2}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
    0.5
    >>> def char_match(x, y):
    ...     return len(set(x).intersection(set(y))) / len(set(x).union(set(y)))
    >>> score_match("abcd", "abc", char_match)
    0.75
    >>> score_match(["abcd", "efgh"], ["ac", "gh"], char_match)
    0.5

    :param target: the expected value
    :param candidate: the predicted value
    :param match_function: optional ``(target, candidate) -> float`` used to score unequal
        atomic values (applied recursively inside collections and dicts)
    :return: a plain Python float (never a NumPy scalar)
    """
    if target == candidate:
        return 1.0
    if target is None or candidate is None:
        return 0.0
    if isinstance(target, (set, list)) and isinstance(candidate, (set, list)):
        # np.max on a zero-size axis raises, so empty collections are handled up front.
        if not target and not candidate:
            return 1.0
        if not target or not candidate:
            return 0.0
        # All-by-all score matrix: rows = target elements, columns = candidate elements.
        score_matrix = np.array(
            [[score_match(t, c, match_function) for c in candidate] for t in target],
            dtype=float,
        )
        best_for_candidates = score_matrix.max(axis=0)
        best_for_targets = score_matrix.max(axis=1)
        total = best_for_candidates.sum() + best_for_targets.sum()
        # float() so callers (and doctests) see a plain float, not np.float64.
        return float(total / (len(target) + len(candidate)))
    if isinstance(target, dict) and isinstance(candidate, dict):
        keys = set(target.keys()).union(candidate.keys())
        scores = [score_match(target.get(k), candidate.get(k), match_function) for k in keys]
        return float(np.mean(scores))
    if match_function:
        return float(match_function(target, candidate))
    return 0.0
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class Outcome(BaseModel):
    """
    Aggregate result of evaluating a predictor over a number of test rows.
    """

    # Sum of per-row match scores; may be fractional when partial matches are scored
    true_positive_count: float
    # Number of rows that were evaluated
    total_count: int

    @property
    def accuracy(self) -> float:
        """
        Mean match score per evaluated row.

        :return: ``true_positive_count / total_count``, or 0.0 when no rows were
            evaluated (avoids ZeroDivisionError on empty test data)
        """
        if not self.total_count:
            return 0.0
        return self.true_positive_count / self.total_count
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def evaluate_predictor(
    predictor: InferenceEngine,
    target_attributes: List[str],
    feature_attributes: Optional[List[str]] = None,
    test_data: Optional[pd.DataFrame] = None,
    evaluation_count: Optional[int] = 10,
    match_function: Optional[Callable] = None,
) -> Outcome:
    """
    Evaluate a predictor by comparing its predictions to the expected values in the testing data.

    For each test row, the expected object is ``select_nested(row, target_attributes)``.
    The predictor is given the row with the (top-level) target attributes removed, so it
    cannot simply echo back the answer. Predictions and expectations are compared with
    :func:`score_match`.

    :param predictor: the inference engine to evaluate
    :param target_attributes: attributes the predictor is expected to infer
    :param feature_attributes: retained for backward compatibility; target attributes are
        always removed from the input, and feature selection is left to the predictor's config
    :param test_data: rows to evaluate; defaults to ``predictor.testing_data.as_dataframe()``
    :param evaluation_count: maximum number of rows to evaluate (None = all rows)
    :param match_function: optional scorer for unequal atomic values, passed to score_match
    :return: an Outcome with the summed scores and the number of rows evaluated
    """
    if test_data is None:
        test_data = predictor.testing_data.as_dataframe()
    n = 0
    tp = 0.0
    for row in test_data.to_dict(orient="records"):
        # Check the cap before predicting so that evaluation_count=0 evaluates nothing.
        if evaluation_count is not None and n >= evaluation_count:
            break
        expected_obj = select_nested(row, target_attributes)
        # Never hand the expected values to the predictor (label leakage).
        test_obj = {k: v for k, v in row.items() if k not in target_attributes}
        result = predictor.derive(test_obj)
        # derive() is Optional: engines return None when they cannot produce a prediction.
        predicted_obj = result.predicted_object if result is not None else None
        logger.info("Predicted: %s Expected: %s", predicted_obj, expected_obj)
        tp += score_match(predicted_obj, expected_obj, match_function)
        n += 1
    return Outcome(true_positive_count=tp, total_count=n)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def score_text_overlap(str1: Any, str2: Any) -> float:
    """
    Compute an overlap score between two values based on their longest common substring.

    Non-string values are compared through their ``str()`` form, so this function can be
    used as a ``match_function`` for :func:`score_match` on heterogeneous data.

    :param str1: first value
    :param str2: second value
    :return: 1.0 if the values are equal; 0.0 if either is empty/falsy; otherwise
        ``len(longest common substring) / max(len(str1), len(str2))``
    """
    if str1 == str2:
        return 1.0
    if not str1 or not str2:
        return 0.0
    # len() would fail on e.g. ints, so compare textual representations.
    text1, text2 = str(str1), str(str2)
    _, length = find_longest_overlap(text1, text2)
    return length / max(len(text1), len(text2))
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def find_longest_overlap(str1: str, str2: str):
    """
    Find the longest overlapping (common) substring between two strings.

    Uses the standard dynamic-programming recurrence, but keeps only the previous row
    of the table, so memory is O(len(str2)) instead of O(len(str1) * len(str2)).
    When several common substrings share the maximum length, the one ending earliest
    in ``str1`` is returned.

    Args:
        str1 (str): The first string
        str2 (str): The second string

    Returns:
        tuple: A tuple containing the longest overlapping substring and its length
        (``("", 0)`` when there is none)

    Examples:
        >>> find_longest_overlap("hello world", "world of programming")
        ('world', 5)
        >>> find_longest_overlap("abcdefg", "defghi")
        ('defg', 4)
        >>> find_longest_overlap("python", "java")
        ('', 0)
        >>> find_longest_overlap("", "test")
        ('', 0)
        >>> find_longest_overlap("aabbcc", "ddeeff")
        ('', 0)
        >>> find_longest_overlap("programming", "PROGRAMMING")
        ('', 0)
    """
    if not str1 or not str2:
        return "", 0

    width = len(str2)
    # prev[j]: length of the common suffix of str1[:i-1] and str2[:j] (previous DP row)
    prev = [0] * (width + 1)
    max_length = 0
    end_pos = 0  # exclusive end index of the best match within str1

    for i, ch1 in enumerate(str1, start=1):
        curr = [0] * (width + 1)
        for j, ch2 in enumerate(str2, start=1):
            if ch1 == ch2:
                curr[j] = prev[j - 1] + 1
                # Strict '>' keeps the earliest match on ties.
                if curr[j] > max_length:
                    max_length = curr[j]
                    end_pos = i
        prev = curr

    return str1[end_pos - max_length : end_pos], max_length
|
|
File without changes
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from llm import get_key
|
|
7
|
+
|
|
8
|
+
from linkml_store.api.collection import OBJECT, Collection
|
|
9
|
+
from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
|
|
10
|
+
from linkml_store.inference.inference_engine import InferenceEngine
|
|
11
|
+
from linkml_store.utils.object_utils import select_nested
|
|
12
|
+
|
|
13
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# System prompt sent with each RAG request. The ``{llm_config.role}`` placeholder is
# filled in with ``str.format(llm_config=...)`` in RAGInferenceEngine.derive().
SYSTEM_PROMPT = """
You are a {llm_config.role}, your task is to infer the YAML
object output given the YAML object input. I will provide you
with a collection of examples that will provide guidance both
on the desired structure of the response, as well as the kind
of content.

You should return ONLY valid YAML in your response.
"""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# def select_object(obj: OBJECT, key_paths: List[str]) -> OBJECT:
|
|
27
|
+
# return {k: obj.get(k, None) for k in keys}
|
|
28
|
+
# return {k: object_path_get(obj, k, None) for k in key_paths}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class RAGInferenceEngine(InferenceEngine):
    """
    AI Retrieval Augmented Generation (RAG) based predictor.

    For each query object, rows of the training collection are retrieved through its
    ``"llm"`` indexer, rendered as few-shot YAML input/output examples, and sent to the
    configured LLM, whose YAML reply is parsed into the predicted object.

    >>> from linkml_store.api.client import Client
    >>> from linkml_store.utils.format_utils import Format
    >>> from linkml_store.inference.inference_config import LLMConfig
    >>> client = Client()
    >>> db = client.attach_database("duckdb", alias="test")
    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
    >>> db.list_collection_names()
    ['countries']
    >>> collection = db.get_collection("countries")
    >>> features = ["name"]
    >>> targets = ["code", "capital", "continent", "languages"]
    >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
    >>> ie = RAGInferenceEngine(config=config)
    >>> ie.load_and_split_data(collection)
    >>> ie.initialize_model()
    >>> prediction = ie.derive({"name": "Uruguay"})
    >>> prediction.predicted_object
    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}

    """

    # Not used by the methods of this class
    classifier: Any = None
    # Not used by the methods of this class
    encoders: dict = None
    # Lazily-loaded llm model; see the ``model`` property
    _model: "llm.Model" = None  # noqa: F821

    # Training collection used for example retrieval; set by initialize_model()
    rag_collection: Collection = None

    def __post_init__(self):
        # Guarantee a config and an LLM config so later code can read them unconditionally.
        if not self.config:
            self.config = InferenceConfig()
        if not self.config.llm_config:
            self.config.llm_config = LLMConfig()

    @property
    def model(self) -> "llm.Model":  # noqa: F821
        """
        Return the ``llm`` model named by ``config.llm_config.model_name``, loading and caching it
        on first access. If the model needs an API key, it is looked up with ``get_key``.
        """
        import llm

        if self._model is None:
            self._model = llm.get_model(self.config.llm_config.model_name)
            if self._model.needs_key:
                key = get_key(None, key_alias=self._model.needs_key)
                self._model.key = key

        return self._model

    def initialize_model(self, **kwargs):
        """
        Use the training collection as the retrieval collection, attaching an ``"llm"``
        indexer to it (with ``auto_index=False``).
        """
        rag_collection = self.training_data.collection
        rag_collection.attach_indexer("llm", auto_index=False)
        self.rag_collection = rag_collection

    def object_to_text(self, object: OBJECT) -> str:
        """Render an object as YAML text (used both for retrieval and for the prompt)."""
        return yaml.dump(object)

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """
        Remove a surrounding Markdown code fence (e.g. ```yaml ... ```) from an LLM reply.

        Text that does not start with a fence is returned unchanged.
        """
        stripped = text.strip()
        if not stripped.startswith("```"):
            return text
        # Drop the opening fence line, which may carry a language tag such as "yaml".
        lines = stripped.splitlines()[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines)

    def derive(self, object: OBJECT) -> Optional[Inference]:
        """
        Predict the target attributes of ``object`` from retrieved few-shot examples.

        :param object: the query object; its feature attributes form the LLM query
        :return: the parsed prediction, or None if the LLM reply is not valid YAML
        :raises ValueError: if initialize_model() has not been called, no indexer is
            attached, or no examples are retrieved
        """
        import llm
        from tiktoken import encoding_for_model

        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text

        if self.rag_collection is None:
            raise ValueError("RAG collection is not set; call initialize_model() first")
        model: llm.Model = self.model
        model_name = self.config.llm_config.model_name
        feature_attributes = self.config.feature_attributes
        target_attributes = self.config.target_attributes
        num_examples = self.config.llm_config.number_of_few_shot_examples or 5
        # Retrieval uses the text of the whole input object.
        search_text = self.object_to_text(object)
        if not self.rag_collection.indexers:
            raise ValueError("RAG collection must have an indexer attached")
        rs = self.rag_collection.search(search_text, limit=num_examples, index_name="llm")
        examples = rs.rows
        if not examples:
            raise ValueError(f"No examples found for {search_text}; size = {self.rag_collection.size()}")
        prompt_clauses = []
        for example in examples:
            input_obj = select_nested(example, feature_attributes)
            output_obj = select_nested(example, target_attributes)
            prompt_clause = (
                "---\nExample:\n"
                f"## INPUT:\n{self.object_to_text(input_obj)}\n"
                f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
            )
            prompt_clauses.append(prompt_clause)
        query_obj = select_nested(object, feature_attributes)
        query_text = self.object_to_text(query_obj)
        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)

        def make_text(texts):
            # Must render the clauses it is given (not the full prompt_clauses list):
            # render_formatted_text receives both this renderer and the clauses,
            # presumably so it can re-render with fewer examples to fit token_limit.
            return "\n".join(texts) + prompt_end

        try:
            encoding = encoding_for_model(model_name)
        except KeyError:
            # Unknown to tiktoken: fall back to the gpt-4 tokenizer for length estimates.
            encoding = encoding_for_model("gpt-4")
        token_limit = get_token_limit(model_name)
        prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
        logger.info(f"Prompt: {prompt}")
        response = model.prompt(prompt, system_prompt)
        yaml_str = response.text()
        logger.info(f"Response: {yaml_str}")
        try:
            predicted_object = yaml.safe_load(self._strip_code_fences(yaml_str))
        except yaml.YAMLError as e:
            # YAMLError covers parser *and* scanner errors (e.g. stray backticks).
            logger.error(f"Error parsing response: {yaml_str}\n{e}")
            return None
        return Inference(predicted_object=predicted_object)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from copy import copy
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from io import StringIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, ClassVar, Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
9
|
+
from linkml_map.utils.eval_utils import eval_expr
|
|
10
|
+
from linkml_runtime import SchemaView
|
|
11
|
+
from linkml_runtime.linkml_model.meta import AnonymousClassExpression, ClassRule
|
|
12
|
+
from linkml_runtime.utils.formatutils import underscore
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
from linkml_store.api.collection import OBJECT, Collection
|
|
16
|
+
from linkml_store.inference.inference_config import Inference, InferenceConfig
|
|
17
|
+
from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
|
|
18
|
+
|
|
19
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def expression_matches(ce: AnonymousClassExpression, object: OBJECT) -> bool:
    """
    Decide whether ``object`` satisfies every constraint of the class expression ``ce``.

    Constraints, checked in order (absent/empty ones impose nothing):

    - ``any_of``: at least one sub-expression must match;
    - ``all_of``: every sub-expression must match;
    - ``none_of``: no sub-expression may match;
    - ``slot_conditions``: for each slot, the object's value must equal
      ``equals_string`` (compared to ``str(value)``), ``equals_integer``, and the
      result of evaluating ``equals_expression`` against the object, where set.

    :param ce: The class expression
    :param object: The object
    :return: True if the class expression matches the object
    """
    if ce.any_of and not any(expression_matches(sub, object) for sub in ce.any_of):
        return False
    if ce.all_of and not all(expression_matches(sub, object) for sub in ce.all_of):
        return False
    if ce.none_of and any(expression_matches(sub, object) for sub in ce.none_of):
        return False
    for condition in (ce.slot_conditions or {}).values():
        actual = object.get(condition.name, None)
        if condition.equals_string is not None and condition.equals_string != str(actual):
            return False
        if condition.equals_integer is not None and condition.equals_integer != actual:
            return False
        if condition.equals_expression is not None:
            expected = eval_expr(condition.equals_expression, **object)
            if actual != expected:
                return False
    return True
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def apply_rule(rule: ClassRule, object: OBJECT):
    """
    Apply a rule to an object.

    Mutates the object (and also returns it).

    If the rule's ``preconditions`` match the object, or the rule has no preconditions,
    the slot assignments of ``postconditions`` are applied. Otherwise those of
    ``elseconditions`` are applied, if present. Rules flagged ``deactivated`` are skipped.

    For the chosen expression, slot conditions of its ``all_of`` members and then of the
    expression itself are applied. A slot is set from ``equals_string``,
    ``equals_integer`` or ``equals_expression``; the latter is evaluated with ``eval_expr``
    against the object.

    :param rule: The rule to apply
    :param object: The object to apply the rule to
    :return: the (mutated) object
    """
    if getattr(rule, "deactivated", False):
        return object

    # In the LinkML metamodel, preconditions/postconditions are single expressions;
    # a list is tolerated for backward compatibility (any match triggers the rule).
    preconditions = rule.preconditions
    if preconditions is None:
        applies = True
    elif isinstance(preconditions, (list, tuple)):
        applies = any(expression_matches(c, object) for c in preconditions)
    else:
        applies = expression_matches(preconditions, object)

    consequents = rule.postconditions if applies else getattr(rule, "elseconditions", None)
    if consequents is None:
        consequents = []
    elif not isinstance(consequents, (list, tuple)):
        consequents = [consequents]

    for consequent in consequents:
        for expr in list(consequent.all_of or []) + [consequent]:
            # slot_conditions is normally a name -> SlotDefinition dict; accept a list too.
            slot_conditions = expr.slot_conditions or {}
            if isinstance(slot_conditions, dict):
                slot_conditions = list(slot_conditions.values())
            for sc in slot_conditions:
                # 'is not None' so that 0 / "" are still assigned.
                if sc.equals_string is not None:
                    object[sc.name] = sc.equals_string
                if sc.equals_integer is not None:
                    object[sc.name] = sc.equals_integer
                if sc.equals_expression is not None:
                    object[sc.name] = eval_expr(sc.equals_expression, **object)
    return object
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class RuleBasedInferenceEngine(InferenceEngine):
    """
    Inference engine that derives attribute values from LinkML rules and expressions.

    Two kinds of knowledge are applied by :meth:`derive`:

    - ``class_rules``: ClassRule objects, applied to a copy of the input via :func:`apply_rule`;
    - ``slot_expressions``: slot name -> expression string, evaluated with ``eval_expr``
      against the object (after its keys are passed through ``underscore``).

    Both are populated from the training collection's schema by :meth:`initialize_model`;
    slot expressions can also be imported from another engine with :meth:`import_model_from`.
    """

    # Rules taken from the class definition of the training collection
    class_rules: Optional[List[ClassRule]] = None
    # Per-slot rules; persisted by save_model() but not used by derive()
    slot_rules: Optional[Dict[str, List[ClassRule]]] = None
    # Slot name -> expression string evaluated in derive()
    slot_expressions: Optional[Dict[str, str]] = None

    # Attributes written by save_model() and restored by load_model()
    PERSIST_COLS: ClassVar = ["config", "class_rules", "slot_rules", "slot_expressions"]

    def initialize_model(self, **kwargs):
        """
        Extract rules and slot expressions from the schema of the training collection.

        The class definition's ``rules`` become ``class_rules``. Every induced slot of the
        class with an ``equals_expression`` is recorded in ``slot_expressions``.
        """
        td = self.training_data
        collection: Collection = td.collection
        cd = collection.class_definition()
        sv: SchemaView = collection.parent.schema_view
        if cd.rules:
            self.class_rules = cd.rules
        for slot in sv.class_induced_slots(cd.name):
            if slot.equals_expression:
                # slot_expressions defaults to None; create the dict on first use.
                if self.slot_expressions is None:
                    self.slot_expressions = {}
                self.slot_expressions[slot.name] = slot.equals_expression

    def derive(self, object: OBJECT) -> Optional[Inference]:
        """
        Infer attribute values for an object.

        :param object: the input object; it is shallow-copied, so its top-level keys are not modified
        :return: an Inference holding the configured target attributes (or the whole derived
            object if no targets are configured), or None if every predicted value is None
        """
        derived = copy(object)
        for rule in self.class_rules or []:
            apply_rule(rule, derived)
        # Keys are normalised with underscore() before expressions are evaluated.
        derived = {underscore(k): v for k, v in derived.items()}
        for slot, expr in (self.slot_expressions or {}).items():
            v = eval_expr(expr, **derived)
            if v is not None:
                derived[slot] = v
        if self.config and self.config.target_attributes:
            predicted_object = {k: derived.get(k, None) for k in self.config.target_attributes}
        else:
            predicted_object = derived
        if all(v is None for v in predicted_object.values()):
            return None
        return Inference(predicted_object=predicted_object)

    def import_model_from(self, inference_engine: InferenceEngine, **kwargs):
        """
        Import another engine's model as the slot expression for its single target attribute.

        The other engine is exported with ``ModelSerialization.LINKML_EXPRESSION`` and the text
        is stored in ``slot_expressions[target_attribute]``. If this engine has no config, the
        other engine's config is adopted.

        :param inference_engine: the engine whose model is imported
        :raises ValueError: if the other engine does not have exactly one target attribute
        """
        buffer = StringIO()
        inference_engine.export_model(buffer, model_serialization=ModelSerialization.LINKML_EXPRESSION)
        config = inference_engine.config
        target_attributes = (config.target_attributes if config else None) or []
        if len(target_attributes) != 1:
            raise ValueError("Can only import models with a single target attribute")
        target_attribute = target_attributes[0]
        if self.slot_expressions is None:
            self.slot_expressions = {}
        self.slot_expressions[target_attribute] = buffer.getvalue()
        if not self.config:
            self.config = inference_engine.config

    def save_model(self, output: Union[str, Path]) -> None:
        """
        Save the trained model and related data to a file.

        The attributes listed in PERSIST_COLS are written as YAML; pydantic models are
        converted with ``model_dump(exclude_unset=True)``.

        :param output: Path to save the model
        """

        def _serialize_value(v: Any) -> Any:
            # NOTE(review): non-pydantic objects (e.g. LinkML ClassRule instances in
            # class_rules) are passed to yaml.dump unchanged; confirm they round-trip
            # through yaml.safe_load in load_model().
            if isinstance(v, BaseModel):
                return v.model_dump(exclude_unset=True)
            return v

        model_data = {k: _serialize_value(getattr(self, k)) for k in self.PERSIST_COLS}
        with open(output, "w", encoding="utf-8") as f:
            yaml.dump(model_data, f)

    @classmethod
    def load_model(cls, file_path: Union[str, Path]) -> "RuleBasedInferenceEngine":
        """
        Load a model previously written by :meth:`save_model`.

        :param file_path: path of the YAML model file
        :return: a new engine with config and persisted attributes restored
        """
        # Read with the same encoding save_model() writes, and close the file promptly.
        with open(file_path, encoding="utf-8") as f:
            model_data = yaml.safe_load(f) or {}

        config_data = model_data.get("config")
        config = InferenceConfig(**config_data) if config_data else None
        engine = cls(config=config)
        for k, v in model_data.items():
            if k == "config":
                continue
            setattr(engine, k, v)

        logger.info(f"Model loaded from {file_path}")
        return engine
|