linkml-store 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of linkml-store might be problematic.

@@ -100,7 +100,7 @@ class Client:
         """
         return self.metadata.base_dir
 
-    def from_config(self, config: Union[ClientConfig, dict, str, Path], base_dir=None, **kwargs):
+    def from_config(self, config: Union[ClientConfig, dict, str, Path], base_dir=None, auto_attach=False, **kwargs):
         """
         Create a client from a configuration.
 
@@ -109,6 +109,10 @@ class Client:
         >>> from linkml_store.api.config import ClientConfig
         >>> client = Client().from_config(ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}}))
        >>> len(client.databases)
+        0
+        >>> client = Client().from_config(ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}}),
+        ...                               auto_attach=True)
+        >>> len(client.databases)
         1
         >>> "test" in client.databases
         True
@@ -116,6 +120,8 @@ class Client:
         'duckdb:///:memory:'
 
         :param config:
+        :param base_dir:
+        :param auto_attach:
         :param kwargs:
         :return:
 
@@ -125,17 +131,17 @@ class Client:
         if isinstance(config, Path):
             config = str(config)
         if isinstance(config, str):
-            # if not base_dir:
-            #     base_dir = Path(config).parent
+            if not base_dir:
+                base_dir = Path(config).parent
             parsed_obj = yaml.safe_load(open(config))
             config = ClientConfig(**parsed_obj)
         self.metadata = config
         if base_dir:
             self.metadata.base_dir = base_dir
-        self._initialize_databases(**kwargs)
+        self._initialize_databases(auto_attach=auto_attach, **kwargs)
         return self
 
-    def _initialize_databases(self, **kwargs):
+    def _initialize_databases(self, auto_attach=False, **kwargs):
         for name, db_config in self.metadata.databases.items():
             base_dir = self.base_dir
             logger.info(f"Initializing database: {name}, base_dir: {base_dir}")
@@ -146,8 +152,22 @@ class Client:
             db_config.handle = handle
             if db_config.schema_location:
                 db_config.schema_location = db_config.schema_location.format(base_dir=base_dir)
-            db = self.attach_database(handle, alias=name, **kwargs)
-            db.from_config(db_config)
+            if auto_attach:
+                db = self.attach_database(handle, alias=name, **kwargs)
+                db.from_config(db_config)
+
+    def _set_database_config(self, db: Database):
+        """
+        Set the configuration for a database.
+
+        :param name:
+        :param config:
+        :return:
+        """
+        if not self.metadata:
+            return
+        if db.alias in self.metadata.databases:
+            db.from_config(self.metadata.databases[db.alias])
 
     def attach_database(
         self,
@@ -202,6 +222,7 @@ class Client:
             raise AssertionError(f"Inconsistent alias: {db.alias} != {alias}")
         else:
             db.metadata.alias = alias
+        self._set_database_config(db)
         return db
 
     def get_database(self, name: Optional[str] = None, create_if_not_exists=True, **kwargs) -> Database:
@@ -230,13 +251,19 @@ class Client:
             return list(self._databases.values())[0]
         if not self._databases:
             self._databases = {}
+        if name not in self._databases and name in self.metadata.databases:
+            db_config = self.metadata.databases[name]
+            db = self.attach_database(db_config.handle, alias=name, **kwargs)
+            self._databases[name] = db
         if name not in self._databases:
             if create_if_not_exists:
                 logger.info(f"Creating database: {name}")
                 self.attach_database(name, **kwargs)
             else:
                 raise ValueError(f"Database {name} does not exist")
-        return self._databases[name]
+        db = self._databases[name]
+        self._set_database_config(db)
+        return db
 
     @property
     def databases(self) -> Dict[str, Database]:
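Note on the changes above: with auto_attach defaulting to False, databases listed in the configuration are no longer attached eagerly; get_database now attaches a configured database on first use and applies its config via _set_database_config. A minimal sketch of the expected behaviour, based on the doctest in the hunk above (same in-memory DuckDB handle as in that doctest):

    from linkml_store import Client
    from linkml_store.api.config import ClientConfig

    config = ClientConfig(databases={"test": {"handle": "duckdb:///:memory:"}})

    client = Client().from_config(config)   # auto_attach defaults to False
    assert len(client.databases) == 0       # nothing attached yet
    db = client.get_database("test")        # lazily attached from the config entry
    assert "test" in client.databases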
@@ -502,6 +502,7 @@ class Collection(Generic[DatabaseType]):
         index_name = self.default_index_name
         ix_coll = self.parent.get_collection(self._index_collection_name(index_name))
         if index_name not in self.indexers:
+            logger.debug(f"Indexer not found: {index_name} -- creating")
             ix = get_indexer(index_name)
             if not self._indexers:
                 self._indexers = {}
@@ -509,6 +510,11 @@ class Collection(Generic[DatabaseType]):
         ix = self.indexers.get(index_name)
         if not ix:
             raise ValueError(f"No index named {index_name}")
+        logger.debug(f"Using indexer {type(ix)} with name {index_name}")
+        if ix_coll.size() == 0:
+            logger.info(f"Index {index_name} is empty; indexing all objects")
+            all_objs = self.find(limit=-1).rows
+            self.index_objects(all_objs, index_name, replace=True, **kwargs)
         qr = ix_coll.find(where=where, limit=-1, **kwargs)
         index_col = ix.index_field
         # TODO: optimize this for large indexes
@@ -518,6 +524,7 @@ class Collection(Generic[DatabaseType]):
             del r[1][index_col]
         new_qr = QueryResult(num_rows=len(results))
         new_qr.ranked_rows = results
+        new_qr.rows = [r[1] for r in results]
         return new_qr
 
     @property
@@ -562,6 +569,7 @@ class Collection(Generic[DatabaseType]):
                 format=source.format,
                 expected_type=source.expected_type,
                 compression=source.compression,
+                select_query=source.select_query,
                 **kwargs,
             )
         elif metadata.source.url:
@@ -570,9 +578,12 @@ class Collection(Generic[DatabaseType]):
                 format=source.format,
                 expected_type=source.expected_type,
                 compression=source.compression,
+                select_query=source.select_query,
                 **kwargs,
             )
-        self.insert(objects)
+        else:
+            raise ValueError("No source local_path or url provided")
+        self.insert(objects)
 
     def _check_if_initialized(self) -> bool:
         return self._initialized
@@ -629,6 +640,14 @@ class Collection(Generic[DatabaseType]):
         self.insert(tr_objs)
         self.commit()
 
+    def size(self) -> int:
+        """
+        Return the number of objects in the collection.
+
+        :return: The number of objects in the collection.
+        """
+        return self.find({}, limit=1).num_rows
+
     def attach_indexer(self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs):
         """
         Attach an index to the collection.
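The new Collection.size() above returns find({}, limit=1).num_rows, which assumes num_rows reports the total number of matches rather than the size of the limited page; the empty-index check added to the search path earlier uses this same helper. A small usage sketch (the collection name and rows are illustrative only):

    from linkml_store import Client

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    collection = db.get_collection("countries", create_if_not_exists=True)
    collection.insert([{"name": "Uruguay", "code": "UY"}, {"name": "Peru", "code": "PE"}])
    assert collection.size() == 2  # total count, even though the underlying find() uses limit=1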
@@ -777,6 +796,8 @@ class Collection(Generic[DatabaseType]):
         sv: SchemaView = self.parent.schema_view
         if sv:
             cls = sv.get_class(self.target_class_name)
+            # if not cls:
+            #     logger.warning(f"{self.target_class_name} not in {sv.all_classes().keys()} ")
             # cls = sv.schema.classes[self.target_class_name]
             if cls and not cls.attributes:
                 if not sv.class_induced_slots(cls.name):
@@ -900,11 +921,14 @@ class Collection(Generic[DatabaseType]):
                     exact_dimensions_list.append(v.shape)
                     break
                 if isinstance(v, list):
+                    # sample first item. TODO: more robust strategy
                     v = v[0] if v else None
                     multivalueds.append(True)
                 elif isinstance(v, dict):
-                    v = list(v.values())[0]
-                    multivalueds.append(True)
+                    pass
+                    # TODO: check if this is a nested object or key-value list
+                    # v = list(v.values())[0]
+                    # multivalueds.append(True)
                 else:
                     multivalueds.append(False)
                 if not v:
@@ -933,10 +957,21 @@ class Collection(Generic[DatabaseType]):
             # raise AssertionError(f"Empty rngs for {k} = {vs}")
             rng = rngs[0] if rngs else None
             for other_rng in rngs:
+                coercions = {
+                    ("integer", "float"): "float",
+                }
                 if rng != other_rng:
-                    raise ValueError(f"Conflict: {rng} != {other_rng} for {vs}")
+                    if (rng, other_rng) in coercions:
+                        rng = coercions[(rng, other_rng)]
+                    elif (other_rng, rng) in coercions:
+                        rng = coercions[(other_rng, rng)]
+                    else:
+                        raise ValueError(f"Conflict: {rng} != {other_rng} for {vs}")
             logger.debug(f"Inducing {k} as {rng} {multivalued} {inlined}")
-            cd.attributes[k] = SlotDefinition(k, range=rng, multivalued=multivalued, inlined=inlined)
+            inlined_as_list = inlined and multivalued
+            cd.attributes[k] = SlotDefinition(
+                k, range=rng, multivalued=multivalued, inlined=inlined, inlined_as_list=inlined_as_list
+            )
             if exact_dimensions_list:
                 array_expr = ArrayExpression(exact_number_dimensions=len(exact_dimensions_list[0]))
                 cd.attributes[k].array = array_expr
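The coercion table above currently handles only the integer/float pair (in either order); any other disagreement between induced ranges still raises. A standalone sketch of the same lookup, separate from the surrounding schema-induction code:

    # standalone illustration of the range-coercion rule introduced above
    coercions = {("integer", "float"): "float"}

    def resolve_range(rng: str, other_rng: str) -> str:
        if rng == other_rng:
            return rng
        if (rng, other_rng) in coercions:
            return coercions[(rng, other_rng)]
        if (other_rng, rng) in coercions:
            return coercions[(other_rng, rng)]
        raise ValueError(f"Conflict: {rng} != {other_rng}")

    assert resolve_range("integer", "float") == "float"
    assert resolve_range("float", "integer") == "float"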
@@ -1,8 +1,8 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
-from linkml_store.graphs.graph_map import GraphProjection
+from linkml_store.graphs.graph_map import EdgeProjection, NodeProjection
 
 
 class ConfiguredBaseModel(BaseModel, extra="forbid"):
@@ -30,13 +30,30 @@ class CollectionSource(ConfiguredBaseModel):
     """
 
     url: Optional[str] = None
+    """Remote URL to fetch data from"""
+
     local_path: Optional[str] = None
+    """Local path to fetch data from"""
+
     source_location: Optional[str] = None
+
     refresh_interval_days: Optional[float] = None
+    """How often to refresh the data, in days"""
+
     expected_type: Optional[str] = None
+    """The expected type of the data, e.g list"""
+
     format: Optional[str] = None
+    """The format of the data, e.g., json, yaml, csv"""
+
     compression: Optional[str] = None
+    """The compression of the data, e.g., tgz, gzip, zip"""
+
+    select_query: Optional[str] = None
+    """A jsonpath query to preprocess the objects with"""
+
     arguments: Optional[Dict[str, Any]] = None
+    """Optional arguments to pass to the source"""
 
 
 class CollectionConfig(ConfiguredBaseModel):
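The new select_query field is documented as a jsonpath expression and is passed along when a collection loads from its configured local_path or url source (see the select_query=source.select_query changes earlier). A hypothetical configuration using the same nested dict form that cli.py builds below; the file path and the jsonpath expression are illustrative only:

    from linkml_store.api.config import ClientConfig

    # hypothetical: keep only the objects under a top-level "countries" array
    config = ClientConfig(
        databases={
            "duckdb": {
                "collections": {
                    "countries": {
                        "source": {
                            "local_path": "data/countries.json",  # illustrative path
                            "format": "json",
                            "select_query": "$.countries[*]",     # jsonpath, per the field docstring
                        }
                    }
                }
            }
        }
    )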
@@ -81,7 +98,7 @@ class CollectionConfig(ConfiguredBaseModel):
         description="LinkML-Map derivations",
     )
     page_size: Optional[int] = Field(default=None, description="Suggested page size (items per page) in apps and APIs")
-    graph_projection: Optional[GraphProjection] = Field(
+    graph_projection: Optional[Union[EdgeProjection, NodeProjection]] = Field(
         default=None,
         description="Optional graph projection configuration",
     )
@@ -707,12 +707,29 @@ class Database(ABC, Generic[CollectionType]):
         """
         raise NotImplementedError()
 
-    def import_database(self, location: str, source_format: Optional[Union[str, Format]] = None, **kwargs):
+    def import_database(
+        self,
+        location: str,
+        source_format: Optional[Union[str, Format]] = None,
+        collection_name: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Import a database from a file or location.
 
+        >>> from linkml_store.api.client import Client
+        >>> client = Client()
+        >>> db = client.attach_database("duckdb", alias="test")
+        >>> db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
+        >>> db.list_collection_names()
+        ['iris']
+        >>> collection = db.get_collection("iris")
+        >>> collection.find({}).num_rows
+        150
+
         :param location: location of the file
         :param source_format: source format
+        :param collection_name: (Optional) name of the collection, for data that is flat
         :param kwargs: additional arguments
         """
         if isinstance(source_format, str):
@@ -732,8 +749,12 @@ class Database(ABC, Generic[CollectionType]):
                 self.store(obj)
             return
         objects = load_objects(location, format=source_format)
-        for obj in objects:
-            self.store(obj)
+        if collection_name:
+            collection = self.get_collection(collection_name, create_if_not_exists=True)
+            collection.insert(objects)
+        else:
+            for obj in objects:
+                self.store(obj)
 
     def export_database(self, location: str, target_format: Optional[Union[str, Format]] = None, **kwargs):
         """
@@ -51,9 +51,13 @@ class MongoDBCollection(Collection):
         if offset and offset >= 0:
             cursor = cursor.skip(offset)
 
+        select_cols = query.select_cols
+
        def _as_row(row: dict):
            row = copy(row)
            del row["_id"]
+            if select_cols:
+                row = {k: row[k] for k in select_cols if k in row}
            return row
 
        rows = [_as_row(row) for row in cursor]
linkml_store/cli.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 import sys
 import warnings
+from pathlib import Path
 from typing import Optional
 
 import click
@@ -10,14 +11,22 @@ from pydantic import BaseModel
 
 from linkml_store import Client
 from linkml_store.api import Collection, Database
+from linkml_store.api.config import ClientConfig
 from linkml_store.api.queries import Query
 from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
+from linkml_store.inference import get_inference_engine
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import ModelSerialization
 from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output, write_output
 from linkml_store.utils.object_utils import object_path_update
 from linkml_store.utils.pandas_utils import facet_summary_to_dataframe_unmelted
 
+DEFAULT_LOCAL_CONF_PATH = Path("linkml.yaml")
+# global path is ~/.linkml.yaml in the user's home directory
+DEFAULT_GLOBAL_CONF_PATH = Path("~/.linkml.yaml").expanduser()
+
 index_type_option = click.option(
     "--index-type",
     "-t",
@@ -84,6 +93,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
+@click.option("--input", "-i", help="Input file (alternative to database/collection)")
 @click.option("--config", "-C", type=click.Path(exists=True), help="Path to the configuration file")
 @click.option("--set", help="Metadata settings in the form PATHEXPR=value", multiple=True)
 @click.option("-v", "--verbose", count=True)
@@ -96,7 +106,7 @@ include_internal_option = click.option("--include-internal/--no-include-internal
     help="If set then show full stacktrace on error",
 )
 @click.pass_context
-def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, **kwargs):
+def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection, config, set, input, **kwargs):
     """A CLI for interacting with the linkml-store."""
     if not stacktrace:
         sys.tracebacklimit = 0
@@ -119,13 +129,25 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if quiet:
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
+    if input:
+        stem = Path(input).stem
+        database = "duckdb"
+        collection = stem
+        config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
+        # collection = Path(input).stem
+        # database = f"file:{Path(input).parent}"
+    if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
+        config = DEFAULT_LOCAL_CONF_PATH
+    if config is None and DEFAULT_GLOBAL_CONF_PATH.exists():
+        config = DEFAULT_GLOBAL_CONF_PATH
+    if config == ".":
+        config = None
+    if not collection and database and "::" in database:
+        database, collection = database.split("::")
+
     client = Client().from_config(config, **kwargs) if config else Client()
     settings = ContextSettings(client=client, database_name=database, collection_name=collection)
     ctx.obj["settings"] = settings
-    # DEPRECATED
-    ctx.obj["client"] = client
-    ctx.obj["database"] = database
-    ctx.obj["collection"] = collection
 
     if settings.database_name:
         db = client.get_database(database)
@@ -136,12 +158,6 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
             val = yaml.safe_load(val)
             logger.info(f"Setting {path} to {val}")
             db.metadata = object_path_update(db.metadata, path, val)
-        # settings.database = db
-        # DEPRECATED
-        ctx.obj["database_obj"] = db
-        if collection:
-            collection_obj = db.get_collection(collection)
-            ctx.obj["collection_obj"] = collection_obj
 
     if not settings.database_name:
         # if len(client.databases) != 1:
@@ -323,11 +339,12 @@ def apply(ctx, patch_files, identifier_attribute):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query, as YAML")
+@click.option("--select", "-s", type=click.STRING, help="SELECT clause for the query, as YAML")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
 @click.pass_context
-def query(ctx, where, limit, output_type, output):
+def query(ctx, where, select, limit, output_type, output):
     """Query objects from the specified collection.
 
 
@@ -353,7 +370,13 @@ def query(ctx, where, limit, output_type, output):
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
-    query = Query(from_table=collection.alias, where_clause=where_clause, limit=limit)
+    select_clause = yaml.safe_load(select) if select else None
+    if select_clause:
+        if isinstance(select_clause, str):
+            select_clause = [select_clause]
+        if not isinstance(select_clause, list):
+            raise ValueError(f"SELECT clause must be a list. Got: {select_clause}")
+    query = Query(from_table=collection.alias, select_cols=select_clause, where_clause=where_clause, limit=limit)
     result = collection.query(query)
    output_data = render_output(result.rows, output_type)
    if output:
@@ -458,6 +481,110 @@ def describe(ctx, where, output_type, output, limit):
     write_output(df.describe(include="all").transpose(), output_type, target=output)
 
 
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--target-attribute", "-T", type=click.STRING, multiple=True, help="Target attributes for inference")
+@click.option(
+    "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
+)
+@click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
+@click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
+@click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
+@click.option("--model-format", "-M", type=click.Choice([x.value for x in ModelSerialization]), help="Format for model")
+@click.option("--training-test-data-split", "-S", type=click.Tuple([float, float]), help="Training/test data split")
+@click.option(
+    "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
+)
+@click.option("--query", "-q", type=click.STRING, help="query term")
+@click.pass_context
+def infer(
+    ctx,
+    inference_config_file,
+    query,
+    training_test_data_split,
+    predictor_type,
+    target_attribute,
+    feature_attributes,
+    output_type,
+    output,
+    model_format,
+    export_model,
+    load_model,
+):
+    """
+    Predict a complete object from a partial object.
+
+    Currently two main prediction methods are provided: RAG and sklearn
+
+    ## RAG:
+
+    The RAG approach will use Retrieval Augmented Generation to inference the missing attributes of an object.
+
+    Example:
+
+        linkml-store -i countries.jsonl inference -t rag -q 'name: Uruguay'
+
+    Result:
+
+        capital: Montevideo, code: UY, continent: South America, languages: [Spanish]
+
+    You can pass in configurations as follows:
+
+        linkml-store -i countries.jsonl inference -t rag:llm_config.model_name=llama-3 -q 'name: Uruguay'
+
+    ## SKLearn:
+
+    This uses scikit-learn (defaulting to simple decision trees) to do the prediction.
+
+        linkml-store -i tests/input/iris.csv inference -t sklearn \
+          -q '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
+    """
+    if query:
+        query_obj = yaml.safe_load(query)
+    else:
+        query_obj = None
+    collection = ctx.obj["settings"].collection
+    atts = collection.class_definition().attributes.keys()
+    if model_format:
+        model_format = ModelSerialization(model_format)
+    if load_model:
+        predictor = get_inference_engine(predictor_type)
+        predictor = type(predictor).load_model(load_model)
+    else:
+        if feature_attributes:
+            features = feature_attributes.split(",")
+            features = [f.strip() for f in features]
+        else:
+            if query_obj:
+                features = query_obj.keys()
+            else:
+                features = None
+        if target_attribute:
+            target_attributes = list(target_attribute)
+        else:
+            target_attributes = [att for att in atts if att not in features]
+        if inference_config_file:
+            config = InferenceConfig.from_file(inference_config_file)
+        else:
+            config = InferenceConfig(target_attributes=target_attributes, feature_attributes=features)
+        if training_test_data_split:
+            config.train_test_split = training_test_data_split
+        predictor = get_inference_engine(predictor_type, config=config)
+        predictor.load_and_split_data(collection)
+        predictor.initialize_model()
+    if export_model:
+        logger.info(f"Exporting model to {export_model} in {model_format}")
+        predictor.export_model(export_model, model_format)
+    if not query_obj:
+        if not export_model:
+            raise ValueError("Query must be specified if not exporting model")
+    if query_obj:
+        result = predictor.derive(query_obj)
+        dumped_obj = result.model_dump(exclude_none=True)
+        write_output([dumped_obj], output_type, target=output)
+
+
 @cli.command()
 @index_type_option
 @click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
@@ -0,0 +1,13 @@
+"""
+inference engine package.
+"""
+
+from linkml_store.inference.inference_config import InferenceConfig
+from linkml_store.inference.inference_engine import InferenceEngine
+from linkml_store.inference.inference_engine_registry import get_inference_engine
+
+__all__ = [
+    "InferenceEngine",
+    "InferenceConfig",
+    "get_inference_engine",
+]
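The new package above re-exports the pieces used by the infer command earlier in this diff. A hedged sketch of that command's sklearn path, lifted out of the CLI; the target and feature attribute names for the iris columns are assumptions, and only calls visible in this diff are used:

    from linkml_store import Client
    from linkml_store.inference import InferenceConfig, get_inference_engine
    from linkml_store.utils.format_utils import Format

    client = Client()
    db = client.attach_database("duckdb", alias="test")
    db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
    collection = db.get_collection("iris")

    # target/feature names are assumed column names for the iris data
    config = InferenceConfig(
        target_attributes=["species"],
        feature_attributes=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    )
    predictor = get_inference_engine("sklearn", config=config)
    predictor.load_and_split_data(collection)
    predictor.initialize_model()
    result = predictor.derive({"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2})
    print(result.model_dump(exclude_none=True))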