chemrecon 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chemrecon/__init__.py +73 -0
- chemrecon/chem/__init__.py +0 -0
- chemrecon/chem/chemreaction.py +223 -0
- chemrecon/chem/constant_compounds.py +3 -0
- chemrecon/chem/create_mol.py +91 -0
- chemrecon/chem/elements.py +141 -0
- chemrecon/chem/gml/__init__.py +0 -0
- chemrecon/chem/gml/gml.py +324 -0
- chemrecon/chem/gml/gml_reactant_matching.py +130 -0
- chemrecon/chem/gml/gml_to_rdk.py +217 -0
- chemrecon/chem/mol.py +483 -0
- chemrecon/chem/sumformula.py +120 -0
- chemrecon/connection.py +97 -0
- chemrecon/core/__init__.py +0 -0
- chemrecon/core/id_types.py +687 -0
- chemrecon/core/ontology.py +209 -0
- chemrecon/core/populate_query_handler.py +336 -0
- chemrecon/core/query_handler.py +587 -0
- chemrecon/database/__init__.py +1 -0
- chemrecon/database/connect.py +63 -0
- chemrecon/database/connection_params/chemrecon_pub.dbinfo +5 -0
- chemrecon/database/connection_params/local_docker_dev.dbinfo +5 -0
- chemrecon/database/connection_params/local_docker_init.dbinfo +5 -0
- chemrecon/database/connection_params/local_docker_pub.dbinfo +5 -0
- chemrecon/database/params.py +88 -0
- chemrecon/entrygraph/draw.py +119 -0
- chemrecon/entrygraph/entrygraph.py +301 -0
- chemrecon/entrygraph/explorationprotocol.py +199 -0
- chemrecon/entrygraph/explore.py +421 -0
- chemrecon/entrygraph/explore_procedure.py +183 -0
- chemrecon/entrygraph/filter.py +88 -0
- chemrecon/entrygraph/scoring.py +141 -0
- chemrecon/query/__init__.py +26 -0
- chemrecon/query/create_entry.py +86 -0
- chemrecon/query/default_protocols.py +57 -0
- chemrecon/query/find_entry.py +84 -0
- chemrecon/query/get_relations.py +143 -0
- chemrecon/query/get_structures_from_compound.py +65 -0
- chemrecon/schema/__init__.py +86 -0
- chemrecon/schema/db_object.py +363 -0
- chemrecon/schema/direction.py +10 -0
- chemrecon/schema/entry_types/__init__.py +0 -0
- chemrecon/schema/entry_types/aam.py +34 -0
- chemrecon/schema/entry_types/aam_repr.py +37 -0
- chemrecon/schema/entry_types/compound.py +52 -0
- chemrecon/schema/entry_types/enzyme.py +49 -0
- chemrecon/schema/entry_types/molstructure.py +64 -0
- chemrecon/schema/entry_types/molstructure_repr.py +41 -0
- chemrecon/schema/entry_types/reaction.py +57 -0
- chemrecon/schema/enums.py +154 -0
- chemrecon/schema/procedural_relation_entrygraph.py +66 -0
- chemrecon/schema/relation_types_composed/__init__.py +0 -0
- chemrecon/schema/relation_types_composed/compound_has_molstructure_relation.py +59 -0
- chemrecon/schema/relation_types_composed/reaction_has_aam_relation.py +50 -0
- chemrecon/schema/relation_types_procedural/__init__.py +0 -0
- chemrecon/schema/relation_types_procedural/aam_convert_relation.py +69 -0
- chemrecon/schema/relation_types_procedural/compound_select_structure_proceduralrelation.py +36 -0
- chemrecon/schema/relation_types_procedural/compound_similarlity_proceduralrelation.py +1 -0
- chemrecon/schema/relation_types_procedural/molstructure_convert_relation.py +49 -0
- chemrecon/schema/relation_types_procedural/reaction_select_aam_proceduralrelation.py +38 -0
- chemrecon/schema/relation_types_procedural/reaction_similarity_proceduralrelation.py +1 -0
- chemrecon/schema/relation_types_source/__init__.py +0 -0
- chemrecon/schema/relation_types_source/aam_involves_molstructure_relation.py +77 -0
- chemrecon/schema/relation_types_source/aam_repr_involves_molstructure_repr_relation.py +79 -0
- chemrecon/schema/relation_types_source/compound_has_structure_representation_relation.py +33 -0
- chemrecon/schema/relation_types_source/compound_reference_relation.py +34 -0
- chemrecon/schema/relation_types_source/molstructure_standardisation_relation.py +71 -0
- chemrecon/schema/relation_types_source/ontology/__init__.py +0 -0
- chemrecon/schema/relation_types_source/ontology/compound_ontology.py +369 -0
- chemrecon/schema/relation_types_source/ontology/enzyme_ontology.py +142 -0
- chemrecon/schema/relation_types_source/ontology/reaction_ontology.py +140 -0
- chemrecon/schema/relation_types_source/reaction_has_aam_representation_relation.py +34 -0
- chemrecon/schema/relation_types_source/reaction_has_enzyme_relation.py +71 -0
- chemrecon/schema/relation_types_source/reaction_involves_compound_relation.py +69 -0
- chemrecon/schema/relation_types_source/reaction_reference_relation.py +33 -0
- chemrecon/scripts/initialize_database.py +494 -0
- chemrecon/utils/copy_signature.py +10 -0
- chemrecon/utils/encodeable_list.py +11 -0
- chemrecon/utils/get_id_type.py +70 -0
- chemrecon/utils/hungarian.py +31 -0
- chemrecon/utils/reactant_matching.py +168 -0
- chemrecon/utils/rxnutils.py +44 -0
- chemrecon/utils/set_cwd.py +12 -0
- chemrecon-0.1.1.dist-info/METADATA +143 -0
- chemrecon-0.1.1.dist-info/RECORD +86 -0
- chemrecon-0.1.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Any, DefaultDict, Generator, Optional, Sequence
|
|
3
|
+
|
|
4
|
+
import psycopg as pg
|
|
5
|
+
import psycopg.sql as sql
|
|
6
|
+
from psycopg.rows import class_row
|
|
7
|
+
|
|
8
|
+
from chemrecon.schema.db_object import Column, DatabaseObject, Entry, Relation
|
|
9
|
+
import chemrecon.schema as schema
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class QueryHandler():
    """ Context manager to wrap database queries.

    Holds the connection passed to the constructor together with a default cursor opened on it.
    On a clean exit from the context the connection is committed and the default cursor is closed.
    """
    # Connection supplied by the caller, and the default cursor created on it in __init__
    conn: pg.Connection
    c: pg.Cursor

    # The handler also keeps track of a parallel 'database' of procedurally generated entries
    # Procedural entries grouped by their entry type
    procedural_entries: dict[
        type[Entry], set[Entry]
    ]
    # Procedural entries keyed by (entry type, tuple of index-column values)
    procedural_entries_index: dict[
        tuple[type[Entry], tuple],
        Entry
    ]
    # Procedural entries keyed by (entry type, procedural recon id)
    procedural_entry_recon_id: dict[
        tuple[type[Entry], int],
        Entry
    ]
    procedural_id_table: dict[type[Entry], int]  # For assigning procedural ids to entries created locally
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Context manager
|
|
34
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
35
|
+
def __init__(self, connection: pg.Connection):
    """ Wrap an open connection and create the (initially empty) procedural-entry stores.

    A default cursor is opened on `connection` and kept as `self.c`.
    """
    self.conn = connection
    self.c = self.conn.cursor()

    # Build the concrete collections.defaultdict directly instead of calling the typing alias
    self.procedural_entries = defaultdict(set)
    self.procedural_entries_index = {}
    self.procedural_entry_recon_id = {}
    self.procedural_id_table = {
        e_type: -1  # Start at -1 and decrement
        for e_type in schema.entrytypes
    }
