linkml-store 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. linkml_store/__init__.py +7 -0
  2. linkml_store/api/__init__.py +8 -0
  3. linkml_store/api/client.py +414 -0
  4. linkml_store/api/collection.py +1280 -0
  5. linkml_store/api/config.py +187 -0
  6. linkml_store/api/database.py +862 -0
  7. linkml_store/api/queries.py +69 -0
  8. linkml_store/api/stores/__init__.py +0 -0
  9. linkml_store/api/stores/chromadb/__init__.py +7 -0
  10. linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
  11. linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
  12. linkml_store/api/stores/dremio/__init__.py +10 -0
  13. linkml_store/api/stores/dremio/dremio_collection.py +555 -0
  14. linkml_store/api/stores/dremio/dremio_database.py +1052 -0
  15. linkml_store/api/stores/dremio/mappings.py +105 -0
  16. linkml_store/api/stores/dremio_rest/__init__.py +11 -0
  17. linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
  18. linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
  19. linkml_store/api/stores/duckdb/__init__.py +16 -0
  20. linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
  21. linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
  22. linkml_store/api/stores/duckdb/mappings.py +8 -0
  23. linkml_store/api/stores/filesystem/__init__.py +15 -0
  24. linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
  25. linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
  26. linkml_store/api/stores/hdf5/__init__.py +7 -0
  27. linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
  28. linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
  29. linkml_store/api/stores/ibis/__init__.py +5 -0
  30. linkml_store/api/stores/ibis/ibis_collection.py +488 -0
  31. linkml_store/api/stores/ibis/ibis_database.py +328 -0
  32. linkml_store/api/stores/mongodb/__init__.py +25 -0
  33. linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
  34. linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
  35. linkml_store/api/stores/neo4j/__init__.py +0 -0
  36. linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
  37. linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
  38. linkml_store/api/stores/solr/__init__.py +3 -0
  39. linkml_store/api/stores/solr/solr_collection.py +224 -0
  40. linkml_store/api/stores/solr/solr_database.py +83 -0
  41. linkml_store/api/stores/solr/solr_utils.py +0 -0
  42. linkml_store/api/types.py +4 -0
  43. linkml_store/cli.py +1147 -0
  44. linkml_store/constants.py +7 -0
  45. linkml_store/graphs/__init__.py +0 -0
  46. linkml_store/graphs/graph_map.py +24 -0
  47. linkml_store/index/__init__.py +53 -0
  48. linkml_store/index/implementations/__init__.py +0 -0
  49. linkml_store/index/implementations/llm_indexer.py +174 -0
  50. linkml_store/index/implementations/simple_indexer.py +43 -0
  51. linkml_store/index/indexer.py +211 -0
  52. linkml_store/inference/__init__.py +13 -0
  53. linkml_store/inference/evaluation.py +195 -0
  54. linkml_store/inference/implementations/__init__.py +0 -0
  55. linkml_store/inference/implementations/llm_inference_engine.py +154 -0
  56. linkml_store/inference/implementations/rag_inference_engine.py +276 -0
  57. linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
  58. linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
  59. linkml_store/inference/inference_config.py +66 -0
  60. linkml_store/inference/inference_engine.py +209 -0
  61. linkml_store/inference/inference_engine_registry.py +74 -0
  62. linkml_store/plotting/__init__.py +5 -0
  63. linkml_store/plotting/cli.py +826 -0
  64. linkml_store/plotting/dimensionality_reduction.py +453 -0
  65. linkml_store/plotting/embedding_plot.py +489 -0
  66. linkml_store/plotting/facet_chart.py +73 -0
  67. linkml_store/plotting/heatmap.py +383 -0
  68. linkml_store/utils/__init__.py +0 -0
  69. linkml_store/utils/change_utils.py +17 -0
  70. linkml_store/utils/dat_parser.py +95 -0
  71. linkml_store/utils/embedding_matcher.py +424 -0
  72. linkml_store/utils/embedding_utils.py +299 -0
  73. linkml_store/utils/enrichment_analyzer.py +217 -0
  74. linkml_store/utils/file_utils.py +37 -0
  75. linkml_store/utils/format_utils.py +550 -0
  76. linkml_store/utils/io.py +38 -0
  77. linkml_store/utils/llm_utils.py +122 -0
  78. linkml_store/utils/mongodb_utils.py +145 -0
  79. linkml_store/utils/neo4j_utils.py +42 -0
  80. linkml_store/utils/object_utils.py +190 -0
  81. linkml_store/utils/pandas_utils.py +93 -0
  82. linkml_store/utils/patch_utils.py +126 -0
  83. linkml_store/utils/query_utils.py +89 -0
  84. linkml_store/utils/schema_utils.py +23 -0
  85. linkml_store/utils/sklearn_utils.py +193 -0
  86. linkml_store/utils/sql_utils.py +177 -0
  87. linkml_store/utils/stats_utils.py +53 -0
  88. linkml_store/utils/vector_utils.py +158 -0
  89. linkml_store/webapi/__init__.py +0 -0
  90. linkml_store/webapi/html/__init__.py +3 -0
  91. linkml_store/webapi/html/base.html.j2 +24 -0
  92. linkml_store/webapi/html/collection_details.html.j2 +15 -0
  93. linkml_store/webapi/html/database_details.html.j2 +16 -0
  94. linkml_store/webapi/html/databases.html.j2 +14 -0
  95. linkml_store/webapi/html/generic.html.j2 +43 -0
  96. linkml_store/webapi/main.py +855 -0
  97. linkml_store-0.3.0.dist-info/METADATA +226 -0
  98. linkml_store-0.3.0.dist-info/RECORD +101 -0
  99. linkml_store-0.3.0.dist-info/WHEEL +4 -0
  100. linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
  101. linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,195 @@
