closurizer 0.7.2__tar.gz → 0.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: closurizer
-Version: 0.7.2
+Version: 0.8.0
 Summary: Add closure expansion fields to kgx files following the Golr pattern
 Author: Kevin Schaper
 Author-email: kevin@tislab.org
@@ -4,7 +4,8 @@ from closurizer.closurizer import add_closure
 
 
 @click.command()
-@click.option('--kg', required=True, help='KGX tar.gz archive')
+@click.option('--kg', required=False, help='KGX tar.gz archive to import (optional - if not provided, uses existing database)')
+@click.option('--database', required=False, default='monarch-kg.duckdb', help='Database path - if kg is provided, data will be loaded into this database; if kg is not provided, this database must already exist with nodes and edges tables (default: monarch-kg.duckdb)')
 @click.option('--closure', required=True, help='TSV file of closure triples')
 @click.option('--nodes-output', required=True, help='file to write nodes kgx file with closure fields added')
 @click.option('--edges-output', required=True, help='file to write edges kgx file with closure fields added')
@@ -14,8 +15,12 @@ from closurizer.closurizer import add_closure
 @click.option('--edge-fields-to-label', multiple=True, help='edge fields to label with category, label, etc. but not full closure expansion')
 @click.option('--node-fields', multiple=True, help='node fields to expand with closure IDs, labels, etc')
 @click.option('--grouping-fields', multiple=True, help='fields to populate a single value grouping_key field')
+@click.option('--multivalued-fields', multiple=True, help='fields containing pipe-delimited values to convert to varchar[] arrays in database')
 @click.option('--dry-run', is_flag=True, help='A dry run will not write the output file, but will print the SQL query')
+@click.option('--export-edges', is_flag=True, help='Export denormalized edges to TSV file (default: database only)')
+@click.option('--export-nodes', is_flag=True, help='Export denormalized nodes to TSV file (default: database only)')
 def main(kg: str,
+         database: str,
          closure: str,
          nodes_output: str,
          edges_output: str,
@@ -24,17 +29,25 @@ def main(kg: str,
          edge_fields: List[str] = None,
          edge_fields_to_label: List[str] = None,
          node_fields: List[str] = None,
-         grouping_fields: List[str] = None):
-    add_closure(kg_archive=kg,
-                closure_file=closure,
-                edge_fields=edge_fields,
-                edge_fields_to_label=edge_fields_to_label,
-                node_fields=node_fields,
-                edges_output_file=edges_output,
+         grouping_fields: List[str] = None,
+         multivalued_fields: List[str] = None,
+         export_edges: bool = False,
+         export_nodes: bool = False):
+
+    add_closure(closure_file=closure,
                 nodes_output_file=nodes_output,
+                edges_output_file=edges_output,
+                kg_archive=kg,
+                database_path=database,
+                edge_fields=edge_fields or ['subject', 'object'],
+                edge_fields_to_label=edge_fields_to_label or [],
+                node_fields=node_fields or [],
                 additional_node_constraints=additional_node_constraints,
                 dry_run=dry_run,
-                grouping_fields=grouping_fields)
+                grouping_fields=grouping_fields or ['subject', 'negated', 'predicate', 'object'],
+                multivalued_fields=multivalued_fields or ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
+                export_edges=export_edges,
+                export_nodes=export_nodes)
 
 if __name__ == "__main__":
     main()
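
With 0.8.0 the archive import becomes optional: pointing --database at an existing DuckDB file skips the tar.gz load, and TSV export is now opt-in via --export-edges/--export-nodes. A minimal sketch of an invocation that exercises the new options, driven through click's CliRunner; the module path closurizer.cli and every file path below are illustrative assumptions, not values taken from this diff:

    # Sketch only: assumes the command above is importable as closurizer.cli.main
    # and that the named input files exist; all paths are hypothetical.
    from click.testing import CliRunner
    from closurizer.cli import main  # assumed module path

    runner = CliRunner()
    result = runner.invoke(main, [
        '--kg', 'monarch-kg.tar.gz',           # hypothetical KGX archive; omit to reuse an existing database
        '--database', 'monarch-kg.duckdb',     # the documented default
        '--closure', 'relation-graph.tsv',     # hypothetical closure triples TSV
        '--nodes-output', 'denormalized_nodes.tsv',
        '--edges-output', 'denormalized_edges.tsv',
        '--export-edges',                      # opt in to TSV export; the new default is database-only
        '--export-nodes',
    ])
    print(result.exit_code, result.output)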
@@ -0,0 +1,447 @@
+from typing import List, Optional
+
+import os
+import tarfile
+import duckdb
+
+def edge_columns(field: str, include_closure_fields: bool = True, node_column_names: list = None):
+    column_text = f"""
+        {field}.name as {field}_label,
+        {field}.category as {field}_category,
+        {field}.namespace as {field}_namespace,
+    """
+    if include_closure_fields:
+        column_text += f"""
+        {field}_closure.closure as {field}_closure,
+        {field}_closure_label.closure_label as {field}_closure_label,
+        """
+
+    # Only add taxon fields if they exist in the nodes table
+    if field in ['subject', 'object'] and node_column_names:
+        if 'in_taxon' in node_column_names:
+            column_text += f"""
+            {field}.in_taxon as {field}_taxon,"""
+        if 'in_taxon_label' in node_column_names:
+            column_text += f"""
+            {field}.in_taxon_label as {field}_taxon_label,"""
+    column_text += """
+    """
+    return column_text
+
+def edge_joins(field: str, include_closure_joins: bool = True, is_multivalued: bool = False):
+    if is_multivalued:
+        # For VARCHAR[] fields, use array containment with list_contains
+        join_condition = f"list_contains(edges.{field}, {field}.id)"
+    else:
+        # For VARCHAR fields, use direct equality
+        join_condition = f"edges.{field} = {field}.id"
+
+    joins = f"""
+    left outer join nodes as {field} on {join_condition}"""
+
+    if include_closure_joins:
+        joins += f"""
+    left outer join closure_id as {field}_closure on {field}.id = {field}_closure.id
+    left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id"""
+
+    return joins + "\n "
+
+def evidence_sum(evidence_fields: List[str], edges_column_names: list = None):
+    """ Sum together the length of each field - assumes fields are VARCHAR[] arrays """
+    evidence_count_parts = []
+    for field in evidence_fields:
+        # Only include fields that actually exist in the edges table
+        if not edges_column_names or field in edges_column_names:
+            # All evidence fields are expected to be VARCHAR[] arrays
+            evidence_count_parts.append(f"ifnull(array_length({field}),0)")
+
+    evidence_count_sum = "+".join(evidence_count_parts) if evidence_count_parts else "0"
+    return f"{evidence_count_sum} as evidence_count,"
+
+
+def node_columns(predicate):
+    # strip the biolink prefix, if necessary, to get the field name
+    field = predicate.replace('biolink:', '')
+
+    return f"""
+    case when count(distinct {field}_edges.object) > 0 then array_agg(distinct {field}_edges.object) else null end as {field},
+    case when count(distinct {field}_edges.object_label) > 0 then array_agg(distinct {field}_edges.object_label) else null end as {field}_label,
+    count(distinct {field}_edges.object) as {field}_count,
+    case when count({field}_closure.closure) > 0 then list_distinct(flatten(array_agg({field}_closure.closure))) else null end as {field}_closure,
+    case when count({field}_closure_label.closure_label) > 0 then list_distinct(flatten(array_agg({field}_closure_label.closure_label))) else null end as {field}_closure_label,
+    """
+
+def node_joins(predicate):
+    # strip the biolink prefix, if necessary, to get the field name
+    field = predicate.replace('biolink:', '')
+    return f"""
+    left outer join denormalized_edges as {field}_edges
+        on nodes.id = {field}_edges.subject
+        and {field}_edges.predicate = 'biolink:{field}'
+    left outer join closure_id as {field}_closure
+        on {field}_edges.object = {field}_closure.id
+    left outer join closure_label as {field}_closure_label
+        on {field}_edges.object = {field}_closure_label.id
+    """
+
+
+def grouping_key(grouping_fields, edges_column_names: list = None):
+    if not grouping_fields:
+        return "null as grouping_key"
+    fragments = []
+    for field in grouping_fields:
+        # Only include fields that actually exist in the edges table
+        if not edges_column_names or field in edges_column_names:
+            if field == 'negated':
+                fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
+            else:
+                fragments.append(field)
+    if not fragments:
+        return "null as grouping_key"
+    grouping_key_fragments = ", ".join(fragments)
+    return f"concat_ws('|', {grouping_key_fragments}) as grouping_key"
+
+
+def load_from_archive(kg_archive: str, db, multivalued_fields: List[str]):
+    """Load nodes and edges tables from a tar.gz archive"""
+
+    tar = tarfile.open(f"{kg_archive}")
+
+    print("Loading node table...")
+    node_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv')][0]
+    tar.extract(node_file_name)
+    node_file = f"{node_file_name}"
+    print(f"node_file: {node_file}")
+
+    db.sql(f"""
+    create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
+    """)
+
+    edge_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv')][0]
+    tar.extract(edge_file_name)
+    edge_file = f"{edge_file_name}"
+    print(f"edge_file: {edge_file}")
+
+    db.sql(f"""
+    create or replace table edges as select * from read_csv('{edge_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
+    """)
+
+    # Convert multivalued fields to arrays
+    prepare_multivalued_fields(db, multivalued_fields)
+
+    # Clean up extracted files
+    if os.path.exists(f"{node_file}"):
+        os.remove(f"{node_file}")
+    if os.path.exists(f"{edge_file}"):
+        os.remove(f"{edge_file}")
+
+
+def prepare_multivalued_fields(db, multivalued_fields: List[str]):
+    """Convert specified fields to varchar[] arrays in both the nodes and edges tables"""
+
+    # Convert multivalued fields in the nodes table to varchar[] arrays
+    nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
+    node_column_names = [col[0] for col in nodes_table_info]
+    node_column_types = {col[0]: col[1] for col in nodes_table_info}
+
+    for field in multivalued_fields:
+        if field in node_column_names:
+            # Check if the field is already VARCHAR[] - if so, skip conversion
+            if 'VARCHAR[]' in node_column_types[field].upper():
+                print(f"Field '{field}' in nodes table is already VARCHAR[], skipping conversion")
+                continue
+
+            print(f"Converting field '{field}' in nodes table to VARCHAR[]")
+            # Create a new column with the proper array type and replace the original
+            db.sql(f"""
+            alter table nodes add column {field}_array VARCHAR[]
+            """)
+            db.sql(f"""
+            update nodes set {field}_array =
+                case
+                    when {field} is null or {field} = '' then null
+                    else split({field}, '|')
+                end
+            """)
+            db.sql(f"""
+            alter table nodes drop column {field}
+            """)
+            db.sql(f"""
+            alter table nodes rename column {field}_array to {field}
+            """)
+
+    # Convert multivalued fields in the edges table to varchar[] arrays
+    edges_table_info = db.sql("DESCRIBE edges").fetchall()
+    edge_column_names = [col[0] for col in edges_table_info]
+    edge_column_types = {col[0]: col[1] for col in edges_table_info}
+
+    for field in multivalued_fields:
+        if field in edge_column_names:
+            # Check if the field is already VARCHAR[] - if so, skip conversion
+            if 'VARCHAR[]' in edge_column_types[field].upper():
+                print(f"Field '{field}' in edges table is already VARCHAR[], skipping conversion")
+                continue
+
+            print(f"Converting field '{field}' in edges table to VARCHAR[]")
+            # Create a new column with the proper array type and replace the original
+            db.sql(f"""
+            alter table edges add column {field}_array VARCHAR[]
+            """)
+            db.sql(f"""
+            update edges set {field}_array =
+                case
+                    when {field} is null or {field} = '' then null
+                    else split({field}, '|')
+                end
+            """)
+            db.sql(f"""
+            alter table edges drop column {field}
+            """)
+            db.sql(f"""
+            alter table edges rename column {field}_array to {field}
+            """)
+
+
+def add_closure(closure_file: str,
+                nodes_output_file: str,
+                edges_output_file: str,
+                kg_archive: Optional[str] = None,
+                database_path: str = 'monarch-kg.duckdb',
+                node_fields: List[str] = [],
+                edge_fields: List[str] = ['subject', 'object'],
+                edge_fields_to_label: List[str] = [],
+                additional_node_constraints: Optional[str] = None,
+                dry_run: bool = False,
+                evidence_fields: List[str] = ['has_evidence', 'publications'],
+                grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object'],
+                multivalued_fields: List[str] = ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
+                export_edges: bool = False,
+                export_nodes: bool = False
+                ):
+    # Validate input parameters
+    if not kg_archive and not os.path.exists(database_path):
+        raise ValueError("Either kg_archive must be specified or database_path must exist")
+
+    print("Generating closure KG...")
+    if kg_archive:
+        print(f"kg_archive: {kg_archive}")
+    print(f"database_path: {database_path}")
+    print(f"closure_file: {closure_file}")
+
+    # Connect to the database
+    db = duckdb.connect(database=database_path)
+
+    if not dry_run:
+        print(f"fields: {','.join(edge_fields)}")
+        print(f"output_file: {edges_output_file}")
+
+    # Load data based on the input method
+    if kg_archive:
+        load_from_archive(kg_archive, db, multivalued_fields)
+    else:
+        # The database already exists and contains data.
+        # Check if the namespace column exists, and add it if needed.
+        node_column_names = [col[0] for col in db.sql("DESCRIBE nodes").fetchall()]
+        if 'namespace' not in node_column_names:
+            print("Adding namespace column to nodes table...")
+            db.sql("ALTER TABLE nodes ADD COLUMN namespace VARCHAR")
+            db.sql("UPDATE nodes SET namespace = substr(id, 1, instr(id,':') -1)")
+
+        # Convert multivalued fields to arrays
+        prepare_multivalued_fields(db, multivalued_fields)
+
+    # Load the relation graph tsv in long format, mapping a node to each of its ancestors
+    db.sql(f"""
+    create or replace table closure as select * from read_csv('{closure_file}', sep='\t', names=['subject_id', 'predicate_id', 'object_id'], AUTO_DETECT=TRUE)
+    """)
+
+    db.sql("""
+    create or replace table closure_id as select subject_id as id, array_agg(object_id) as closure from closure group by subject_id
+    """)
+
+    db.sql("""
+    create or replace table closure_label as select subject_id as id, array_agg(name) as closure_label from closure join nodes on object_id = id
+    group by subject_id
+    """)
+
+    db.sql("""
+    create or replace table descendants_id as
+    select object_id as id, array_agg(subject_id) as descendants
+    from closure
+    group by object_id
+    """)
+
+    db.sql("""
+    create or replace table descendants_label as
+    select object_id as id, array_agg(name) as descendants_label
+    from closure
+    join nodes on subject_id = nodes.id
+    group by object_id
+    """)
+
+    # Get the edges table schema to determine which fields are VARCHAR[]
+    edges_table_info = db.sql("DESCRIBE edges").fetchall()
+    edges_table_types = {col[0]: col[1] for col in edges_table_info}
+    edges_column_names = [col[0] for col in edges_table_info]
+
+    # Get the nodes table schema to check for available columns
+    nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
+    node_column_names = [col[0] for col in nodes_table_info]
+
+    # Build edge joins with proper multivalued field handling
+    edge_field_joins = []
+    for field in edge_fields:
+        is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
+        edge_field_joins.append(edge_joins(field, is_multivalued=is_multivalued))
+
+    edge_field_to_label_joins = []
+    for field in edge_fields_to_label:
+        is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
+        edge_field_to_label_joins.append(edge_joins(field, include_closure_joins=False, is_multivalued=is_multivalued))
+
+    edges_query = f"""
+    create or replace table denormalized_edges as
+    select edges.*,
+    {"".join([edge_columns(field, node_column_names=node_column_names) for field in edge_fields])}
+    {"".join([edge_columns(field, include_closure_fields=False, node_column_names=node_column_names) for field in edge_fields_to_label])}
+    {evidence_sum(evidence_fields, edges_column_names)}
+    {grouping_key(grouping_fields, edges_column_names)}
+    from edges
+    {"".join(edge_field_joins)}
+    {"".join(edge_field_to_label_joins)}
+    """
+
+    print(edges_query)
+
+    additional_node_constraints = f"where {additional_node_constraints}" if additional_node_constraints else ""
+
+    # Get nodes table info to handle multivalued fields in the query
+    nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
+    nodes_table_column_names = [col[0] for col in nodes_table_info]
+    nodes_table_types = {col[0]: col[1] for col in nodes_table_info}
+
+    # Create field selections for nodes, converting VARCHAR[] back to pipe-delimited strings
+    nodes_field_selections = []
+    for field in nodes_table_column_names:
+        if field in multivalued_fields and 'VARCHAR[]' in nodes_table_types[field].upper():
+            # Convert VARCHAR[] back to a pipe-delimited string
+            nodes_field_selections.append(f"list_aggregate({field}, 'string_agg', '|') as {field}")
+        else:
+            # Regular field, use as-is (but it must be listed explicitly for GROUP BY)
+            nodes_field_selections.append(f"nodes.{field}")
+
+    nodes_base_fields = ",\n ".join(nodes_field_selections)
+
+    nodes_query = f"""
+    create or replace table denormalized_nodes as
+    select {nodes_base_fields},
+    {"".join([node_columns(node_field) for node_field in node_fields])}
+    from nodes
+    {node_joins('has_phenotype')}
+    {additional_node_constraints}
+    group by {", ".join([f"nodes.{field}" for field in nodes_table_column_names])}
+    """
+    print(nodes_query)
+
+    if not dry_run:
+
+        db.sql(edges_query)
+
+        # Export edges to TSV only if requested
+        if export_edges:
+            edge_closure_replacements = [
+                f"""
+                list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
+                list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
+                """
+                for field in edge_fields
+            ]
+
+            # Add conversions for the original multivalued fields back to pipe-delimited strings
+            edge_table_info = db.sql("DESCRIBE denormalized_edges").fetchall()
+            edge_table_column_names = [col[0] for col in edge_table_info]
+            edge_table_types = {col[0]: col[1] for col in edge_table_info}
+
+            # Create the set of closure fields already handled by edge_closure_replacements
+            closure_fields_handled = set()
+            for field in edge_fields:
+                closure_fields_handled.add(f"{field}_closure")
+                closure_fields_handled.add(f"{field}_closure_label")
+
+            multivalued_replacements = [
+                f"list_aggregate({field}, 'string_agg', '|') as {field}"
+                for field in multivalued_fields
+                if field in edge_table_column_names and 'VARCHAR[]' in edge_table_types[field].upper()
+                and field not in closure_fields_handled
+            ]
+
+            all_replacements = edge_closure_replacements + multivalued_replacements
+            edge_closure_replacements = "REPLACE (\n" + ",\n".join(all_replacements) + ")\n"
+
+            edges_export_query = f"""
+            -- write denormalized_edges as tsv
+            copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
+            """
+            print(edges_export_query)
+            db.sql(edges_export_query)
+
+        db.sql(nodes_query)
+
+        # Add descendant columns separately to avoid memory issues with a large GROUP BY
+        print("Adding descendant columns to denormalized_nodes...")
+        db.sql("alter table denormalized_nodes add column has_descendant VARCHAR[]")
+        db.sql("alter table denormalized_nodes add column has_descendant_label VARCHAR[]")
+        db.sql("alter table denormalized_nodes add column has_descendant_count INTEGER")
+
+        db.sql("""
+        update denormalized_nodes
+        set has_descendant = descendants_id.descendants
+        from descendants_id
+        where denormalized_nodes.id = descendants_id.id
+        """)
+
+        db.sql("""
+        update denormalized_nodes
+        set has_descendant_label = descendants_label.descendants_label
+        from descendants_label
+        where denormalized_nodes.id = descendants_label.id
+        """)
+
+        db.sql("""
+        update denormalized_nodes
+        set has_descendant_count = coalesce(array_length(has_descendant), 0)
+        """)
+
+        # Export nodes to TSV only if requested
+        if export_nodes:
+            # Get denormalized_nodes table info to handle array fields in the export
+            denorm_nodes_table_info = db.sql("DESCRIBE denormalized_nodes").fetchall()
+            denorm_nodes_column_names = [col[0] for col in denorm_nodes_table_info]
+            denorm_nodes_types = {col[0]: col[1] for col in denorm_nodes_table_info}
+
+            # Find all VARCHAR[] fields that need conversion to pipe-delimited strings
+            array_field_replacements = [
+                f"list_aggregate({field}, 'string_agg', '|') as {field}"
+                for field in denorm_nodes_column_names
+                if 'VARCHAR[]' in denorm_nodes_types[field].upper()
+            ]
+
+            # The descendant fields are already handled by the general VARCHAR[] logic above,
+            # so there is no need to add them separately
+
+            if array_field_replacements:
+                nodes_replacements = "REPLACE (\n" + ",\n".join(array_field_replacements) + ")\n"
+                nodes_export_query = f"""
+                -- write denormalized_nodes as tsv
+                copy (select * {nodes_replacements} from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
+                """
+            else:
+                nodes_export_query = f"""
+                -- write denormalized_nodes as tsv
+                copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
+                """
+            print(nodes_export_query)
+            db.sql(nodes_export_query)
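
Used as a library, add_closure now makes kg_archive optional and adds database_path, multivalued_fields, and the two export flags. A minimal sketch of a call that reuses a pre-loaded database, assuming monarch-kg.duckdb already contains nodes and edges tables (all file paths are illustrative):

    # Sketch only: all file paths are hypothetical.
    from closurizer.closurizer import add_closure

    add_closure(
        closure_file='relation-graph.tsv',          # closure triples TSV
        nodes_output_file='denormalized_nodes.tsv',
        edges_output_file='denormalized_edges.tsv',
        kg_archive=None,                            # skip the tar.gz import; reuse the database
        database_path='monarch-kg.duckdb',          # must already hold nodes and edges tables
        multivalued_fields=['has_evidence', 'publications'],
        export_edges=True,                          # also write TSVs; otherwise results stay in DuckDB
        export_nodes=True,
    )

If neither a kg_archive nor an existing database_path is available, the function raises ValueError, per the validation at the top of add_closure.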
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "closurizer"
-version = "0.7.2"
+version = "0.8.0"
 description = "Add closure expansion fields to kgx files following the Golr pattern"
 authors = ["Kevin Schaper <kevin@tislab.org>"]
 
@@ -1,191 +0,0 @@
-from typing import List, Optional
-
-import os
-import tarfile
-import duckdb
-
-def edge_columns(field: str, include_closure_fields: bool = True):
-    column_text = f"""
-        {field}.name as {field}_label,
-        {field}.category as {field}_category,
-        {field}.namespace as {field}_namespace,
-    """
-    if include_closure_fields:
-        column_text += f"""
-        {field}_closure.closure as {field}_closure,
-        {field}_closure_label.closure_label as {field}_closure_label,
-        """
-
-    if field in ['subject', 'object']:
-        column_text += f"""
-        {field}.in_taxon as {field}_taxon,
-        {field}.in_taxon_label as {field}_taxon_label,
-        """
-    return column_text
-
-def edge_joins(field: str, include_closure_joins: bool = True):
-    return f"""
-    left outer join nodes as {field} on edges.{field} = {field}.id
-    left outer join closure_id as {field}_closure on {field}.id = {field}_closure.id
-    left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id
-    """
-
-def evidence_sum(evidence_fields: List[str]):
-    """ Sum together the length of each field after splitting on | """
-    evidence_count_sum = "+".join([f"ifnull(len(split({field}, '|')),0)" for field in evidence_fields])
-    return f"{evidence_count_sum} as evidence_count,"
-
-
-def node_columns(predicate):
-    # strip the biolink prefix, if necessary, to get the field name
-    field = predicate.replace('biolink:', '')
-
-    return f"""
-    string_agg({field}_edges.object, '|') as {field},
-    string_agg({field}_edges.object_label, '|') as {field}_label,
-    count(distinct {field}_edges.object) as {field}_count,
-    list_aggregate(list_distinct(flatten(array_agg({field}_closure.closure))), 'string_agg', '|') as {field}_closure,
-    list_aggregate(list_distinct(flatten(array_agg({field}_closure_label.closure_label))), 'string_agg', '|') as {field}_closure_label,
-    """
-
-def node_joins(predicate):
-    # strip the biolink prefix, if necessary, to get the field name
-    field = predicate.replace('biolink:', '')
-    return f"""
-    left outer join denormalized_edges as {field}_edges
-        on nodes.id = {field}_edges.subject
-        and {field}_edges.predicate = 'biolink:{field}'
-    left outer join closure_id as {field}_closure
-        on {field}_edges.object = {field}_closure.id
-    left outer join closure_label as {field}_closure_label
-        on {field}_edges.object = {field}_closure_label.id
-    """
-
-
-def grouping_key(grouping_fields):
-    fragments = []
-    for field in grouping_fields:
-        if field == 'negated':
-            fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
-        else:
-            fragments.append(field)
-    grouping_key_fragments = ", ".join(fragments)
-    return f"concat_ws('|', {grouping_key_fragments}) as grouping_key"
-
-
-def add_closure(kg_archive: str,
-                closure_file: str,
-                nodes_output_file: str,
-                edges_output_file: str,
-                node_fields: List[str] = [],
-                edge_fields: List[str] = ['subject', 'object'],
-                edge_fields_to_label: List[str] = [],
-                additional_node_constraints: Optional[str] = None,
-                dry_run: bool = False,
-                evidence_fields: List[str] = ['has_evidence', 'publications'],
-                grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object']
-                ):
-    print("Generating closure KG...")
-    print(f"kg_archive: {kg_archive}")
-    print(f"closure_file: {closure_file}")
-
-    db = duckdb.connect(database='monarch-kg.duckdb')
-
-    if not dry_run:
-        print(f"fields: {','.join(edge_fields)}")
-        print(f"output_file: {edges_output_file}")
-
-    tar = tarfile.open(f"{kg_archive}")
-
-    print("Loading node table...")
-    node_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv')][0]
-    tar.extract(node_file_name)
-    node_file = f"{node_file_name}"
-    print(f"node_file: {node_file}")
-
-    db.sql(f"""
-    create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
-    """)
-
-    edge_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv')][0]
-    tar.extract(edge_file_name)
-    edge_file = f"{edge_file_name}"
-    print(f"edge_file: {edge_file}")
-
-    db.sql(f"""
-    create or replace table edges as select * from read_csv('{edge_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
-    """)
-
-    # Load the relation graph tsv in long format, mapping a node to each of its ancestors
-    db.sql(f"""
-    create or replace table closure as select * from read_csv('{closure_file}', sep='\t', names=['subject_id', 'predicate_id', 'object_id'], AUTO_DETECT=TRUE)
-    """)
-
-    db.sql("""
-    create or replace table closure_id as select subject_id as id, array_agg(object_id) as closure from closure group by subject_id
-    """)
-
-    db.sql("""
-    create or replace table closure_label as select subject_id as id, array_agg(name) as closure_label from closure join nodes on object_id = id
-    group by subject_id
-    """)
-
-    edges_query = f"""
-    create or replace table denormalized_edges as
-    select edges.*,
-    {"".join([edge_columns(field) for field in edge_fields])}
-    {"".join([edge_columns(field, include_closure_fields=False) for field in edge_fields_to_label])}
-    {evidence_sum(evidence_fields)}
-    {grouping_key(grouping_fields)}
-    from edges
-    {"".join([edge_joins(field) for field in edge_fields])}
-    {"".join([edge_joins(field, include_closure_joins=False) for field in edge_fields_to_label])}
-    """
-
-    print(edges_query)
-
-    additional_node_constraints = f"where {additional_node_constraints}" if additional_node_constraints else ""
-    nodes_query = f"""
-    create or replace table denormalized_nodes as
-    select nodes.*,
-    {"".join([node_columns(node_field) for node_field in node_fields])}
-    from nodes
-    {node_joins('has_phenotype')}
-    {additional_node_constraints}
-    group by nodes.*
-    """
-    print(nodes_query)
-
-    if not dry_run:
-
-        edge_closure_replacements = [
-            f"""
-            list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
-            list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
-            """
-            for field in edge_fields
-        ]
-
-        edge_closure_replacements = "REPLACE (\n" + ",\n".join(edge_closure_replacements) + ")\n"
-
-        edges_export_query = f"""
-        -- write denormalized_edges as tsv
-        copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
-        """
-        print(edges_export_query)
-        db.query(edges_export_query)
-
-        nodes_export_query = f"""
-        -- write denormalized_nodes as tsv
-        copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
-        """
-        print(nodes_export_query)
-
-    # Clean up extracted node & edge files
-    if os.path.exists(f"{node_file}"):
-        os.remove(f"{node_file}")
-    if os.path.exists(f"{edge_file}"):
-        os.remove(f"{edge_file}")