linkml-store 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of linkml-store has been flagged.
- linkml_store/api/client.py +37 -8
- linkml_store/api/collection.py +81 -9
- linkml_store/api/config.py +28 -1
- linkml_store/api/database.py +26 -3
- linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
- linkml_store/api/stores/neo4j/__init__.py +0 -0
- linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
- linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
- linkml_store/cli.py +140 -13
- linkml_store/graphs/__init__.py +0 -0
- linkml_store/graphs/graph_map.py +24 -0
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/rag_inference_engine.py +145 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +158 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +290 -0
- linkml_store/inference/inference_config.py +62 -0
- linkml_store/inference/inference_engine.py +173 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/utils/format_utils.py +21 -90
- linkml_store/utils/llm_utils.py +95 -0
- linkml_store/utils/neo4j_utils.py +42 -0
- linkml_store/utils/object_utils.py +3 -1
- linkml_store/utils/pandas_utils.py +55 -2
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/stats_utils.py +53 -0
- {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/METADATA +30 -3
- {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/RECORD +31 -14
- {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/LICENSE +0 -0
- {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/WHEEL +0 -0
- {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/entry_points.txt +0 -0
linkml_store/cli.py
CHANGED
@@ -1,6 +1,7 @@
 import logging
 import sys
 import warnings
+from pathlib import Path
 from typing import Optional
 
 import click
@@ -10,14 +11,22 @@ from pydantic import BaseModel
 
 from linkml_store import Client
 from linkml_store.api import Collection, Database
+from linkml_store.api.config import ClientConfig
 from linkml_store.api.queries import Query
 from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
+from linkml_store.inference import get_inference_engine
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import ModelSerialization
 from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output, write_output
 from linkml_store.utils.object_utils import object_path_update
 from linkml_store.utils.pandas_utils import facet_summary_to_dataframe_unmelted
 
+DEFAULT_LOCAL_CONF_PATH = Path("linkml.yaml")
+# global path is ~/.linkml.yaml in the user's home directory
+DEFAULT_GLOBAL_CONF_PATH = Path("~/.linkml.yaml").expanduser()
+
 index_type_option = click.option(
     "--index-type",
     "-t",
@@ -84,6 +93,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
+@click.option("--input", "-i", help="Input file (alternative to database/collection)")
 @click.option("--config", "-C", type=click.Path(exists=True), help="Path to the configuration file")
 @click.option("--set", help="Metadata settings in the form PATHEXPR=value", multiple=True)
 @click.option("-v", "--verbose", count=True)
@@ -96,7 +106,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
     help="If set then show full stacktrace on error",
 )
 @click.pass_context
-def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, **kwargs):
+def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, input, **kwargs):
     """A CLI for interacting with the linkml-store."""
     if not stacktrace:
         sys.tracebacklimit = 0
@@ -119,13 +129,25 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if quiet:
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
+    if input:
+        stem = Path(input).stem
+        database = "duckdb"
+        collection = stem
+        config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
+        # collection = Path(input).stem
+        # database = f"file:{Path(input).parent}"
+    if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
+        config = DEFAULT_LOCAL_CONF_PATH
+    if config is None and DEFAULT_GLOBAL_CONF_PATH.exists():
+        config = DEFAULT_GLOBAL_CONF_PATH
+    if config == ".":
+        config = None
+    if not collection and database and "::" in database:
+        database, collection = database.split("::")
+
     client = Client().from_config(config, **kwargs) if config else Client()
     settings = ContextSettings(client=client, database_name=database, collection_name=collection)
     ctx.obj["settings"] = settings
-    # DEPRECATED
-    ctx.obj["client"] = client
-    ctx.obj["database"] = database
-    ctx.obj["collection"] = collection
     if settings.database_name:
         db = client.get_database(database)
         if set:
@@ -136,12 +158,6 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
             val = yaml.safe_load(val)
             logger.info(f"Setting {path} to {val}")
             db.metadata = object_path_update(db.metadata, path, val)
-            # settings.database = db
-        # DEPRECATED
-        ctx.obj["database_obj"] = db
-        if collection:
-            collection_obj = db.get_collection(collection)
-            ctx.obj["collection_obj"] = collection_obj
     if not settings.database_name:
         # if len(client.databases) != 1:
         #     raise ValueError("Database must be specified if there are multiple databases.")
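The new --input flag above is shorthand for constructing a throwaway client configuration over a single file. A minimal sketch of the equivalent Python API calls, mirroring the ClientConfig built in the hunk above; the countries.jsonl filename is illustrative:

    from linkml_store import Client
    from linkml_store.api.config import ClientConfig

    # what `linkml-store -i countries.jsonl ...` builds internally: a duckdb
    # database with one collection named after the file stem
    config = ClientConfig(
        databases={"duckdb": {"collections": {"countries": {"source": {"local_path": "countries.jsonl"}}}}}
    )
    client = Client().from_config(config)
    collection = client.get_database("duckdb").get_collection("countries")
    print(collection.find({}, limit=3).rows)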
@@ -323,11 +339,12 @@ def apply(ctx, patch_files, identifier_attribute):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query, as YAML")
+@click.option("--select", "-s", type=click.STRING, help="SELECT clause for the query, as YAML")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
 @click.pass_context
-def query(ctx, where, limit, output_type, output):
+def query(ctx, where, select, limit, output_type, output):
     """Query objects from the specified collection.
 
 
@@ -353,7 +370,13 @@ def query(ctx, where, limit, output_type, output):
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
-
+    select_clause = yaml.safe_load(select) if select else None
+    if select_clause:
+        if isinstance(select_clause, str):
+            select_clause = [select_clause]
+        if not isinstance(select_clause, list):
+            raise ValueError(f"SELECT clause must be a list. Got: {select_clause}")
+    query = Query(from_table=collection.alias, select_cols=select_clause, where_clause=where_clause, limit=limit)
     result = collection.query(query)
     output_data = render_output(result.rows, output_type)
     if output:
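The --select handling above feeds straight into the Query model imported at the top of the module. A sketch of the equivalent programmatic call, reusing the illustrative setup from the previous sketch:

    from linkml_store import Client
    from linkml_store.api.config import ClientConfig
    from linkml_store.api.queries import Query

    config = ClientConfig(
        databases={"duckdb": {"collections": {"countries": {"source": {"local_path": "countries.jsonl"}}}}}
    )
    collection = Client().from_config(config).get_database("duckdb").get_collection("countries")
    # -s 'name' becomes ["name"]; -s '[name, capital]' is already a list
    query = Query(from_table=collection.alias, select_cols=["name", "capital"], where_clause=None, limit=10)
    result = collection.query(query)
    print(result.rows)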
@@ -458,6 +481,110 @@ def describe(ctx, where, output_type, output, limit):
     write_output(df.describe(include="all").transpose(), output_type, target=output)
 
 
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--target-attribute", "-T", type=click.STRING, multiple=True, help="Target attributes for inference")
+@click.option(
+    "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
+)
+@click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
+@click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
+@click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
+@click.option("--model-format", "-M", type=click.Choice([x.value for x in ModelSerialization]), help="Format for model")
+@click.option("--training-test-data-split", "-S", type=click.Tuple([float, float]), help="Training/test data split")
+@click.option(
+    "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
+)
+@click.option("--query", "-q", type=click.STRING, help="query term")
+@click.pass_context
+def infer(
+    ctx,
+    inference_config_file,
+    query,
+    training_test_data_split,
+    predictor_type,
+    target_attribute,
+    feature_attributes,
+    output_type,
+    output,
+    model_format,
+    export_model,
+    load_model,
+):
+    """
+    Predict a complete object from a partial object.
+
+    Currently two main prediction methods are provided: RAG and sklearn
+
+    ## RAG:
+
+    The RAG approach will use Retrieval Augmented Generation to inference the missing attributes of an object.
+
+    Example:
+
+        linkml-store -i countries.jsonl inference -t rag -q 'name: Uruguay'
+
+    Result:
+
+        capital: Montevideo, code: UY, continent: South America, languages: [Spanish]
+
+    You can pass in configurations as follows:
+
+        linkml-store -i countries.jsonl inference -t rag:llm_config.model_name=llama-3 -q 'name: Uruguay'
+
+    ## SKLearn:
+
+    This uses scikit-learn (defaulting to simple decision trees) to do the prediction.
+
+        linkml-store -i tests/input/iris.csv inference -t sklearn \
+          -q '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
+    """
+    if query:
+        query_obj = yaml.safe_load(query)
+    else:
+        query_obj = None
+    collection = ctx.obj["settings"].collection
+    atts = collection.class_definition().attributes.keys()
+    if model_format:
+        model_format = ModelSerialization(model_format)
+    if load_model:
+        predictor = get_inference_engine(predictor_type)
+        predictor = type(predictor).load_model(load_model)
+    else:
+        if feature_attributes:
+            features = feature_attributes.split(",")
+            features = [f.strip() for f in features]
+        else:
+            if query_obj:
+                features = query_obj.keys()
+            else:
+                features = None
+        if target_attribute:
+            target_attributes = list(target_attribute)
+        else:
+            target_attributes = [att for att in atts if att not in features]
+        if inference_config_file:
+            config = InferenceConfig.from_file(inference_config_file)
+        else:
+            config = InferenceConfig(target_attributes=target_attributes, feature_attributes=features)
+        if training_test_data_split:
+            config.train_test_split = training_test_data_split
+        predictor = get_inference_engine(predictor_type, config=config)
+        predictor.load_and_split_data(collection)
+        predictor.initialize_model()
+    if export_model:
+        logger.info(f"Exporting model to {export_model} in {model_format}")
+        predictor.export_model(export_model, model_format)
+    if not query_obj:
+        if not export_model:
+            raise ValueError("Query must be specified if not exporting model")
+    if query_obj:
+        result = predictor.derive(query_obj)
+        dumped_obj = result.model_dump(exclude_none=True)
+        write_output([dumped_obj], output_type, target=output)
+
+
 @cli.command()
 @index_type_option
 @click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
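The infer command is a thin wrapper over the new inference package. A condensed sketch of the same flow through the Python API; the iris file path and attribute names are taken from the docstring example, and Format.CSV and the "species" target column are assumptions:

    from linkml_store import Client
    from linkml_store.inference import InferenceConfig, get_inference_engine
    from linkml_store.utils.format_utils import Format

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")  # Format.CSV assumed
    collection = db.get_collection("iris")

    config = InferenceConfig(
        target_attributes=["species"],  # assumed target column for the iris data
        feature_attributes=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    )
    engine = get_inference_engine("sklearn", config=config)
    engine.load_and_split_data(collection)
    engine.initialize_model()
    result = engine.derive({"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2})
    print(result.model_dump(exclude_none=True))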
linkml_store/graphs/__init__.py
File without changes
linkml_store/graphs/graph_map.py
ADDED
@@ -0,0 +1,24 @@
+from abc import ABC
+from typing import Optional
+
+from pydantic import BaseModel
+
+DEFAULT_IDENTIFIER_ATTRIBUTE = "id"
+DEFAULT_CATEGORY_LABELS_ATTRIBUTE = "category"
+DEFAULT_SUBJECT_ATTRIBUTE = "subject"
+DEFAULT_PREDICATE_ATTRIBUTE = "predicate"
+DEFAULT_OBJECT_ATTRIBUTE = "object"
+
+
+class GraphProjection(BaseModel, ABC):
+    identifier_attribute: str = DEFAULT_IDENTIFIER_ATTRIBUTE
+
+
+class NodeProjection(GraphProjection):
+    category_labels_attribute: Optional[str] = DEFAULT_CATEGORY_LABELS_ATTRIBUTE
+
+
+class EdgeProjection(GraphProjection):
+    subject_attribute: str = DEFAULT_SUBJECT_ATTRIBUTE
+    predicate_attribute: str = DEFAULT_PREDICATE_ATTRIBUTE
+    object_attribute: str = DEFAULT_OBJECT_ATTRIBUTE
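These projections are plain pydantic models that declare which attributes of a collection's rows act as identifiers, categories, and triple components; how the new Neo4j backend consumes them is not shown in this excerpt. A brief usage sketch:

    from linkml_store.graphs.graph_map import EdgeProjection, NodeProjection

    # rows shaped like {"id": ..., "category": ...} become nodes
    nodes = NodeProjection(identifier_attribute="id", category_labels_attribute="category")
    # rows shaped like {"subject": ..., "predicate": ..., "object": ...} become edges
    edges = EdgeProjection()
    print(nodes.model_dump(), edges.model_dump())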
linkml_store/inference/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""
+inference engine package.
+"""
+
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+from linkml_store.inference.inference_engine_registry import get_inference_engine
+
+__all__ = [
+    "InferenceEngine",
+    "InferenceConfig",
+    "get_inference_engine",
+]
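The package façade means callers need only one import. The engine-type strings visible elsewhere in this diff are "sklearn" (the CLI default) and "rag"; for example:

    from linkml_store.inference import InferenceConfig, get_inference_engine

    config = InferenceConfig(target_attributes=["capital"], feature_attributes=["name"])
    engine = get_inference_engine("rag", config=config)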
linkml_store/inference/implementations/__init__.py
File without changes
linkml_store/inference/implementations/rag_inference_engine.py
ADDED
@@ -0,0 +1,145 @@
+import logging
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import yaml
+from llm import get_key
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = """
+You are a {llm_config.role}, your task is to inference the YAML
+object output given the YAML object input. I will provide you
+with a collection of examples that will provide guidance both
+on the desired structure of the response, as well as the kind
+of content.
+
+You should return ONLY valid YAML in your response.
+"""
+
+
+@dataclass
+class RAGInferenceEngine(InferenceEngine):
+    """
+    AI Retrieval Augmented Generation (RAG) based predictor.
+
+
+    >>> from linkml_store.api.client import Client
+    >>> from linkml_store.utils.format_utils import Format
+    >>> from linkml_store.inference.inference_config import LLMConfig
+    >>> client = Client()
+    >>> db = client.attach_database("duckdb", alias="test")
+    >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
+    >>> db.list_collection_names()
+    ['countries']
+    >>> collection = db.get_collection("countries")
+    >>> features = ["name"]
+    >>> targets = ["code", "capital", "continent", "languages"]
+    >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
+    >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
+    >>> ie = RAGInferenceEngine(config=config)
+    >>> ie.load_and_split_data(collection)
+    >>> ie.initialize_model()
+    >>> prediction = ie.derive({"name": "Uruguay"})
+    >>> prediction.predicted_object
+    {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
+
+    """
+
+    classifier: Any = None
+    encoders: dict = None
+    _model: "llm.Model" = None  # noqa: F821
+
+    rag_collection: Collection = None
+
+    def __post_init__(self):
+        if not self.config:
+            self.config = InferenceConfig()
+        if not self.config.llm_config:
+            self.config.llm_config = LLMConfig()
+
+    @property
+    def model(self) -> "llm.Model":  # noqa: F821
+        import llm
+
+        if self._model is None:
+            self._model = llm.get_model(self.config.llm_config.model_name)
+            if self._model.needs_key:
+                key = get_key(None, key_alias=self._model.needs_key)
+                self._model.key = key
+
+        return self._model
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        s = td.slice
+        if not s[0] and not s[1]:
+            rag_collection = td.collection
+        else:
+            base_collection = td.collection
+            objs = base_collection.find({}, offset=s[0], limit=s[1] - s[0]).rows
+            db = base_collection.parent
+            rag_collection = db.get_collection(f"{base_collection.alias}__rag_{s[0]}_{s[1]}", create_if_not_exists=True)
+            rag_collection.insert(objs)
+        rag_collection.attach_indexer("llm", auto_index=False)
+        self.rag_collection = rag_collection
+
+    def object_to_text(self, object: OBJECT) -> str:
+        return yaml.dump(object)
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        import llm
+        from tiktoken import encoding_for_model
+
+        from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
+
+        model: llm.Model = self.model
+        model_name = self.config.llm_config.model_name
+        feature_attributes = self.config.feature_attributes
+        target_attributes = self.config.target_attributes
+        num_examples = self.config.llm_config.number_of_few_shot_examples or 5
+        query_text = self.object_to_text(object)
+        if not self.rag_collection.indexers:
+            raise ValueError("RAG collection must have an indexer attached")
+        rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
+        examples = rs.rows
+        if not examples:
+            raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
+        prompt_clauses = []
+        for example in examples:
+            input_obj = {k: example.get(k, None) for k in feature_attributes}
+            output_obj = {k: example.get(k, None) for k in target_attributes}
+            prompt_clause = (
+                "---\nExample:\n"
+                f"## INPUT:\n{self.object_to_text(input_obj)}\n"
+                f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
+            )
+            prompt_clauses.append(prompt_clause)
+        query_obj = {k: object.get(k, None) for k in feature_attributes}
+        query_text = self.object_to_text(query_obj)
+        prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
+        system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+
+        def make_text(texts):
+            return "\n".join(prompt_clauses) + prompt_end
+
+        try:
+            encoding = encoding_for_model(model_name)
+        except KeyError:
+            encoding = encoding_for_model("gpt-4")
+        token_limit = get_token_limit(model_name)
+        prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
+        logger.info(f"Prompt: {prompt}")
+        response = model.prompt(prompt, system_prompt)
+        yaml_str = response.text()
+        logger.info(f"Response: {yaml_str}")
+        try:
+            predicted_object = yaml.safe_load(yaml_str)
+            return Inference(predicted_object=predicted_object)
+        except yaml.parser.ParserError as e:
+            logger.error(f"Error parsing response: {yaml_str}\n{e}")
+            return None
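The derive() method assembles a few-shot prompt from the retrieved examples. A sketch of how one clause is built, mirroring the f-strings above with illustrative values:

    import yaml

    # one retrieved example contributes one INPUT/OUTPUT clause (values illustrative)
    input_obj = {"name": "France"}
    output_obj = {"capital": "Paris", "code": "FR"}
    clause = (
        "---\nExample:\n"
        f"## INPUT:\n{yaml.dump(input_obj)}\n"
        f"## OUTPUT:\n{yaml.dump(output_obj)}\n"
    )
    print(clause)

The query object is appended the same way, with its OUTPUT section left empty for the model to complete, and render_formatted_text drops clauses that would exceed the model's token limit.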
linkml_store/inference/implementations/rule_based_inference_engine.py
ADDED
@@ -0,0 +1,158 @@
+import logging
+from copy import copy
+from dataclasses import dataclass
+from io import StringIO
+from pathlib import Path
+from typing import Any, ClassVar, Dict, List, Optional, Union
+
+import yaml
+from linkml_map.utils.eval_utils import eval_expr
+from linkml_runtime import SchemaView
+from linkml_runtime.linkml_model.meta import AnonymousClassExpression, ClassRule
+from linkml_runtime.utils.formatutils import underscore
+from pydantic import BaseModel
+
+from linkml_store.api.collection import OBJECT, Collection
+from linkml_store.inference.inference_config import Inference
+from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+
+logger = logging.getLogger(__name__)
+
+
+def expression_matches(ce: AnonymousClassExpression, object: OBJECT) -> bool:
+    """
+    Check if a class expression matches an object.
+
+    :param ce: The class expression
+    :param object: The object
+    :return: True if the class expression matches the object
+    """
+    if ce.any_of:
+        if not any(expression_matches(subce, object) for subce in ce.any_of):
+            return False
+    if ce.all_of:
+        if not all(expression_matches(subce, object) for subce in ce.all_of):
+            return False
+    if ce.none_of:
+        if any(expression_matches(subce, object) for subce in ce.none_of):
+            return False
+    if ce.slot_conditions:
+        for slot in ce.slot_conditions.values():
+            slot_name = slot.name
+            v = object.get(slot_name, None)
+            if slot.equals_string is not None:
+                if slot.equals_string != str(v):
+                    return False
+            if slot.equals_integer is not None:
+                if slot.equals_integer != v:
+                    return False
+            if slot.equals_expression is not None:
+                eval_v = eval_expr(slot.equals_expression, **object)
+                if v != eval_v:
+                    return False
+    return True
+
+
+def apply_rule(rule: ClassRule, object: OBJECT):
+    """
+    Apply a rule to an object.
+
+    Mutates the object
+
+    :param rule: The rule to apply
+    :param object: The object to apply the rule to
+    """
+    for condition in rule.preconditions:
+        if expression_matches(condition, object):
+            for postcondition in rule.postconditions:
+                all_of = [x for x in postcondition.all_of] + [postcondition]
+                for pc in all_of:
+                    sc = pc.slot_condition
+                    if sc:
+                        if sc.equals_string:
+                            object[sc.name] = sc.equals_string
+                        if sc.equals_integer:
+                            object[sc.name] = sc.equals_integer
+                        if sc.equals_expression:
+                            object[sc.name] = eval_expr(sc.equals_expression, **object)
+    return object
+
+
+@dataclass
+class RuleBasedInferenceEngine(InferenceEngine):
+    """
+    TODO
+
+    """
+
+    class_rules: Optional[List[ClassRule]] = None
+    slot_rules: Optional[Dict[str, List[ClassRule]]] = None
+    slot_expressions: Optional[Dict[str, str]] = None
+
+    PERSIST_COLS: ClassVar = ["config", "class_rules", "slot_rules", "slot_expressions"]
+
+    def initialize_model(self, **kwargs):
+        td = self.training_data
+        collection: Collection = td.collection
+        cd = collection.class_definition()
+        sv: SchemaView = collection.parent.schema_view
+        class_rules = cd.rules
+        if class_rules:
+            self.class_rules = class_rules
+        for slot in sv.class_induced_slots(cd.name):
+            if slot.equals_expression:
+                self.slot_expressions[slot.name] = slot.equals_expression
+
+    def derive(self, object: OBJECT) -> Optional[Inference]:
+        object = copy(object)
+        if self.class_rules:
+            for rule in self.class_rules:
+                apply_rule(rule, object)
+        object = {underscore(k): v for k, v in object.items()}
+        if self.slot_expressions:
+            for slot, expr in self.slot_expressions.items():
+                print(f"EVAL {object}")
+                v = eval_expr(expr, **object)
+                if v is not None:
+                    object[slot] = v
+        return Inference(predicted_object=object)
+
+    def import_model_from(self, inference_engine: InferenceEngine, **kwargs):
+        io = StringIO()
+        inference_engine.export_model(io, model_serialization=ModelSerialization.LINKML_EXPRESSION)
+        config = inference_engine.config
+        if len(config.target_attributes) != 1:
+            raise ValueError("Can only import models with a single target attribute")
+        target_attribute = config.target_attributes[0]
+        if self.slot_expressions is None:
+            self.slot_expressions = {}
+        self.slot_expressions[target_attribute] = io.getvalue()
+
+    def save_model(self, output: Union[str, Path]) -> None:
+        """
+        Save the trained model and related data to a file.
+
+        :param output: Path to save the model
+        """
+
+        def _serialize_value(v: Any) -> Any:
+            if isinstance(v, BaseModel):
+                return v.model_dump(exclude_unset=True)
+            return v
+
+        model_data = {k: _serialize_value(getattr(self, k)) for k in self.PERSIST_COLS}
+        with open(output, "w", encoding="utf-8") as f:
+            yaml.dump(model_data, f)
+
+    @classmethod
+    def load_model(cls, file_path: Union[str, Path]) -> "RuleBasedInferenceEngine":
+        model_data = yaml.safe_load(open(file_path))
+
+        engine = cls(config=model_data["config"])
+        for k, v in model_data.items():
+            if k == "config":
+                continue
+            setattr(engine, k, v)
+
+        logger.info(f"Model loaded from {file_path}")
+        return engine
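Both the rule postconditions and the equals_expression slots bottom out in linkml_map's eval_expr, which evaluates an expression string against the object's fields passed as keyword arguments. A small illustration with made-up fields, using the same call shape as `object[sc.name] = eval_expr(sc.equals_expression, **object)` above:

    from linkml_map.utils.eval_utils import eval_expr

    obj = {"base": 3, "bonus": 4}
    total = eval_expr("base + bonus", **obj)  # 7, assuming simple arithmetic is supported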