1
+ import logging
2
+ from collections.abc import Callable
3
+ from typing import Any, List, Optional
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pydantic import BaseModel
8
+
9
+ from linkml_store.inference import InferenceEngine
10
+ from linkml_store.utils.object_utils import select_nested
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def score_match(target: Optional[Any], candidate: Optional[Any], match_function: Optional[Callable] = None) -> float:
    """
    Compute a score for a match between two objects.

    Rules are applied in this order:

    1. Objects that compare equal score 1.0.
    2. If exactly one side is None the score is 0.0.
    3. Two lists/sets are compared all-by-all; every element is credited with its
       best match on the other side and the credits are averaged over the elements
       of both collections. An empty collection scores 0.0 against a non-empty one.
    4. Two dicts are compared key by key over the union of their keys (a missing
       key is treated as None) and the per-key scores are averaged.
    5. Anything else is delegated to ``match_function`` when one is given,
       otherwise the score is 0.0.

    >>> score_match("a", "a")
    1.0
    >>> score_match("a", "b")
    0.0
    >>> score_match("abcd", "abcde")
    0.0
    >>> score_match("a", None)
    0.0
    >>> score_match(None, "a")
    0.0
    >>> score_match(None, None)
    1.0
    >>> score_match(["a", "b"], ["a", "b"])
    1.0
    >>> score_match(["a", "b"], ["b", "a"])
    1.0
    >>> round(score_match(["a"], ["b", "a"]), 2)
    0.67
    >>> score_match([], ["a"])
    0.0
    >>> score_match({"a": 1}, {"a": 1})
    1.0
    >>> score_match({"a": 1}, {"a": 2})
    0.0
    >>> score_match({"a": 1, "b": None}, {"a": 1})
    1.0
    >>> score_match([{"a": 1, "b": 2}, {"a": 3, "b": 4}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
    1.0
    >>> score_match([{"a": 1, "b": 4}, {"a": 3, "b": 2}], [{"a": 1, "b": 2}, {"a": 3, "b": 4}])
    0.5
    >>> def char_match(x, y):
    ...     return len(set(x).intersection(set(y))) / len(set(x).union(set(y)))
    >>> score_match("abcd", "abc", char_match)
    0.75
    >>> score_match(["abcd", "efgh"], ["ac", "gh"], char_match)
    0.5

    :param target: first object (may be None, a scalar, list, set or dict)
    :param candidate: second object (same kinds as ``target``)
    :param match_function: optional ``f(target, candidate) -> float`` used for leaf values
        that are not equal; without it unequal leaves score 0.0
    :return: the match score
    """
    if target == candidate:
        return 1.0
    if target is None or candidate is None:
        return 0.0
    if isinstance(target, (set, list)) and isinstance(candidate, (set, list)):
        if not target or not candidate:
            # np.max raises ValueError on a zero-size axis, so settle empty collections here.
            # Two empty collections of different type (e.g. set() vs []) count as a full match.
            return 1.0 if not target and not candidate else 0.0
        # All-by-all score matrix: rows are target elements, columns are candidate elements.
        score_matrix = np.array([[score_match(t, c, match_function) for c in candidate] for t in target])
        best_per_candidate = np.max(score_matrix, axis=0)
        best_per_target = np.max(score_matrix, axis=1)
        total = np.sum(best_per_candidate) + np.sum(best_per_target)
        # float() keeps the return type a builtin float rather than a NumPy scalar.
        return float(total / (len(target) + len(candidate)))
    if isinstance(target, dict) and isinstance(candidate, dict):
        keys = set(target.keys()).union(candidate.keys())
        scores = [score_match(target.get(k), candidate.get(k), match_function) for k in keys]
        return float(np.mean(scores))
    if match_function:
        return match_function(target, candidate)
    return 0.0
