linkml-store 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linkml_store/__init__.py +7 -0
- linkml_store/api/__init__.py +8 -0
- linkml_store/api/client.py +414 -0
- linkml_store/api/collection.py +1280 -0
- linkml_store/api/config.py +187 -0
- linkml_store/api/database.py +862 -0
- linkml_store/api/queries.py +69 -0
- linkml_store/api/stores/__init__.py +0 -0
- linkml_store/api/stores/chromadb/__init__.py +7 -0
- linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
- linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
- linkml_store/api/stores/dremio/__init__.py +10 -0
- linkml_store/api/stores/dremio/dremio_collection.py +555 -0
- linkml_store/api/stores/dremio/dremio_database.py +1052 -0
- linkml_store/api/stores/dremio/mappings.py +105 -0
- linkml_store/api/stores/dremio_rest/__init__.py +11 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
- linkml_store/api/stores/duckdb/__init__.py +16 -0
- linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
- linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
- linkml_store/api/stores/duckdb/mappings.py +8 -0
- linkml_store/api/stores/filesystem/__init__.py +15 -0
- linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
- linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
- linkml_store/api/stores/hdf5/__init__.py +7 -0
- linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
- linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
- linkml_store/api/stores/ibis/__init__.py +5 -0
- linkml_store/api/stores/ibis/ibis_collection.py +488 -0
- linkml_store/api/stores/ibis/ibis_database.py +328 -0
- linkml_store/api/stores/mongodb/__init__.py +25 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
- linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
- linkml_store/api/stores/neo4j/__init__.py +0 -0
- linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
- linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
- linkml_store/api/stores/solr/__init__.py +3 -0
- linkml_store/api/stores/solr/solr_collection.py +224 -0
- linkml_store/api/stores/solr/solr_database.py +83 -0
- linkml_store/api/stores/solr/solr_utils.py +0 -0
- linkml_store/api/types.py +4 -0
- linkml_store/cli.py +1147 -0
- linkml_store/constants.py +7 -0
- linkml_store/graphs/__init__.py +0 -0
- linkml_store/graphs/graph_map.py +24 -0
- linkml_store/index/__init__.py +53 -0
- linkml_store/index/implementations/__init__.py +0 -0
- linkml_store/index/implementations/llm_indexer.py +174 -0
- linkml_store/index/implementations/simple_indexer.py +43 -0
- linkml_store/index/indexer.py +211 -0
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/evaluation.py +195 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/llm_inference_engine.py +154 -0
- linkml_store/inference/implementations/rag_inference_engine.py +276 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
- linkml_store/inference/inference_config.py +66 -0
- linkml_store/inference/inference_engine.py +209 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/plotting/__init__.py +5 -0
- linkml_store/plotting/cli.py +826 -0
- linkml_store/plotting/dimensionality_reduction.py +453 -0
- linkml_store/plotting/embedding_plot.py +489 -0
- linkml_store/plotting/facet_chart.py +73 -0
- linkml_store/plotting/heatmap.py +383 -0
- linkml_store/utils/__init__.py +0 -0
- linkml_store/utils/change_utils.py +17 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/embedding_matcher.py +424 -0
- linkml_store/utils/embedding_utils.py +299 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/file_utils.py +37 -0
- linkml_store/utils/format_utils.py +550 -0
- linkml_store/utils/io.py +38 -0
- linkml_store/utils/llm_utils.py +122 -0
- linkml_store/utils/mongodb_utils.py +145 -0
- linkml_store/utils/neo4j_utils.py +42 -0
- linkml_store/utils/object_utils.py +190 -0
- linkml_store/utils/pandas_utils.py +93 -0
- linkml_store/utils/patch_utils.py +126 -0
- linkml_store/utils/query_utils.py +89 -0
- linkml_store/utils/schema_utils.py +23 -0
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/sql_utils.py +177 -0
- linkml_store/utils/stats_utils.py +53 -0
- linkml_store/utils/vector_utils.py +158 -0
- linkml_store/webapi/__init__.py +0 -0
- linkml_store/webapi/html/__init__.py +3 -0
- linkml_store/webapi/html/base.html.j2 +24 -0
- linkml_store/webapi/html/collection_details.html.j2 +15 -0
- linkml_store/webapi/html/database_details.html.j2 +16 -0
- linkml_store/webapi/html/databases.html.j2 +14 -0
- linkml_store/webapi/html/generic.html.j2 +43 -0
- linkml_store/webapi/main.py +855 -0
- linkml_store-0.3.0.dist-info/METADATA +226 -0
- linkml_store-0.3.0.dist-info/RECORD +101 -0
- linkml_store-0.3.0.dist-info/WHEEL +4 -0
- linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
- linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
linkml_store/inference/evaluation.py
@@ -0,0 +1,195 @@
+import logging
+from collections.abc import Callable
+from typing import Any, List, Optional
+
+import numpy as np
+import pandas as pd
+from pydantic import BaseModel
+
+from linkml_store.inference import InferenceEngine
+from linkml_store.utils.object_utils import select_nested
+
+logger = logging.getLogger(__name__)
+
+
+def score_match(target: Optional[Any], candidate: Optional[Any], match_function: Optional[Callable] = None) -> float:
+    """
+    Compute a score for a match between two objects.
+
+    >>> score_match("a", "a")
+    1.0
+    >>> score_match("a", "b")
+    0.0
+    >>> score_match("abcd", "abcde")
+    0.0
+    >>> score_match("a", None)
+    0.0
+    >>> score_match(None, "a")
+    0.0
+    >>> score_match(None, None)
+    1.0
+    >>> score_match(["a", "b"], ["a", "b"])
+    1.0
+    >>> score_match(["a", "b"], ["b", "a"])
+    1.0
+    >>> round(score_match(["a"], ["b", "a"]), 2)
+    0.67
+    >>> score_match({"a": 1}, {"a": 1})
+    1.0
+    >>> score_match({"a": 1}, {"a": 2})
+    0.0
+    >>> score_match({"a": 1, "b": None}, {"a": 1})
+    1.0
+    >>> score_match([{"a": 1, "b": 2}, {"a": 3, "b": 4}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
+    1.0
+    >>> score_match([{"a": 1, "b": 4}, {"a": 3, "b": 2}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
+    0.5
+    >>> def char_match(x, y):
+    ...     return len(set(x).intersection(set(y))) / len(set(x).union(set(y)))
+    >>> score_match("abcd", "abc", char_match)
+    0.75
+    >>> score_match(["abcd", "efgh"], ["ac", "gh"], char_match)
+    0.5
+
+    :param target: the expected (ground truth) object
+    :param candidate: the predicted object
+    :param match_function: scorer applied to atomic values; defaults to strict equality
+    :return: a score between 0.0 and 1.0
+    """
+    if target == candidate:
+        return 1.0
+    if target is None or candidate is None:
+        return 0.0
+    if isinstance(target, (set, list)) and isinstance(candidate, (set, list)):
+        # create an all-by-all score matrix using numpy,
+        # take each element's best match in the other list,
+        # and return the average of those best-match scores
+        score_matrix = np.array([[score_match(t, c, match_function) for c in candidate] for t in target])
+        best_matches0 = np.max(score_matrix, axis=0)
+        best_matches1 = np.max(score_matrix, axis=1)
+        return (np.sum(best_matches0) + np.sum(best_matches1)) / (len(target) + len(candidate))
+    if isinstance(target, dict) and isinstance(candidate, dict):
+        keys = set(target.keys()).union(candidate.keys())
+        scores = [score_match(target.get(k), candidate.get(k), match_function) for k in keys]
+        return np.mean(scores)
+    if match_function:
+        return match_function(target, candidate)
+    return 0.0
+
+
+class Outcome(BaseModel):
+    true_positive_count: float
+    total_count: int
+
+    @property
+    def accuracy(self) -> float:
+        return self.true_positive_count / self.total_count
+
+
+def evaluate_predictor(
+    predictor: InferenceEngine,
+    target_attributes: List[str],
+    feature_attributes: Optional[List[str]] = None,
+    test_data: Optional[pd.DataFrame] = None,
+    evaluation_count: Optional[int] = 10,
+    match_function: Optional[Callable] = None,
+) -> Outcome:
+    """
+    Evaluate a predictor by comparing its predictions to the expected values in the testing data.
+
+    :param predictor: inference engine to evaluate
+    :param target_attributes: attributes the predictor should infer
+    :param feature_attributes: attributes used as input features; defaults to all non-target attributes
+    :param test_data: test set; defaults to the predictor's held-out testing data
+    :param evaluation_count: maximum number of rows to evaluate
+    :param match_function: scorer passed through to score_match
+    :return: an Outcome with match counts and accuracy
+    """
+    n = 0
+    tp = 0
+    if test_data is None:
+        test_data = predictor.testing_data.as_dataframe()
+    for row in test_data.to_dict(orient="records"):
+        expected_obj = select_nested(row, target_attributes)
+        if feature_attributes:
+            # restrict the query object to the declared feature attributes
+            test_obj = select_nested(row, feature_attributes)
+        else:
+            test_obj = row
+        result = predictor.derive(test_obj)
+        tp += score_match(result.predicted_object, expected_obj, match_function)
+        logger.info(f"TP={tp} MF={match_function} Predicted: {result.predicted_object} Expected: {expected_obj}")
+        n += 1
+        if evaluation_count is not None and n >= evaluation_count:
+            break
+    return Outcome(true_positive_count=tp, total_count=n)
+
+
+def score_text_overlap(str1: Any, str2: Any) -> float:
+    """
+    Compute the overlap score between two strings.
+
+    >>> score_text_overlap("abc", "bcde")
+    0.5
+
+    :param str1: the first string
+    :param str2: the second string
+    :return: length of the longest common substring divided by the longer string's length
+    """
+    if str1 == str2:
+        return 1.0
+    if not str1 or not str2:
+        return 0.0
+    _overlap, overlap_length = find_longest_overlap(str1, str2)
+    return overlap_length / max(len(str1), len(str2))
+
+
+def find_longest_overlap(str1: str, str2: str):
+    """
+    Find the longest overlapping substring between two strings.
+
+    Args:
+        str1 (str): The first string
+        str2 (str): The second string
+
+    Returns:
+        tuple: A tuple containing the longest overlapping substring and its length
+
+    Examples:
+        >>> find_longest_overlap("hello world", "world of programming")
+        ('world', 5)
+        >>> find_longest_overlap("abcdefg", "defghi")
+        ('defg', 4)
+        >>> find_longest_overlap("python", "java")
+        ('', 0)
+        >>> find_longest_overlap("", "test")
+        ('', 0)
+        >>> find_longest_overlap("aabbcc", "ddeeff")
+        ('', 0)
+        >>> find_longest_overlap("programming", "PROGRAMMING")
+        ('', 0)
+    """
+    if not str1 or not str2:
+        return "", 0
+
+    # Dynamic-programming table: dp[i][j] holds the length of the common
+    # substring ending at str1[i - 1] and str2[j - 1]
+    m, n = len(str1), len(str2)
+    dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+    # Variables to store the maximum length and ending position
+    max_length = 0
+    end_pos = 0
+
+    # Fill the dp table
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            if str1[i - 1] == str2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1] + 1
+                if dp[i][j] > max_length:
+                    max_length = dp[i][j]
+                    end_pos = i
+
+    # Extract the longest common substring
+    start_pos = end_pos - max_length
+    longest_substring = str1[start_pos:end_pos]
+
+    return longest_substring, max_length
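A minimal usage sketch (not part of the package) for evaluate_predictor, pairing it with
score_text_overlap so string targets get partial credit instead of all-or-nothing equality.
The engine setup mirrors the RAGInferenceEngine doctest later in this diff; the countries
fixture path and the gpt-4o-mini model name are taken from that doctest.

    from linkml_store.api.client import Client
    from linkml_store.inference.evaluation import evaluate_predictor, score_text_overlap
    from linkml_store.inference.implementations.rag_inference_engine import RAGInferenceEngine
    from linkml_store.inference.inference_config import InferenceConfig, LLMConfig
    from linkml_store.utils.format_utils import Format

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
    collection = db.get_collection("countries")

    config = InferenceConfig(
        target_attributes=["capital"],
        feature_attributes=["name"],
        llm_config=LLMConfig(model_name="gpt-4o-mini"),
    )
    ie = RAGInferenceEngine(config=config)
    ie.load_and_split_data(collection)  # assumed to populate the testing_data split used below
    ie.initialize_model()

    # With no test_data argument, evaluate_predictor draws rows from
    # predictor.testing_data; score_text_overlap scores each string pair by
    # longest-common-substring overlap rather than strict equality.
    outcome = evaluate_predictor(
        ie,
        target_attributes=["capital"],
        feature_attributes=["name"],
        evaluation_count=5,
        match_function=score_text_overlap,
    )
    print(outcome.true_positive_count, outcome.total_count, round(outcome.accuracy, 2))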
linkml_store/inference/implementations/__init__.py
File without changes
linkml_store/inference/implementations/llm_inference_engine.py
@@ -0,0 +1,154 @@
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, List, Optional, TextIO, Union
+
+import yaml
+from llm import get_key
+from pydantic import BaseModel
+
+from linkml_store.api.collection import OBJECT
+from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+from linkml_store.utils.llm_utils import parse_yaml_payload
+
+logger = logging.getLogger(__name__)
+
+MAX_ITERATIONS = 5
+DEFAULT_NUM_EXAMPLES = 20
+
+SYSTEM_PROMPT = """
+Your task is to infer the complete YAML
+object output given the YAML object input. I will provide you
+with contextual information, including the schema,
+to help with the inference.
+
+You should return ONLY valid YAML in your response.
+"""
+
+
+class TrainedModel(BaseModel, extra="forbid"):
+    index_rows: List[OBJECT]
+    config: Optional[InferenceConfig] = None
+
+
+class LLMInference(Inference):
+    iterations: int = 0
+
+
+@dataclass
+class LLMInferenceEngine(InferenceEngine):
+    """
+    LLM-based predictor.
+
+    Unlike the RAG predictor, which performs few-shot inference with retrieved
+    examples, this engine performs zero-shot inference using only the schema as context.
+    """
+
+    _model: "llm.Model" = None  # noqa: F821
+
+    PERSIST_COLS: ClassVar[List[str]] = [
+        "config",
+    ]
+
+    def __post_init__(self):
+        if not self.config:
+            self.config = InferenceConfig()
+        if not self.config.llm_config:
+            self.config.llm_config = LLMConfig()
+
+    @property
+    def model(self) -> "llm.Model":  # noqa: F821
+        import llm
+
+        if self._model is None:
+            self._model = llm.get_model(self.config.llm_config.model_name)
+            if self._model.needs_key:
+                key = get_key(None, key_alias=self._model.needs_key)
+                self._model.key = key
+
+        return self._model
+
+    def initialize_model(self, **kwargs):
+        logger.info(f"Initializing model {self.model}")
+
+    def object_to_text(self, object: OBJECT) -> str:
+        return yaml.dump(object)
+
+    def _schema_str(self) -> str:
+        db = self.training_data.base_collection.parent
+        from linkml_runtime.dumpers import json_dumper
+
+        schema_dict = json_dumper.to_dict(db.schema_view.schema)
+        return yaml.dump(schema_dict)
+
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[LLMInference]:
+        import llm
+
+        model: llm.Model = self.model
+        # model_name = self.config.llm_config.model_name
+        # feature_attributes = self.config.feature_attributes
+        target_attributes = self.config.target_attributes
+        query_text = self.object_to_text(object)
+
+        if not target_attributes:
+            target_attributes = [k for k, v in object.items() if v is None or v == ""]
+        # if not feature_attributes:
+        #     feature_attributes = [k for k, v in object.items() if v is not None and v != ""]
+
+        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+        # include feedback from failed validation attempts on retries (mirrors the RAG engine)
+        system_prompt += "\n".join(additional_prompt_texts or [])
+
+        system_prompt += "\n## SCHEMA:\n\n" + self._schema_str()
+
+        stub = ", ".join([f"{k}: ..." for k in target_attributes])
+        stub = "{" + stub + "}"
+        prompt = (
+            "Provide a YAML object of the form\n"
+            "```yaml\n"
+            f"{stub}\n"
+            "```\n"
+            "---\nQuery:\n"
+            f"## INCOMPLETE OBJECT:\n{query_text}\n"
+            "## OUTPUT:\n"
+        )
+        logger.info(f"Prompt: {prompt}")
+        response = model.prompt(prompt, system=system_prompt)
+        yaml_str = response.text()
+        logger.info(f"Response: {yaml_str}")
+        predicted_object = parse_yaml_payload(yaml_str, strict=True)
+        predicted_object = {**object, **predicted_object}
+        if self.config.validate_results:
+            base_collection = self.training_data.base_collection
+            errs = list(base_collection.iter_validate_collection([predicted_object]))
+            if errs:
+                print(f"{iteration} // FAILED TO VALIDATE: {yaml_str}")
+                print(f"PARSED: {predicted_object}")
+                print(f"ERRORS: {errs}")
+                if iteration > MAX_ITERATIONS:
+                    raise ValueError(f"Validation errors: {errs}")
+                extra_texts = [
+                    "Make sure results conform to the schema. Previously you provided:\n",
+                    yaml_str,
+                    "\nThis was invalid.\n",
+                    "Validation errors:\n",
+                ] + [self.object_to_text(e) for e in errs]
+                return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return LLMInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)
+
+    def export_model(
+        self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
+    ):
+        self.save_model(output)
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+        raise NotImplementedError("Does not make sense for this engine")
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "LLMInferenceEngine":
+        raise NotImplementedError("Does not make sense for this engine")
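A hypothetical usage sketch (not part of the package) for the zero-shot engine above. The
setup mirrors the RAGInferenceEngine doctest later in this diff; note that a collection is
still attached via load_and_split_data, because _schema_str pulls the schema from the
training data's parent database even though no examples are sent to the LLM.

    from linkml_store.api.client import Client
    from linkml_store.inference.implementations.llm_inference_engine import LLMInferenceEngine
    from linkml_store.inference.inference_config import InferenceConfig, LLMConfig
    from linkml_store.utils.format_utils import Format

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
    collection = db.get_collection("countries")

    config = InferenceConfig(
        target_attributes=["capital", "code"],
        llm_config=LLMConfig(model_name="gpt-4o-mini"),
    )
    ie = LLMInferenceEngine(config=config)
    ie.load_and_split_data(collection)
    ie.initialize_model()

    # If target_attributes were unset, derive would instead infer them from
    # the None/"" valued keys of the query object (see derive above).
    inference = ie.derive({"name": "Uruguay"})
    print(inference.predicted_object)  # the input keys merged with the LLM's completions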
linkml_store/inference/implementations/rag_inference_engine.py
@@ -0,0 +1,276 @@
+import json
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, List, Optional, TextIO, Union
+
+import yaml
+from llm import get_key
+from pydantic import BaseModel
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+from linkml_store.utils.object_utils import select_nested
+
+logger = logging.getLogger(__name__)
+
+MAX_ITERATIONS = 5
+DEFAULT_NUM_EXAMPLES = 20
+DEFAULT_MMR_RELEVANCE_FACTOR = 0.8
+
+SYSTEM_PROMPT = """
+You are a {llm_config.role}; your task is to infer the YAML
+object output given the YAML object input. I will provide you
+with a collection of examples that will provide guidance both
+on the desired structure of the response, as well as the kind
+of content.
+
+You should return ONLY valid YAML in your response.
+"""
+
+
+class TrainedModel(BaseModel, extra="forbid"):
+    rag_collection_rows: List[OBJECT]
+    index_rows: List[OBJECT]
+    config: Optional[InferenceConfig] = None
+
+
+class RAGInference(Inference):
+    iterations: int = 0
+
+
+@dataclass
+class RAGInferenceEngine(InferenceEngine):
+    """
+    Retrieval Augmented Generation (RAG) based predictor.
+
+    >>> from linkml_store.api.client import Client
+    >>> from linkml_store.utils.format_utils import Format
+    >>> from linkml_store.inference.inference_config import LLMConfig
+    >>> client = Client()
+    >>> db = client.attach_database("duckdb", alias="test")
+    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
+    >>> db.list_collection_names()
+    ['countries']
+    >>> collection = db.get_collection("countries")
+    >>> features = ["name"]
+    >>> targets = ["code", "capital", "continent", "languages"]
+    >>> llm_config = LLMConfig(model_name="gpt-4o-mini")
+    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
+    >>> ie = RAGInferenceEngine(config=config)
+    >>> ie.load_and_split_data(collection)
+    >>> ie.initialize_model()
+    >>> prediction = ie.derive({"name": "Uruguay"})
+    >>> prediction.predicted_object
+    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
+
+    The "model" can be saved for later use:
+
+    >>> ie.export_model("tests/output/countries.rag_model.json")
+
+    Note that in this case the model is not the underlying LLM, but the "RAG model": the vectorized
+    representation of the training-set objects.
+    """
+
+    _model: "llm.Model" = None  # noqa: F821
+
+    rag_collection: Collection = None
+
+    PERSIST_COLS: ClassVar[List[str]] = [
+        "config",
+    ]
+
+    def __post_init__(self):
+        if not self.config:
+            self.config = InferenceConfig()
+        if not self.config.llm_config:
+            self.config.llm_config = LLMConfig()
+
+    @property
+    def model(self) -> "llm.Model":  # noqa: F821
+        import llm
+
+        if self._model is None:
+            self._model = llm.get_model(self.config.llm_config.model_name)
+            if self._model.needs_key:
+                key = get_key(None, key_alias=self._model.needs_key)
+                self._model.key = key
+
+        return self._model
+
+    def initialize_model(self, **kwargs):
+        logger.info(f"Initializing model {self.model}")
+        if self.training_data:
+            rag_collection = self.training_data.collection
+            rag_collection.attach_indexer("llm", auto_index=False)
+            self.rag_collection = rag_collection
+
+    def object_to_text(self, object: OBJECT) -> str:
+        return yaml.dump(object)
+
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[RAGInference]:
+        import llm
+        from tiktoken import encoding_for_model
+
+        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
+
+        model: llm.Model = self.model
+        model_name = self.config.llm_config.model_name
+        feature_attributes = self.config.feature_attributes
+        target_attributes = self.config.target_attributes
+        num_examples = self.config.llm_config.number_of_few_shot_examples or DEFAULT_NUM_EXAMPLES
+        query_text = self.object_to_text(object)
+        mmr_relevance_factor = DEFAULT_MMR_RELEVANCE_FACTOR
+        if not self.rag_collection:
+            # TODO: zero-shot mode
+            examples = []
+        else:
+            if not self.rag_collection.indexers:
+                raise ValueError("RAG collection must have an indexer attached")
+            logger.info(f"Searching {self.rag_collection.alias} for examples for: {query_text}")
+            rs = self.rag_collection.search(
+                query_text, limit=num_examples, index_name="llm", mmr_relevance_factor=mmr_relevance_factor
+            )
+            examples = rs.rows
+            logger.info(f"Found {len(examples)} examples")
+            if not examples:
+                raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
+        prompt_clauses = []
+        this_feature_attributes = feature_attributes
+        if not this_feature_attributes:
+            this_feature_attributes = list(set(object.keys()) - set(target_attributes))
+        query_obj = select_nested(object, this_feature_attributes)
+        query_text = self.object_to_text(query_obj)
+        for example in examples:
+            this_feature_attributes = feature_attributes
+            if not this_feature_attributes:
+                this_feature_attributes = list(set(example.keys()) - set(target_attributes))
+            if not this_feature_attributes:
+                raise ValueError(f"No feature attributes found in example {example}")
+            input_obj = select_nested(example, this_feature_attributes)
+            input_obj_text = self.object_to_text(input_obj)
+            if input_obj_text == query_text:
+                continue
+                # raise ValueError(
+                #     f"Query object {query_text} is the same as example object {input_obj_text}\n"
+                #     "This indicates possible test data leakage.\n"
+                #     "TODO: allow an option that allows user to treat this as a basic lookup\n"
+                # )
+            output_obj = select_nested(example, target_attributes)
+            prompt_clause = (
+                "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
+            )
+            prompt_clauses.append(prompt_clause)
+
+        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+        system_prompt += "\n".join(additional_prompt_texts or [])
+        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
+
+        def make_text(texts: List[str]):
+            return "\n".join(texts) + prompt_end
+
+        try:
+            encoding = encoding_for_model(model_name)
+        except KeyError:
+            encoding = encoding_for_model("gpt-4")
+        token_limit = get_token_limit(model_name)
+        prompt = render_formatted_text(
+            make_text, values=prompt_clauses, encoding=encoding, token_limit=token_limit, additional_text=system_prompt
+        )
+        logger.info(f"Prompt: {prompt}")
+        response = model.prompt(prompt, system=system_prompt)
+        yaml_str = response.text()
+        logger.info(f"Response: {yaml_str}")
+        predicted_object = self._parse_yaml_payload(yaml_str, strict=True)
+        if self.config.validate_results:
+            base_collection = self.training_data.base_collection
+            errs = list(base_collection.iter_validate_collection([predicted_object]))
+            if errs:
+                print(f"{iteration} // FAILED TO VALIDATE: {yaml_str}")
+                print(f"PARSED: {predicted_object}")
+                print(f"ERRORS: {errs}")
+                if iteration > MAX_ITERATIONS:
+                    raise ValueError(f"Validation errors: {errs}")
+                extra_texts = [
+                    "Make sure results conform to the schema. Previously you provided:\n",
+                    yaml_str,
+                    "\nThis was invalid.\n",
+                    "Validation errors:\n",
+                ] + [self.object_to_text(e) for e in errs]
+                return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return RAGInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)
+
+    def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
+        if "```" in yaml_str:
+            yaml_str = yaml_str.split("```")[1].strip()
+            if yaml_str.startswith("yaml"):
+                yaml_str = yaml_str[4:].strip()
+        try:
+            return yaml.safe_load(yaml_str)
+        except Exception as e:
+            if strict:
+                raise e
+            logger.error(f"Error parsing YAML: {yaml_str}\n{e}")
+            return None
+
+    def export_model(
+        self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
+    ):
+        self.save_model(output)
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+
+        # trigger the index so that index rows exist before export
+        _qr = self.rag_collection.search("*", limit=1)
+        assert len(_qr.ranked_rows) > 0
+
+        rows = self.rag_collection.find(limit=-1).rows
+
+        indexers = self.rag_collection.indexers
+        assert len(indexers) == 1
+        ix = self.rag_collection.indexers["llm"]
+        ix_coll = self.rag_collection.parent.get_collection(self.rag_collection.get_index_collection_name(ix))
+
+        ix_rows = ix_coll.find(limit=-1).rows
+        assert len(ix_rows) > 0
+        tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows, config=self.config)
+        with open(output, "w", encoding="utf-8") as f:
+            json.dump(tm.model_dump(), f)
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "RAGInferenceEngine":
+        """
+        Load a trained model and related data from a file.
+
+        :param file_path: Path to the saved model
+        :return: RAGInferenceEngine instance with loaded model
+        """
+        with open(file_path, "r", encoding="utf-8") as f:
+            model_data = json.load(f)
+        tm = TrainedModel(**model_data)
+        from linkml_store.api import Client
+
+        client = Client()
+        db = client.attach_database("duckdb", alias="training")
+        db.store({"data": tm.rag_collection_rows})
+        collection = db.get_collection("data")
+        ix = collection.attach_indexer("llm", auto_index=False)
+        assert ix.name
+        ix_coll_name = collection.get_index_collection_name(ix)
+        assert ix_coll_name
+        ix_coll = db.get_collection(ix_coll_name, create_if_not_exists=True)
+        ix_coll.insert(tm.index_rows)
+        ie = cls(config=tm.config)
+        ie.rag_collection = collection
+        return ie
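A round-trip sketch (not part of the package), continuing from the class doctest above
where ie is an initialized RAGInferenceEngine. The saved artifact is the TrainedModel
defined at the top of this file: the example rows, their index vectors, and the config.

    from linkml_store.inference.implementations.rag_inference_engine import RAGInferenceEngine

    # save_model forces one search to materialize the "llm" index, then dumps
    # the collection rows and index rows as JSON
    ie.export_model("tests/output/countries.rag_model.json")

    # load_model rebuilds an in-memory duckdb collection and re-inserts the
    # saved index rows, so no re-embedding is needed on reload
    ie2 = RAGInferenceEngine.load_model("tests/output/countries.rag_model.json")
    prediction = ie2.derive({"name": "Uruguay"})
    print(prediction.predicted_object)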