chemrecon 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. chemrecon/__init__.py +73 -0
  2. chemrecon/chem/__init__.py +0 -0
  3. chemrecon/chem/chemreaction.py +223 -0
  4. chemrecon/chem/constant_compounds.py +3 -0
  5. chemrecon/chem/create_mol.py +91 -0
  6. chemrecon/chem/elements.py +141 -0
  7. chemrecon/chem/gml/__init__.py +0 -0
  8. chemrecon/chem/gml/gml.py +324 -0
  9. chemrecon/chem/gml/gml_reactant_matching.py +130 -0
  10. chemrecon/chem/gml/gml_to_rdk.py +217 -0
  11. chemrecon/chem/mol.py +483 -0
  12. chemrecon/chem/sumformula.py +120 -0
  13. chemrecon/connection.py +97 -0
  14. chemrecon/core/__init__.py +0 -0
  15. chemrecon/core/id_types.py +687 -0
  16. chemrecon/core/ontology.py +209 -0
  17. chemrecon/core/populate_query_handler.py +336 -0
  18. chemrecon/core/query_handler.py +587 -0
  19. chemrecon/database/__init__.py +1 -0
  20. chemrecon/database/connect.py +63 -0
  21. chemrecon/database/connection_params/chemrecon_pub.dbinfo +5 -0
  22. chemrecon/database/connection_params/local_docker_dev.dbinfo +5 -0
  23. chemrecon/database/connection_params/local_docker_init.dbinfo +5 -0
  24. chemrecon/database/connection_params/local_docker_pub.dbinfo +5 -0
  25. chemrecon/database/params.py +88 -0
  26. chemrecon/entrygraph/draw.py +119 -0
  27. chemrecon/entrygraph/entrygraph.py +301 -0
  28. chemrecon/entrygraph/explorationprotocol.py +199 -0
  29. chemrecon/entrygraph/explore.py +421 -0
  30. chemrecon/entrygraph/explore_procedure.py +183 -0
  31. chemrecon/entrygraph/filter.py +88 -0
  32. chemrecon/entrygraph/scoring.py +141 -0
  33. chemrecon/query/__init__.py +26 -0
  34. chemrecon/query/create_entry.py +86 -0
  35. chemrecon/query/default_protocols.py +57 -0
  36. chemrecon/query/find_entry.py +84 -0
  37. chemrecon/query/get_relations.py +143 -0
  38. chemrecon/query/get_structures_from_compound.py +65 -0
  39. chemrecon/schema/__init__.py +86 -0
  40. chemrecon/schema/db_object.py +363 -0
  41. chemrecon/schema/direction.py +10 -0
  42. chemrecon/schema/entry_types/__init__.py +0 -0
  43. chemrecon/schema/entry_types/aam.py +34 -0
  44. chemrecon/schema/entry_types/aam_repr.py +37 -0
  45. chemrecon/schema/entry_types/compound.py +52 -0
  46. chemrecon/schema/entry_types/enzyme.py +49 -0
  47. chemrecon/schema/entry_types/molstructure.py +64 -0
  48. chemrecon/schema/entry_types/molstructure_repr.py +41 -0
  49. chemrecon/schema/entry_types/reaction.py +57 -0
  50. chemrecon/schema/enums.py +154 -0
  51. chemrecon/schema/procedural_relation_entrygraph.py +66 -0
  52. chemrecon/schema/relation_types_composed/__init__.py +0 -0
  53. chemrecon/schema/relation_types_composed/compound_has_molstructure_relation.py +59 -0
  54. chemrecon/schema/relation_types_composed/reaction_has_aam_relation.py +50 -0
  55. chemrecon/schema/relation_types_procedural/__init__.py +0 -0
  56. chemrecon/schema/relation_types_procedural/aam_convert_relation.py +69 -0
  57. chemrecon/schema/relation_types_procedural/compound_select_structure_proceduralrelation.py +36 -0
  58. chemrecon/schema/relation_types_procedural/compound_similarlity_proceduralrelation.py +1 -0
  59. chemrecon/schema/relation_types_procedural/molstructure_convert_relation.py +49 -0
  60. chemrecon/schema/relation_types_procedural/reaction_select_aam_proceduralrelation.py +38 -0
  61. chemrecon/schema/relation_types_procedural/reaction_similarity_proceduralrelation.py +1 -0
  62. chemrecon/schema/relation_types_source/__init__.py +0 -0
  63. chemrecon/schema/relation_types_source/aam_involves_molstructure_relation.py +77 -0
  64. chemrecon/schema/relation_types_source/aam_repr_involves_molstructure_repr_relation.py +79 -0
  65. chemrecon/schema/relation_types_source/compound_has_structure_representation_relation.py +33 -0
  66. chemrecon/schema/relation_types_source/compound_reference_relation.py +34 -0
  67. chemrecon/schema/relation_types_source/molstructure_standardisation_relation.py +71 -0
  68. chemrecon/schema/relation_types_source/ontology/__init__.py +0 -0
  69. chemrecon/schema/relation_types_source/ontology/compound_ontology.py +369 -0
  70. chemrecon/schema/relation_types_source/ontology/enzyme_ontology.py +142 -0
  71. chemrecon/schema/relation_types_source/ontology/reaction_ontology.py +140 -0
  72. chemrecon/schema/relation_types_source/reaction_has_aam_representation_relation.py +34 -0
  73. chemrecon/schema/relation_types_source/reaction_has_enzyme_relation.py +71 -0
  74. chemrecon/schema/relation_types_source/reaction_involves_compound_relation.py +69 -0
  75. chemrecon/schema/relation_types_source/reaction_reference_relation.py +33 -0
  76. chemrecon/scripts/initialize_database.py +494 -0
  77. chemrecon/utils/copy_signature.py +10 -0
  78. chemrecon/utils/encodeable_list.py +11 -0
  79. chemrecon/utils/get_id_type.py +70 -0
  80. chemrecon/utils/hungarian.py +31 -0
  81. chemrecon/utils/reactant_matching.py +168 -0
  82. chemrecon/utils/rxnutils.py +44 -0
  83. chemrecon/utils/set_cwd.py +12 -0
  84. chemrecon-0.1.1.dist-info/METADATA +143 -0
  85. chemrecon-0.1.1.dist-info/RECORD +86 -0
  86. chemrecon-0.1.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,587 @@