79
+
80
+
81
class Outcome(BaseModel):
    """
    Aggregate result of an evaluation run.
    """

    # Sum of per-row match scores; fractional because partial matches are allowed.
    true_positive_count: float
    # Number of rows that were evaluated.
    total_count: int

    @property
    def accuracy(self) -> float:
        """
        Mean match score over the evaluated rows.

        :return: ``true_positive_count / total_count``, or 0.0 when nothing was evaluated
        """
        if not self.total_count:
            # An empty evaluation would otherwise raise ZeroDivisionError.
            return 0.0
        return self.true_positive_count / self.total_count
88
+
89
+
90
def evaluate_predictor(
    predictor: InferenceEngine,
    target_attributes: List[str],
    feature_attributes: Optional[List[str]] = None,
    test_data: Optional[pd.DataFrame] = None,
    evaluation_count: Optional[int] = 10,
    match_function: Optional[Callable] = None,
) -> Outcome:
    """
    Evaluate a predictor by comparing its predictions to the expected values in the testing data.

    For each test row the target attributes are removed, the remainder is passed to
    ``predictor.derive`` and the predicted object is scored against the expected
    target values with :func:`score_match`.

    :param predictor: engine whose ``derive`` method is evaluated
    :param target_attributes: attributes to be predicted; selected from each row as the expected object
        and removed (by top-level key) from the object handed to the predictor
    :param feature_attributes: accepted for backwards compatibility; it is not used to filter the input here
    :param test_data: rows to evaluate; defaults to ``predictor.testing_data.as_dataframe()``
    :param evaluation_count: max iterations; None evaluates every row
    :param match_function: function to use for matching
    :return: accumulated score and number of rows evaluated
    """
    n = 0
    tp = 0.0
    if test_data is None:
        test_data = predictor.testing_data.as_dataframe()
    for row in test_data.to_dict(orient="records"):
        expected_obj = select_nested(row, target_attributes)
        # Always withhold the expected answers from the predictor; previously they were
        # only stripped when feature_attributes was set, leaking targets otherwise.
        test_obj = {k: v for k, v in row.items() if k not in target_attributes}
        result = predictor.derive(test_obj)
        # derive() is declared Optional; a missing result is scored as a None prediction.
        predicted_obj = result.predicted_object if result is not None else None
        tp += score_match(predicted_obj, expected_obj, match_function)
        logger.info("TP=%s MF=%s Predicted: %s Expected: %s", tp, match_function, predicted_obj, expected_obj)
        n += 1
        if evaluation_count is not None and n >= evaluation_count:
            break
    return Outcome(true_positive_count=tp, total_count=n)
125
+
126
+
127
def score_text_overlap(str1: Any, str2: Any) -> float:
    """
    Score two strings by the size of their longest shared substring.

    The score is the length of the longest common substring divided by the
    length of the longer input. Equal inputs score 1.0; if either input is
    empty or otherwise falsy the score is 0.0.

    >>> score_text_overlap("abc", "bcde")
    0.5

    :param str1: first string
    :param str2: second string
    :return: overlap score between 0.0 and 1.0
    """
    if str1 == str2:
        return 1.0
    if not (str1 and str2):
        return 0.0
    shared = find_longest_overlap(str1, str2)[0]
    longer_length = max(len(str1), len(str2))
    return len(shared) / longer_length
