linkml-store 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (35)
  1. linkml_store/api/client.py +15 -4
  2. linkml_store/api/collection.py +185 -15
  3. linkml_store/api/config.py +11 -3
  4. linkml_store/api/database.py +36 -5
  5. linkml_store/api/stores/duckdb/duckdb_collection.py +6 -3
  6. linkml_store/api/stores/duckdb/duckdb_database.py +20 -1
  7. linkml_store/api/stores/filesystem/__init__.py +7 -8
  8. linkml_store/api/stores/filesystem/filesystem_collection.py +150 -113
  9. linkml_store/api/stores/filesystem/filesystem_database.py +57 -21
  10. linkml_store/api/stores/mongodb/mongodb_collection.py +82 -34
  11. linkml_store/api/stores/mongodb/mongodb_database.py +13 -2
  12. linkml_store/api/types.py +4 -0
  13. linkml_store/cli.py +97 -8
  14. linkml_store/index/__init__.py +5 -3
  15. linkml_store/index/indexer.py +7 -2
  16. linkml_store/utils/change_utils.py +17 -0
  17. linkml_store/utils/format_utils.py +89 -8
  18. linkml_store/utils/patch_utils.py +126 -0
  19. linkml_store/utils/query_utils.py +89 -0
  20. linkml_store/utils/schema_utils.py +23 -0
  21. linkml_store/webapi/__init__.py +0 -0
  22. linkml_store/webapi/html/__init__.py +3 -0
  23. linkml_store/webapi/html/base.html.j2 +24 -0
  24. linkml_store/webapi/html/collection_details.html.j2 +15 -0
  25. linkml_store/webapi/html/database_details.html.j2 +16 -0
  26. linkml_store/webapi/html/databases.html.j2 +14 -0
  27. linkml_store/webapi/html/generic.html.j2 +46 -0
  28. linkml_store/webapi/main.py +572 -0
  29. linkml_store-0.1.10.dist-info/METADATA +138 -0
  30. linkml_store-0.1.10.dist-info/RECORD +58 -0
  31. {linkml_store-0.1.8.dist-info → linkml_store-0.1.10.dist-info}/entry_points.txt +1 -0
  32. linkml_store-0.1.8.dist-info/METADATA +0 -58
  33. linkml_store-0.1.8.dist-info/RECORD +0 -45
  34. {linkml_store-0.1.8.dist-info → linkml_store-0.1.10.dist-info}/LICENSE +0 -0
  35. {linkml_store-0.1.8.dist-info → linkml_store-0.1.10.dist-info}/WHEEL +0 -0
linkml_store/api/stores/filesystem/filesystem_collection.py
@@ -1,142 +1,179 @@
 import logging
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-import sqlalchemy as sqla
-from linkml_runtime.linkml_model import ClassDefinition, SlotDefinition
-from sqlalchemy import Column, Table, delete, insert, inspect, text
-from sqlalchemy.sql.ddl import CreateTable
-
 from linkml_store.api import Collection
 from linkml_store.api.collection import DEFAULT_FACET_LIMIT, OBJECT
-from linkml_store.api.queries import Query
-from linkml_store.api.stores.duckdb.mappings import TMAP
-from linkml_store.utils.sql_utils import facet_count_sql
+from linkml_store.api.queries import Query, QueryResult
+from linkml_store.api.types import DatabaseType
+from linkml_store.utils.query_utils import mongo_query_to_match_function
 
 logger = logging.getLogger(__name__)
 
 
-class FileSystemCollection(Collection):
-    _table_created: bool = None
+class FileSystemCollection(Collection[DatabaseType]):
+    path: Optional[Path] = None
+    file_format: Optional[str] = None
+    encoding: Optional[str] = None
+    _objects_list: List[OBJECT] = None
+    _object_map: Dict[str, OBJECT] = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        parent: DatabaseType = self.parent
+        if not self.path:
+            if self.parent:
+                self.path = Path(parent.directory_path)
+        self._objects_list = []
+        self._object_map = {}
+        if not self.file_format:
+            self.file_format = "json"
+
+    @property
+    def path_to_file(self):
+        return Path(self.parent.directory_path) / f"{self.name}.{self.file_format}"
+
+    @property
+    def objects_as_list(self) -> List[OBJECT]:
+        if self._object_map:
+            return list(self._object_map.values())
+        else:
+            return self._objects_list
+
+    def _set_objects(self, objs: List[OBJECT]):
+        pk = self.identifier_attribute_name
+        if pk:
+            self._object_map = {obj[pk]: obj for obj in objs}
+            self._objects_list = []
+        else:
+            self._objects_list = objs
+            self._object_map = {}
+
+    def commit(self):
+        path = self.path_to_file
+        if not path:
+            raise ValueError("Path not set")
+        path.parent.mkdir(parents=True, exist_ok=True)
+        self._save(path)
+
+    def _save(self, path: Path):
+        encoding = self.encoding or "utf-8"
+        fmt = self.file_format or "json"
+        mode = "w"
+        if fmt == "parquet":
+            mode = "wb"
+            encoding = None
+        with open(path, mode, encoding=encoding) as stream:
+            if fmt == "json":
+                import json
+
+                json.dump(self.objects_as_list, stream, indent=2)
+            elif fmt == "jsonl":
+                import jsonlines
+
+                writer = jsonlines.Writer(stream)
+                writer.write_all(self.objects_as_list)
+            elif fmt == "yaml":
+                import yaml
+
+                yaml.dump_all(self.objects_as_list, stream)
+            elif fmt == "parquet":
+                import pandas as pd
+                import pyarrow
+                import pyarrow.parquet as pq
+
+                df = pd.DataFrame(self.objects_as_list)
+                table = pyarrow.Table.from_pandas(df)
+                pq.write_table(table, stream)
+            elif fmt in {"csv", "tsv"}:
+                import csv
+
+                delimiter = "\t" if fmt == "tsv" else ","
+                fieldnames = list(self.objects_as_list[0].keys())
+                for obj in self.objects_as_list[1:]:
+                    fieldnames.extend([k for k in obj.keys() if k not in fieldnames])
+                writer = csv.DictWriter(stream, fieldnames=fieldnames, delimiter=delimiter)
+                writer.writeheader()
+                for obj in self.objects_as_list:
+                    writer.writerow(obj)
+            else:
+                raise ValueError(f"Unsupported file format: {fmt}")
 
     def insert(self, objs: Union[OBJECT, List[OBJECT]], **kwargs):
         if not isinstance(objs, list):
             objs = [objs]
         if not objs:
             return
-        cd = self.class_definition()
-        if not cd:
-            cd = self.induce_class_definition_from_objects(objs)
-        self._create_table(cd)
-        table = self._sqla_table(cd)
-        logger.info(f"Inserting into: {self.alias} // T={table.name}")
-        engine = self.parent.engine
-        col_names = [c.name for c in table.columns]
-        objs = [{k: obj.get(k, None) for k in col_names} for obj in objs]
-        with engine.connect() as conn:
-            with conn.begin():
-                conn.execute(insert(table), objs)
-            conn.commit()
+        pk = self.identifier_attribute_name
+        if pk:
+            for obj in objs:
+                if pk not in obj:
+                    raise ValueError(f"Primary key {pk} not found in object {obj}")
+                pk_val = obj[pk]
+                self._object_map[pk_val] = obj
+        else:
+            self._objects_list.extend(objs)
 
     def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
         if not isinstance(objs, list):
             objs = [objs]
-        cd = self.class_definition()
-        if not cd:
-            cd = self.induce_class_definition_from_objects(objs)
-        table = self._sqla_table(cd)
-        engine = self.parent.engine
-        with engine.connect() as conn:
+        if not objs:
+            return 0
+        pk = self.identifier_attribute_name
+        n = 0
+        if pk:
             for obj in objs:
-                conditions = [table.c[k] == v for k, v in obj.items() if k in cd.attributes]
-                stmt = delete(table).where(*conditions)
-                stmt = stmt.compile(engine)
-                conn.execute(stmt)
-            conn.commit()
-        return
+                pk_val = obj[pk]
+                if pk_val in self._object_map:
+                    del self._object_map[pk_val]
+                    n += 1
+        else:
+            n = len(self._objects_list)
+            self._objects_list = [o for o in self._objects_list if o not in objs]
+            n = n - len(self._objects_list)
+        return n
 
     def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
         logger.info(f"Deleting from {self.target_class_name} where: {where}")
         if where is None:
             where = {}
-        cd = self.class_definition()
-        if not cd:
-            logger.info(f"No class definition found for {self.target_class_name}, assuming not prepopulated")
-            return 0
-        table = self._sqla_table(cd)
-        engine = self.parent.engine
-        inspector = inspect(engine)
-        table_exists = table.name in inspector.get_table_names()
-        if not table_exists:
-            logger.info(f"Table {table.name} does not exist, assuming no data")
-            return 0
-        with engine.connect() as conn:
-            conditions = [table.c[k] == v for k, v in where.items()]
-            stmt = delete(table).where(*conditions)
-            stmt = stmt.compile(engine)
-            result = conn.execute(stmt)
-            deleted_rows_count = result.rowcount
-            if deleted_rows_count == 0 and not missing_ok:
-                raise ValueError(f"No rows found for {where}")
-            conn.commit()
-            return deleted_rows_count if deleted_rows_count > -1 else None
+
+        def matches(obj: OBJECT):
+            for k, v in where.items():
+                if obj.get(k) != v:
+                    return False
+            return True
+
+        n_before = len(self.objects_as_list)
+        curr_objects = [o for o in self.objects_as_list if not matches(o)]
+        self._set_objects(curr_objects)
+        n_deleted = n_before - len(curr_objects)
+        if n_deleted == 0 and not missing_ok:
+            raise ValueError(f"No objects found matching {where}")
+        return n_deleted
+
+    def query(self, query: Query, **kwargs) -> QueryResult:
+        where = query.where_clause or {}
+        match = mongo_query_to_match_function(where)
+        rows = [o for o in self.objects_as_list if match(o)]
+        count = len(rows)
+        return QueryResult(query=query, num_rows=count, rows=rows)
 
     def query_facets(
         self, where: Dict = None, facet_columns: List[str] = None, facet_limit=DEFAULT_FACET_LIMIT, **kwargs
     ) -> Dict[str, Dict[str, int]]:
-        results = {}
-        cd = self.class_definition()
-        with self.parent.engine.connect() as conn:
-            if not facet_columns:
-                facet_columns = list(self.class_definition().attributes.keys())
-            for col in facet_columns:
-                logger.debug(f"Faceting on {col}")
-                if isinstance(col, tuple):
-                    sd = SlotDefinition(name="PLACEHOLDER")
-                else:
-                    sd = cd.attributes[col]
-                facet_query = self._create_query(where_clause=where)
-                facet_query_str = facet_count_sql(facet_query, col, multivalued=sd.multivalued)
-                logger.debug(f"Facet query: {facet_query_str}")
-                rows = list(conn.execute(text(facet_query_str)))
-                results[col] = rows
-        return results
-
-    def _sqla_table(self, cd: ClassDefinition) -> Table:
-        schema_view = self.parent.schema_view
-        metadata_obj = sqla.MetaData()
-        cols = []
-        for att in schema_view.class_induced_slots(cd.name):
-            typ = TMAP.get(att.range, sqla.String)
-            if att.inlined:
-                typ = sqla.JSON
-            if att.multivalued:
-                typ = sqla.ARRAY(typ, dimensions=1)
-            if att.array:
-                typ = sqla.ARRAY(typ, dimensions=1)
-            col = Column(att.name, typ)
-            cols.append(col)
-        t = Table(self.alias, metadata_obj, *cols)
-        return t
-
-    def _create_table(self, cd: ClassDefinition):
-        if self._table_created or self.metadata.is_prepopulated:
-            logger.info(f"Already have table for: {cd.name}")
-            return
-        query = Query(
-            from_table="information_schema.tables", where_clause={"table_type": "BASE TABLE", "table_name": self.alias}
-        )
-        qr = self.parent.query(query)
-        if qr.num_rows > 0:
-            logger.info(f"Table already exists for {cd.name}")
-            self._table_created = True
-            self.metadata.is_prepopulated = True
-            return
-        logger.info(f"Creating table for {cd.name}")
-        t = self._sqla_table(cd)
-        ct = CreateTable(t)
-        ddl = str(ct.compile(self.parent.engine))
-        with self.parent.engine.connect() as conn:
-            conn.execute(text(ddl))
-            conn.commit()
-        self._table_created = True
-        self.metadata.is_prepopulated = True
+        match = mongo_query_to_match_function(where or {})
+        rows = [o for o in self.objects_as_list if match(o)]
+        if not facet_columns:
+            facet_columns = self.class_definition().attributes.keys()
+        facet_results = {c: {} for c in facet_columns}
+        for row in rows:
+            for fc in facet_columns:
+                if fc in row:
+                    v = row[fc]
+                    if not isinstance(v, str):
+                        v = str(v)
+                    if v not in facet_results[fc]:
+                        facet_results[fc][v] = 1
+                    else:
+                        facet_results[fc][v] += 1
+        return {fc: list(facet_results[fc].items()) for fc in facet_results}
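
A note for readers, not part of the diff: the new FileSystemCollection query path is just a predicate filter over an in-memory object list. A minimal standalone sketch of the same idea, where make_match is a hypothetical, equality-only stand-in for mongo_query_to_match_function:

from typing import Any, Callable, Dict, List

def make_match(where: Dict[str, Any]) -> Callable[[Dict[str, Any]], bool]:
    # Equality-only stand-in for mongo_query_to_match_function:
    # every key/value pair in the where clause must match the object.
    def match(obj: Dict[str, Any]) -> bool:
        return all(obj.get(k) == v for k, v in where.items())
    return match

objects: List[Dict[str, Any]] = [
    {"id": "P1", "species": "human"},
    {"id": "P2", "species": "mouse"},
]
match = make_match({"species": "human"})
rows = [o for o in objects if match(o)]       # what query() returns as rows
kept = [o for o in objects if not match(o)]   # what delete_where() keeps
assert [o["id"] for o in rows] == ["P1"]
assert [o["id"] for o in kept] == ["P2"]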
linkml_store/api/stores/filesystem/filesystem_database.py
@@ -1,36 +1,72 @@
 import logging
+from pathlib import Path
 from typing import Optional
 
-from linkml_store.api import Collection, Database
-from linkml_store.api.config import CollectionConfig
-from linkml_store.api.stores.duckdb import DuckDBDatabase
+import yaml
+from linkml.utils.schema_builder import SchemaBuilder
+from linkml_runtime import SchemaView
+
+from linkml_store.api import Database
+from linkml_store.api.config import DatabaseConfig
 from linkml_store.api.stores.filesystem.filesystem_collection import FileSystemCollection
+from linkml_store.utils.format_utils import Format, load_objects
 
 logger = logging.getLogger(__name__)
 
 
 class FileSystemDatabase(Database):
     collection_class = FileSystemCollection
-    wrapped_database: Database = None
 
-    def __init__(self, handle: Optional[str] = None, recreate_if_exists: bool = False, **kwargs):
-        self.wrapped_database = DuckDBDatabase("duckdb:///:memory:")
+    directory_path: Optional[Path] = None
+    default_file_format: Optional[str] = None
+
+    def __init__(self, handle: Optional[str] = None, **kwargs):
+        handle = handle.replace("file:", "")
+        if handle.startswith("//"):
+            handle = handle[2:]
+        self.directory_path = Path(handle)
+        self.load_metadata()
         super().__init__(handle=handle, **kwargs)
 
-    def commit(self, **kwargs):
-        # TODO: sync
-        pass
+    @property
+    def metadata_path(self) -> Path:
+        return self.directory_path / ".linkml_metadata.yaml"
+
+    def load_metadata(self):
+        if self.metadata_path.exists():
+            md_dict = yaml.safe_load(open(self.metadata_path))
+            metadata = DatabaseConfig(**md_dict)
+        else:
+            metadata = DatabaseConfig()
+        self.metadata = metadata
 
     def close(self, **kwargs):
-        self.wrapped_database.close()
-
-    def create_collection(
-        self,
-        name: str,
-        alias: Optional[str] = None,
-        metadata: Optional[CollectionConfig] = None,
-        recreate_if_exists=False,
-        **kwargs,
-    ) -> Collection:
-        wd = self.wrapped_database
-        wd.create_collection()
+        pass
+
+    def init_collections(self):
+        metadata = self.metadata
+        if self._collections is None:
+            self._collections = {}
+        for name, collection_config in metadata.collections.items():
+            collection = FileSystemCollection(parent=self, **collection_config.dict())
+            self._collections[name] = collection
+        path = self.directory_path
+        if path.exists():
+            for fmt in Format:
+                suffix = fmt.value
+                logger.info(f"Looking for {suffix} files in {path}")
+                for f in path.glob(f"*.{suffix}"):
+                    logger.info(f"Found {f}")
+                    n = f.stem
+                    objs = load_objects(f, suffix, expected_type=list)
+                    collection = FileSystemCollection(parent=self, name=n)
+                    self._collections[n] = collection
+                    collection._set_objects(objs)
+
+    def induce_schema_view(self) -> SchemaView:
+        logger.info(f"Inducing schema view for {self.handle}")
+        sb = SchemaBuilder()
+
+        for collection_name in self.list_collection_names():
+            sb.add_class(collection_name)
+        return SchemaView(sb.schema)
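
Aside, not part of the diff: the handle normalization in FileSystemDatabase.__init__ is terse and easy to misread in diff form. Restated as a standalone sketch (the function name is ours, not the package's):

from pathlib import Path

def directory_from_handle(handle: str) -> Path:
    # Mirrors FileSystemDatabase.__init__: drop the "file:" scheme,
    # then any leading "//", leaving an ordinary filesystem path.
    handle = handle.replace("file:", "")
    if handle.startswith("//"):
        handle = handle[2:]
    return Path(handle)

assert directory_from_handle("file:///tmp/mydb") == Path("/tmp/mydb")
assert directory_from_handle("file:relative/dir") == Path("relative/dir")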
linkml_store/api/stores/mongodb/mongodb_collection.py
@@ -2,7 +2,6 @@ import logging
 from copy import copy
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from linkml_runtime.linkml_model import SlotDefinition
 from pymongo.collection import Collection as MongoCollection
 
 from linkml_store.api import Collection
@@ -26,19 +25,27 @@ class MongoDBCollection(Collection):
     def mongo_collection(self) -> MongoCollection:
         if not self.name:
             raise ValueError("Collection name not set")
-        return self.parent.native_db[self.name]
+        collection_name = self.alias or self.name
+        return self.parent.native_db[collection_name]
 
     def insert(self, objs: Union[OBJECT, List[OBJECT]], **kwargs):
         if not isinstance(objs, list):
             objs = [objs]
         self.mongo_collection.insert_many(objs)
+        # TODO: allow mapping of _id to id for efficiency
+        for obj in objs:
+            del obj["_id"]
+        self._post_insert_hook(objs)
 
-    def query(self, query: Query, **kwargs) -> QueryResult:
+    def query(self, query: Query, limit: Optional[int] = None, offset: Optional[int] = None, **kwargs) -> QueryResult:
         mongo_filter = self._build_mongo_filter(query.where_clause)
-        if query.limit:
-            cursor = self.mongo_collection.find(mongo_filter).limit(query.limit)
-        else:
-            cursor = self.mongo_collection.find(mongo_filter)
+        limit = limit or query.limit
+        cursor = self.mongo_collection.find(mongo_filter)
+        if limit and limit >= 0:
+            cursor = cursor.limit(limit)
+        offset = offset or query.offset
+        if offset and offset >= 0:
+            cursor = cursor.skip(offset)
 
         def _as_row(row: dict):
             row = copy(row)
@@ -57,46 +64,87 @@ class MongoDBCollection(Collection):
             mongo_filter[field] = value
         return mongo_filter
 
     def query_facets(
-        self, where: Dict = None, facet_columns: List[str] = None, facet_limit=DEFAULT_FACET_LIMIT, **kwargs
-    ) -> Dict[str, List[Tuple[Any, int]]]:
+        self,
+        where: Dict = None,
+        facet_columns: List[Union[str, Tuple[str, ...]]] = None,
+        facet_limit=DEFAULT_FACET_LIMIT,
+        **kwargs,
+    ) -> Dict[Union[str, Tuple[str, ...]], List[Tuple[Any, int]]]:
         results = {}
-        cd = self.class_definition()
         if not facet_columns:
             facet_columns = list(self.class_definition().attributes.keys())
 
         for col in facet_columns:
             logger.debug(f"Faceting on {col}")
+
+            # Handle tuple columns
+            if isinstance(col, tuple):
+                group_id = {k.replace(".", "_"): f"${k}" for k in col}
+                all_fields = col
+            else:
+                group_id = f"${col}"
+                all_fields = [col]
+
+            # Initial pipeline without unwinding
+            facet_pipeline = [
+                {"$match": where} if where else {"$match": {}},
+                {"$group": {"_id": group_id, "count": {"$sum": 1}}},
+                {"$sort": {"count": -1}},
+                {"$limit": facet_limit},
+            ]
+
+            logger.info(f"Initial facet pipeline: {facet_pipeline}")
+            initial_results = list(self.mongo_collection.aggregate(facet_pipeline))
+
+            # Check if we need to unwind based on the results
+            needs_unwinding = False
             if isinstance(col, tuple):
-                sd = SlotDefinition(name="PLACEHOLDER")
+                needs_unwinding = any(
+                    isinstance(result["_id"], dict) and any(isinstance(v, list) for v in result["_id"].values())
+                    for result in initial_results
+                )
+            else:
+                needs_unwinding = any(isinstance(result["_id"], list) for result in initial_results)
+
+            if needs_unwinding:
+                logger.info(f"Detected array values for {col}, unwinding...")
+                facet_pipeline = [{"$match": where} if where else {"$match": {}}]
+
+                # Unwind each field if needed
+                for field in all_fields:
+                    field_parts = field.split(".")
+                    for i in range(len(field_parts)):
+                        facet_pipeline.append({"$unwind": f"${'.'.join(field_parts[:i + 1])}"})
+
+                facet_pipeline.extend(
+                    [
+                        {"$group": {"_id": group_id, "count": {"$sum": 1}}},
+                        {"$sort": {"count": -1}},
+                        {"$limit": facet_limit},
+                    ]
+                )
+
+                logger.info(f"Updated facet pipeline with unwinding: {facet_pipeline}")
+                facet_results = list(self.mongo_collection.aggregate(facet_pipeline))
             else:
-                if col in cd.attributes:
-                    sd = cd.attributes[col]
-                else:
-                    logger.info(f"No schema metadata for {col}")
-                    sd = SlotDefinition(name=col)
-            group = {"$group": {"_id": f"${col}", "count": {"$sum": 1}}}
+                facet_results = initial_results
+
+            logger.info(f"Facet results: {facet_results}")
+
+            # Process results
             if isinstance(col, tuple):
-                q = {k.replace(".", ""): f"${k}" for k in col}
-                group["$group"]["_id"] = q
-            if sd and sd.multivalued:
-                facet_pipeline = [
-                    {"$match": where} if where else {"$match": {}},
-                    {"$unwind": f"${col}"},
-                    group,
-                    {"$sort": {"count": -1}},
-                    {"$limit": facet_limit},
+                results[col] = [
+                    (tuple(result["_id"].values()), result["count"])
+                    for result in facet_results
+                    if result["_id"] is not None and all(v is not None for v in result["_id"].values())
                 ]
             else:
-                facet_pipeline = [
-                    {"$match": where} if where else {"$match": {}},
-                    group,
-                    {"$sort": {"count": -1}},
-                    {"$limit": facet_limit},
+                results[col] = [
+                    (result["_id"], result["count"]) for result in facet_results if result["_id"] is not None
                 ]
-            logger.info(f"Facet pipeline: {facet_pipeline}")
-            facet_results = list(self.mongo_collection.aggregate(facet_pipeline))
-            results[col] = [(result["_id"], result["count"]) for result in facet_results]
 
         return results
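Aside, not part of the diff: for a single scalar column, the rewritten query_facets() first runs a plain group-and-count pipeline and only adds $unwind stages if the grouped values turn out to be arrays. A sketch of the initial pipeline's shape, with illustrative values (the column name and filter below are made up):

where = {"status": "active"}
col = "species"
facet_limit = 100

facet_pipeline = [
    {"$match": where},  # becomes {"$match": {}} when no filter is given
    {"$group": {"_id": f"${col}", "count": {"$sum": 1}}},
    {"$sort": {"count": -1}},
    {"$limit": facet_limit},
]
# Each aggregation result has the shape {"_id": "human", "count": 42};
# query_facets() reshapes these into ("human", 42) tuples per column.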
linkml_store/api/stores/mongodb/mongodb_database.py
@@ -29,9 +29,17 @@ class MongoDBDatabase(Database):
 
     def __init__(self, handle: Optional[str] = None, **kwargs):
         if handle is None:
-            handle = "mongodb://localhost:27017"
+            handle = "mongodb://localhost:27017/test"
         super().__init__(handle=handle, **kwargs)
 
+    @property
+    def _db_name(self) -> str:
+        if self.handle:
+            db = self.handle.split("/")[-1]
+        else:
+            db = "default"
+        return db
+
     @property
     def native_client(self) -> MongoClient:
         if self._native_client is None:
@@ -44,7 +52,7 @@ class MongoDBDatabase(Database):
         alias = self.metadata.alias
         if not alias:
             alias = "default"
-        self._native_db = self.native_client[alias]
+        self._native_db = self.native_client[self._db_name]
         return self._native_db
 
     def commit(self, **kwargs):
@@ -58,9 +66,12 @@ class MongoDBDatabase(Database):
         self.native_client.drop_database(self.metadata.alias)
 
     def query(self, query: Query, **kwargs) -> QueryResult:
+        # TODO: DRY
        if query.from_table:
             collection = self.get_collection(query.from_table)
             return collection.query(query, **kwargs)
+        else:
+            raise NotImplementedError(f"Querying without a table is not supported in {self.__class__.__name__}")
 
     def init_collections(self):
         if self._collections is None:
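
Aside, not part of the diff: the new _db_name property takes the database name from the last path segment of the connection handle, which is why the default handle now ends in /test. A standalone restatement (function name ours):

def db_name_from_handle(handle: str) -> str:
    # Mirrors MongoDBDatabase._db_name: the database is the last
    # "/"-separated segment of the handle.
    return handle.split("/")[-1] if handle else "default"

assert db_name_from_handle("mongodb://localhost:27017/test") == "test"
# Caveat visible in the code: a handle with no database segment,
# e.g. "mongodb://localhost:27017", would yield "localhost:27017".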
linkml_store/api/types.py
@@ -0,0 +1,4 @@
+from typing import TypeVar
+
+DatabaseType = TypeVar("DatabaseType", bound="Database")  # noqa: F821
+CollectionType = TypeVar("CollectionType", bound="Collection")  # noqa: F821
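
Aside, not part of the diff: these TypeVars let store implementations parameterize Collection by their concrete Database type, as FileSystemCollection(Collection[DatabaseType]) does above, so that parent can be typed precisely. A minimal sketch of the pattern, with the classes reduced to stubs (attribute names are illustrative):

from typing import Generic, Optional, TypeVar

DatabaseType = TypeVar("DatabaseType", bound="Database")

class Database:
    directory_path: Optional[str] = None

class Collection(Generic[DatabaseType]):
    # parent is typed by the concrete database subclass, so a
    # filesystem collection can use directory_path without casts.
    parent: Optional[DatabaseType] = None

class FileSystemDatabase(Database):
    pass

class FileSystemCollection(Collection[FileSystemDatabase]):
    pass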