|
|
46
|
+
|
|
47
|
+
def __enter__(self):
    """ Enter the context; the handler itself serves as the query interface. """
    return self
|
|
50
|
+
|
|
51
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """ Leave the context: commit on success and close the default cursor on every path.

    If the body raised, nothing is committed and the exception propagates to the caller. This is
    signalled by returning False instead of re-raising, which leaves the original traceback intact.
    """
    try:
        if exc_type is None:
            self.conn.commit()
    finally:
        # Release the cursor whether or not the body (or the commit) failed
        self.c.close()
    return False
|
|
58
|
+
|
|
59
|
+
# Procedural entries
|
|
60
|
+
# ------------------------------------------------------------------------------------------------------------------
|
|
61
|
+
def add_procedural_entry[T: Entry](self, entry: T) -> int:
|
|
62
|
+
""" Adds a procedural entry, and returns its virtual recon id. If already exists,
|
|
63
|
+
return the virtual recon id.
|
|
64
|
+
"""
|
|
65
|
+
# TODO does the fact that index is not an ordereddict cause problems?
|
|
66
|
+
index = tuple(val for val in entry.get_index_columns_with_values().values())
|
|
67
|
+
try:
|
|
68
|
+
# Try to return existing
|
|
69
|
+
existing_entry = self.procedural_entries_index[
|
|
70
|
+
type(entry), index
|
|
71
|
+
]
|
|
72
|
+
entry.recon_id = existing_entry.recon_id
|
|
73
|
+
return existing_entry.recon_id
|
|
74
|
+
except KeyError:
|
|
75
|
+
# Not found, create and insert
|
|
76
|
+
new_id = self._assign_procedural_id(entry)
|
|
77
|
+
self.procedural_entries[type(entry)].add(entry)
|
|
78
|
+
self.procedural_entries_index[(type(entry), index)] = entry
|
|
79
|
+
self.procedural_entry_recon_id[(type(entry), new_id)] = entry
|
|
80
|
+
return new_id
|
|
81
|
+
|
|
82
|
+
def _assign_procedural_id(self, entry: Entry) -> int:
    """ Give the entry the next free procedural id of its type and return that id.

    The id is stored on `entry.recon_id`; afterwards the per-type counter is lowered by one so the
    next entry of the same type receives the next lower id.
    """
    entry_type = type(entry)
    assigned_id = self.procedural_id_table[entry_type]
    entry.recon_id = assigned_id
    self.procedural_id_table[entry_type] = self.procedural_id_table[entry_type] - 1
    return assigned_id