144
+
145
+
146
def find_longest_overlap(str1: str, str2: str) -> tuple:
    """
    Find the longest overlapping substring between two strings.

    Uses the classic dynamic-programming recurrence for the longest common
    substring, keeping only the previous row of the table so memory use is
    proportional to ``len(str2)`` rather than ``len(str1) * len(str2)``.
    When several substrings share the maximum length, the one that ends
    earliest in ``str1`` is returned. The comparison is case-sensitive.

    Args:
        str1 (str): The first string
        str2 (str): The second string

    Returns:
        tuple: A tuple containing the longest overlapping substring and its length

    Examples:
        >>> find_longest_overlap("hello world", "world of programming")
        ('world', 5)
        >>> find_longest_overlap("abcdefg", "defghi")
        ('defg', 4)
        >>> find_longest_overlap("python", "java")
        ('', 0)
        >>> find_longest_overlap("", "test")
        ('', 0)
        >>> find_longest_overlap("aabbcc", "ddeeff")
        ('', 0)
        >>> find_longest_overlap("programming", "PROGRAMMING")
        ('', 0)
    """
    if not str1 or not str2:
        return "", 0

    width = len(str2)
    # previous_row[j] = length of the common suffix of str1[:i-1] and str2[:j]
    previous_row = [0] * (width + 1)

    max_length = 0
    end_pos = 0  # exclusive end index in str1 of the best substring found so far

    for i, char1 in enumerate(str1, start=1):
        current_row = [0] * (width + 1)
        for j, char2 in enumerate(str2, start=1):
            if char1 == char2:
                run = previous_row[j - 1] + 1
                current_row[j] = run
                # Strict '>' keeps the first maximal run encountered.
                if run > max_length:
                    max_length = run
                    end_pos = i
        previous_row = current_row

    longest_substring = str1[end_pos - max_length:end_pos]
    return longest_substring, max_length
File without changes
@@ -0,0 +1,154 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import ClassVar, List, Optional, TextIO, Union
5
+
6
+ import yaml
7
+ from llm import get_key
8
+ from pydantic import BaseModel
9
+
10
+ from linkml_store.api.collection import OBJECT
11
+ from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
12
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
13
+ from linkml_store.utils.llm_utils import parse_yaml_payload
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
# LLMInferenceEngine.derive raises once its retry counter exceeds this value.
MAX_ITERATIONS = 5
# Not referenced by the code in this module.
DEFAULT_NUM_EXAMPLES = 20

# System prompt sent with every request. It is passed through str.format()
# before use, so it must not contain stray braces.
# The earlier wording ended with a dangling fragment ("You can use the following");
# that fragment has been removed and "to inference" corrected to "to infer".
SYSTEM_PROMPT = """
Your task is to infer the complete YAML
object output given the YAML object input. I will provide you
with contextual information, including the schema,
to help with the inference.

You should return ONLY valid YAML in your response.
"""
28
+
29
+
30
class TrainedModel(BaseModel, extra="forbid"):
    """
    Serializable container for a persisted model.

    Unknown fields are rejected (``extra="forbid"``).
    NOTE(review): this class is not referenced by the rest of the code visible in this module.
    """

    # Rows of an index collection, one object per row.
    index_rows: List[OBJECT]
    # Inference configuration stored alongside the rows; may be absent.
    config: Optional[InferenceConfig] = None
33
+
34
+
35
class LLMInference(Inference):
    """
    Inference result that also records how many attempts were made.
    """

    # LLMInferenceEngine.derive sets this to its retry counter + 1.
    iterations: int = 0
