linkml-store 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (27)
  1. linkml_store/api/client.py +35 -8
  2. linkml_store/api/collection.py +40 -5
  3. linkml_store/api/config.py +20 -3
  4. linkml_store/api/database.py +24 -3
  5. linkml_store/api/stores/duckdb/duckdb_collection.py +3 -0
  6. linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
  7. linkml_store/cli.py +149 -13
  8. linkml_store/inference/__init__.py +13 -0
  9. linkml_store/inference/evaluation.py +189 -0
  10. linkml_store/inference/implementations/__init__.py +0 -0
  11. linkml_store/inference/implementations/rag_inference_engine.py +145 -0
  12. linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
  13. linkml_store/inference/implementations/sklearn_inference_engine.py +308 -0
  14. linkml_store/inference/inference_config.py +62 -0
  15. linkml_store/inference/inference_engine.py +200 -0
  16. linkml_store/inference/inference_engine_registry.py +74 -0
  17. linkml_store/utils/format_utils.py +27 -90
  18. linkml_store/utils/llm_utils.py +96 -0
  19. linkml_store/utils/object_utils.py +103 -2
  20. linkml_store/utils/pandas_utils.py +55 -2
  21. linkml_store/utils/sklearn_utils.py +193 -0
  22. linkml_store/utils/stats_utils.py +53 -0
  23. {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/METADATA +28 -2
  24. {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/RECORD +27 -15
  25. {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/LICENSE +0 -0
  26. {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/WHEEL +0 -0
  27. {linkml_store-0.1.13.dist-info → linkml_store-0.2.0.dist-info}/entry_points.txt +0 -0
linkml_store/api/client.py CHANGED
@@ -100,7 +100,7 @@ class Client:
         """
         return self.metadata.base_dir
 
-    def from_config(self, config: Union[ClientConfig, dict, str, Path], base_dir=None, **kwargs):
+    def from_config(self, config: Union[ClientConfig, dict, str, Path], base_dir=None, auto_attach=False, **kwargs):
         """
         Create a client from a configuration.
 
@@ -109,6 +109,10 @@ class Client:
         >>> from linkml_store.api.config import ClientConfig
         >>> client = Client().from_config(ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}}))
         >>> len(client.databases)
+        0
+        >>> client = Client().from_config(ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}}),
+        ...                               auto_attach=True)
+        >>> len(client.databases)
         1
         >>> "test" in client.databases
         True
@@ -116,6 +120,8 @@ class Client:
         'duckdb:///:memory:'
 
         :param config:
+        :param base_dir:
+        :param auto_attach:
         :param kwargs:
         :return:
 
@@ -125,17 +131,17 @@ class Client:
         if isinstance(config, Path):
             config = str(config)
         if isinstance(config, str):
-            # if not base_dir:
-            #     base_dir = Path(config).parent
+            if not base_dir:
+                base_dir = Path(config).parent
             parsed_obj = yaml.safe_load(open(config))
             config = ClientConfig(**parsed_obj)
         self.metadata = config
         if base_dir:
             self.metadata.base_dir = base_dir
-        self._initialize_databases(**kwargs)
+        self._initialize_databases(auto_attach=auto_attach, **kwargs)
         return self
 
-    def _initialize_databases(self, **kwargs):
+    def _initialize_databases(self, auto_attach=False, **kwargs):
         for name, db_config in self.metadata.databases.items():
             base_dir = self.base_dir
             logger.info(f"Initializing database: {name}, base_dir: {base_dir}")
@@ -146,8 +152,22 @@ class Client:
                 db_config.handle = handle
             if db_config.schema_location:
                 db_config.schema_location = db_config.schema_location.format(base_dir=base_dir)
-            db = self.attach_database(handle, alias=name, **kwargs)
-            db.from_config(db_config)
+            if auto_attach:
+                db = self.attach_database(handle, alias=name, **kwargs)
+                db.from_config(db_config)
+
+    def _set_database_config(self, db: Database):
+        """
+        Set the configuration for a database.
+
+        :param name:
+        :param config:
+        :return:
+        """
+        if not self.metadata:
+            return
+        if db.alias in self.metadata.databases:
+            db.from_config(self.metadata.databases[db.alias])
 
     def attach_database(
         self,
@@ -202,6 +222,7 @@ class Client:
                 raise AssertionError(f"Inconsistent alias: {db.alias} != {alias}")
             else:
                 db.metadata.alias = alias
+        self._set_database_config(db)
         return db
 
     def get_database(self, name: Optional[str] = None, create_if_not_exists=True, **kwargs) -> Database:
@@ -230,13 +251,19 @@ class Client:
             return list(self._databases.values())[0]
         if not self._databases:
             self._databases = {}
+        if name not in self._databases and name in self.metadata.databases:
+            db_config = self.metadata.databases[name]
+            db = self.attach_database(db_config.handle, alias=name, **kwargs)
+            self._databases[name] = db
         if name not in self._databases:
             if create_if_not_exists:
                 logger.info(f"Creating database: {name}")
                 self.attach_database(name, **kwargs)
             else:
                 raise ValueError(f"Database {name} does not exist")
-        return self._databases[name]
+        db = self._databases[name]
+        self._set_database_config(db)
+        return db
 
     @property
     def databases(self) -> Dict[str, Database]:
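
Taken together, these changes make database attachment lazy by default. A minimal sketch of the new behavior (only `from_config`, `auto_attach`, and `get_database` come from the diff above; the config itself is the one used in the doctests):

    from linkml_store import Client
    from linkml_store.api.config import ClientConfig

    config = ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}})

    # By default, nothing is attached up front...
    client = Client().from_config(config)
    assert len(client.databases) == 0

    # ...but get_database() now attaches a configured database on demand,
    # applying its config via the new _set_database_config hook.
    db = client.get_database("test")
    assert "test" in client.databases

    # auto_attach=True restores the previous eager behavior.
    client2 = Client().from_config(config, auto_attach=True)
    assert len(client2.databases) == 1
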
linkml_store/api/collection.py CHANGED
@@ -502,6 +502,7 @@ class Collection(Generic[DatabaseType]):
         index_name = self.default_index_name
         ix_coll = self.parent.get_collection(self._index_collection_name(index_name))
         if index_name not in self.indexers:
+            logger.debug(f"Indexer not found: {index_name} -- creating")
             ix = get_indexer(index_name)
             if not self._indexers:
                 self._indexers = {}
@@ -509,6 +510,11 @@ class Collection(Generic[DatabaseType]):
         ix = self.indexers.get(index_name)
         if not ix:
             raise ValueError(f"No index named {index_name}")
+        logger.debug(f"Using indexer {type(ix)} with name {index_name}")
+        if ix_coll.size() == 0:
+            logger.info(f"Index {index_name} is empty; indexing all objects")
+            all_objs = self.find(limit=-1).rows
+            self.index_objects(all_objs, index_name, replace=True, **kwargs)
         qr = ix_coll.find(where=where, limit=-1, **kwargs)
         index_col = ix.index_field
         # TODO: optimize this for large indexes
@@ -518,6 +524,7 @@ class Collection(Generic[DatabaseType]):
             del r[1][index_col]
         new_qr = QueryResult(num_rows=len(results))
         new_qr.ranked_rows = results
+        new_qr.rows = [r[1] for r in results]
         return new_qr
 
     @property
@@ -562,6 +569,7 @@ class Collection(Generic[DatabaseType]):
                 format=source.format,
                 expected_type=source.expected_type,
                 compression=source.compression,
+                select_query=source.select_query,
                 **kwargs,
             )
         elif metadata.source.url:
@@ -570,9 +578,12 @@ class Collection(Generic[DatabaseType]):
                 format=source.format,
                 expected_type=source.expected_type,
                 compression=source.compression,
+                select_query=source.select_query,
                 **kwargs,
             )
-        self.insert(objects)
+        else:
+            raise ValueError("No source local_path or url provided")
+        self.insert(objects)
 
     def _check_if_initialized(self) -> bool:
         return self._initialized
@@ -629,6 +640,14 @@ class Collection(Generic[DatabaseType]):
         self.insert(tr_objs)
         self.commit()
 
+    def size(self) -> int:
+        """
+        Return the number of objects in the collection.
+
+        :return: The number of objects in the collection.
+        """
+        return self.find({}, limit=1).num_rows
+
     def attach_indexer(self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs):
         """
         Attach an index to the collection.
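
Two of these additions pair naturally: `size()` gives a cheap count check, and searches against an empty index now trigger a full (re)index first. A hedged sketch (the collection setup and the "simple" indexer name are assumptions; `size()`, `attach_indexer`, and the populated `rows` attribute are what the hunks above add):

    from linkml_store import Client

    client = Client()
    db = client.attach_database("duckdb", alias="demo")
    collection = db.create_collection("Country", alias="countries")
    collection.insert([{"name": "Uruguay", "capital": "Montevideo"},
                       {"name": "Japan", "capital": "Tokyo"}])

    assert collection.size() == 2   # new helper, built on find().num_rows

    # Searching with an empty index collection now indexes all objects first,
    # and QueryResult.rows is populated alongside ranked_rows.
    collection.attach_indexer("simple")
    results = collection.search("Uruguay")
    print(results.rows[0]["name"])
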
@@ -777,6 +796,8 @@ class Collection(Generic[DatabaseType]):
         sv: SchemaView = self.parent.schema_view
         if sv:
             cls = sv.get_class(self.target_class_name)
+            # if not cls:
+            #     logger.warning(f"{self.target_class_name} not in {sv.all_classes().keys()} ")
             # cls = sv.schema.classes[self.target_class_name]
             if cls and not cls.attributes:
                 if not sv.class_induced_slots(cls.name):
@@ -900,11 +921,14 @@ class Collection(Generic[DatabaseType]):
                     exact_dimensions_list.append(v.shape)
                     break
                 if isinstance(v, list):
+                    # sample first item. TODO: more robust strategy
                     v = v[0] if v else None
                     multivalueds.append(True)
                 elif isinstance(v, dict):
-                    v = list(v.values())[0]
-                    multivalueds.append(True)
+                    pass
+                    # TODO: check if this is a nested object or key-value list
+                    # v = list(v.values())[0]
+                    # multivalueds.append(True)
                 else:
                     multivalueds.append(False)
                 if not v:
@@ -933,10 +957,21 @@ class Collection(Generic[DatabaseType]):
             #     raise AssertionError(f"Empty rngs for {k} = {vs}")
             rng = rngs[0] if rngs else None
             for other_rng in rngs:
+                coercions = {
+                    ("integer", "float"): "float",
+                }
                 if rng != other_rng:
-                    raise ValueError(f"Conflict: {rng} != {other_rng} for {vs}")
+                    if (rng, other_rng) in coercions:
+                        rng = coercions[(rng, other_rng)]
+                    elif (other_rng, rng) in coercions:
+                        rng = coercions[(other_rng, rng)]
+                    else:
+                        raise ValueError(f"Conflict: {rng} != {other_rng} for {vs}")
             logger.debug(f"Inducing {k} as {rng} {multivalued} {inlined}")
-            cd.attributes[k] = SlotDefinition(k, range=rng, multivalued=multivalued, inlined=inlined)
+            inlined_as_list = inlined and multivalued
+            cd.attributes[k] = SlotDefinition(
+                k, range=rng, multivalued=multivalued, inlined=inlined, inlined_as_list=inlined_as_list
+            )
             if exact_dimensions_list:
                 array_expr = ArrayExpression(exact_number_dimensions=len(exact_dimensions_list[0]))
                 cd.attributes[k].array = array_expr
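
The coercion change means a column that mixes integers and floats now induces as `float` instead of raising. The rule in isolation (the `coercions` table is copied from the hunk; the standalone function is a simplified stand-in for the induction loop):

    coercions = {
        ("integer", "float"): "float",
    }

    def merge_ranges(rngs):
        # simplified version of the conflict handling shown above
        rng = rngs[0] if rngs else None
        for other_rng in rngs:
            if rng != other_rng:
                if (rng, other_rng) in coercions:
                    rng = coercions[(rng, other_rng)]
                elif (other_rng, rng) in coercions:
                    rng = coercions[(other_rng, rng)]
                else:
                    raise ValueError(f"Conflict: {rng} != {other_rng}")
        return rng

    assert merge_ranges(["integer", "float", "integer"]) == "float"
    assert merge_ranges(["string", "string"]) == "string"
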
linkml_store/api/config.py CHANGED
@@ -1,8 +1,8 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
-from linkml_store.graphs.graph_map import GraphProjection
+from linkml_store.graphs.graph_map import EdgeProjection, NodeProjection
 
 
 class ConfiguredBaseModel(BaseModel, extra="forbid"):
@@ -30,13 +30,30 @@ class CollectionSource(ConfiguredBaseModel):
     """
 
     url: Optional[str] = None
+    """Remote URL to fetch data from"""
+
     local_path: Optional[str] = None
+    """Local path to fetch data from"""
+
     source_location: Optional[str] = None
+
     refresh_interval_days: Optional[float] = None
+    """How often to refresh the data, in days"""
+
     expected_type: Optional[str] = None
+    """The expected type of the data, e.g. list"""
+
     format: Optional[str] = None
+    """The format of the data, e.g., json, yaml, csv"""
+
     compression: Optional[str] = None
+    """The compression of the data, e.g., tgz, gzip, zip"""
+
+    select_query: Optional[str] = None
+    """A jsonpath query to preprocess the objects with"""
+
     arguments: Optional[Dict[str, Any]] = None
+    """Optional arguments to pass to the source"""
 
 
 class CollectionConfig(ConfiguredBaseModel):
@@ -81,7 +98,7 @@ class CollectionConfig(ConfiguredBaseModel):
         description="LinkML-Map derivations",
     )
     page_size: Optional[int] = Field(default=None, description="Suggested page size (items per page) in apps and APIs")
-    graph_projection: Optional[GraphProjection] = Field(
+    graph_projection: Optional[Union[EdgeProjection, NodeProjection]] = Field(
         default=None,
         description="Optional graph projection configuration",
    )
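
Together with the `select_query` plumbing in collection.py above, a source can now trim fetched objects before insertion. A hypothetical source definition (the field names come from `CollectionSource`; the URL and jsonpath values are invented for illustration):

    from linkml_store.api.config import CollectionSource

    source = CollectionSource(
        url="https://example.org/data/countries.json.gz",   # illustrative URL
        format="json",
        compression="gzip",
        expected_type="list",
        select_query="$.countries[*]",   # illustrative jsonpath preprocessor
    )
    print(source.model_dump(exclude_none=True))
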
linkml_store/api/database.py CHANGED
@@ -707,12 +707,29 @@ class Database(ABC, Generic[CollectionType]):
         """
         raise NotImplementedError()
 
-    def import_database(self, location: str, source_format: Optional[Union[str, Format]] = None, **kwargs):
+    def import_database(
+        self,
+        location: str,
+        source_format: Optional[Union[str, Format]] = None,
+        collection_name: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Import a database from a file or location.
 
+        >>> from linkml_store.api.client import Client
+        >>> client = Client()
+        >>> db = client.attach_database("duckdb", alias="test")
+        >>> db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
+        >>> db.list_collection_names()
+        ['iris']
+        >>> collection = db.get_collection("iris")
+        >>> collection.find({}).num_rows
+        150
+
         :param location: location of the file
         :param source_format: source format
+        :param collection_name: (Optional) name of the collection, for data that is flat
         :param kwargs: additional arguments
         """
         if isinstance(source_format, str):
@@ -732,8 +749,12 @@ class Database(ABC, Generic[CollectionType]):
                 self.store(obj)
             return
         objects = load_objects(location, format=source_format)
-        for obj in objects:
-            self.store(obj)
+        if collection_name:
+            collection = self.get_collection(collection_name, create_if_not_exists=True)
+            collection.insert(objects)
+        else:
+            for obj in objects:
+                self.store(obj)
 
     def export_database(self, location: str, target_format: Optional[Union[str, Format]] = None, **kwargs):
         """
linkml_store/api/stores/duckdb/duckdb_collection.py CHANGED
@@ -36,6 +36,9 @@ class DuckDBCollection(Collection):
         logger.info(f"Inserting into: {self.alias} // T={table.name}")
         engine = self.parent.engine
         col_names = [c.name for c in table.columns]
+        bad_objs = [obj for obj in objs if not isinstance(obj, dict)]
+        if bad_objs:
+            logger.error(f"Bad objects: {bad_objs}")
         objs = [{k: obj.get(k, None) for k in col_names} for obj in objs]
         with engine.connect() as conn:
             with conn.begin():
linkml_store/api/stores/mongodb/mongodb_collection.py CHANGED
@@ -51,9 +51,13 @@ class MongoDBCollection(Collection):
         if offset and offset >= 0:
             cursor = cursor.skip(offset)
 
+        select_cols = query.select_cols
+
         def _as_row(row: dict):
             row = copy(row)
             del row["_id"]
+            if select_cols:
+                row = {k: row[k] for k in select_cols if k in row}
             return row
 
         rows = [_as_row(row) for row in cursor]
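
The MongoDB backend now honors `select_cols`, which is what the new `--select` CLI option below feeds into. The projection logic on its own (a standalone adaptation, with plain dicts standing in for MongoDB documents):

    from copy import copy

    select_cols = ["name", "capital"]

    def _as_row(row: dict):
        row = copy(row)
        row.pop("_id", None)   # MongoDB's internal id is always dropped
        if select_cols:
            # keep only the requested columns, in order, if present
            row = {k: row[k] for k in select_cols if k in row}
        return row

    doc = {"_id": "abc123", "name": "Uruguay", "capital": "Montevideo", "code": "UY"}
    assert _as_row(doc) == {"name": "Uruguay", "capital": "Montevideo"}
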
linkml_store/cli.py CHANGED
@@ -1,23 +1,34 @@
 import logging
 import sys
 import warnings
+from pathlib import Path
 from typing import Optional
 
 import click
 import yaml
 from linkml_runtime.dumpers import json_dumper
+from linkml_runtime.utils.formatutils import underscore
 from pydantic import BaseModel
 
 from linkml_store import Client
 from linkml_store.api import Collection, Database
+from linkml_store.api.config import ClientConfig
 from linkml_store.api.queries import Query
 from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
+from linkml_store.inference import get_inference_engine
+from linkml_store.inference.evaluation import evaluate_predictor, score_text_overlap
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import ModelSerialization
 from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output, write_output
 from linkml_store.utils.object_utils import object_path_update
 from linkml_store.utils.pandas_utils import facet_summary_to_dataframe_unmelted
 
+DEFAULT_LOCAL_CONF_PATH = Path("linkml.yaml")
+# global path is ~/.linkml.yaml in the user's home directory
+DEFAULT_GLOBAL_CONF_PATH = Path("~/.linkml.yaml").expanduser()
+
 index_type_option = click.option(
     "--index-type",
     "-t",
@@ -84,6 +95,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
+@click.option("--input", "-i", help="Input file (alternative to database/collection)")
 @click.option("--config", "-C", type=click.Path(exists=True), help="Path to the configuration file")
 @click.option("--set", help="Metadata settings in the form PATHEXPR=value", multiple=True)
 @click.option("-v", "--verbose", count=True)
@@ -96,7 +108,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
     help="If set then show full stacktrace on error",
 )
 @click.pass_context
-def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, **kwargs):
+def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, input, **kwargs):
     """A CLI for interacting with the linkml-store."""
     if not stacktrace:
         sys.tracebacklimit = 0
@@ -119,13 +131,25 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if quiet:
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
+    if input:
+        stem = underscore(Path(input).stem)
+        database = "duckdb"
+        collection = stem
+        config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
+        # collection = Path(input).stem
+        # database = f"file:{Path(input).parent}"
+    if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
+        config = DEFAULT_LOCAL_CONF_PATH
+    if config is None and DEFAULT_GLOBAL_CONF_PATH.exists():
+        config = DEFAULT_GLOBAL_CONF_PATH
+    if config == ".":
+        config = None
+    if not collection and database and "::" in database:
+        database, collection = database.split("::")
+
     client = Client().from_config(config, **kwargs) if config else Client()
     settings = ContextSettings(client=client, database_name=database, collection_name=collection)
     ctx.obj["settings"] = settings
-    # DEPRECATED
-    ctx.obj["client"] = client
-    ctx.obj["database"] = database
-    ctx.obj["collection"] = collection
     if settings.database_name:
         db = client.get_database(database)
         if set:
@@ -136,12 +160,6 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
             val = yaml.safe_load(val)
             logger.info(f"Setting {path} to {val}")
             db.metadata = object_path_update(db.metadata, path, val)
-        # settings.database = db
-        # DEPRECATED
-        ctx.obj["database_obj"] = db
-        if collection:
-            collection_obj = db.get_collection(collection)
-            ctx.obj["collection_obj"] = collection_obj
     if not settings.database_name:
         # if len(client.databases) != 1:
         #     raise ValueError("Database must be specified if there are multiple databases.")
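
The practical effect is that the CLI can now run directly against a flat file, with no pre-configured database: `--input` builds an in-memory DuckDB config on the fly, `linkml.yaml` / `~/.linkml.yaml` are picked up automatically, and `db::collection` is accepted as shorthand. Hypothetical invocations (the filter values are illustrative):

    # query a CSV directly; the collection is named after the file stem
    linkml-store -i tests/input/iris.csv query -w 'petal_width: 0.2' -l 5

    # database::collection shorthand, split by the new "::" handling
    linkml-store -d mydb::countries query -w 'code: UY'
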
@@ -323,11 +341,12 @@ def apply(ctx, patch_files, identifier_attribute):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query, as YAML")
+@click.option("--select", "-s", type=click.STRING, help="SELECT clause for the query, as YAML")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
 @click.pass_context
-def query(ctx, where, limit, output_type, output):
+def query(ctx, where, select, limit, output_type, output):
     """Query objects from the specified collection.
 
@@ -353,7 +372,13 @@ def query(ctx, where, limit, output_type, output):
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
-    query = Query(from_table=collection.alias, where_clause=where_clause, limit=limit)
+    select_clause = yaml.safe_load(select) if select else None
+    if select_clause:
+        if isinstance(select_clause, str):
+            select_clause = [select_clause]
+        if not isinstance(select_clause, list):
+            raise ValueError(f"SELECT clause must be a list. Got: {select_clause}")
+    query = Query(from_table=collection.alias, select_cols=select_clause, where_clause=where_clause, limit=limit)
     result = collection.query(query)
     output_data = render_output(result.rows, output_type)
    if output:
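
A hypothetical use of the new `--select` option (per the coercion above, a bare string or a YAML list is accepted):

    # project just two columns from each matching object
    linkml-store -i countries.jsonl query -w 'continent: South America' -s '[name, capital]'
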
@@ -458,6 +483,117 @@ def describe(ctx, where, output_type, output, limit):
     write_output(df.describe(include="all").transpose(), output_type, target=output)
 
 
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--target-attribute", "-T", type=click.STRING, multiple=True, help="Target attributes for inference")
+@click.option(
+    "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
+)
+@click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
+@click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
+@click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
+@click.option("--model-format", "-M", type=click.Choice([x.value for x in ModelSerialization]), help="Format for model")
+@click.option("--training-test-data-split", "-S", type=click.Tuple([float, float]), help="Training/test data split")
+@click.option(
+    "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
+)
+@click.option("--evaluation-count", "-n", type=click.INT, help="Number of examples to evaluate over")
+@click.option("--query", "-q", type=click.STRING, help="query term")
+@click.pass_context
+def infer(
+    ctx,
+    inference_config_file,
+    query,
+    evaluation_count,
+    training_test_data_split,
+    predictor_type,
+    target_attribute,
+    feature_attributes,
+    output_type,
+    output,
+    model_format,
+    export_model,
+    load_model,
+):
+    """
+    Predict a complete object from a partial object.
+
+    Currently two main prediction methods are provided: RAG and sklearn.
+
+    ## RAG:
+
+    The RAG approach uses Retrieval Augmented Generation to infer the missing attributes of an object.
+
+    Example:
+
+        linkml-store -i countries.jsonl inference -t rag -q 'name: Uruguay'
+
+    Result:
+
+        capital: Montevideo, code: UY, continent: South America, languages: [Spanish]
+
+    You can pass in configurations as follows:
+
+        linkml-store -i countries.jsonl inference -t rag:llm_config.model_name=llama-3 -q 'name: Uruguay'
+
+    ## SKLearn:
+
+    This uses scikit-learn (defaulting to simple decision trees) to do the prediction.
+
+        linkml-store -i tests/input/iris.csv inference -t sklearn \
+          -q '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
+    """
+    if query:
+        query_obj = yaml.safe_load(query)
+    else:
+        query_obj = None
+    collection = ctx.obj["settings"].collection
+    atts = collection.class_definition().attributes.keys()
+    if feature_attributes:
+        features = feature_attributes.split(",")
+        features = [f.strip() for f in features]
+    else:
+        if query_obj:
+            features = query_obj.keys()
+        else:
+            features = None
+    if target_attribute:
+        target_attributes = list(target_attribute)
+    else:
+        target_attributes = [att for att in atts if att not in features]
+    if model_format:
+        model_format = ModelSerialization(model_format)
+    if load_model:
+        predictor = get_inference_engine(predictor_type)
+        predictor = type(predictor).load_model(load_model)
+    else:
+        if inference_config_file:
+            config = InferenceConfig.from_file(inference_config_file)
+        else:
+            config = InferenceConfig(target_attributes=target_attributes, feature_attributes=features)
+        if training_test_data_split:
+            config.train_test_split = training_test_data_split
+        predictor = get_inference_engine(predictor_type, config=config)
+        predictor.load_and_split_data(collection)
+        predictor.initialize_model()
+    if export_model:
+        logger.info(f"Exporting model to {export_model} in {model_format}")
+        predictor.export_model(export_model, model_format)
+    if not query_obj:
+        if not export_model and not evaluation_count:
+            raise ValueError("Query or evaluate must be specified if not exporting model")
+    if evaluation_count:
+        outcome = evaluate_predictor(
+            predictor, target_attributes, evaluation_count=evaluation_count, match_function=score_text_overlap
+        )
+        print(f"Outcome: {outcome} // accuracy: {outcome.accuracy}")
+    if query_obj:
+        result = predictor.derive(query_obj)
+        dumped_obj = result.model_dump(exclude_none=True)
+        write_output([dumped_obj], output_type, target=output)
+
+
 @cli.command()
 @index_type_option
 @click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
linkml_store/inference/__init__.py ADDED
@@ -0,0 +1,13 @@
+"""
+inference engine package.
+"""
+
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+from linkml_store.inference.inference_engine_registry import get_inference_engine
+
+__all__ = [
+    "InferenceEngine",
+    "InferenceConfig",
+    "get_inference_engine",
+]
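
For programmatic use, the package's exports mirror what the `infer` CLI command does internally. A minimal sketch (the data setup and the `species` target are illustrative; the engine calls follow the CLI code above):

    from linkml_store import Client
    from linkml_store.inference import InferenceConfig, get_inference_engine
    from linkml_store.utils.format_utils import Format

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
    collection = db.get_collection("iris")

    # configure a sklearn engine to predict species from the other columns
    config = InferenceConfig(target_attributes=["species"], feature_attributes=None)
    engine = get_inference_engine("sklearn", config=config)
    engine.load_and_split_data(collection)
    engine.initialize_model()

    result = engine.derive({"sepal_length": 5.1, "sepal_width": 3.5,
                            "petal_length": 1.4, "petal_width": 0.2})
    print(result.model_dump(exclude_none=True))
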