|
|
89
|
+
|
|
90
|
+
# Row factories for Psycopg3
|
|
91
|
+
# ------------------------------------------------------------------------------------------------------------------
|
|
92
|
+
# Row factory for relation views
|
|
93
|
+
def make_relation_entry_view_row_factory[T1: Entry, T2: Entry](
|
|
94
|
+
self,
|
|
95
|
+
relationtype: type[Relation[T1, T2]]
|
|
96
|
+
) -> pg.rows.RowFactory[tuple[Relation[T1, T2], T1, T2]]:
|
|
97
|
+
|
|
98
|
+
class RelationIteratorRowFactory[T1_: Entry, T2_: Entry](pg.rows.RowFactory):
|
|
99
|
+
""" Creates a row factory which constructs rows based on the relation view with attached entries.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def __init__(self, cursor: pg.Cursor[Any]):
|
|
103
|
+
self.fields = [c.name for c in cursor.description]
|
|
104
|
+
|
|
105
|
+
# Create fields dict (for fields of each table to the fields of the view)
|
|
106
|
+
self.rel_fields: dict[str, str] = dict()
|
|
107
|
+
self.t1_fields = dict()
|
|
108
|
+
self.t2_fields = dict()
|
|
109
|
+
|
|
110
|
+
for field in self.fields:
|
|
111
|
+
if field == 'recon_id_1':
|
|
112
|
+
self.rel_fields[field] = 'recon_id_1'
|
|
113
|
+
self.t1_fields[field] = 'recon_id'
|
|
114
|
+
elif field == 'recon_id_2':
|
|
115
|
+
self.rel_fields[field] = 'recon_id_2'
|
|
116
|
+
self.t2_fields[field] = 'recon_id'
|
|
117
|
+
elif field.startswith('rel_'):
|
|
118
|
+
self.rel_fields[field] = field.removeprefix('rel_')
|
|
119
|
+
elif field.startswith('t1_'):
|
|
120
|
+
self.t1_fields[field] = field.removeprefix('t1_')
|
|
121
|
+
elif field.startswith('t2_'):
|
|
122
|
+
self.t2_fields[field] = field.removeprefix('t2_')
|
|
123
|
+
|
|
124
|
+
def __call__(self, values: Sequence[Any]) -> tuple[Relation[T1_, T2_], T1_, T2_]:
|
|
125
|
+
# Get fields of each object
|
|
126
|
+
rel_args: dict[str, Any] = dict()
|
|
127
|
+
t1_args: dict[str, Any] = dict()
|
|
128
|
+
t2_args: dict[str, Any] = dict()
|
|
129
|
+
for fieldname, val in zip(self.fields, values):
|
|
130
|
+
# Assign values to fields
|
|
131
|
+
if fieldname in self.rel_fields:
|
|
132
|
+
rel_args[self.rel_fields[fieldname]] = val
|
|
133
|
+
if fieldname in self.t1_fields:
|
|
134
|
+
t1_args[self.t1_fields[fieldname]] = val
|
|
135
|
+
if fieldname in self.t2_fields:
|
|
136
|
+
t2_args[self.t2_fields[fieldname]] = val
|
|
137
|
+
|
|
138
|
+
# Call constructors with the the given args
|
|
139
|
+
return (
|
|
140
|
+
relationtype(**rel_args),
|
|
141
|
+
relationtype.source_entrytype(**t1_args),
|
|
142
|
+
relationtype.target_entrytype(**t2_args)
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
return RelationIteratorRowFactory
|
|
146
|
+
|
|
147
|
+
# Getters
|
|
148
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
149
|
+
def get_entry_by_recon_id[T: Entry](self, table: type[T], recon_id: int) -> T:
|
|
150
|
+
""" Raises KeyError on failure. """
|
|
151
|
+
res = self.get_entries_by_recon_ids(table, [recon_id])
|
|
152
|
+
return res[recon_id]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def get_entries_by_recon_ids[T: Entry](self, table: type[T], recon_ids: list[int]) -> dict[int, T]:
|
|
156
|
+
""" Get the entries associated with the given recon ids. Raises KeyError on failure. """
|
|
157
|
+
c = self.conn.cursor(row_factory = class_row(table))
|
|
158
|
+
|
|
159
|
+
q = sql.SQL("""
|
|
160
|
+
SELECT *
|
|
161
|
+
FROM {table}
|
|
162
|
+
WHERE recon_id = %s;
|
|
163
|
+
""").format(
|
|
164
|
+
table = sql.Identifier(table.get_table_name())
|
|
165
|
+
)
|
|
166
|
+
c.executemany(
|
|
167
|
+
q,
|
|
168
|
+
params_seq = [[i] for i in recon_ids],
|
|
169
|
+
returning = True
|
|
170
|
+
)
|
|
171
|
+
cursor_results = self._extract_results_from_cursor(c)
|
|
172
|
+
return {
|
|
173
|
+
i: e
|
|
174
|
+
for i, e in zip(recon_ids, cursor_results)
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def get_entry_by_index[T: Entry](self, entry_index: T) -> T:
|
|
179
|
+
""" Given an entry with a valid index, get the corresponding full entry from the database.
|
|
180
|
+
Raise KeyError in failure.
|
|
181
|
+
"""
|
|
182
|
+
return self.get_entries_by_indices([entry_index])[0]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def get_entries_by_indices[T: Entry](self, entry_indices: list[T]) -> list[T]:
|
|
186
|
+
""" Given an entry with a valid index, get the corresponding full entry from the database.
|
|
187
|
+
Raise KeyError in failure.
|
|
188
|
+
"""
|
|
189
|
+
table = type(entry_indices[0])
|
|
190
|
+
index_cols: list[Column] = table.get_index_columns()
|
|
191
|
+
c = self.conn.cursor(row_factory = class_row(table))
|
|
192
|
+
|
|
193
|
+
q = sql.SQL("""
|
|
194
|
+
SELECT *
|
|
195
|
+
FROM {table}
|
|
196
|
+
WHERE ({index_cols}) = ({index_placeholders})
|
|
197
|
+
""").format(
|
|
198
|
+
table = sql.Identifier(table.get_table_name()),
|
|
199
|
+
index_cols = sql.SQL(', ').join(
|
|
200
|
+
sql.Identifier(c.name) for c in index_cols
|
|
201
|
+
),
|
|
202
|
+
index_placeholders = sql.SQL(', ').join(sql.SQL('%s') for _ in index_cols)
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
# Execute and fetch
|
|
206
|
+
c.executemany(
|
|
207
|
+
q,
|
|
208
|
+
params_seq = [
|
|
209
|
+
list(entry_index.get_index_columns_with_values().values())
|
|
210
|
+
for entry_index in entry_indices
|
|
211
|
+
],
|
|
212
|
+
returning = True
|
|
213
|
+
)
|
|
214
|
+
cursor_results = self._extract_results_from_cursor(c)
|
|
215
|
+
return cursor_results
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Relation getters
|
|
219
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
220
|
+
def get_relations_by_recon_ids[T1: Entry, T2: Entry](
|
|
221
|
+
self,
|
|
222
|
+
entry_type: type[T1],
|
|
223
|
+
recon_ids: list[int],
|
|
224
|
+
relation_type: type[Relation[T1, T2]] | type[Relation[T2, T1]]
|
|
225
|
+
) -> list[list[Relation[T1,T2]]] | list[list[Relation[T2, T1]]]:
|
|
226
|
+
""" Given a list of entries, get all connected relations of a given type.
|
|
227
|
+
"""
|
|
228
|
+
# TODO dispatch to the three sub-functions.
|
|
229
|
+
raise NotImplementedError()
|
|
230
|
+
|
|
231
|
+
def get_relations_by_recon_ids_symmetric[T: Entry](
|
|
232
|
+
self,
|
|
233
|
+
recon_ids: list[int],
|
|
234
|
+
relation_type: type[Relation[T, T]]
|
|
235
|
+
) -> list[list[Relation[T, T]]]:
|
|
236
|
+
# Symmetric - either recon_id 1 or 2 can be found
|
|
237
|
+
c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
|
|
238
|
+
q = sql.SQL("""
|
|
239
|
+
SELECT *
|
|
240
|
+
FROM {table}
|
|
241
|
+
WHERE (recon_id_1 = {recon_id_placeholder}) OR (recon_id_2 = {recon_id_placeholder});
|
|
242
|
+
""").format(
|
|
243
|
+
table = sql.Identifier(relation_type.get_table_name()),
|
|
244
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
245
|
+
)
|
|
246
|
+
c.executemany(q, params_seq = [[recon_id] * 2 for recon_id in recon_ids], returning = True)
|
|
247
|
+
|
|
248
|
+
# Each input recon_id corresponds to a result set
|
|
249
|
+
relations: list[list[Relation[T, T]]] = list()
|
|
250
|
+
for _ in recon_ids:
|
|
251
|
+
relations.append(c.fetchall())
|
|
252
|
+
c.nextset()
|
|
253
|
+
|
|
254
|
+
return relations
|
|
255
|
+
|
|
256
|
+
def get_relations_by_recon_ids_of_t1[T1: Entry, T2: Entry](
|
|
257
|
+
self,
|
|
258
|
+
recon_ids: list[int],
|
|
259
|
+
relation_type: type[Relation[T1, T2]]
|
|
260
|
+
) -> list[list[Relation[T1, T2]]]:
|
|
261
|
+
c = self.conn.cursor(row_factory = class_row(relation_type))
|
|
262
|
+
q = sql.SQL("""
|
|
263
|
+
SELECT *
|
|
264
|
+
FROM {table}
|
|
265
|
+
WHERE (recon_id_1 = {recon_id_placeholder});
|
|
266
|
+
""").format(
|
|
267
|
+
table = sql.Identifier(relation_type.get_table_name()),
|
|
268
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
269
|
+
)
|
|
270
|
+
c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
|
|
271
|
+
|
|
272
|
+
# Each input recon_id corresponds to a result set
|
|
273
|
+
relations: list[list[Relation[T1, T2]]] = list()
|
|
274
|
+
for _ in recon_ids:
|
|
275
|
+
relations.append(c.fetchall())
|
|
276
|
+
c.nextset()
|
|
277
|
+
|
|
278
|
+
return relations
|
|
279
|
+
|
|
280
|
+
def get_relations_by_recon_ids_of_t2[T1: Entry, T2: Entry](
|
|
281
|
+
self,
|
|
282
|
+
recon_ids: list[int],
|
|
283
|
+
relation_type: type[Relation[T2, T1]]
|
|
284
|
+
) -> list[list[Relation[T2, T1]]]:
|
|
285
|
+
c = self.conn.cursor(row_factory = class_row(relation_type))
|
|
286
|
+
q = sql.SQL("""
|
|
287
|
+
SELECT *
|
|
288
|
+
FROM {table}
|
|
289
|
+
WHERE (recon_id_2 = {recon_id_placeholder});
|
|
290
|
+
""").format(
|
|
291
|
+
table = sql.Identifier(relation_type.get_table_name()),
|
|
292
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
293
|
+
)
|
|
294
|
+
c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
|
|
295
|
+
|
|
296
|
+
# Each input recon_id corresponds to a result set
|
|
297
|
+
relations: list[list[Relation[T1, T2]]] = list()
|
|
298
|
+
for _ in recon_ids:
|
|
299
|
+
relations.append(c.fetchall())
|
|
300
|
+
c.nextset()
|
|
301
|
+
|
|
302
|
+
return relations
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# Relation and entry getters
|
|
306
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
307
|
+
def get_relations_with_entries_by_recon_ids[T1: Entry, T2: Entry](
|
|
308
|
+
self,
|
|
309
|
+
entry_type: type[T1],
|
|
310
|
+
recon_ids: list[int],
|
|
311
|
+
relation_type: type[Relation[T1, T2]] | type[Relation[T2, T1]],
|
|
312
|
+
) -> list[dict[T2, list[Relation[T1, T2]] | list[Relation[T2, T1]]]]:
|
|
313
|
+
""" Given a list of entries, get connected relations of the given type.
|
|
314
|
+
Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
|
|
315
|
+
to the relation(s) with which they are connected.
|
|
316
|
+
"""
|
|
317
|
+
# First, establish the types
|
|
318
|
+
t2_entrytype = (
|
|
319
|
+
relation_type.source_entrytype if relation_type.target_entrytype is entry_type
|
|
320
|
+
else relation_type.target_entrytype
|
|
321
|
+
)
|
|
322
|
+
is_symmetric = (entry_type is t2_entrytype)
|
|
323
|
+
|
|
324
|
+
# Dispatch to sub-functions depending on types
|
|
325
|
+
if is_symmetric:
|
|
326
|
+
# Case 1: Symmetric - either recon_id 1 or 2 can be found
|
|
327
|
+
return self.get_relations_with_entries_by_recon_ids_symmetric(entry_type, recon_ids, relation_type)
|
|
328
|
+
else:
|
|
329
|
+
if relation_type.source_entrytype is entry_type:
|
|
330
|
+
# Case 2: Relation is [T1, T2]
|
|
331
|
+
return self.get_relations_with_entries_by_recon_ids_of_t1(entry_type, recon_ids, relation_type)
|
|
332
|
+
elif relation_type.target_entrytype is entry_type:
|
|
333
|
+
# Case 3: Relation is [T2, T1]
|
|
334
|
+
return self.get_relations_with_entries_by_recon_ids_of_t2(entry_type, recon_ids, relation_type)
|
|
335
|
+
else:
|
|
336
|
+
# Invalid shape of entrytype and relation?
|
|
337
|
+
raise ValueError()
|
|
338
|
+
|
|
339
|
+
# Specialized getters for different entry/relation cases
|
|
340
|
+
def get_relations_with_entries_by_recon_ids_symmetric[T: Entry](
|
|
341
|
+
self,
|
|
342
|
+
entry_type: type[T],
|
|
343
|
+
recon_ids: list[int],
|
|
344
|
+
relation_type: type[Relation[T, T]],
|
|
345
|
+
) -> list[dict[T, list[Relation[T, T]]]]:
|
|
346
|
+
""" Specialization for symmetric relations.
|
|
347
|
+
"""
|
|
348
|
+
c = self.conn.cursor(row_factory = class_row(relation_type))
|
|
349
|
+
|
|
350
|
+
# TODO REDO, call above functions
|
|
351
|
+
# TODO get by views!
|
|
352
|
+
|
|
353
|
+
# Symmetric - either recon_id 1 or 2 can be found
|
|
354
|
+
q = sql.SQL("""
|
|
355
|
+
SELECT *
|
|
356
|
+
FROM {table}
|
|
357
|
+
WHERE (recon_id_1 = {recon_id_placeholder}) OR (recon_id_2 = {recon_id_placeholder});
|
|
358
|
+
""").format(
|
|
359
|
+
table = sql.Identifier(relation_type.get_table_name()),
|
|
360
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
361
|
+
)
|
|
362
|
+
c.executemany(q, params_seq = [[recon_id] * 2 for recon_id in recon_ids], returning = True)
|
|
363
|
+
|
|
364
|
+
# Each input recon_id corresponds to a result set
|
|
365
|
+
entry_relations: dict[int, list[Relation]] = dict()
|
|
366
|
+
for recon_id in recon_ids:
|
|
367
|
+
entry_relations[recon_id] = c.fetchall()
|
|
368
|
+
c.nextset()
|
|
369
|
+
|
|
370
|
+
# Get entries
|
|
371
|
+
results: list[dict[T, list[Relation[T, T]]]] = list()
|
|
372
|
+
|
|
373
|
+
for input_recon_id, rels in entry_relations.items():
|
|
374
|
+
# For each source recon id, run a query to get relation targets
|
|
375
|
+
subresults: dict[T, list[Relation[T, T]]] = defaultdict(list)
|
|
376
|
+
target_recon_ids = [ # List corresponding to the list of relations rels
|
|
377
|
+
rel.recon_id_1 if rel.recon_id_2 == input_recon_id else rel.recon_id_2
|
|
378
|
+
for rel in rels
|
|
379
|
+
]
|
|
380
|
+
q = sql.SQL("""
|
|
381
|
+
SELECT *
|
|
382
|
+
FROM {t2_table}
|
|
383
|
+
WHERE recon_id = %s;
|
|
384
|
+
""").format(
|
|
385
|
+
t2_table = sql.Identifier(entry_type.get_table_name())
|
|
386
|
+
)
|
|
387
|
+
c = self.conn.cursor(row_factory = class_row(entry_type))
|
|
388
|
+
c.executemany(q, params_seq = [[tid] for tid in target_recon_ids], returning = True)
|
|
389
|
+
target_entries: list[T] = self._extract_results_from_cursor(c)
|
|
390
|
+
|
|
391
|
+
# Add to results
|
|
392
|
+
for rel, target_entry in zip(rels, target_entries):
|
|
393
|
+
subresults[target_entry].append(rel)
|
|
394
|
+
results.append(subresults)
|
|
395
|
+
|
|
396
|
+
# Return
|
|
397
|
+
return results
|
|
398
|
+
|
|
399
|
+
def get_relations_with_entries_by_recon_ids_of_t1[T1: Entry, T2: Entry](
|
|
400
|
+
self,
|
|
401
|
+
entry_type: type[T1],
|
|
402
|
+
recon_ids: list[int],
|
|
403
|
+
relation_type: type[Relation[T1, T2]],
|
|
404
|
+
) -> list[dict[T2, list[Relation[T1, T2]]]]:
|
|
405
|
+
""" Given a list of entries, get connected relations of the given type.
|
|
406
|
+
Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
|
|
407
|
+
to the relation(s) with which they are connected.
|
|
408
|
+
"""
|
|
409
|
+
c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
|
|
410
|
+
q = sql.SQL("""
|
|
411
|
+
SELECT *
|
|
412
|
+
FROM {view}
|
|
413
|
+
WHERE (recon_id_1 = {recon_id_placeholder});
|
|
414
|
+
""").format(
|
|
415
|
+
view = sql.Identifier(f'{relation_type.get_table_name()}_v'),
|
|
416
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
417
|
+
)
|
|
418
|
+
c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
|
|
419
|
+
|
|
420
|
+
# Process output
|
|
421
|
+
result: list[dict[T2, list[Relation[T1, T2]]]] = list()
|
|
422
|
+
cursor_result = self._extract_results_from_cursor_by_sets(c)
|
|
423
|
+
for recon_id, relation_list in zip(recon_ids, cursor_result):
|
|
424
|
+
subdict: dict[T2, list[Relation[T1, T2]]] = defaultdict(list)
|
|
425
|
+
for relation, entry_1, entry_2 in relation_list:
|
|
426
|
+
assert entry_1.recon_id == recon_id
|
|
427
|
+
subdict[entry_2].append(relation)
|
|
428
|
+
result.append(subdict)
|
|
429
|
+
return result
|
|
430
|
+
|
|
431
|
+
def get_relations_with_entries_by_recon_ids_of_t2[T1: Entry, T2: Entry](
|
|
432
|
+
self,
|
|
433
|
+
entry_type: type[T1],
|
|
434
|
+
recon_ids: list[int],
|
|
435
|
+
relation_type: type[Relation[T2, T1]],
|
|
436
|
+
) -> list[dict[T2, list[Relation[T2, T1]]]]:
|
|
437
|
+
""" Given a list of entries, get connected relations of the given type.
|
|
438
|
+
Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
|
|
439
|
+
to the relation(s) with which they are connected.
|
|
440
|
+
"""
|
|
441
|
+
c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
|
|
442
|
+
q = sql.SQL("""
|
|
443
|
+
SELECT *
|
|
444
|
+
FROM {view}
|
|
445
|
+
WHERE (recon_id_2 = {recon_id_placeholder});
|
|
446
|
+
""").format(
|
|
447
|
+
view = sql.Identifier(f'{relation_type.get_table_name()}_v'),
|
|
448
|
+
recon_id_placeholder = sql.SQL('%s')
|
|
449
|
+
)
|
|
450
|
+
c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
|
|
451
|
+
|
|
452
|
+
# Process output
|
|
453
|
+
result: list[dict[T2, list[Relation[T2, T1]]]] = list()
|
|
454
|
+
cursor_result = self._extract_results_from_cursor_by_sets(c)
|
|
455
|
+
for recon_id, relation_list in zip(recon_ids, cursor_result):
|
|
456
|
+
subdict: dict[T2, list[Relation[T2, T1]]] = defaultdict(list)
|
|
457
|
+
for relation, entry_2, entry_1 in relation_list:
|
|
458
|
+
assert entry_1.recon_id == recon_id
|
|
459
|
+
subdict[entry_2].append(relation)
|
|
460
|
+
result.append(subdict)
|
|
461
|
+
return result
|
|
462
|
+
|
|
463
|
+
# Graph exploration getters
|
|
464
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
465
|
+
# These are the main queries used to efficiently explore the graph.
|
|
466
|
+
def explore_vertex[T: Entry](
|
|
467
|
+
self,
|
|
468
|
+
entry: T,
|
|
469
|
+
relation_types: type[Relation[T, Any]]
|
|
470
|
+
) -> list[tuple[Entry, list[Relation]]]:
|
|
471
|
+
""" Explores the neighbouring entries to a given entry.
|
|
472
|
+
Results are returned as a list of tuples, each containing a neighbouring entry and the relations
|
|
473
|
+
which point to it.
|
|
474
|
+
Note that the input entry may be given in the output in the case of reflexive relations.
|
|
475
|
+
"""
|
|
476
|
+
|
|
477
|
+
# TODO Run get_relations_with_entries_from_source() for each allowed relation, and collate the results
|
|
478
|
+
|
|
479
|
+
# TODO
|
|
480
|
+
raise NotImplementedError()
|
|
481
|
+
|
|
482
|
+
# Misc
|
|
483
|
+
# ----------------------------------------------------------------------------------------------------------
|
|
484
|
+
def get_number_of_entries[T: DatabaseObject](self, table: type[T]) -> int:
|
|
485
|
+
""" Return hte number of entries in the table"""
|
|
486
|
+
table_name = table.get_table_name()
|
|
487
|
+
q = sql.SQL("""SELECT count(*) FROM {table}""").format(
|
|
488
|
+
table = sql.Identifier(table_name)
|
|
489
|
+
)
|
|
490
|
+
self.c.execute(q)
|
|
491
|
+
result: int = self.c.fetchone()[0]
|
|
492
|
+
return result
|
|
493
|
+
|
|
494
|
+
def iterate_table[T: DatabaseObject](
|
|
495
|
+
self,
|
|
496
|
+
table: type[T],
|
|
497
|
+
batch_size: int = 1000,
|
|
498
|
+
up_to: int = 0
|
|
499
|
+
) -> Generator[T, None, None]:
|
|
500
|
+
""" Creates an iterator which queries all entries in a table.
|
|
501
|
+
Entries are requested from the database in batches, but returned as a continuous iterator.
|
|
502
|
+
"""
|
|
503
|
+
i: int = 0
|
|
504
|
+
batch: list[T] = list()
|
|
505
|
+
|
|
506
|
+
with self.conn.cursor(row_factory = class_row(table)) as cursor:
|
|
507
|
+
q = sql.SQL("""
|
|
508
|
+
SELECT *
|
|
509
|
+
FROM {table}
|
|
510
|
+
;
|
|
511
|
+
""").format(
|
|
512
|
+
table = sql.Identifier(table.get_table_name())
|
|
513
|
+
)
|
|
514
|
+
try:
|
|
515
|
+
cursor.execute(q)
|
|
516
|
+
except Exception as e:
|
|
517
|
+
raise RuntimeError('!: {e}')
|
|
518
|
+
|
|
519
|
+
while True:
|
|
520
|
+
# Main loop, fetch batches and yield from the batches
|
|
521
|
+
batch = cursor.fetchmany(size = batch_size)
|
|
522
|
+
if not batch:
|
|
523
|
+
break
|
|
524
|
+
item: T
|
|
525
|
+
for item in batch:
|
|
526
|
+
i += 1
|
|
527
|
+
yield item
|
|
528
|
+
if up_to and i > up_to:
|
|
529
|
+
return
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def iterate_relation_with_entries[T1: Entry, T2: Entry](
|
|
533
|
+
self,
|
|
534
|
+
relationtype: type[Relation[T1, T2]],
|
|
535
|
+
batch_size: int = 1000,
|
|
536
|
+
up_to: int = 0
|
|
537
|
+
) -> Generator[tuple[Relation[T1, T2], T1, T2], None, None]:
|
|
538
|
+
""" Iterates the view over a relation table, getting entries adjacent to each relation.
|
|
539
|
+
"""
|
|
540
|
+
i: int = 0
|
|
541
|
+
batch: list = list()
|
|
542
|
+
|
|
543
|
+
# TODO make a relationIteratorRowFactory and use
|
|
544
|
+
|
|
545
|
+
# With the factory created, iterate the rows
|
|
546
|
+
with self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relationtype)) as cursor:
|
|
547
|
+
q = sql.SQL("""
|
|
548
|
+
SELECT *
|
|
549
|
+
FROM {view}
|
|
550
|
+
;
|
|
551
|
+
""").format(
|
|
552
|
+
view = sql.Identifier(f'{relationtype.get_table_name()}_v')
|
|
553
|
+
)
|
|
554
|
+
try:
|
|
555
|
+
cursor.execute(q)
|
|
556
|
+
except Exception as e:
|
|
557
|
+
raise RuntimeError('!: {e}')
|
|
558
|
+
|
|
559
|
+
while True:
|
|
560
|
+
# Main loop, fetch batches and yield from the batches
|
|
561
|
+
batch = cursor.fetchmany(size = batch_size)
|
|
562
|
+
if not batch:
|
|
563
|
+
break
|
|
564
|
+
item: tuple[Relation[T1, T2], T1, T2]
|
|
565
|
+
for item in batch:
|
|
566
|
+
i += 1
|
|
567
|
+
yield item
|
|
568
|
+
if up_to and i > up_to:
|
|
569
|
+
return
|
|
570
|
+
|
|
571
|
+
def _extract_results_from_cursor(self, cursor: Optional[pg.Cursor] = None) -> list[Any]:
    """ Collect one row (the result of fetchone()) from every result set of the cursor.

    Works on `cursor` if one is given, otherwise on the handler's default cursor. The returned
    list has one element per result set, in order.
    """
    active = cursor if cursor else self.c
    rows: list[Any] = []
    while True:
        rows.append(active.fetchone())
        # nextset() reports whether there was another result set to advance to
        if not active.nextset():
            break
    return rows
|
|
579
|
+
|
|
580
|
+
def _extract_results_from_cursor_by_sets(self, cursor: Optional[pg.Cursor] = None) -> list[list[Any]]:
    """ Collect all rows (the result of fetchall()) of every result set of the cursor.

    Works on `cursor` if one is given, otherwise on the handler's default cursor. The returned
    list has one list of rows per result set, in order.
    """
    active = cursor if cursor else self.c
    row_sets: list[list[Any]] = []
    while True:
        row_sets.append(active.fetchall())
        # nextset() reports whether there was another result set to advance to
        if not active.nextset():
            break
    return row_sets
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from chemrecon.database.params import parameter_sets
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
""" Implements database connection, cursor, etc. functionality.
|
|
2
|
+
"""
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import psycopg as pg
|
|
6
|
+
import psycopg.sql as sql
|
|
7
|
+
from psycopg.types.enum import EnumInfo, register_enum
|
|
8
|
+
|
|
9
|
+
import chemrecon.schema as schema
|
|
10
|
+
from chemrecon.database.params import Params
|
|
11
|
+
|
|
12
|
+
# Initialisation functions
|
|
13
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
|
14
|
+
def postgres_connect(
        params: Params,
        initialise_enums: bool = True
) -> pg.Connection:
    """ Connect to the database server described by `params`.

    After connecting, the search path is set to the schemata listed below. If `initialise_enums`
    is set, every enum in `schema.enums.enum_register` is looked up in the database by its class
    name and registered on the connection.

    Raises ConnectionError if the connection cannot be established, and RuntimeError if an enum
    type is not found in the database. If set-up fails after connecting, the connection is closed
    before the error propagates.
    """
    # Load parameters
    # NOTE(review): the connection string is printed verbatim - if it can contain a password,
    # it should be masked before logging.
    print(f'[ChemRecon] Attempting to connect with string: {params.connection_string()}')
    try:
        conn: pg.Connection = pg.connect(
            params.connection_string()
        )
    except pg.OperationalError as e:
        print(f'[ChemRecon] Could not connect: {e}')
        raise ConnectionError from e

    print(f'[ChemRecon] Connection succesful.')

    try:
        # Set correct schemata to search
        schemata = [
            'chemrecon',
            'pg_catalog',
            'meta',
            'public'
        ]
        conn.execute(
            sql.SQL("SET {search_path} TO {schemata}").format(
                search_path = sql.Identifier('search_path'),
                schemata = sql.SQL(', ').join(
                    [sql.Identifier(s) for s in schemata]
                )
            )
        )
        conn.commit()

        # Register enum mapping
        if initialise_enums:
            for enum in schema.enums.enum_register:
                enum_info = EnumInfo.fetch(conn, sql.Identifier(enum.__name__))
                if enum_info is None:
                    raise RuntimeError(f'Enum info none for enum {enum}.')
                register_enum(enum_info, conn, enum)

        # Done
        conn.commit()
    except BaseException:
        # Do not hand back (or leak) a half-initialised connection
        conn.close()
        raise

    return conn
|