37
+
38
+
39
@dataclass
class LLMInferenceEngine(InferenceEngine):
    """
    LLM based predictor.

    Unlike the RAG predictor, no retrieved examples are placed in the prompt;
    the model is given the database schema plus the incomplete object and is
    asked to fill in the target attributes.
    """

    # Lazily created llm model; use the ``model`` property instead of this field.
    _model: "llm.Model" = None  # noqa: F821

    PERSIST_COLS: ClassVar[List[str]] = [
        "config",
    ]

    def __post_init__(self):
        """Ensure that both ``config`` and ``config.llm_config`` are populated."""
        if not self.config:
            self.config = InferenceConfig()
        if not self.config.llm_config:
            self.config.llm_config = LLMConfig()

    @property
    def model(self) -> "llm.Model":  # noqa: F821
        """
        The llm model named by ``config.llm_config.model_name``.

        Created on first access; if the model declares ``needs_key`` the key is
        looked up via ``llm.get_key`` and attached to the model.
        """
        import llm

        if self._model is None:
            self._model = llm.get_model(self.config.llm_config.model_name)
            if self._model.needs_key:
                key = get_key(None, key_alias=self._model.needs_key)
                self._model.key = key

        return self._model

    def initialize_model(self, **kwargs):
        """Force creation of the underlying model (accessing ``self.model``) and log it."""
        logger.info(f"Initializing model {self.model}")

    def object_to_text(self, object: OBJECT) -> str:
        """
        Serialize an object as YAML text for use in a prompt.

        :param object: object to serialize
        :return: YAML string
        """
        return yaml.dump(object)

    def _schema_str(self) -> str:
        """
        Return, as YAML, the schema of the database owning the training data's base collection.
        """
        db = self.training_data.base_collection.parent
        from linkml_runtime.dumpers import json_dumper

        schema_dict = json_dumper.to_dict(db.schema_view.schema)
        return yaml.dump(schema_dict)

    def derive(
        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
    ) -> Optional[LLMInference]:
        """
        Ask the LLM to complete ``object`` and return the merged result.

        When ``config.validate_results`` is set, the merged object is validated
        against the training data's base collection; on failure the call is
        retried with the validation errors added to the system prompt.

        :param object: incomplete object to be filled in
        :param iteration: retry counter; callers normally leave this at 0
        :param additional_prompt_texts: extra lines appended to the system prompt
            (used to feed validation errors back on a retry)
        :return: inference whose ``predicted_object`` is ``object`` overlaid with the LLM output
        :raises ValueError: if validation still fails once ``iteration`` exceeds MAX_ITERATIONS
        """
        import llm

        model: llm.Model = self.model
        target_attributes = self.config.target_attributes
        query_text = self.object_to_text(object)

        if not target_attributes:
            # No configured targets: treat every empty slot of the input as a target.
            target_attributes = [k for k, v in object.items() if v is None or v == ""]

        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
        system_prompt += "\n## SCHEMA:\n\n" + self._schema_str()
        if additional_prompt_texts:
            # Previously these texts were accepted but silently dropped, so a retry
            # sent exactly the same prompt as the failed attempt.
            system_prompt += "\n" + "\n".join(additional_prompt_texts)

        stub = ", ".join([f"{k}: ..." for k in target_attributes])
        stub = "{" + stub + "}"
        prompt = (
            # Colon and newline keep the sentence from running into the code fence.
            "Provide a YAML object of the form:\n"
            "```yaml\n"
            f"{stub}\n"
            "```\n"
            "---\nQuery:\n"
            f"## INCOMPLETE OBJECT:\n{query_text}\n"
            "## OUTPUT:\n"
        )
        logger.info(f"Prompt: {prompt}")
        response = model.prompt(prompt, system=system_prompt)
        yaml_str = response.text()
        logger.info(f"Response: {yaml_str}")
        predicted_object = parse_yaml_payload(yaml_str, strict=True)
        # Values returned by the LLM take precedence over those in the input object.
        predicted_object = {**object, **predicted_object}
        if self.config.validate_results:
            base_collection = self.training_data.base_collection
            errs = list(base_collection.iter_validate_collection([predicted_object]))
            if errs:
                logger.warning(f"{iteration} // FAILED TO VALIDATE: {yaml_str}")
                logger.warning(f"PARSED: {predicted_object}")
                logger.warning(f"ERRORS: {errs}")
                if iteration > MAX_ITERATIONS:
                    raise ValueError(f"Validation errors: {errs}")
                extra_texts = [
                    "Make sure results conform to the schema. Previously you provided:\n",
                    yaml_str,
                    "\nThis was invalid.\n",
                    "Validation errors:\n",
                ] + [self.object_to_text(e) for e in errs]
                return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
        return LLMInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)

    def export_model(
        self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
    ):
        """
        Delegate to :meth:`save_model`, which is not supported by this engine.

        :param output: destination passed on to ``save_model``
        :param model_serialization: ignored
        :raises NotImplementedError: always, via ``save_model``
        """
        self.save_model(output)

    def save_model(self, output: Union[str, Path]) -> None:
        """
        Save the trained model and related data to a file.

        :param output: Path to save the model
        :raises NotImplementedError: always; this engine has no trained state to save
        """
        raise NotImplementedError("Does not make sense for this engine")

    @classmethod
    def load_model(cls, file_path: Union[str, Path]) -> "LLMInferenceEngine":
        """
        Load a model from a file.

        :param file_path: path of the saved model
        :raises NotImplementedError: always; this engine has no trained state to load
        """
        raise NotImplementedError("Does not make sense for this engine")
