linkml-store 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.



Files changed (31)
  1. linkml_store/api/client.py +37 -8
  2. linkml_store/api/collection.py +81 -9
  3. linkml_store/api/config.py +28 -1
  4. linkml_store/api/database.py +26 -3
  5. linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
  6. linkml_store/api/stores/neo4j/__init__.py +0 -0
  7. linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
  8. linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
  9. linkml_store/cli.py +140 -13
  10. linkml_store/graphs/__init__.py +0 -0
  11. linkml_store/graphs/graph_map.py +24 -0
  12. linkml_store/inference/__init__.py +13 -0
  13. linkml_store/inference/implementations/__init__.py +0 -0
  14. linkml_store/inference/implementations/rag_inference_engine.py +145 -0
  15. linkml_store/inference/implementations/rule_based_inference_engine.py +158 -0
  16. linkml_store/inference/implementations/sklearn_inference_engine.py +290 -0
  17. linkml_store/inference/inference_config.py +62 -0
  18. linkml_store/inference/inference_engine.py +173 -0
  19. linkml_store/inference/inference_engine_registry.py +74 -0
  20. linkml_store/utils/format_utils.py +21 -90
  21. linkml_store/utils/llm_utils.py +95 -0
  22. linkml_store/utils/neo4j_utils.py +42 -0
  23. linkml_store/utils/object_utils.py +3 -1
  24. linkml_store/utils/pandas_utils.py +55 -2
  25. linkml_store/utils/sklearn_utils.py +193 -0
  26. linkml_store/utils/stats_utils.py +53 -0
  27. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/METADATA +30 -3
  28. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/RECORD +31 -14
  29. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/LICENSE +0 -0
  30. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/WHEEL +0 -0
  31. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/entry_points.txt +0 -0
linkml_store/cli.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 import sys
 import warnings
+from pathlib import Path
 from typing import Optional
 
 import click
@@ -10,14 +11,22 @@ from pydantic import BaseModel
 
 from linkml_store import Client
 from linkml_store.api import Collection, Database
+from linkml_store.api.config import ClientConfig
 from linkml_store.api.queries import Query
 from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
+from linkml_store.inference import get_inference_engine
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import ModelSerialization
 from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output, write_output
 from linkml_store.utils.object_utils import object_path_update
 from linkml_store.utils.pandas_utils import facet_summary_to_dataframe_unmelted
 
+DEFAULT_LOCAL_CONF_PATH = Path("linkml.yaml")
+# global path is ~/.linkml.yaml in the user's home directory
+DEFAULT_GLOBAL_CONF_PATH = Path("~/.linkml.yaml").expanduser()
+
 index_type_option = click.option(
     "--index-type",
     "-t",
@@ -84,6 +93,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
+@click.option("--input", "-i", help="Input file (alternative to database/collection)")
 @click.option("--config", "-C", type=click.Path(exists=True), help="Path to the configuration file")
 @click.option("--set", help="Metadata settings in the form PATHEXPR=value", multiple=True)
 @click.option("-v", "--verbose", count=True)
@@ -96,7 +106,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
     help="If set then show full stacktrace on error",
 )
 @click.pass_context
-def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, **kwargs):
+def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, input, **kwargs):
     """A CLI for interacting with the linkml-store."""
     if not stacktrace:
         sys.tracebacklimit = 0
@@ -119,13 +129,25 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if quiet:
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
+    if input:
+        stem = Path(input).stem
+        database = "duckdb"
+        collection = stem
+        config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
+        # collection = Path(input).stem
+        # database = f"file:{Path(input).parent}"
+    if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
+        config = DEFAULT_LOCAL_CONF_PATH
+    if config is None and DEFAULT_GLOBAL_CONF_PATH.exists():
+        config = DEFAULT_GLOBAL_CONF_PATH
+    if config == ".":
+        config = None
+    if not collection and database and "::" in database:
+        database, collection = database.split("::")
+
     client = Client().from_config(config, **kwargs) if config else Client()
     settings = ContextSettings(client=client, database_name=database, collection_name=collection)
     ctx.obj["settings"] = settings
-    # DEPRECATED
-    ctx.obj["client"] = client
-    ctx.obj["database"] = database
-    ctx.obj["collection"] = collection
     if settings.database_name:
         db = client.get_database(database)
         if set:
@@ -136,12 +158,6 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
                 val = yaml.safe_load(val)
                 logger.info(f"Setting {path} to {val}")
                 db.metadata = object_path_update(db.metadata, path, val)
-        # settings.database = db
-        # DEPRECATED
-        ctx.obj["database_obj"] = db
-        if collection:
-            collection_obj = db.get_collection(collection)
-            ctx.obj["collection_obj"] = collection_obj
     if not settings.database_name:
         # if len(client.databases) != 1:
         #     raise ValueError("Database must be specified if there are multiple databases.")
@@ -323,11 +339,12 @@ def apply(ctx, patch_files, identifier_attribute):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query, as YAML")
+@click.option("--select", "-s", type=click.STRING, help="SELECT clause for the query, as YAML")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
 @click.pass_context
-def query(ctx, where, limit, output_type, output):
+def query(ctx, where, select, limit, output_type, output):
     """Query objects from the specified collection.
 
 
@@ -353,7 +370,13 @@
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
-    query = Query(from_table=collection.alias, where_clause=where_clause, limit=limit)
+    select_clause = yaml.safe_load(select) if select else None
+    if select_clause:
+        if isinstance(select_clause, str):
+            select_clause = [select_clause]
+        if not isinstance(select_clause, list):
+            raise ValueError(f"SELECT clause must be a list. Got: {select_clause}")
+    query = Query(from_table=collection.alias, select_cols=select_clause, where_clause=where_clause, limit=limit)
     result = collection.query(query)
     output_data = render_output(result.rows, output_type)
     if output:
@@ -458,6 +481,110 @@ def describe(ctx, where, output_type, output, limit):
         write_output(df.describe(include="all").transpose(), output_type, target=output)
 
 
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--target-attribute", "-T", type=click.STRING, multiple=True, help="Target attributes for inference")
+@click.option(
+    "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
+)
+@click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
+@click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
+@click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
+@click.option("--model-format", "-M", type=click.Choice([x.value for x in ModelSerialization]), help="Format for model")
+@click.option("--training-test-data-split", "-S", type=click.Tuple([float, float]), help="Training/test data split")
+@click.option(
+    "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
+)
+@click.option("--query", "-q", type=click.STRING, help="query term")
+@click.pass_context
+def infer(
+    ctx,
+    inference_config_file,
+    query,
+    training_test_data_split,
+    predictor_type,
+    target_attribute,
+    feature_attributes,
+    output_type,
+    output,
+    model_format,
+    export_model,
+    load_model,
+):
+    """
+    Predict a complete object from a partial object.
+
+    Currently two main prediction methods are provided: RAG and sklearn
+
+    ## RAG:
+
+    The RAG approach will use Retrieval Augmented Generation to inference the missing attributes of an object.
+
+    Example:
+
+        linkml-store -i countries.jsonl inference -t rag -q 'name: Uruguay'
+
+    Result:
+
+        capital: Montevideo, code: UY, continent: South America, languages: [Spanish]
+
+    You can pass in configurations as follows:
+
+        linkml-store -i countries.jsonl inference -t rag:llm_config.model_name=llama-3 -q 'name: Uruguay'
+
+    ## SKLearn:
+
+    This uses scikit-learn (defaulting to simple decision trees) to do the prediction.
+
+        linkml-store -i tests/input/iris.csv inference -t sklearn \
+            -q '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
+    """
+    if query:
+        query_obj = yaml.safe_load(query)
+    else:
+        query_obj = None
+    collection = ctx.obj["settings"].collection
+    atts = collection.class_definition().attributes.keys()
+    if model_format:
+        model_format = ModelSerialization(model_format)
+    if load_model:
+        predictor = get_inference_engine(predictor_type)
+        predictor = type(predictor).load_model(load_model)
+    else:
+        if feature_attributes:
+            features = feature_attributes.split(",")
+            features = [f.strip() for f in features]
+        else:
+            if query_obj:
+                features = query_obj.keys()
+            else:
+                features = None
+        if target_attribute:
+            target_attributes = list(target_attribute)
+        else:
+            target_attributes = [att for att in atts if att not in features]
+        if inference_config_file:
+            config = InferenceConfig.from_file(inference_config_file)
+        else:
+            config = InferenceConfig(target_attributes=target_attributes, feature_attributes=features)
+        if training_test_data_split:
+            config.train_test_split = training_test_data_split
+        predictor = get_inference_engine(predictor_type, config=config)
+        predictor.load_and_split_data(collection)
+        predictor.initialize_model()
+    if export_model:
+        logger.info(f"Exporting model to {export_model} in {model_format}")
+        predictor.export_model(export_model, model_format)
+    if not query_obj:
+        if not export_model:
+            raise ValueError("Query must be specified if not exporting model")
+    if query_obj:
+        result = predictor.derive(query_obj)
+        dumped_obj = result.model_dump(exclude_none=True)
+        write_output([dumped_obj], output_type, target=output)
+
+
 @cli.command()
 @index_type_option
 @click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
linkml_store/graphs/__init__.py ADDED
File without changes
linkml_store/graphs/graph_map.py ADDED
@@ -0,0 +1,24 @@
+from abc import ABC
+from typing import Optional
+
+from pydantic import BaseModel
+
+DEFAULT_IDENTIFIER_ATTRIBUTE = "id"
+DEFAULT_CATEGORY_LABELS_ATTRIBUTE = "category"
+DEFAULT_SUBJECT_ATTRIBUTE = "subject"
+DEFAULT_PREDICATE_ATTRIBUTE = "predicate"
+DEFAULT_OBJECT_ATTRIBUTE = "object"
+
+
+class GraphProjection(BaseModel, ABC):
+    identifier_attribute: str = DEFAULT_IDENTIFIER_ATTRIBUTE
+
+
+class NodeProjection(GraphProjection):
+    category_labels_attribute: Optional[str] = DEFAULT_CATEGORY_LABELS_ATTRIBUTE
+
+
+class EdgeProjection(GraphProjection):
+    subject_attribute: str = DEFAULT_SUBJECT_ATTRIBUTE
+    predicate_attribute: str = DEFAULT_PREDICATE_ATTRIBUTE
+    object_attribute: str = DEFAULT_OBJECT_ATTRIBUTE
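These Pydantic models carry only attribute-name mappings; the new Neo4j store added in this release presumably consumes them when projecting collections onto graph nodes and edges. A minimal sketch of constructing projections for a hypothetical edge collection whose rows use non-default keys:

```python
from linkml_store.graphs.graph_map import EdgeProjection, NodeProjection

# hypothetical rows: {"curie": ..., "category": ..., "from": ..., "rel": ..., "to": ...}
nodes = NodeProjection(identifier_attribute="curie")  # category defaults to "category"
edges = EdgeProjection(subject_attribute="from", predicate_attribute="rel", object_attribute="to")

print(nodes.model_dump())  # {'identifier_attribute': 'curie', 'category_labels_attribute': 'category'}
print(edges.model_dump())
```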
linkml_store/inference/__init__.py ADDED
@@ -0,0 +1,13 @@
+"""
+inference engine package.
+"""
+
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+from linkml_store.inference.inference_engine_registry import get_inference_engine
+
+__all__ = [
+    "InferenceEngine",
+    "InferenceConfig",
+    "get_inference_engine",
+]
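The package exposes a small facade: configure with `InferenceConfig`, obtain an engine via `get_inference_engine`, then train and derive. A sketch of the programmatic flow, mirroring what the `infer` CLI command above does (the `"rag"` engine key, the countries test file, and the method sequence are all taken from elsewhere in this diff):

```python
from linkml_store import Client
from linkml_store.inference import InferenceConfig, get_inference_engine
from linkml_store.utils.format_utils import Format

client = Client()
db = client.attach_database("duckdb", alias="test")
db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
collection = db.get_collection("countries")

# predict capital and continent from the country name
config = InferenceConfig(target_attributes=["capital", "continent"], feature_attributes=["name"])
engine = get_inference_engine("rag", config=config)
engine.load_and_split_data(collection)
engine.initialize_model()
print(engine.derive({"name": "Uruguay"}).predicted_object)
```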
linkml_store/inference/implementations/__init__.py ADDED
File without changes
linkml_store/inference/implementations/rag_inference_engine.py ADDED
@@ -0,0 +1,145 @@
+import logging
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import yaml
+from llm import get_key
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = """
+You are a {llm_config.role}, your task is to inference the YAML
+object output given the YAML object input. I will provide you
+with a collection of examples that will provide guidance both
+on the desired structure of the response, as well as the kind
+of content.
+
+You should return ONLY valid YAML in your response.
+"""
+
+
+@dataclass
+class RAGInferenceEngine(InferenceEngine):
+    """
+    AI Retrieval Augmented Generation (RAG) based predictor.
+
+    >>> from linkml_store.api.client import Client
+    >>> from linkml_store.utils.format_utils import Format
+    >>> from linkml_store.inference.inference_config import LLMConfig
+    >>> client = Client()
+    >>> db = client.attach_database("duckdb", alias="test")
+    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
+    >>> db.list_collection_names()
+    ['countries']
+    >>> collection = db.get_collection("countries")
+    >>> features = ["name"]
+    >>> targets = ["code", "capital", "continent", "languages"]
+    >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
+    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
+    >>> ie = RAGInferenceEngine(config=config)
+    >>> ie.load_and_split_data(collection)
+    >>> ie.initialize_model()
+    >>> prediction = ie.derive({"name": "Uruguay"})
+    >>> prediction.predicted_object
+    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
+
+    """
+
+    classifier: Any = None
+    encoders: dict = None
+    _model: "llm.Model" = None  # noqa: F821
+
+    rag_collection: Collection = None
+
+    def __post_init__(self):
+        if not self.config:
+            self.config = InferenceConfig()
+        if not self.config.llm_config:
+            self.config.llm_config = LLMConfig()
+
+    @property
+    def model(self) -> "llm.Model":  # noqa: F821
+        import llm
+
+        if self._model is None:
+            self._model = llm.get_model(self.config.llm_config.model_name)
+            if self._model.needs_key:
+                key = get_key(None, key_alias=self._model.needs_key)
+                self._model.key = key
+
+        return self._model
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        s = td.slice
+        if not s[0] and not s[1]:
+            rag_collection = td.collection
+        else:
+            base_collection = td.collection
+            objs = base_collection.find({}, offset=s[0], limit=s[1] - s[0]).rows
+            db = base_collection.parent
+            rag_collection = db.get_collection(f"{base_collection.alias}__rag_{s[0]}_{s[1]}", create_if_not_exists=True)
+            rag_collection.insert(objs)
+        rag_collection.attach_indexer("llm", auto_index=False)
+        self.rag_collection = rag_collection
+
+    def object_to_text(self, object: OBJECT) -> str:
+        return yaml.dump(object)
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        import llm
+        from tiktoken import encoding_for_model
+
+        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
+
+        model: llm.Model = self.model
+        model_name = self.config.llm_config.model_name
+        feature_attributes = self.config.feature_attributes
+        target_attributes = self.config.target_attributes
+        num_examples = self.config.llm_config.number_of_few_shot_examples or 5
+        query_text = self.object_to_text(object)
+        if not self.rag_collection.indexers:
+            raise ValueError("RAG collection must have an indexer attached")
+        rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
+        examples = rs.rows
+        if not examples:
+            raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
+        prompt_clauses = []
+        for example in examples:
+            input_obj = {k: example.get(k, None) for k in feature_attributes}
+            output_obj = {k: example.get(k, None) for k in target_attributes}
+            prompt_clause = (
+                "---\nExample:\n"
+                f"## INPUT:\n{self.object_to_text(input_obj)}\n"
+                f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
+            )
+            prompt_clauses.append(prompt_clause)
+        query_obj = {k: object.get(k, None) for k in feature_attributes}
+        query_text = self.object_to_text(query_obj)
+        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
+        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+
+        def make_text(texts):
+            return "\n".join(prompt_clauses) + prompt_end
+
+        try:
+            encoding = encoding_for_model(model_name)
+        except KeyError:
+            encoding = encoding_for_model("gpt-4")
+        token_limit = get_token_limit(model_name)
+        prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
+        logger.info(f"Prompt: {prompt}")
+        response = model.prompt(prompt, system_prompt)
+        yaml_str = response.text()
+        logger.info(f"Response: {yaml_str}")
+        try:
+            predicted_object = yaml.safe_load(yaml_str)
+            return Inference(predicted_object=predicted_object)
+        except yaml.parser.ParserError as e:
+            logger.error(f"Error parsing response: {yaml_str}\n{e}")
+            return None
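`derive` defers token budgeting to `render_formatted_text` from the new `llm_utils.py` (+95 lines, not shown in this excerpt). From the call site one can infer the intended contract: re-render the prompt with progressively fewer few-shot clauses until it fits `token_limit`. A speculative sketch of that contract, not the actual implementation:

```python
from typing import Callable, List

def render_formatted_text(make_text: Callable[[List[str]], str], values: List[str],
                          encoding, token_limit: int) -> str:
    """Speculative stand-in for linkml_store.utils.llm_utils.render_formatted_text:
    shrink the clause list until the rendered prompt fits the token budget."""
    while True:
        text = make_text(values)
        if not values or len(encoding.encode(text)) <= token_limit:
            return text
        values.pop()  # drop the last (least similar) example and re-render
```

Note that the `make_text` closure in `derive` ignores its `texts` argument and joins `prompt_clauses` directly, so trimming only takes effect if the helper mutates the passed-in list in place, as sketched here.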
linkml_store/inference/implementations/rule_based_inference_engine.py ADDED
@@ -0,0 +1,158 @@
+import logging
+from copy import copy
+from dataclasses import dataclass
+from io import StringIO
+from pathlib import Path
+from typing import Any, ClassVar, Dict, List, Optional, Union
+
+import yaml
+from linkml_map.utils.eval_utils import eval_expr
+from linkml_runtime import SchemaView
+from linkml_runtime.linkml_model.meta import AnonymousClassExpression, ClassRule
+from linkml_runtime.utils.formatutils import underscore
+from pydantic import BaseModel
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+
+logger = logging.getLogger(__name__)
+
+
+def expression_matches(ce: AnonymousClassExpression, object: OBJECT) -> bool:
+    """
+    Check if a class expression matches an object.
+
+    :param ce: The class expression
+    :param object: The object
+    :return: True if the class expression matches the object
+    """
+    if ce.any_of:
+        if not any(expression_matches(subce, object) for subce in ce.any_of):
+            return False
+    if ce.all_of:
+        if not all(expression_matches(subce, object) for subce in ce.all_of):
+            return False
+    if ce.none_of:
+        if any(expression_matches(subce, object) for subce in ce.none_of):
+            return False
+    if ce.slot_conditions:
+        for slot in ce.slot_conditions.values():
+            slot_name = slot.name
+            v = object.get(slot_name, None)
+            if slot.equals_string is not None:
+                if slot.equals_string != str(v):
+                    return False
+            if slot.equals_integer is not None:
+                if slot.equals_integer != v:
+                    return False
+            if slot.equals_expression is not None:
+                eval_v = eval_expr(slot.equals_expression, **object)
+                if v != eval_v:
+                    return False
+    return True
+
+
+def apply_rule(rule: ClassRule, object: OBJECT):
+    """
+    Apply a rule to an object.
+
+    Mutates the object
+
+    :param rule: The rule to apply
+    :param object: The object to apply the rule to
+    """
+    for condition in rule.preconditions:
+        if expression_matches(condition, object):
+            for postcondition in rule.postconditions:
+                all_of = [x for x in postcondition.all_of] + [postcondition]
+                for pc in all_of:
+                    sc = pc.slot_condition
+                    if sc:
+                        if sc.equals_string:
+                            object[sc.name] = sc.equals_string
+                        if sc.equals_integer:
+                            object[sc.name] = sc.equals_integer
+                        if sc.equals_expression:
+                            object[sc.name] = eval_expr(sc.equals_expression, **object)
+    return object
+
+
+@dataclass
+class RuleBasedInferenceEngine(InferenceEngine):
+    """
+    TODO
+
+    """
+
+    class_rules: Optional[List[ClassRule]] = None
+    slot_rules: Optional[Dict[str, List[ClassRule]]] = None
+    slot_expressions: Optional[Dict[str, str]] = None
+
+    PERSIST_COLS: ClassVar = ["config", "class_rules", "slot_rules", "slot_expressions"]
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        collection: Collection = td.collection
+        cd = collection.class_definition()
+        sv: SchemaView = collection.parent.schema_view
+        class_rules = cd.rules
+        if class_rules:
+            self.class_rules = class_rules
+        for slot in sv.class_induced_slots(cd.name):
+            if slot.equals_expression:
+                self.slot_expressions[slot.name] = slot.equals_expression
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        object = copy(object)
+        if self.class_rules:
+            for rule in self.class_rules:
+                apply_rule(rule, object)
+        object = {underscore(k): v for k, v in object.items()}
+        if self.slot_expressions:
+            for slot, expr in self.slot_expressions.items():
+                print(f"EVAL {object}")
+                v = eval_expr(expr, **object)
+                if v is not None:
+                    object[slot] = v
+        return Inference(predicted_object=object)
+
+    def import_model_from(self, inference_engine: InferenceEngine, **kwargs):
+        io = StringIO()
+        inference_engine.export_model(io, model_serialization=ModelSerialization.LINKML_EXPRESSION)
+        config = inference_engine.config
+        if len(config.target_attributes) != 1:
+            raise ValueError("Can only import models with a single target attribute")
+        target_attribute = config.target_attributes[0]
+        if self.slot_expressions is None:
+            self.slot_expressions = {}
+        self.slot_expressions[target_attribute] = io.getvalue()
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+
+        def _serialize_value(v: Any) -> Any:
+            if isinstance(v, BaseModel):
+                return v.model_dump(exclude_unset=True)
+            return v
+
+        model_data = {k: _serialize_value(getattr(self, k)) for k in self.PERSIST_COLS}
+        with open(output, "w", encoding="utf-8") as f:
+            yaml.dump(model_data, f)
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "RuleBasedInferenceEngine":
+        model_data = yaml.safe_load(open(file_path))
+
+        engine = cls(config=model_data["config"])
+        for k, v in model_data.items():
+            if k == "config":
+                continue
+            setattr(engine, k, v)
+
+        logger.info(f"Model loaded from {file_path}")
+        return engine
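For orientation: `initialize_model` reads `rules` from the collection's class definition (plus any `equals_expression` slots from the schema), and `derive` applies them to each object. A hypothetical schema fragment of the kind this engine consumes, written in standard LinkML rule syntax with preconditions and postconditions as slot conditions:

```yaml
classes:
  Country:
    attributes:
      name:
      continent:
      hemisphere:
    rules:
      - preconditions:
          slot_conditions:
            continent:
              equals_string: South America
        postconditions:
          slot_conditions:
            hemisphere:
              equals_string: southern
```

With such a schema attached to the collection, `derive({"name": "Uruguay", "continent": "South America"})` would fill in `hemisphere: southern`.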