1
+ from collections import defaultdict
2
+ from typing import Any, DefaultDict, Generator, Optional, Sequence
3
+
4
+ import psycopg as pg
5
+ import psycopg.sql as sql
6
+ from psycopg.rows import class_row
7
+
8
+ from chemrecon.schema.db_object import Column, DatabaseObject, Entry, Relation
9
+ import chemrecon.schema as schema
10
+
11
+
12
+ class QueryHandler():
13
+ """ Context manager to wrap database queries.
14
+ """
15
+ conn: pg.Connection
16
+ c: pg.Cursor
17
+
18
+ # The handler also keeps track of a parallel 'database' of procedurally generated entries
19
+ procedural_entries: dict[
20
+ type[Entry], set[Entry]
21
+ ]
22
+ procedural_entries_index: dict[
23
+ tuple[type[Entry], tuple],
24
+ Entry
25
+ ]
26
+ procedural_entry_recon_id: dict[
27
+ tuple[type[Entry], int],
28
+ Entry
29
+ ]
30
+ procedural_id_table: dict[type[Entry], int] # For assigning procedural ids to entries created locally
31
+
32
+
33
+ # Context manager
34
+ # ----------------------------------------------------------------------------------------------------------
35
+ def __init__(self, connection: pg.Connection):
36
+ self.conn = connection
37
+ self.c = self.conn.cursor()
38
+
39
+ self.procedural_entries = DefaultDict(set)
40
+ self.procedural_entries_index = dict()
41
+ self.procedural_entry_recon_id = dict()
42
+ self.procedural_id_table = {
43
+ e_type: -1 # Start at -1 and decrement
44
+ for e_type in schema.entrytypes
45
+ }
46
+
47
+ def __enter__(self):
48
+ # Enter, returning the handler as the interface
49
+ return self
50
+
51
+ def __exit__(self, exc_type, exc_val, exc_tb):
52
+ # On leaving the context, execute the commands.
53
+ if exc_type is not None:
54
+ # Exception was raised in context
55
+ raise exc_val
56
+ self.conn.commit()
57
+ self.c.close()
58
+
59
+ # Procedural entries
60
+ # ------------------------------------------------------------------------------------------------------------------
61
+ def add_procedural_entry[T: Entry](self, entry: T) -> int:
62
+ """ Adds a procedural entry, and returns its virtual recon id. If already exists,
63
+ return the virtual recon id.
64
+ """
65
+ # TODO does the fact that index is not an ordereddict cause problems?
66
+ index = tuple(val for val in entry.get_index_columns_with_values().values())
67
+ try:
68
+ # Try to return existing
69
+ existing_entry = self.procedural_entries_index[
70
+ type(entry), index
71
+ ]
72
+ entry.recon_id = existing_entry.recon_id
73
+ return existing_entry.recon_id
74
+ except KeyError:
75
+ # Not found, create and insert
76
+ new_id = self._assign_procedural_id(entry)
77
+ self.procedural_entries[type(entry)].add(entry)
78
+ self.procedural_entries_index[(type(entry), index)] = entry
79
+ self.procedural_entry_recon_id[(type(entry), new_id)] = entry
80
+ return new_id
81
+
82
+ def _assign_procedural_id(self, entry: Entry) -> int:
83
+ """ Assign a local id to an object, returning the id.
84
+ """
85
+ new_id = self.procedural_id_table[type(entry)]
86
+ entry.recon_id = new_id
87
+ self.procedural_id_table[type(entry)] -= 1
88
+ return new_id
89
+
90
+ # Row factories for Psycopg3
91
+ # ------------------------------------------------------------------------------------------------------------------
92
+ # Row factory for relation views
93
+ def make_relation_entry_view_row_factory[T1: Entry, T2: Entry](
94
+ self,
95
+ relationtype: type[Relation[T1, T2]]
96
+ ) -> pg.rows.RowFactory[tuple[Relation[T1, T2], T1, T2]]:
97
+
98
+ class RelationIteratorRowFactory[T1_: Entry, T2_: Entry](pg.rows.RowFactory):
99
+ """ Creates a row factory which constructs rows based on the relation view with attached entries.
100
+ """
101
+
102
+ def __init__(self, cursor: pg.Cursor[Any]):
103
+ self.fields = [c.name for c in cursor.description]
104
+
105
+ # Create fields dict (for fields of each table to the fields of the view)
106
+ self.rel_fields: dict[str, str] = dict()
107
+ self.t1_fields = dict()
108
+ self.t2_fields = dict()
109
+
110
+ for field in self.fields:
111
+ if field == 'recon_id_1':
112
+ self.rel_fields[field] = 'recon_id_1'
113
+ self.t1_fields[field] = 'recon_id'
114
+ elif field == 'recon_id_2':
115
+ self.rel_fields[field] = 'recon_id_2'
116
+ self.t2_fields[field] = 'recon_id'
117
+ elif field.startswith('rel_'):
118
+ self.rel_fields[field] = field.removeprefix('rel_')
119
+ elif field.startswith('t1_'):
120
+ self.t1_fields[field] = field.removeprefix('t1_')
121
+ elif field.startswith('t2_'):
122
+ self.t2_fields[field] = field.removeprefix('t2_')
123
+
124
+ def __call__(self, values: Sequence[Any]) -> tuple[Relation[T1_, T2_], T1_, T2_]:
125
+ # Get fields of each object
126
+ rel_args: dict[str, Any] = dict()
127
+ t1_args: dict[str, Any] = dict()
128
+ t2_args: dict[str, Any] = dict()
129
+ for fieldname, val in zip(self.fields, values):
130
+ # Assign values to fields
131
+ if fieldname in self.rel_fields:
132
+ rel_args[self.rel_fields[fieldname]] = val
133
+ if fieldname in self.t1_fields:
134
+ t1_args[self.t1_fields[fieldname]] = val
135
+ if fieldname in self.t2_fields:
136
+ t2_args[self.t2_fields[fieldname]] = val
137
+
138
+ # Call constructors with the the given args
139
+ return (
140
+ relationtype(**rel_args),
141
+ relationtype.source_entrytype(**t1_args),
142
+ relationtype.target_entrytype(**t2_args)
143
+ )
144
+
145
+ return RelationIteratorRowFactory
146
+
147
+ # Getters
148
+ # ----------------------------------------------------------------------------------------------------------
149
+ def get_entry_by_recon_id[T: Entry](self, table: type[T], recon_id: int) -> T:
150
+ """ Raises KeyError on failure. """
151
+ res = self.get_entries_by_recon_ids(table, [recon_id])
152
+ return res[recon_id]
153
+
154
+
155
+ def get_entries_by_recon_ids[T: Entry](self, table: type[T], recon_ids: list[int]) -> dict[int, T]:
156
+ """ Get the entries associated with the given recon ids. Raises KeyError on failure. """
157
+ c = self.conn.cursor(row_factory = class_row(table))
158
+
159
+ q = sql.SQL("""
160
+ SELECT *
161
+ FROM {table}
162
+ WHERE recon_id = %s;
163
+ """).format(
164
+ table = sql.Identifier(table.get_table_name())
165
+ )
166
+ c.executemany(
167
+ q,
168
+ params_seq = [[i] for i in recon_ids],
169
+ returning = True
170
+ )
171
+ cursor_results = self._extract_results_from_cursor(c)
172
+ return {
173
+ i: e
174
+ for i, e in zip(recon_ids, cursor_results)
175
+ }
176
+
177
+
178
+ def get_entry_by_index[T: Entry](self, entry_index: T) -> T:
179
+ """ Given an entry with a valid index, get the corresponding full entry from the database.
180
+ Raise KeyError in failure.
181
+ """
182
+ return self.get_entries_by_indices([entry_index])[0]
183
+
184
+
185
+ def get_entries_by_indices[T: Entry](self, entry_indices: list[T]) -> list[T]:
186
+ """ Given an entry with a valid index, get the corresponding full entry from the database.
187
+ Raise KeyError in failure.
188
+ """
189
+ table = type(entry_indices[0])
190
+ index_cols: list[Column] = table.get_index_columns()
191
+ c = self.conn.cursor(row_factory = class_row(table))
192
+
193
+ q = sql.SQL("""
194
+ SELECT *
195
+ FROM {table}
196
+ WHERE ({index_cols}) = ({index_placeholders})
197
+ """).format(
198
+ table = sql.Identifier(table.get_table_name()),
199
+ index_cols = sql.SQL(', ').join(
200
+ sql.Identifier(c.name) for c in index_cols
201
+ ),
202
+ index_placeholders = sql.SQL(', ').join(sql.SQL('%s') for _ in index_cols)
203
+ )
204
+
205
+ # Execute and fetch
206
+ c.executemany(
207
+ q,
208
+ params_seq = [
209
+ list(entry_index.get_index_columns_with_values().values())
210
+ for entry_index in entry_indices
211
+ ],
212
+ returning = True
213
+ )
214
+ cursor_results = self._extract_results_from_cursor(c)
215
+ return cursor_results
216
+
217
+
218
+ # Relation getters
219
+ # ----------------------------------------------------------------------------------------------------------
220
+ def get_relations_by_recon_ids[T1: Entry, T2: Entry](
221
+ self,
222
+ entry_type: type[T1],
223
+ recon_ids: list[int],
224
+ relation_type: type[Relation[T1, T2]] | type[Relation[T2, T1]]
225
+ ) -> list[list[Relation[T1,T2]]] | list[list[Relation[T2, T1]]]:
226
+ """ Given a list of entries, get all connected relations of a given type.
227
+ """
228
+ # TODO dispatch to the three sub-functions.
229
+ raise NotImplementedError()
230
+
231
+ def get_relations_by_recon_ids_symmetric[T: Entry](
232
+ self,
233
+ recon_ids: list[int],
234
+ relation_type: type[Relation[T, T]]
235
+ ) -> list[list[Relation[T, T]]]:
236
+ # Symmetric - either recon_id 1 or 2 can be found
237
+ c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
238
+ q = sql.SQL("""
239
+ SELECT *
240
+ FROM {table}
241
+ WHERE (recon_id_1 = {recon_id_placeholder}) OR (recon_id_2 = {recon_id_placeholder});
242
+ """).format(
243
+ table = sql.Identifier(relation_type.get_table_name()),
244
+ recon_id_placeholder = sql.SQL('%s')
245
+ )
246
+ c.executemany(q, params_seq = [[recon_id] * 2 for recon_id in recon_ids], returning = True)
247
+
248
+ # Each input recon_id corresponds to a result set
249
+ relations: list[list[Relation[T, T]]] = list()
250
+ for _ in recon_ids:
251
+ relations.append(c.fetchall())
252
+ c.nextset()
253
+
254
+ return relations
255
+
256
+ def get_relations_by_recon_ids_of_t1[T1: Entry, T2: Entry](
257
+ self,
258
+ recon_ids: list[int],
259
+ relation_type: type[Relation[T1, T2]]
260
+ ) -> list[list[Relation[T1, T2]]]:
261
+ c = self.conn.cursor(row_factory = class_row(relation_type))
262
+ q = sql.SQL("""
263
+ SELECT *
264
+ FROM {table}
265
+ WHERE (recon_id_1 = {recon_id_placeholder});
266
+ """).format(
267
+ table = sql.Identifier(relation_type.get_table_name()),
268
+ recon_id_placeholder = sql.SQL('%s')
269
+ )
270
+ c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
271
+
272
+ # Each input recon_id corresponds to a result set
273
+ relations: list[list[Relation[T1, T2]]] = list()
274
+ for _ in recon_ids:
275
+ relations.append(c.fetchall())
276
+ c.nextset()
277
+
278
+ return relations
279
+
280
+ def get_relations_by_recon_ids_of_t2[T1: Entry, T2: Entry](
281
+ self,
282
+ recon_ids: list[int],
283
+ relation_type: type[Relation[T2, T1]]
284
+ ) -> list[list[Relation[T2, T1]]]:
285
+ c = self.conn.cursor(row_factory = class_row(relation_type))
286
+ q = sql.SQL("""
287
+ SELECT *
288
+ FROM {table}
289
+ WHERE (recon_id_2 = {recon_id_placeholder});
290
+ """).format(
291
+ table = sql.Identifier(relation_type.get_table_name()),
292
+ recon_id_placeholder = sql.SQL('%s')
293
+ )
294
+ c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
295
+
296
+ # Each input recon_id corresponds to a result set
297
+ relations: list[list[Relation[T1, T2]]] = list()
298
+ for _ in recon_ids:
299
+ relations.append(c.fetchall())
300
+ c.nextset()
301
+
302
+ return relations
303
+
304
+
305
+ # Relation and entry getters
306
+ # ----------------------------------------------------------------------------------------------------------
307
+ def get_relations_with_entries_by_recon_ids[T1: Entry, T2: Entry](
308
+ self,
309
+ entry_type: type[T1],
310
+ recon_ids: list[int],
311
+ relation_type: type[Relation[T1, T2]] | type[Relation[T2, T1]],
312
+ ) -> list[dict[T2, list[Relation[T1, T2]] | list[Relation[T2, T1]]]]:
313
+ """ Given a list of entries, get connected relations of the given type.
314
+ Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
315
+ to the relation(s) with which they are connected.
316
+ """
317
+ # First, establish the types
318
+ t2_entrytype = (
319
+ relation_type.source_entrytype if relation_type.target_entrytype is entry_type
320
+ else relation_type.target_entrytype
321
+ )
322
+ is_symmetric = (entry_type is t2_entrytype)
323
+
324
+ # Dispatch to sub-functions depending on types
325
+ if is_symmetric:
326
+ # Case 1: Symmetric - either recon_id 1 or 2 can be found
327
+ return self.get_relations_with_entries_by_recon_ids_symmetric(entry_type, recon_ids, relation_type)
328
+ else:
329
+ if relation_type.source_entrytype is entry_type:
330
+ # Case 2: Relation is [T1, T2]
331
+ return self.get_relations_with_entries_by_recon_ids_of_t1(entry_type, recon_ids, relation_type)
332
+ elif relation_type.target_entrytype is entry_type:
333
+ # Case 3: Relation is [T2, T1]
334
+ return self.get_relations_with_entries_by_recon_ids_of_t2(entry_type, recon_ids, relation_type)
335
+ else:
336
+ # Invalid shape of entrytype and relation?
337
+ raise ValueError()
338
+
339
+ # Specialized getters for different entry/relation cases
340
+ def get_relations_with_entries_by_recon_ids_symmetric[T: Entry](
341
+ self,
342
+ entry_type: type[T],
343
+ recon_ids: list[int],
344
+ relation_type: type[Relation[T, T]],
345
+ ) -> list[dict[T, list[Relation[T, T]]]]:
346
+ """ Specialization for symmetric relations.
347
+ """
348
+ c = self.conn.cursor(row_factory = class_row(relation_type))
349
+
350
+ # TODO REDO, call above functions
351
+ # TODO get by views!
352
+
353
+ # Symmetric - either recon_id 1 or 2 can be found
354
+ q = sql.SQL("""
355
+ SELECT *
356
+ FROM {table}
357
+ WHERE (recon_id_1 = {recon_id_placeholder}) OR (recon_id_2 = {recon_id_placeholder});
358
+ """).format(
359
+ table = sql.Identifier(relation_type.get_table_name()),
360
+ recon_id_placeholder = sql.SQL('%s')
361
+ )
362
+ c.executemany(q, params_seq = [[recon_id] * 2 for recon_id in recon_ids], returning = True)
363
+
364
+ # Each input recon_id corresponds to a result set
365
+ entry_relations: dict[int, list[Relation]] = dict()
366
+ for recon_id in recon_ids:
367
+ entry_relations[recon_id] = c.fetchall()
368
+ c.nextset()
369
+
370
+ # Get entries
371
+ results: list[dict[T, list[Relation[T, T]]]] = list()
372
+
373
+ for input_recon_id, rels in entry_relations.items():
374
+ # For each source recon id, run a query to get relation targets
375
+ subresults: dict[T, list[Relation[T, T]]] = defaultdict(list)
376
+ target_recon_ids = [ # List corresponding to the list of relations rels
377
+ rel.recon_id_1 if rel.recon_id_2 == input_recon_id else rel.recon_id_2
378
+ for rel in rels
379
+ ]
380
+ q = sql.SQL("""
381
+ SELECT *
382
+ FROM {t2_table}
383
+ WHERE recon_id = %s;
384
+ """).format(
385
+ t2_table = sql.Identifier(entry_type.get_table_name())
386
+ )
387
+ c = self.conn.cursor(row_factory = class_row(entry_type))
388
+ c.executemany(q, params_seq = [[tid] for tid in target_recon_ids], returning = True)
389
+ target_entries: list[T] = self._extract_results_from_cursor(c)
390
+
391
+ # Add to results
392
+ for rel, target_entry in zip(rels, target_entries):
393
+ subresults[target_entry].append(rel)
394
+ results.append(subresults)
395
+
396
+ # Return
397
+ return results
398
+
399
+ def get_relations_with_entries_by_recon_ids_of_t1[T1: Entry, T2: Entry](
400
+ self,
401
+ entry_type: type[T1],
402
+ recon_ids: list[int],
403
+ relation_type: type[Relation[T1, T2]],
404
+ ) -> list[dict[T2, list[Relation[T1, T2]]]]:
405
+ """ Given a list of entries, get connected relations of the given type.
406
+ Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
407
+ to the relation(s) with which they are connected.
408
+ """
409
+ c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
410
+ q = sql.SQL("""
411
+ SELECT *
412
+ FROM {view}
413
+ WHERE (recon_id_1 = {recon_id_placeholder});
414
+ """).format(
415
+ view = sql.Identifier(f'{relation_type.get_table_name()}_v'),
416
+ recon_id_placeholder = sql.SQL('%s')
417
+ )
418
+ c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
419
+
420
+ # Process output
421
+ result: list[dict[T2, list[Relation[T1, T2]]]] = list()
422
+ cursor_result = self._extract_results_from_cursor_by_sets(c)
423
+ for recon_id, relation_list in zip(recon_ids, cursor_result):
424
+ subdict: dict[T2, list[Relation[T1, T2]]] = defaultdict(list)
425
+ for relation, entry_1, entry_2 in relation_list:
426
+ assert entry_1.recon_id == recon_id
427
+ subdict[entry_2].append(relation)
428
+ result.append(subdict)
429
+ return result
430
+
431
+ def get_relations_with_entries_by_recon_ids_of_t2[T1: Entry, T2: Entry](
432
+ self,
433
+ entry_type: type[T1],
434
+ recon_ids: list[int],
435
+ relation_type: type[Relation[T2, T1]],
436
+ ) -> list[dict[T2, list[Relation[T2, T1]]]]:
437
+ """ Given a list of entries, get connected relations of the given type.
438
+ Returns a list, containing, for each input recon_id, a dictionary mapping connected vertices
439
+ to the relation(s) with which they are connected.
440
+ """
441
+ c = self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relation_type))
442
+ q = sql.SQL("""
443
+ SELECT *
444
+ FROM {view}
445
+ WHERE (recon_id_2 = {recon_id_placeholder});
446
+ """).format(
447
+ view = sql.Identifier(f'{relation_type.get_table_name()}_v'),
448
+ recon_id_placeholder = sql.SQL('%s')
449
+ )
450
+ c.executemany(q, params_seq = [[recon_id] for recon_id in recon_ids], returning = True)
451
+
452
+ # Process output
453
+ result: list[dict[T2, list[Relation[T2, T1]]]] = list()
454
+ cursor_result = self._extract_results_from_cursor_by_sets(c)
455
+ for recon_id, relation_list in zip(recon_ids, cursor_result):
456
+ subdict: dict[T2, list[Relation[T2, T1]]] = defaultdict(list)
457
+ for relation, entry_2, entry_1 in relation_list:
458
+ assert entry_1.recon_id == recon_id
459
+ subdict[entry_2].append(relation)
460
+ result.append(subdict)
461
+ return result
462
+
463
+ # Graph exploration getters
464
+ # ----------------------------------------------------------------------------------------------------------
465
+ # These are the main queries used to efficiently explore the graph.
466
+ def explore_vertex[T: Entry](
467
+ self,
468
+ entry: T,
469
+ relation_types: type[Relation[T, Any]]
470
+ ) -> list[tuple[Entry, list[Relation]]]:
471
+ """ Explores the neighbouring entries to a given entry.
472
+ Results are returned as a list of tuples, each containing a neighbouring entry and the relations
473
+ which point to it.
474
+ Note that the input entry may be given in the output in the case of reflexive relations.
475
+ """
476
+
477
+ # TODO Run get_relations_with_entries_from_source() for each allowed relation, and collate the results
478
+
479
+ # TODO
480
+ raise NotImplementedError()
481
+
482
+ # Misc
483
+ # ----------------------------------------------------------------------------------------------------------
484
+ def get_number_of_entries[T: DatabaseObject](self, table: type[T]) -> int:
485
+ """ Return hte number of entries in the table"""
486
+ table_name = table.get_table_name()
487
+ q = sql.SQL("""SELECT count(*) FROM {table}""").format(
488
+ table = sql.Identifier(table_name)
489
+ )
490
+ self.c.execute(q)
491
+ result: int = self.c.fetchone()[0]
492
+ return result
493
+
494
+ def iterate_table[T: DatabaseObject](
495
+ self,
496
+ table: type[T],
497
+ batch_size: int = 1000,
498
+ up_to: int = 0
499
+ ) -> Generator[T, None, None]:
500
+ """ Creates an iterator which queries all entries in a table.
501
+ Entries are requested from the database in batches, but returned as a continuous iterator.
502
+ """
503
+ i: int = 0
504
+ batch: list[T] = list()
505
+
506
+ with self.conn.cursor(row_factory = class_row(table)) as cursor:
507
+ q = sql.SQL("""
508
+ SELECT *
509
+ FROM {table}
510
+ ;
511
+ """).format(
512
+ table = sql.Identifier(table.get_table_name())
513
+ )
514
+ try:
515
+ cursor.execute(q)
516
+ except Exception as e:
517
+ raise RuntimeError('!: {e}')
518
+
519
+ while True:
520
+ # Main loop, fetch batches and yield from the batches
521
+ batch = cursor.fetchmany(size = batch_size)
522
+ if not batch:
523
+ break
524
+ item: T
525
+ for item in batch:
526
+ i += 1
527
+ yield item
528
+ if up_to and i > up_to:
529
+ return
530
+
531
+
532
+ def iterate_relation_with_entries[T1: Entry, T2: Entry](
533
+ self,
534
+ relationtype: type[Relation[T1, T2]],
535
+ batch_size: int = 1000,
536
+ up_to: int = 0
537
+ ) -> Generator[tuple[Relation[T1, T2], T1, T2], None, None]:
538
+ """ Iterates the view over a relation table, getting entries adjacent to each relation.
539
+ """
540
+ i: int = 0
541
+ batch: list = list()
542
+
543
+ # TODO make a relationIteratorRowFactory and use
544
+
545
+ # With the factory created, iterate the rows
546
+ with self.conn.cursor(row_factory = self.make_relation_entry_view_row_factory(relationtype)) as cursor:
547
+ q = sql.SQL("""
548
+ SELECT *
549
+ FROM {view}
550
+ ;
551
+ """).format(
552
+ view = sql.Identifier(f'{relationtype.get_table_name()}_v')
553
+ )
554
+ try:
555
+ cursor.execute(q)
556
+ except Exception as e:
557
+ raise RuntimeError('!: {e}')
558
+
559
+ while True:
560
+ # Main loop, fetch batches and yield from the batches
561
+ batch = cursor.fetchmany(size = batch_size)
562
+ if not batch:
563
+ break
564
+ item: tuple[Relation[T1, T2], T1, T2]
565
+ for item in batch:
566
+ i += 1
567
+ yield item
568
+ if up_to and i > up_to:
569
+ return
570
+
571
+ def _extract_results_from_cursor(self, cursor: Optional[pg.Cursor] = None) -> list[Any]:
572
+ """ Iterates over all result sets of the Psycopg3 cursor and returns.
573
+ """
574
+ c = cursor or self.c
575
+ result: list[Any] = [c.fetchone()]
576
+ while c.nextset():
577
+ result.append(c.fetchone())
578
+ return result
579
+
580
+ def _extract_results_from_cursor_by_sets(self, cursor: Optional[pg.Cursor] = None) -> list[list[Any]]:
581
+ """ Iterates over all result sets of the Psycopg3 cursor and returns.
582
+ """
583
+ c = cursor or self.c
584
+ result: list[list[Any]] = [c.fetchall()]
585
+ while c.nextset():
586
+ result.append(c.fetchall())
587
+ return result
@@ -0,0 +1 @@
1
+ from chemrecon.database.params import parameter_sets
@@ -0,0 +1,63 @@
1
+ """ Implements database connection, cursor, etc. functionality.
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import psycopg as pg
6
+ import psycopg.sql as sql
7
+ from psycopg.types.enum import EnumInfo, register_enum
8
+
9
+ import chemrecon.schema as schema
10
+ from chemrecon.database.params import Params
11
+
12
+ # Initialisation functions
13
+ # ----------------------------------------------------------------------------------------------------------------------
14
def postgres_connect(
    params: Params,
    initialise_enums: bool = True
) -> pg.Connection:
    """ Connect to the database server.

    Opens a connection using params.connection_string(), sets the schema search path, and
    (optionally) registers the enums in schema.enums.enum_register on the connection.

    :param params: Connection parameters; only connection_string() is used here.
    :param initialise_enums: If set, register each enum of the schema with the connection.
    :return: The open connection.
    :raises ConnectionError: If the server could not be reached (wraps psycopg's OperationalError).
    :raises RuntimeError: If an enum of the schema is not known to the database.
    """
    # Do not echo the connection string: it may embed the password.
    print('[ChemRecon] Attempting to connect to the database.')
    try:
        conn: pg.Connection = pg.connect(
            params.connection_string()
        )
    except pg.OperationalError as e:
        print(f'[ChemRecon] Could not connect: {e}')
        raise ConnectionError from e

    print('[ChemRecon] Connection succesful.')

    try:
        # Set correct schemata to search
        schemata = [
            'chemrecon',
            'pg_catalog',
            'meta',
            'public'
        ]
        conn.execute(
            sql.SQL("SET {search_path} TO {schemata}").format(
                search_path = sql.Identifier('search_path'),
                schemata = sql.SQL(', ').join(
                    [sql.Identifier(s) for s in schemata]
                )
            )
        )
        conn.commit()

        # Register enum mapping
        if initialise_enums:
            for enum in schema.enums.enum_register:
                # The enum is looked up by its unqualified name, resolved through the search path
                enum_info = EnumInfo.fetch(conn, sql.Identifier(enum.__name__))
                if enum_info is None:
                    raise RuntimeError(f'Enum info none for enum {enum}.')
                register_enum(enum_info, conn, enum)

        # Done
        conn.commit()
    except BaseException:
        # Do not leak the connection if the setup fails half-way
        conn.close()
        raise
    return conn
@@ -0,0 +1,5 @@
1
+ chemrecon_db
2
+ public_user
3
+ chemrecon_public_password
4
+ 130.225.164.152
5
+ 5432
@@ -0,0 +1,5 @@
1
+ chemrecon_db
2
+ dev
3
+ testpassword_dev
4
+ localhost
5
+ 54320
@@ -0,0 +1,5 @@
1
+ chemrecon_db
2
+ postgres
3
+ testpassword
4
+ localhost
5
+ 54320
@@ -0,0 +1,5 @@
1
+ chemrecon_db
2
+ public_user
3
+ chemrecon_public_password
4
+ localhost
5
+ 54320