@@ -0,0 +1,276 @@
1
+ import json
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import ClassVar, List, Optional, TextIO, Union
6
+
7
+ import yaml
8
+ from llm import get_key
9
+ from pydantic import BaseModel
10
+
11
+ from linkml_store.api.collection import OBJECT, Collection
12
+ from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
13
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
14
+ from linkml_store.utils.object_utils import select_nested
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# RAGInferenceEngine.derive raises once its retry counter exceeds this value.
MAX_ITERATIONS = 5
# Number of examples requested from the search when
# llm_config.number_of_few_shot_examples is unset (or otherwise falsy).
DEFAULT_NUM_EXAMPLES = 20
# Passed as mmr_relevance_factor to the example search in derive().
DEFAULT_MMR_RELEVANCE_FACTOR = 0.8

# System prompt template. derive() calls .format(llm_config=...) on it, which
# fills the {llm_config.role} placeholder from the LLM configuration.
SYSTEM_PROMPT = """
You are a {llm_config.role}, your task is to infer the YAML
object output given the YAML object input. I will provide you
with a collection of examples that will provide guidance both
on the desired structure of the response, as well as the kind
of content.

You should return ONLY valid YAML in your response.
"""
31
+
32
+
33
class TrainedModel(BaseModel, extra="forbid"):
    """
    Serialized form of a RAG model, as written by ``save_model`` and read by ``load_model``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    # All rows of the RAG collection (the example objects).
    rag_collection_rows: List[OBJECT]
    # All rows of the collection backing the "llm" index of the RAG collection.
    index_rows: List[OBJECT]
    # Inference configuration of the engine that was saved; may be absent.
    config: Optional[InferenceConfig] = None
37
+
38
+
39
class RAGInference(Inference):
    """
    Inference result that also records how many attempts were made.
    """

    # RAGInferenceEngine.derive sets this to its retry counter + 1.
    iterations: int = 0
41
+
42
+
43
@dataclass
class RAGInferenceEngine(InferenceEngine):
    """
    AI Retrieval Augmented Generation (RAG) based predictor.


    >>> from linkml_store.api.client import Client
    >>> from linkml_store.utils.format_utils import Format
    >>> from linkml_store.inference.inference_config import LLMConfig
    >>> client = Client()
    >>> db = client.attach_database("duckdb", alias="test")
    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
    >>> db.list_collection_names()
    ['countries']
    >>> collection = db.get_collection("countries")
    >>> features = ["name"]
    >>> targets = ["code", "capital", "continent", "languages"]
    >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
    >>> ie = RAGInferenceEngine(config=config)
    >>> ie.load_and_split_data(collection)
    >>> ie.initialize_model()
    >>> prediction = ie.derive({"name": "Uruguay"})
    >>> prediction.predicted_object
    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}

    The "model" can be saved for later use:

    >>> ie.export_model("tests/output/countries.rag_model.json")

    Note in this case the model is not the underlying LLM, but the "RAG Model" which is the vectorized
    representation of training set objects.

    """

    # Lazily created llm model; use the ``model`` property instead of this field.
    _model: "llm.Model" = None  # noqa: F821

    # Collection searched for few-shot examples; set by initialize_model() or load_model().
    rag_collection: Collection = None

    PERSIST_COLS: ClassVar[List[str]] = [
        "config",
    ]

    def __post_init__(self):
        """Ensure that both ``config`` and ``config.llm_config`` are populated."""
        if not self.config:
            self.config = InferenceConfig()
        if not self.config.llm_config:
            self.config.llm_config = LLMConfig()

    @property
    def model(self) -> "llm.Model":  # noqa: F821
        """
        The llm model named by ``config.llm_config.model_name``.

        Created on first access; if the model declares ``needs_key`` the key is
        looked up via ``llm.get_key`` and attached to the model.
        """
        import llm

        if self._model is None:
            self._model = llm.get_model(self.config.llm_config.model_name)
            if self._model.needs_key:
                key = get_key(None, key_alias=self._model.needs_key)
                self._model.key = key

        return self._model

    def initialize_model(self, **kwargs):
        """
        Create the LLM and, if training data is present, use its collection as the RAG collection.

        An "llm" indexer is attached to that collection with ``auto_index=False``.
        """
        logger.info(f"Initializing model {self.model}")
        if self.training_data:
            rag_collection = self.training_data.collection
            rag_collection.attach_indexer("llm", auto_index=False)
            self.rag_collection = rag_collection

    def object_to_text(self, object: OBJECT) -> str:
        """
        Serialize an object as YAML text for use in a prompt.

        :param object: object to serialize
        :return: YAML string
        """
        return yaml.dump(object)

    def derive(
        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
    ) -> Optional[RAGInference]:
        """
        Predict the target attributes of ``object`` using retrieved examples as few-shot context.

        Examples are retrieved from ``rag_collection`` via its "llm" index, rendered as
        INPUT/OUTPUT pairs and packed into the prompt within the model's token limit.
        When ``config.validate_results`` is set, the parsed answer is validated against
        the training data's base collection; on failure the call is retried with the
        validation errors appended to the system prompt.

        :param object: query object holding the feature values
        :param iteration: retry counter; callers normally leave this at 0
        :param additional_prompt_texts: extra lines appended to the system prompt
        :return: inference holding the object parsed from the LLM response
        :raises ValueError: if the RAG collection has no indexer, no examples are found,
            an example has no feature attributes, or validation still fails once
            ``iteration`` exceeds MAX_ITERATIONS
        """
        import llm
        from tiktoken import encoding_for_model

        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text

        model: llm.Model = self.model
        model_name = self.config.llm_config.model_name
        feature_attributes = self.config.feature_attributes
        target_attributes = self.config.target_attributes
        # Used only for set arithmetic below; guards against unset target attributes,
        # which would otherwise make set(None) raise TypeError.
        excluded_keys = set(target_attributes or [])
        num_examples = self.config.llm_config.number_of_few_shot_examples or DEFAULT_NUM_EXAMPLES
        query_text = self.object_to_text(object)
        mmr_relevance_factor = DEFAULT_MMR_RELEVANCE_FACTOR
        if not self.rag_collection:
            # TODO: zero-shot mode
            examples = []
        else:
            if not self.rag_collection.indexers:
                raise ValueError("RAG collection must have an indexer attached")
            logger.info(f"Searching {self.rag_collection.alias} for examples for: {query_text}")
            rs = self.rag_collection.search(
                query_text, limit=num_examples, index_name="llm", mmr_relevance_factor=mmr_relevance_factor
            )
            examples = rs.rows
            logger.info(f"Found {len(examples)} examples")
            if not examples:
                raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
        prompt_clauses = []
        this_feature_attributes = feature_attributes
        if not this_feature_attributes:
            # No configured features: every non-target key of the query is a feature.
            this_feature_attributes = list(set(object.keys()) - excluded_keys)
        query_obj = select_nested(object, this_feature_attributes)
        query_text = self.object_to_text(query_obj)
        for example in examples:
            this_feature_attributes = feature_attributes
            if not this_feature_attributes:
                this_feature_attributes = list(set(example.keys()) - excluded_keys)
            if not this_feature_attributes:
                raise ValueError(f"No feature attributes found in example {example}")
            input_obj = select_nested(example, this_feature_attributes)
            input_obj_text = self.object_to_text(input_obj)
            if input_obj_text == query_text:
                # The example has the same features as the query; skip it so the
                # answer is not simply looked up (possible test data leakage).
                continue
            output_obj = select_nested(example, target_attributes)
            prompt_clause = (
                "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
            )
            prompt_clauses.append(prompt_clause)

        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
        system_prompt += "\n".join(additional_prompt_texts or [])
        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"

        def make_text(texts: List[str]):
            return "\n".join(texts) + prompt_end

        try:
            encoding = encoding_for_model(model_name)
        except KeyError:
            # Model name not recognised by tiktoken: fall back to the gpt-4 encoding.
            encoding = encoding_for_model("gpt-4")
        token_limit = get_token_limit(model_name)
        prompt = render_formatted_text(
            make_text, values=prompt_clauses, encoding=encoding, token_limit=token_limit, additional_text=system_prompt
        )
        logger.info(f"Prompt: {prompt}")
        response = model.prompt(prompt, system=system_prompt)
        yaml_str = response.text()
        logger.info(f"Response: {yaml_str}")
        predicted_object = self._parse_yaml_payload(yaml_str, strict=True)
        if self.config.validate_results:
            base_collection = self.training_data.base_collection
            errs = list(base_collection.iter_validate_collection([predicted_object]))
            if errs:
                logger.warning(f"{iteration} // FAILED TO VALIDATE: {yaml_str}")
                logger.warning(f"PARSED: {predicted_object}")
                logger.warning(f"ERRORS: {errs}")
                if iteration > MAX_ITERATIONS:
                    raise ValueError(f"Validation errors: {errs}")
                extra_texts = [
                    "Make sure results conform to the schema. Previously you provided:\n",
                    yaml_str,
                    "\nThis was invalid.\n",
                    "Validation errors:\n",
                ] + [self.object_to_text(e) for e in errs]
                return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
        return RAGInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)

    def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
        """
        Parse YAML from an LLM response, unwrapping a fenced code block if present.

        :param yaml_str: raw response text
        :param strict: re-raise parse errors instead of logging them and returning None
        :return: the parsed object, or None if parsing failed and ``strict`` is false
        """
        if "```" in yaml_str:
            # Keep only the contents of the first fenced block.
            yaml_str = yaml_str.split("```")[1].strip()
            if yaml_str.startswith("yaml"):
                yaml_str = yaml_str[4:].strip()
        try:
            return yaml.safe_load(yaml_str)
        except Exception as e:
            if strict:
                raise
            logger.error(f"Error parsing YAML: {yaml_str}\n{e}")
            return None

    def export_model(
        self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
    ):
        """
        Export the RAG model; delegates to :meth:`save_model`.

        :param output: path or writable text stream
        :param model_serialization: ignored; the model is always written as JSON
        """
        self.save_model(output)

    def save_model(self, output: Union[str, Path, TextIO]) -> None:
        """
        Save the trained model and related data to a file.

        The RAG collection rows, the rows of its "llm" index collection and the
        engine config are written as a single JSON document.

        :param output: Path to save the model, or an already open writable text stream
        """

        # trigger index
        _qr = self.rag_collection.search("*", limit=1)
        assert len(_qr.ranked_rows) > 0

        rows = self.rag_collection.find(limit=-1).rows

        indexers = self.rag_collection.indexers
        assert len(indexers) == 1
        ix = self.rag_collection.indexers["llm"]
        ix_coll = self.rag_collection.parent.get_collection(self.rag_collection.get_index_collection_name(ix))

        ix_rows = ix_coll.find(limit=-1).rows
        assert len(ix_rows) > 0
        tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows, config=self.config)
        payload = tm.model_dump()
        if hasattr(output, "write"):
            # export_model advertises TextIO; open() would reject a stream object.
            json.dump(payload, output)
        else:
            with open(output, "w", encoding="utf-8") as f:
                json.dump(payload, f)

    @classmethod
    def load_model(cls, file_path: Union[str, Path]) -> "RAGInferenceEngine":
        """
        Load a trained model and related data from a file.

        The saved rows are stored in a fresh duckdb database (alias "training"),
        an "llm" indexer is attached, and the saved index rows are inserted into
        the index collection.

        :param file_path: Path to the saved model
        :return: RAGInferenceEngine instance with loaded model
        """
        with open(file_path, "r", encoding="utf-8") as f:
            model_data = json.load(f)
        tm = TrainedModel(**model_data)
        from linkml_store.api import Client

        client = Client()
        db = client.attach_database("duckdb", alias="training")
        db.store({"data": tm.rag_collection_rows})
        collection = db.get_collection("data")
        ix = collection.attach_indexer("llm", auto_index=False)
        assert ix.name
        ix_coll_name = collection.get_index_collection_name(ix)
        assert ix_coll_name
        ix_coll = db.get_collection(ix_coll_name, create_if_not_exists=True)
        ix_coll.insert(tm.index_rows)
        ie = cls(config=tm.config)
        ie.rag_collection = collection
        return ie