linkml-store 0.3.0__py3-none-any.whl

Files changed (101)
  1. linkml_store/__init__.py +7 -0
  2. linkml_store/api/__init__.py +8 -0
  3. linkml_store/api/client.py +414 -0
  4. linkml_store/api/collection.py +1280 -0
  5. linkml_store/api/config.py +187 -0
  6. linkml_store/api/database.py +862 -0
  7. linkml_store/api/queries.py +69 -0
  8. linkml_store/api/stores/__init__.py +0 -0
  9. linkml_store/api/stores/chromadb/__init__.py +7 -0
  10. linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
  11. linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
  12. linkml_store/api/stores/dremio/__init__.py +10 -0
  13. linkml_store/api/stores/dremio/dremio_collection.py +555 -0
  14. linkml_store/api/stores/dremio/dremio_database.py +1052 -0
  15. linkml_store/api/stores/dremio/mappings.py +105 -0
  16. linkml_store/api/stores/dremio_rest/__init__.py +11 -0
  17. linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
  18. linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
  19. linkml_store/api/stores/duckdb/__init__.py +16 -0
  20. linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
  21. linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
  22. linkml_store/api/stores/duckdb/mappings.py +8 -0
  23. linkml_store/api/stores/filesystem/__init__.py +15 -0
  24. linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
  25. linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
  26. linkml_store/api/stores/hdf5/__init__.py +7 -0
  27. linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
  28. linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
  29. linkml_store/api/stores/ibis/__init__.py +5 -0
  30. linkml_store/api/stores/ibis/ibis_collection.py +488 -0
  31. linkml_store/api/stores/ibis/ibis_database.py +328 -0
  32. linkml_store/api/stores/mongodb/__init__.py +25 -0
  33. linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
  34. linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
  35. linkml_store/api/stores/neo4j/__init__.py +0 -0
  36. linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
  37. linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
  38. linkml_store/api/stores/solr/__init__.py +3 -0
  39. linkml_store/api/stores/solr/solr_collection.py +224 -0
  40. linkml_store/api/stores/solr/solr_database.py +83 -0
  41. linkml_store/api/stores/solr/solr_utils.py +0 -0
  42. linkml_store/api/types.py +4 -0
  43. linkml_store/cli.py +1147 -0
  44. linkml_store/constants.py +7 -0
  45. linkml_store/graphs/__init__.py +0 -0
  46. linkml_store/graphs/graph_map.py +24 -0
  47. linkml_store/index/__init__.py +53 -0
  48. linkml_store/index/implementations/__init__.py +0 -0
  49. linkml_store/index/implementations/llm_indexer.py +174 -0
  50. linkml_store/index/implementations/simple_indexer.py +43 -0
  51. linkml_store/index/indexer.py +211 -0
  52. linkml_store/inference/__init__.py +13 -0
  53. linkml_store/inference/evaluation.py +195 -0
  54. linkml_store/inference/implementations/__init__.py +0 -0
  55. linkml_store/inference/implementations/llm_inference_engine.py +154 -0
  56. linkml_store/inference/implementations/rag_inference_engine.py +276 -0
  57. linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
  58. linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
  59. linkml_store/inference/inference_config.py +66 -0
  60. linkml_store/inference/inference_engine.py +209 -0
  61. linkml_store/inference/inference_engine_registry.py +74 -0
  62. linkml_store/plotting/__init__.py +5 -0
  63. linkml_store/plotting/cli.py +826 -0
  64. linkml_store/plotting/dimensionality_reduction.py +453 -0
  65. linkml_store/plotting/embedding_plot.py +489 -0
  66. linkml_store/plotting/facet_chart.py +73 -0
  67. linkml_store/plotting/heatmap.py +383 -0
  68. linkml_store/utils/__init__.py +0 -0
  69. linkml_store/utils/change_utils.py +17 -0
  70. linkml_store/utils/dat_parser.py +95 -0
  71. linkml_store/utils/embedding_matcher.py +424 -0
  72. linkml_store/utils/embedding_utils.py +299 -0
  73. linkml_store/utils/enrichment_analyzer.py +217 -0
  74. linkml_store/utils/file_utils.py +37 -0
  75. linkml_store/utils/format_utils.py +550 -0
  76. linkml_store/utils/io.py +38 -0
  77. linkml_store/utils/llm_utils.py +122 -0
  78. linkml_store/utils/mongodb_utils.py +145 -0
  79. linkml_store/utils/neo4j_utils.py +42 -0
  80. linkml_store/utils/object_utils.py +190 -0
  81. linkml_store/utils/pandas_utils.py +93 -0
  82. linkml_store/utils/patch_utils.py +126 -0
  83. linkml_store/utils/query_utils.py +89 -0
  84. linkml_store/utils/schema_utils.py +23 -0
  85. linkml_store/utils/sklearn_utils.py +193 -0
  86. linkml_store/utils/sql_utils.py +177 -0
  87. linkml_store/utils/stats_utils.py +53 -0
  88. linkml_store/utils/vector_utils.py +158 -0
  89. linkml_store/webapi/__init__.py +0 -0
  90. linkml_store/webapi/html/__init__.py +3 -0
  91. linkml_store/webapi/html/base.html.j2 +24 -0
  92. linkml_store/webapi/html/collection_details.html.j2 +15 -0
  93. linkml_store/webapi/html/database_details.html.j2 +16 -0
  94. linkml_store/webapi/html/databases.html.j2 +14 -0
  95. linkml_store/webapi/html/generic.html.j2 +43 -0
  96. linkml_store/webapi/main.py +855 -0
  97. linkml_store-0.3.0.dist-info/METADATA +226 -0
  98. linkml_store-0.3.0.dist-info/RECORD +101 -0
  99. linkml_store-0.3.0.dist-info/WHEEL +4 -0
  100. linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
  101. linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
linkml_store/inference/implementations/rule_based_inference_engine.py
@@ -0,0 +1,169 @@
+ import logging
+ from copy import copy
+ from dataclasses import dataclass
+ from io import StringIO
+ from pathlib import Path
+ from typing import Any, ClassVar, Dict, List, Optional, Union
+
+ import yaml
+ from linkml_map.utils.eval_utils import eval_expr
+ from linkml_runtime import SchemaView
+ from linkml_runtime.linkml_model.meta import AnonymousClassExpression, ClassRule
+ from linkml_runtime.utils.formatutils import underscore
+ from pydantic import BaseModel
+
+ from linkml_store.api.collection import OBJECT, Collection
+ from linkml_store.inference.inference_config import Inference, InferenceConfig
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+
+ logger = logging.getLogger(__name__)
+
+
+ def expression_matches(ce: AnonymousClassExpression, object: OBJECT) -> bool:
+     """
+     Check if a class expression matches an object.
+
+     :param ce: The class expression
+     :param object: The object
+     :return: True if the class expression matches the object
+     """
+     if ce.any_of:
+         if not any(expression_matches(subce, object) for subce in ce.any_of):
+             return False
+     if ce.all_of:
+         if not all(expression_matches(subce, object) for subce in ce.all_of):
+             return False
+     if ce.none_of:
+         if any(expression_matches(subce, object) for subce in ce.none_of):
+             return False
+     if ce.slot_conditions:
+         for slot in ce.slot_conditions.values():
+             slot_name = slot.name
+             v = object.get(slot_name, None)
+             if slot.equals_string is not None:
+                 if slot.equals_string != str(v):
+                     return False
+             if slot.equals_integer is not None:
+                 if slot.equals_integer != v:
+                     return False
+             if slot.equals_expression is not None:
+                 eval_v = eval_expr(slot.equals_expression, **object)
+                 if v != eval_v:
+                     return False
+     return True
+
+
+ def apply_rule(rule: ClassRule, object: OBJECT):
+     """
+     Apply a rule to an object.
+
+     Mutates the object.
+
+     :param rule: The rule to apply
+     :param object: The object to apply the rule to
+     """
+     for condition in rule.preconditions:
+         if expression_matches(condition, object):
+             for postcondition in rule.postconditions:
+                 all_of = [x for x in postcondition.all_of] + [postcondition]
+                 for pc in all_of:
+                     sc = pc.slot_condition
+                     if sc:
+                         # explicit None checks so falsy values like 0 or "" are still applied
+                         if sc.equals_string is not None:
+                             object[sc.name] = sc.equals_string
+                         if sc.equals_integer is not None:
+                             object[sc.name] = sc.equals_integer
+                         if sc.equals_expression is not None:
+                             object[sc.name] = eval_expr(sc.equals_expression, **object)
+     return object
+
+
+ @dataclass
+ class RuleBasedInferenceEngine(InferenceEngine):
+     """
+     Inference engine that derives slot values by applying LinkML class rules
+     and slot ``equals_expression`` declarations.
+     """
+
+     class_rules: Optional[List[ClassRule]] = None
+     slot_rules: Optional[Dict[str, List[ClassRule]]] = None
+     slot_expressions: Optional[Dict[str, str]] = None
+
+     PERSIST_COLS: ClassVar = ["config", "class_rules", "slot_rules", "slot_expressions"]
+
+     def initialize_model(self, **kwargs):
+         td = self.training_data
+         collection: Collection = td.collection
+         cd = collection.class_definition()
+         sv: SchemaView = collection.parent.schema_view
+         class_rules = cd.rules
+         if class_rules:
+             self.class_rules = class_rules
+         for slot in sv.class_induced_slots(cd.name):
+             if slot.equals_expression:
+                 # slot_expressions defaults to None, so create the dict on first use
+                 if self.slot_expressions is None:
+                     self.slot_expressions = {}
+                 self.slot_expressions[slot.name] = slot.equals_expression
+
+     def derive(self, object: OBJECT) -> Optional[Inference]:
+         object = copy(object)
+         if self.class_rules:
+             for rule in self.class_rules:
+                 apply_rule(rule, object)
+         object = {underscore(k): v for k, v in object.items()}
+         if self.slot_expressions:
+             for slot, expr in self.slot_expressions.items():
+                 v = eval_expr(expr, **object)
+                 if v is not None:
+                     object[slot] = v
+         if self.config and self.config.target_attributes:
+             predicted_object = {k: object.get(k, None) for k in self.config.target_attributes}
+         else:
+             predicted_object = object
+         if all(v is None for v in predicted_object.values()):
+             return None
+         return Inference(predicted_object=predicted_object)
+
+     def import_model_from(self, inference_engine: InferenceEngine, **kwargs):
+         io = StringIO()
+         inference_engine.export_model(io, model_serialization=ModelSerialization.LINKML_EXPRESSION)
+         config = inference_engine.config
+         if len(config.target_attributes) != 1:
+             raise ValueError("Can only import models with a single target attribute")
+         target_attribute = config.target_attributes[0]
+         if self.slot_expressions is None:
+             self.slot_expressions = {}
+         self.slot_expressions[target_attribute] = io.getvalue()
+         if not self.config:
+             self.config = inference_engine.config
+
+     def save_model(self, output: Union[str, Path]) -> None:
+         """
+         Save the trained model and related data to a file.
+
+         :param output: Path to save the model
+         """
+
+         def _serialize_value(v: Any) -> Any:
+             if isinstance(v, BaseModel):
+                 return v.model_dump(exclude_unset=True)
+             return v
+
+         model_data = {k: _serialize_value(getattr(self, k)) for k in self.PERSIST_COLS}
+         with open(output, "w", encoding="utf-8") as f:
+             yaml.dump(model_data, f)
+
+     @classmethod
+     def load_model(cls, file_path: Union[str, Path]) -> "RuleBasedInferenceEngine":
+         with open(file_path, encoding="utf-8") as f:
+             model_data = yaml.safe_load(f)
+
+         if model_data.get("config"):
+             config = InferenceConfig(**model_data["config"])
+         else:
+             config = None
+         engine = cls(config=config)
+         for k, v in model_data.items():
+             if k == "config":
+                 continue
+             setattr(engine, k, v)
+
+         logger.info(f"Model loaded from {file_path}")
+         return engine
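
As a rough usage sketch (not part of the package source): the engine can be driven directly by populating `slot_expressions`, which `derive` evaluates with `eval_expr`. The attribute names and the expression below are hypothetical; the `{slot}` syntax follows LinkML's `equals_expression` convention. Normally `initialize_model()` fills `slot_expressions` from the schema instead.

    from linkml_store.inference.implementations.rule_based_inference_engine import RuleBasedInferenceEngine
    from linkml_store.inference.inference_config import InferenceConfig

    # Hypothetical target attribute and expression
    engine = RuleBasedInferenceEngine(config=InferenceConfig(target_attributes=["bmi"]))
    engine.slot_expressions = {"bmi": "{weight_kg} / ({height_m} * {height_m})"}

    result = engine.derive({"weight_kg": 70.0, "height_m": 1.75})
    if result:
        print(result.predicted_object)  # {'bmi': 22.857...}
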
linkml_store/inference/implementations/sklearn_inference_engine.py
@@ -0,0 +1,314 @@
+ import logging
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, ClassVar, Dict, List, Optional, TextIO, Type, Union
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import cross_val_score
+ from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer, OneHotEncoder
+ from sklearn.tree import DecisionTreeClassifier
+
+ from linkml_store.api.collection import OBJECT
+ from linkml_store.inference.implementations.rule_based_inference_engine import RuleBasedInferenceEngine
+ from linkml_store.inference.inference_config import Inference, InferenceConfig
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+ from linkml_store.utils.sklearn_utils import tree_to_nested_expression, visualize_decision_tree
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class SklearnInferenceEngine(InferenceEngine):
+     config: InferenceConfig
+     classifier: Any = None
+     encoders: Dict[str, Any] = field(default_factory=dict)
+     transformed_features: List[str] = field(default_factory=list)
+     transformed_targets: List[str] = field(default_factory=list)
+     skip_features: List[str] = field(default_factory=list)
+     categorical_encoder_class: Optional[Type[Union[OneHotEncoder, MultiLabelBinarizer]]] = None
+     maximum_proportion_distinct_features: float = 0.2
+     confidence: float = 0.0
+
+     strict: bool = False
+
+     PERSIST_COLS: ClassVar = [
+         "config",
+         "classifier",
+         "encoders",
+         "transformed_features",
+         "transformed_targets",
+         "skip_features",
+         "confidence",
+     ]
+
+     def _get_encoder(self, v: Union[List[Any], Any]) -> Any:
+         if isinstance(v, list):
+             if all(isinstance(x, list) for x in v):
+                 return MultiLabelBinarizer()
+             elif all(isinstance(x, str) for x in v):
+                 return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+             elif all(isinstance(x, (int, float)) for x in v):
+                 return None
+             else:
+                 raise ValueError("Mixed data types in the list are not supported")
+         else:
+             if hasattr(v, "dtype"):
+                 if v.dtype == "object" or v.dtype.name == "category":
+                     if isinstance(v.iloc[0], list):
+                         return MultiLabelBinarizer()
+                     elif self.categorical_encoder_class:
+                         return self.categorical_encoder_class(handle_unknown="ignore")
+                     else:
+                         return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+                 elif v.dtype.kind in "biufc":
+                     return None
+             raise ValueError("Unable to determine an appropriate encoder for the input data")
+
+     def _is_complex_column(self, column: pd.Series) -> bool:
+         """Check if the column contains complex data types like lists or dicts."""
+         MV_TYPE = (list,)
+         return (column.dtype == "object" or column.dtype == "category") and any(
+             isinstance(x, MV_TYPE) for x in column.dropna()
+         )
+
+     def _get_unique_values(self, column: pd.Series) -> set:
+         """Get unique values from a column, handling list-type data."""
+         if self._is_complex_column(column):
+             # For columns with lists, flatten the lists and get unique values
+             return set(
+                 item for sublist in column.dropna() for item in (sublist if isinstance(sublist, list) else [sublist])
+             )
+         else:
+             return set(column.unique())
+
+     def initialize_model(self, **kwargs):
+         logger.info(f"Initializing model with config: {self.config}")
+         df = self.training_data.as_dataframe(flattened=True)
+         logger.info(f"Training data shape: {df.shape}")
+         target_cols = self.config.target_attributes
+         feature_cols = self.config.feature_attributes
+         if len(target_cols) != 1:
+             raise ValueError("Only one target column is supported")
+         if not feature_cols:
+             feature_cols = df.columns.difference(target_cols).tolist()
+             self.config.feature_attributes = feature_cols
+         if not feature_cols:
+             raise ValueError("No features found in the data")
+         target_col = target_cols[0]
+         logger.info(f"Feature columns: {feature_cols}")
+         X = df[feature_cols].copy()
+         logger.info(f"Target column: {target_col}")
+         y = df[target_col].copy()
+
+         # find features to skip: columns whose distinct-value count exceeds the allowed proportion
+         skip_features = []
+         if not len(X.columns):
+             raise ValueError("No features to train on")
+         for col in X.columns:
+             unique_values = self._get_unique_values(X[col])
+             if len(unique_values) > self.maximum_proportion_distinct_features * len(X[col]):
+                 skip_features.append(col)
+         self.skip_features = skip_features
+         X = X.drop(skip_features, axis=1)
+         logger.info(f"Skipping features: {skip_features}")
+
+         # Encode features
+         encoded_features = []
+         if not len(X.columns):
+             raise ValueError(f"No features to train on after skipping {skip_features}")
+         for col in X.columns:
+             logger.info(f"Checking whether to encode: {col}")
+             col_encoder = self._get_encoder(X[col])
+             if col_encoder:
+                 self.encoders[col] = col_encoder
+                 if isinstance(col_encoder, OneHotEncoder):
+                     encoded = col_encoder.fit_transform(X[[col]])
+                     feature_names = col_encoder.get_feature_names_out([col])
+                     encoded_df = pd.DataFrame(encoded, columns=feature_names, index=X.index)
+                     X = pd.concat([X.drop(col, axis=1), encoded_df], axis=1)
+                     encoded_features.extend(feature_names)
+                 elif isinstance(col_encoder, MultiLabelBinarizer):
+                     encoded = col_encoder.fit_transform(X[col])
+                     feature_names = [f"{col}_{c}" for c in col_encoder.classes_]
+                     encoded_df = pd.DataFrame(encoded, columns=feature_names, index=X.index)
+                     X = pd.concat([X.drop(col, axis=1), encoded_df], axis=1)
+                     encoded_features.extend(feature_names)
+                 else:
+                     X[col] = col_encoder.fit_transform(X[col])
+                     encoded_features.append(col)
+             else:
+                 encoded_features.append(col)
+
+         self.transformed_features = encoded_features
+         logger.info(f"Encoded features: {self.transformed_features}")
+         logger.info(f"Number of features after encoding: {len(self.transformed_features)}")
+
+         # Encode target; a OneHotEncoder is inappropriate for a 1D target, so fall back to LabelEncoder
+         y_encoder = self._get_encoder(y)
+         if isinstance(y_encoder, OneHotEncoder):
+             y_encoder = LabelEncoder()
+         if y_encoder:
+             self.encoders[target_col] = y_encoder
+             y = y_encoder.fit_transform(y.values.ravel())  # Convert to 1D numpy array
+             self.transformed_targets = y_encoder.classes_
+
+         clf = DecisionTreeClassifier(random_state=42)
+         clf.fit(X, y)
+         self.classifier = clf
+         logger.info("Model fit complete")
+         cv_scores = cross_val_score(self.classifier, X, y, cv=5)
+         self.confidence = cv_scores.mean()
+         logger.info(f"Cross-validation scores: {cv_scores}")
+
+     def derive(self, object: OBJECT) -> Optional[Inference]:
+         object = self._normalize(object)
+         new_X = pd.DataFrame([object])
+
+         # Apply encodings
+         encoded_features = {}
+         for col in self.config.feature_attributes:
+             if col in self.skip_features:
+                 continue
+             if col in self.encoders:
+                 encoder = self.encoders[col]
+                 if isinstance(encoder, OneHotEncoder):
+                     logger.debug(f"Encoding: {col} v={object[col]} df={new_X[[col]]} encoder={encoder}")
+                     encoded = encoder.transform(new_X[[col]])
+                     feature_names = encoder.get_feature_names_out([col])
+                     for i, name in enumerate(feature_names):
+                         encoded_features[name] = encoded[0, i]
+                 elif isinstance(encoder, MultiLabelBinarizer):
+                     encoded = encoder.transform(new_X[col])
+                     feature_names = [f"{col}_{c}" for c in encoder.classes_]
+                     for i, name in enumerate(feature_names):
+                         encoded_features[name] = encoded[0, i]
+                 else:  # LabelEncoder or similar
+                     encoded_features[col] = encoder.transform(new_X[col].astype(str))[0]
+             else:
+                 encoded_features[col] = new_X[col].iloc[0]
+
+         # Ensure all expected features are present and in the correct order
+         final_features = []
+         for feature in self.transformed_features:
+             if feature in encoded_features:
+                 final_features.append(encoded_features[feature])
+             else:
+                 final_features.append(0)  # default for features absent from the input
+
+         # Create the final input array
+         new_X_array = np.array(final_features).reshape(1, -1)
+
+         logger.info(f"Input features: {self.transformed_features}")
+         logger.info(f"Number of input features: {len(self.transformed_features)}")
+
+         predictions = self.classifier.predict(new_X_array)
+         target_attribute = self.config.target_attributes[0]
+         y_encoder = self.encoders.get(target_attribute)
+
+         if y_encoder:
+             v = y_encoder.inverse_transform(predictions)
+         else:
+             v = predictions
+
+         predicted_object = {target_attribute: v[0]}
+         logger.info(f"Predicted object: {predicted_object}")
+         return Inference(predicted_object=predicted_object, confidence=self.confidence)
+
+     def _normalize(self, object: OBJECT) -> OBJECT:
+         """
+         Normalize the input object to ensure it has all the expected attributes.
+
+         Also remove any numpy/pandas oddities such as NaN.
+
+         :param object: The object to normalize
+         :return: The normalized object
+         """
+         np_map = {np.nan: None}
+
+         def _tr(x: Any):
+             # TODO: find a more elegant way to do this
+             try:
+                 return np_map.get(x, x)
+             except TypeError:
+                 return x
+
+         return {k: _tr(object.get(k, None)) for k in self.config.feature_attributes}
+
+     def export_model(
+         self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
+     ):
+         if model_serialization is None:
+             if isinstance(output, (str, Path)):
+                 model_serialization = ModelSerialization.from_filepath(output)
+             if model_serialization is None:
+                 model_serialization = ModelSerialization.JOBLIB
+
+         if model_serialization == ModelSerialization.LINKML_EXPRESSION:
+             expr = tree_to_nested_expression(
+                 self.classifier,
+                 self.transformed_features,
+                 self.encoders.keys(),
+                 feature_encoders=self.encoders,
+                 target_encoder=self.encoders.get(self.config.target_attributes[0]),
+             )
+             if isinstance(output, (str, Path)):
+                 # open via a context manager so the file handle is closed
+                 with open(output, "w") as f:
+                     f.write(expr)
+             else:
+                 output.write(expr)
+         elif model_serialization == ModelSerialization.JOBLIB:
+             self.save_model(output)
+         elif model_serialization == ModelSerialization.RULE_BASED:
+             rbie = RuleBasedInferenceEngine(config=self.config)
+             rbie.import_model_from(self)
+             rbie.save_model(output)
+         elif model_serialization == ModelSerialization.PNG:
+             visualize_decision_tree(self.classifier, self.transformed_features, self.transformed_targets, output)
+         else:
+             raise ValueError(f"Unsupported model serialization: {model_serialization}")
+
+     def save_model(self, output: Union[str, Path]) -> None:
+         """
+         Save the trained model and related data to a file.
+
+         :param output: Path to save the model
+         """
+         import joblib
+
+         if self.classifier is None:
+             raise ValueError("Model has not been trained. Call initialize_model() first.")
+
+         model_data = {k: getattr(self, k) for k in self.PERSIST_COLS}
+
+         joblib.dump(model_data, output)
+
+     @classmethod
+     def load_model(cls, file_path: Union[str, Path]) -> "SklearnInferenceEngine":
+         """
+         Load a trained model and related data from a file.
+
+         :param file_path: Path to the saved model
+         :return: SklearnInferenceEngine instance with loaded model
+         """
+         import joblib
+
+         model_data = joblib.load(file_path)
+
+         engine = cls(config=model_data["config"])
+         for k, v in model_data.items():
+             if k == "config":
+                 continue
+             setattr(engine, k, v)
+
+         logger.info(f"Model loaded from {file_path}")
+         return engine
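
To illustrate the `_get_encoder` dispatch above, a small sketch (the `InferenceConfig` here is a throwaway assumption; only the encoder selection is exercised):

    import pandas as pd

    from linkml_store.inference.implementations.sklearn_inference_engine import SklearnInferenceEngine
    from linkml_store.inference.inference_config import InferenceConfig

    engine = SklearnInferenceEngine(config=InferenceConfig(target_attributes=["label"]))

    # string-valued categorical column -> one-hot encoding
    print(engine._get_encoder(pd.Series(["a", "b", "a"])))      # OneHotEncoder(...)
    # list-valued (multivalued) column -> per-label binarization
    print(engine._get_encoder(pd.Series([["x"], ["x", "y"]])))  # MultiLabelBinarizer()
    # numeric column -> None, i.e. passed through unencoded
    print(engine._get_encoder(pd.Series([1.0, 2.0, 3.0])))      # None
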
linkml_store/inference/inference_config.py
@@ -0,0 +1,66 @@
+ import logging
+ from typing import Any, List, Optional, Tuple
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from linkml_store.api.collection import OBJECT
+ from linkml_store.utils.format_utils import Format, load_objects
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMConfig(BaseModel, extra="forbid"):
+     """
+     Configuration for the LLM indexer.
+     """
+
+     model_config = ConfigDict(protected_namespaces=())
+
+     model_name: str = "gpt-4o-mini"
+     token_limit: Optional[int] = None
+     number_of_few_shot_examples: Optional[int] = None
+     role: str = "Domain Expert"
+     cached_embeddings_database: Optional[str] = None
+     cached_embeddings_collection: Optional[str] = None
+     text_template: Optional[str] = None
+     text_template_syntax: Optional[str] = None
+
+
+ class InferenceConfig(BaseModel, extra="forbid"):
+     """
+     Configuration for inference engines.
+     """
+
+     target_attributes: Optional[List[str]] = None
+     feature_attributes: Optional[List[str]] = None
+     train_test_split: Optional[Tuple[float, float]] = None
+     llm_config: Optional[LLMConfig] = None
+     random_seed: Optional[int] = None
+     validate_results: Optional[bool] = None
+
+     @classmethod
+     def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":
+         """
+         Load an inference config from a file.
+
+         :param file_path: Path to the file.
+         :param format: Format of the file (YAML is recommended).
+         :return: InferenceConfig
+         """
+         if format and format.is_xsv():
+             logger.warning("XSV format is not recommended for inference config files")
+         objs = load_objects(file_path, format=format)
+         if len(objs) != 1:
+             raise ValueError(f"Expected 1 object, got {len(objs)}")
+         return cls(**objs[0])
+
+
+ class Inference(BaseModel, extra="forbid"):
+     """
+     Result of an inference derivation.
+     """
+
+     query: Optional[OBJECT] = Field(default=None, description="The query object.")
+     predicted_object: OBJECT = Field(..., description="The predicted object.")
+     confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)
+     explanation: Optional[Any] = Field(default=None, description="Explanation of the prediction.")
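
For reference, a minimal YAML file that `InferenceConfig.from_file` should accept (the attribute names are illustrative; because the model is declared with `extra="forbid"`, only the declared fields may appear):

    target_attributes: [species]
    feature_attributes: [sepal_length, sepal_width, petal_length, petal_width]
    random_seed: 42
    llm_config:
      model_name: gpt-4o-mini
      number_of_few_shot_examples: 5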