closurizer 0.7.3__tar.gz → 0.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,8 @@ from closurizer.closurizer import add_closure
|
|
4
4
|
|
5
5
|
|
6
6
|
@click.command()
|
7
|
-
@click.option('--kg', required=
|
7
|
+
@click.option('--kg', required=False, help='KGX tar.gz archive to import (optional - if not provided, uses existing database)')
|
8
|
+
@click.option('--database', required=False, default='monarch-kg.duckdb', help='Database path - if kg is provided, data will be loaded into this database; if kg is not provided, this database must already exist with nodes and edges tables (default: monarch-kg.duckdb)')
|
8
9
|
@click.option('--closure', required=True, help='TSV file of closure triples')
|
9
10
|
@click.option('--nodes-output', required=True, help='file write nodes kgx file with closure fields added')
|
10
11
|
@click.option('--edges-output', required=True, help='file write edges kgx file with closure fields added')
|
@@ -14,8 +15,12 @@ from closurizer.closurizer import add_closure
|
|
14
15
|
@click.option('--edge-fields-to-label', multiple=True, help='edge fields to with category, label, etc but not full closure exansion')
|
15
16
|
@click.option('--node-fields', multiple=True, help='node fields to expand with closure IDs, labels, etc')
|
16
17
|
@click.option('--grouping-fields', multiple=True, help='fields to populate a single value grouping_key field')
|
18
|
+
@click.option('--multivalued-fields', multiple=True, help='fields containing pipe-delimited values to convert to varchar[] arrays in database')
|
17
19
|
@click.option('--dry-run', is_flag=True, help='A dry run will not write the output file, but will print the SQL query')
|
20
|
+
@click.option('--export-edges', is_flag=True, help='Export denormalized edges to TSV file (default: database only)')
|
21
|
+
@click.option('--export-nodes', is_flag=True, help='Export denormalized nodes to TSV file (default: database only)')
|
18
22
|
def main(kg: str,
|
23
|
+
database: str,
|
19
24
|
closure: str,
|
20
25
|
nodes_output: str,
|
21
26
|
edges_output: str,
|
@@ -24,17 +29,25 @@ def main(kg: str,
|
|
24
29
|
edge_fields: List[str] = None,
|
25
30
|
edge_fields_to_label: List[str] = None,
|
26
31
|
node_fields: List[str] = None,
|
27
|
-
grouping_fields: List[str] = None
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
edges_output_file=edges_output,
|
32
|
+
grouping_fields: List[str] = None,
|
33
|
+
multivalued_fields: List[str] = None,
|
34
|
+
export_edges: bool = False,
|
35
|
+
export_nodes: bool = False):
|
36
|
+
|
37
|
+
add_closure(closure_file=closure,
|
34
38
|
nodes_output_file=nodes_output,
|
39
|
+
edges_output_file=edges_output,
|
40
|
+
kg_archive=kg,
|
41
|
+
database_path=database,
|
42
|
+
edge_fields=edge_fields or ['subject', 'object'],
|
43
|
+
edge_fields_to_label=edge_fields_to_label or [],
|
44
|
+
node_fields=node_fields or [],
|
35
45
|
additional_node_constraints=additional_node_constraints,
|
36
46
|
dry_run=dry_run,
|
37
|
-
grouping_fields=grouping_fields
|
47
|
+
grouping_fields=grouping_fields or ['subject', 'negated', 'predicate', 'object'],
|
48
|
+
multivalued_fields=multivalued_fields or ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
|
49
|
+
export_edges=export_edges,
|
50
|
+
export_nodes=export_nodes)
|
38
51
|
|
39
52
|
if __name__ == "__main__":
|
40
53
|
main()
|
@@ -0,0 +1,447 @@
|
|
1
|
+
from typing import List, Optional
|
2
|
+
|
3
|
+
import os
|
4
|
+
import tarfile
|
5
|
+
import duckdb
|
6
|
+
|
7
|
+
def edge_columns(field: str, include_closure_fields: bool = True, node_column_names: list = None):
    """Build the SELECT-list fragment of denormalized columns for one edge field.

    Always emits label/category/namespace columns for *field*; optionally the
    closure id/label columns; and, for subject/object, the taxon columns when
    the nodes table actually has them (per *node_column_names*).
    """
    parts = [f"""
    {field}.name as {field}_label,
    {field}.category as {field}_category,
    {field}.namespace as {field}_namespace,
    """]

    if include_closure_fields:
        parts.append(f"""
    {field}_closure.closure as {field}_closure,
    {field}_closure_label.closure_label as {field}_closure_label,
    """)

    # Taxon columns only make sense on subject/object, and only when the
    # nodes table carries them — otherwise the query would reference
    # nonexistent columns.
    if field in ('subject', 'object') and node_column_names:
        if 'in_taxon' in node_column_names:
            parts.append(f"""
    {field}.in_taxon as {field}_taxon,""")
        if 'in_taxon_label' in node_column_names:
            parts.append(f"""
    {field}.in_taxon_label as {field}_taxon_label,""")
        parts.append("""
    """)

    return "".join(parts)
|
30
|
+
|
31
|
+
def edge_joins(field: str, include_closure_joins: bool = True, is_multivalued: bool = False):
    """Build the JOIN clause(s) linking one edge field to nodes and closure tables.

    A multivalued (VARCHAR[]) edge column is matched with list_contains();
    a scalar column joins on plain equality.
    """
    on_clause = (
        f"list_contains(edges.{field}, {field}.id)"
        if is_multivalued
        else f"edges.{field} = {field}.id"
    )

    sql = f"""
    left outer join nodes as {field} on {on_clause}"""

    if include_closure_joins:
        sql += f"""
    left outer join closure_id as {field}_closure on {field}.id = {field}_closure.id
    left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id"""

    return sql + "\n    "
|
48
|
+
|
49
|
+
def evidence_sum(evidence_fields: List[str], edges_column_names: list = None):
    """Build an expression summing array lengths into an evidence_count column.

    Evidence fields are assumed to be VARCHAR[] arrays. Fields missing from
    *edges_column_names* (when given) are skipped; with nothing usable the
    count is the constant 0.
    """
    terms = [
        f"ifnull(array_length({field}),0)"
        for field in evidence_fields
        if not edges_column_names or field in edges_column_names
    ]
    # "+".join([]) is "", so `or "0"` supplies the empty-case constant.
    total = "+".join(terms) or "0"
    return f"{total} as evidence_count,"
|
60
|
+
|
61
|
+
|
62
|
+
def node_columns(predicate):
    """SELECT-list fragment aggregating edge objects for one node predicate.

    Produces array columns for related object ids/labels, a distinct count,
    and flattened distinct closure id/label arrays; empty aggregates become
    NULL rather than empty arrays.
    """
    # The column name is the predicate with any biolink: prefix stripped.
    field = predicate.replace('biolink:', '')

    return f"""
    case when count(distinct {field}_edges.object) > 0 then array_agg(distinct {field}_edges.object) else null end as {field},
    case when count(distinct {field}_edges.object_label) > 0 then array_agg(distinct {field}_edges.object_label) else null end as {field}_label,
    count (distinct {field}_edges.object) as {field}_count,
    case when count({field}_closure.closure) > 0 then list_distinct(flatten(array_agg({field}_closure.closure))) else null end as {field}_closure,
    case when count({field}_closure_label.closure_label) > 0 then list_distinct(flatten(array_agg({field}_closure_label.closure_label))) else null end as {field}_closure_label,
    """
|
73
|
+
|
74
|
+
def node_joins(predicate):
    """JOIN fragment linking nodes to their denormalized edges (and edge-object
    closures) for one predicate.

    Joins denormalized_edges on subject id filtered to the given biolink
    predicate, then the closure id/label tables on the edge object.
    """
    # The alias prefix is the predicate with any biolink: prefix stripped.
    field = predicate.replace('biolink:', '')

    return f"""
    left outer join denormalized_edges as {field}_edges
        on nodes.id = {field}_edges.subject
        and {field}_edges.predicate = 'biolink:{field}'
    left outer join closure_id as {field}_closure
        on {field}_edges.object = {field}_closure.id
    left outer join closure_label as {field}_closure_label
        on {field}_edges.object = {field}_closure_label.id
    """
|
86
|
+
|
87
|
+
|
88
|
+
def grouping_key(grouping_fields, edges_column_names: list = None):
    """Build a concat_ws expression producing the grouping_key column.

    Fields absent from *edges_column_names* (when provided) are skipped; the
    'negated' field is rendered as the word NOT when true, empty otherwise.
    Returns a literal null column when no usable fields remain.
    """
    fragments = []
    for field in grouping_fields or []:
        # Skip fields the edges table doesn't actually have.
        if edges_column_names and field not in edges_column_names:
            continue
        if field == 'negated':
            # DuckDB dot-call syntax: cast to varchar, then swap 'true' -> 'NOT'.
            fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
        else:
            fragments.append(field)

    if not fragments:
        return "null as grouping_key"
    return f"concat_ws('|', {', '.join(fragments)}) as grouping_key"
|
103
|
+
|
104
|
+
|
105
|
+
|
106
|
+
|
107
|
+
def load_from_archive(kg_archive: str, db, multivalued_fields: List[str]):
    """Load nodes and edges tables into *db* from a KGX tar.gz archive.

    Extracts the first ``*_nodes.tsv`` and ``*_edges.tsv`` members into the
    working directory, loads them into ``nodes`` (adding a derived
    ``namespace`` column from the CURIE prefix) and ``edges``, converts the
    multivalued fields to VARCHAR[] arrays, and removes the extracted files.

    Args:
        kg_archive: path to the KGX tar.gz archive.
        db: open DuckDB connection.
        multivalued_fields: pipe-delimited columns to convert to arrays.

    Raises:
        IndexError: if the archive has no ``*_nodes.tsv`` or ``*_edges.tsv`` member.
    """
    # Fix: the original opened the tarfile and never closed it (leaked file
    # handle). A context manager closes it deterministically.
    with tarfile.open(kg_archive) as tar:
        print("Loading node table...")
        node_file = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv')][0]
        tar.extract(node_file)
        print(f"node_file: {node_file}")

        db.sql(f"""
        create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file}', header=True, sep='\t', AUTO_DETECT=TRUE)
        """)

        edge_file = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv')][0]
        tar.extract(edge_file)
        print(f"edge_file: {edge_file}")

        db.sql(f"""
        create or replace table edges as select * from read_csv('{edge_file}', header=True, sep='\t', AUTO_DETECT=TRUE)
        """)

    # Convert multivalued fields to arrays
    prepare_multivalued_fields(db, multivalued_fields)

    # Clean up extracted files
    if os.path.exists(node_file):
        os.remove(node_file)
    if os.path.exists(edge_file):
        os.remove(edge_file)
|
139
|
+
|
140
|
+
|
141
|
+
def prepare_multivalued_fields(db, multivalued_fields: List[str]):
    """Convert specified fields to varchar[] arrays in both nodes and edges tables.

    Fix: the original duplicated the entire conversion loop for nodes and
    edges; the shared logic now lives in one helper applied to each table.
    """
    for table in ('nodes', 'edges'):
        _convert_multivalued_columns(db, table, multivalued_fields)


def _convert_multivalued_columns(db, table: str, fields: List[str]):
    """Split pipe-delimited VARCHAR columns of *table* into VARCHAR[] columns.

    Columns missing from the table are skipped; columns already typed
    VARCHAR[] are left untouched. NULLs and empty strings become NULL arrays.
    """
    table_info = db.sql(f"DESCRIBE {table}").fetchall()
    column_names = [col[0] for col in table_info]
    column_types = {col[0]: col[1] for col in table_info}

    for field in fields:
        if field not in column_names:
            continue

        # Check if field is already VARCHAR[] - if so, skip conversion
        if 'VARCHAR[]' in column_types[field].upper():
            print(f"Field '{field}' in {table} table is already VARCHAR[], skipping conversion")
            continue

        print(f"Converting field '{field}' in {table} table to VARCHAR[]")
        # Add a parallel array column, populate it from the scalar column,
        # then swap it in under the original name.
        db.sql(f"""
            alter table {table} add column {field}_array VARCHAR[]
        """)
        db.sql(f"""
            update {table} set {field}_array =
            case
                when {field} is null or {field} = '' then null
                else split({field}, '|')
            end
        """)
        db.sql(f"""
            alter table {table} drop column {field}
        """)
        db.sql(f"""
            alter table {table} rename column {field}_array to {field}
        """)
|
205
|
+
|
206
|
+
|
207
|
+
def add_closure(closure_file: str,
                nodes_output_file: str,
                edges_output_file: str,
                kg_archive: Optional[str] = None,
                database_path: str = 'monarch-kg.duckdb',
                # NOTE(review): the list defaults below are mutable and shared
                # across calls — safe only while nothing mutates them in place.
                node_fields: List[str] = [],
                edge_fields: List[str] = ['subject', 'object'],
                edge_fields_to_label: List[str] = [],
                additional_node_constraints: Optional[str] = None,
                dry_run: bool = False,
                evidence_fields: List[str] = ['has_evidence', 'publications'],
                grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object'],
                multivalued_fields: List[str] = ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
                export_edges: bool = False,
                export_nodes: bool = False
                ):
    """Denormalize a KG with ontology-closure fields inside a DuckDB database.

    Loads nodes/edges either from *kg_archive* (tar.gz of KGX TSVs) or from an
    existing database at *database_path*, loads the closure TSV, then builds
    ``denormalized_edges`` and ``denormalized_nodes`` tables; each may
    optionally be exported to TSV (*export_edges* / *export_nodes*).

    With ``dry_run=True`` the closure/descendant helper tables are still
    created, but the denormalization queries are only printed, not executed,
    and nothing is exported.

    Raises:
        ValueError: if neither *kg_archive* is given nor *database_path* exists.
    """
    # Validate input parameters: without an archive we must have a pre-built db.
    if not kg_archive and not os.path.exists(database_path):
        raise ValueError("Either kg_archive must be specified or database_path must exist")

    print("Generating closure KG...")
    if kg_archive:
        print(f"kg_archive: {kg_archive}")
    print(f"database_path: {database_path}")
    print(f"closure_file: {closure_file}")

    # Connect to database
    db = duckdb.connect(database=database_path)

    if not dry_run:
        print(f"fields: {','.join(edge_fields)}")
        print(f"output_file: {edges_output_file}")

    # Load data based on input method
    if kg_archive:
        load_from_archive(kg_archive, db, multivalued_fields)
    else:
        # Database already exists and contains data
        # Check if namespace column exists, add it if needed
        node_column_names = [col[0] for col in db.sql("DESCRIBE nodes").fetchall()]
        if 'namespace' not in node_column_names:
            print("Adding namespace column to nodes table...")
            db.sql("ALTER TABLE nodes ADD COLUMN namespace VARCHAR")
            db.sql("UPDATE nodes SET namespace = substr(id, 1, instr(id,':') -1)")

        # Convert multivalued fields to arrays
        prepare_multivalued_fields(db, multivalued_fields)

    # Load the relation graph tsv in long format mapping a node to each of it's ancestors
    db.sql(f"""
    create or replace table closure as select * from read_csv('{closure_file}', sep='\t', names=['subject_id', 'predicate_id', 'object_id'], AUTO_DETECT=TRUE)
    """)

    # One row per node: the array of all its ancestor ids.
    db.sql("""
    create or replace table closure_id as select subject_id as id, array_agg(object_id) as closure from closure group by subject_id
    """)

    # One row per node: the array of its ancestors' names (via nodes join).
    db.sql("""
    create or replace table closure_label as select subject_id as id, array_agg(name) as closure_label from closure join nodes on object_id = id
    group by subject_id
    """)

    # Inverse direction: one row per node with the array of its descendant ids.
    db.sql("""
    create or replace table descendants_id as
    select object_id as id, array_agg(subject_id) as descendants
    from closure
    group by object_id
    """)

    db.sql("""
    create or replace table descendants_label as
    select object_id as id, array_agg(name) as descendants_label
    from closure
    join nodes on subject_id = nodes.id
    group by object_id
    """)

    # Get edges table schema to determine which fields are VARCHAR[]
    edges_table_info = db.sql("DESCRIBE edges").fetchall()
    edges_table_types = {col[0]: col[1] for col in edges_table_info}
    edges_column_names = [col[0] for col in edges_table_info]

    # Get nodes table schema to check for available columns
    nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
    node_column_names = [col[0] for col in nodes_table_info]

    # Build edge joins with proper multivalued field handling: a field joins
    # with list_contains() only if it was declared multivalued AND actually
    # converted to VARCHAR[] in the table.
    edge_field_joins = []
    for field in edge_fields:
        is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
        edge_field_joins.append(edge_joins(field, is_multivalued=is_multivalued))

    edge_field_to_label_joins = []
    for field in edge_fields_to_label:
        is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
        edge_field_to_label_joins.append(edge_joins(field, include_closure_joins=False, is_multivalued=is_multivalued))

    edges_query = f"""
    create or replace table denormalized_edges as
    select edges.*,
    {"".join([edge_columns(field, node_column_names=node_column_names) for field in edge_fields])}
    {"".join([edge_columns(field, include_closure_fields=False, node_column_names=node_column_names) for field in edge_fields_to_label])}
    {evidence_sum(evidence_fields, edges_column_names)}
    {grouping_key(grouping_fields, edges_column_names)}
    from edges
    {"".join(edge_field_joins)}
    {"".join(edge_field_to_label_joins)}
    """

    print(edges_query)

    # Rebind the constraint as a ready-to-splice WHERE clause (or empty string).
    additional_node_constraints = f"where {additional_node_constraints}" if additional_node_constraints else ""

    # Get nodes table info to handle multivalued fields in the query
    nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
    nodes_table_column_names = [col[0] for col in nodes_table_info]
    nodes_table_types = {col[0]: col[1] for col in nodes_table_info}

    # Create field selections for nodes, converting VARCHAR[] back to pipe-delimited strings
    nodes_field_selections = []
    for field in nodes_table_column_names:
        if field in multivalued_fields and 'VARCHAR[]' in nodes_table_types[field].upper():
            # Convert VARCHAR[] back to pipe-delimited string
            nodes_field_selections.append(f"list_aggregate({field}, 'string_agg', '|') as {field}")
        else:
            # Regular field, use as-is (but need to specify for GROUP BY)
            nodes_field_selections.append(f"nodes.{field}")

    nodes_base_fields = ",\n    ".join(nodes_field_selections)

    # NOTE(review): when node_fields is empty the SELECT list ends with a
    # trailing comma — relies on DuckDB accepting trailing commas; confirm.
    nodes_query = f"""
    create or replace table denormalized_nodes as
    select {nodes_base_fields},
    {"".join([node_columns(node_field) for node_field in node_fields])}
    from nodes
    {node_joins('has_phenotype')}
    {additional_node_constraints}
    group by {", ".join([f"nodes.{field}" for field in nodes_table_column_names])}
    """
    print(nodes_query)


    if not dry_run:

        db.sql(edges_query)

        # Export edges to TSV only if requested
        if export_edges:
            # Flatten the closure arrays back to pipe-delimited strings for TSV.
            edge_closure_replacements = [
                f"""
                list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
                list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
                """
                for field in edge_fields
            ]

            # Add conversions for original multivalued fields back to pipe-delimited strings
            edge_table_info = db.sql("DESCRIBE denormalized_edges").fetchall()
            edge_table_column_names = [col[0] for col in edge_table_info]
            edge_table_types = {col[0]: col[1] for col in edge_table_info}

            # Create set of closure fields already handled by edge_closure_replacements
            closure_fields_handled = set()
            for field in edge_fields:
                closure_fields_handled.add(f"{field}_closure")
                closure_fields_handled.add(f"{field}_closure_label")

            multivalued_replacements = [
                f"list_aggregate({field}, 'string_agg', '|') as {field}"
                for field in multivalued_fields
                if field in edge_table_column_names and 'VARCHAR[]' in edge_table_types[field].upper()
                and field not in closure_fields_handled
            ]

            all_replacements = edge_closure_replacements + multivalued_replacements
            edge_closure_replacements = "REPLACE (\n" + ",\n".join(all_replacements) + ")\n"

            edges_export_query = f"""
            -- write denormalized_edges as tsv
            copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
            """
            print(edges_export_query)
            db.sql(edges_export_query)

        db.sql(nodes_query)

        # Add descendant columns separately to avoid memory issues with large GROUP BY
        print("Adding descendant columns to denormalized_nodes...")
        db.sql("alter table denormalized_nodes add column has_descendant VARCHAR[]")
        db.sql("alter table denormalized_nodes add column has_descendant_label VARCHAR[]")
        db.sql("alter table denormalized_nodes add column has_descendant_count INTEGER")

        db.sql("""
        update denormalized_nodes
        set has_descendant = descendants_id.descendants
        from descendants_id
        where denormalized_nodes.id = descendants_id.id
        """)

        db.sql("""
        update denormalized_nodes
        set has_descendant_label = descendants_label.descendants_label
        from descendants_label
        where denormalized_nodes.id = descendants_label.id
        """)

        db.sql("""
        update denormalized_nodes
        set has_descendant_count = coalesce(array_length(has_descendant), 0)
        """)

        # Export nodes to TSV only if requested
        if export_nodes:
            # Get denormalized_nodes table info to handle array fields in export
            denorm_nodes_table_info = db.sql("DESCRIBE denormalized_nodes").fetchall()
            denorm_nodes_column_names = [col[0] for col in denorm_nodes_table_info]
            denorm_nodes_types = {col[0]: col[1] for col in denorm_nodes_table_info}

            # Find all VARCHAR[] fields that need conversion to pipe-delimited strings
            array_field_replacements = [
                f"list_aggregate({field}, 'string_agg', '|') as {field}"
                for field in denorm_nodes_column_names
                if 'VARCHAR[]' in denorm_nodes_types[field].upper()
            ]

            # The descendants fields are already handled by the general VARCHAR[] logic above
            # No need to add them separately

            if array_field_replacements:
                nodes_replacements = "REPLACE (\n" + ",\n".join(array_field_replacements) + ")\n"
                nodes_export_query = f"""
                -- write denormalized_nodes as tsv
                copy (select * {nodes_replacements} from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
                """
            else:
                nodes_export_query = f"""
                -- write denormalized_nodes as tsv
                copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
                """
            print(nodes_export_query)
            db.sql(nodes_export_query)
|
@@ -1,193 +0,0 @@
|
|
1
|
-
from typing import List, Optional
|
2
|
-
|
3
|
-
import os
|
4
|
-
import tarfile
|
5
|
-
import duckdb
|
6
|
-
|
7
|
-
def edge_columns(field: str, include_closure_fields: bool =True):
|
8
|
-
column_text = f"""
|
9
|
-
{field}.name as {field}_label,
|
10
|
-
{field}.category as {field}_category,
|
11
|
-
{field}.namespace as {field}_namespace,
|
12
|
-
"""
|
13
|
-
if include_closure_fields:
|
14
|
-
column_text += f"""
|
15
|
-
{field}_closure.closure as {field}_closure,
|
16
|
-
{field}_closure_label.closure_label as {field}_closure_label,
|
17
|
-
"""
|
18
|
-
|
19
|
-
if field in ['subject', 'object']:
|
20
|
-
column_text += f"""
|
21
|
-
{field}.in_taxon as {field}_taxon,
|
22
|
-
{field}.in_taxon_label as {field}_taxon_label,
|
23
|
-
"""
|
24
|
-
return column_text
|
25
|
-
|
26
|
-
def edge_joins(field: str, include_closure_joins: bool =True):
|
27
|
-
return f"""
|
28
|
-
left outer join nodes as {field} on edges.{field} = {field}.id
|
29
|
-
left outer join closure_id as {field}_closure on {field}.id = {field}_closure.id
|
30
|
-
left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id
|
31
|
-
"""
|
32
|
-
|
33
|
-
def evidence_sum(evidence_fields: List[str]):
|
34
|
-
""" Sum together the length of each field after splitting on | """
|
35
|
-
evidence_count_sum = "+".join([f"ifnull(len(split({field}, '|')),0)" for field in evidence_fields])
|
36
|
-
return f"{evidence_count_sum} as evidence_count,"
|
37
|
-
|
38
|
-
|
39
|
-
def node_columns(predicate):
|
40
|
-
# strip the biolink predicate, if necessary to get the field name
|
41
|
-
field = predicate.replace('biolink:','')
|
42
|
-
|
43
|
-
return f"""
|
44
|
-
string_agg({field}_edges.object, '|') as {field},
|
45
|
-
string_agg({field}_edges.object_label, '|') as {field}_label,
|
46
|
-
count (distinct {field}_edges.object) as {field}_count,
|
47
|
-
list_aggregate(list_distinct(flatten(array_agg({field}_closure.closure))), 'string_agg', '|') as {field}_closure,
|
48
|
-
list_aggregate(list_distinct(flatten(array_agg({field}_closure_label.closure_label))), 'string_agg', '|') as {field}_closure_label,
|
49
|
-
"""
|
50
|
-
|
51
|
-
def node_joins(predicate):
|
52
|
-
# strip the biolink predicate, if necessary to get the field name
|
53
|
-
field = predicate.replace('biolink:','')
|
54
|
-
return f"""
|
55
|
-
left outer join denormalized_edges as {field}_edges
|
56
|
-
on nodes.id = {field}_edges.subject
|
57
|
-
and {field}_edges.predicate = 'biolink:{field}'
|
58
|
-
left outer join closure_id as {field}_closure
|
59
|
-
on {field}_edges.object = {field}_closure.id
|
60
|
-
left outer join closure_label as {field}_closure_label
|
61
|
-
on {field}_edges.object = {field}_closure_label.id
|
62
|
-
"""
|
63
|
-
|
64
|
-
|
65
|
-
def grouping_key(grouping_fields):
|
66
|
-
fragments = []
|
67
|
-
for field in grouping_fields:
|
68
|
-
if field == 'negated':
|
69
|
-
fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
|
70
|
-
else:
|
71
|
-
fragments.append(field)
|
72
|
-
grouping_key_fragments = ", ".join(fragments)
|
73
|
-
return f"concat_ws('|', {grouping_key_fragments}) as grouping_key"
|
74
|
-
|
75
|
-
|
76
|
-
def add_closure(kg_archive: str,
|
77
|
-
closure_file: str,
|
78
|
-
nodes_output_file: str,
|
79
|
-
edges_output_file: str,
|
80
|
-
node_fields: List[str] = [],
|
81
|
-
edge_fields: List[str] = ['subject', 'object'],
|
82
|
-
edge_fields_to_label: List[str] = [],
|
83
|
-
additional_node_constraints: Optional[str] = None,
|
84
|
-
dry_run: bool = False,
|
85
|
-
evidence_fields: List[str] = ['has_evidence', 'publications'],
|
86
|
-
grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object']
|
87
|
-
):
|
88
|
-
print("Generating closure KG...")
|
89
|
-
print(f"kg_archive: {kg_archive}")
|
90
|
-
print(f"closure_file: {closure_file}")
|
91
|
-
|
92
|
-
db = duckdb.connect(database='monarch-kg.duckdb')
|
93
|
-
|
94
|
-
if not dry_run:
|
95
|
-
print(f"fields: {','.join(edge_fields)}")
|
96
|
-
print(f"output_file: {edges_output_file}")
|
97
|
-
|
98
|
-
tar = tarfile.open(f"{kg_archive}")
|
99
|
-
|
100
|
-
print("Loading node table...")
|
101
|
-
node_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv') ][0]
|
102
|
-
tar.extract(node_file_name,)
|
103
|
-
node_file = f"{node_file_name}"
|
104
|
-
print(f"node_file: {node_file}")
|
105
|
-
|
106
|
-
db.sql(f"""
|
107
|
-
create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
|
108
|
-
""")
|
109
|
-
|
110
|
-
edge_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv') ][0]
|
111
|
-
tar.extract(edge_file_name)
|
112
|
-
edge_file = f"{edge_file_name}"
|
113
|
-
print(f"edge_file: {edge_file}")
|
114
|
-
|
115
|
-
db.sql(f"""
|
116
|
-
create or replace table edges as select * from read_csv('{edge_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
|
117
|
-
""")
|
118
|
-
|
119
|
-
# Load the relation graph tsv in long format mapping a node to each of it's ancestors
|
120
|
-
db.sql(f"""
|
121
|
-
create or replace table closure as select * from read_csv('{closure_file}', sep='\t', names=['subject_id', 'predicate_id', 'object_id'], AUTO_DETECT=TRUE)
|
122
|
-
""")
|
123
|
-
|
124
|
-
db.sql("""
|
125
|
-
create or replace table closure_id as select subject_id as id, array_agg(object_id) as closure from closure group by subject_id
|
126
|
-
""")
|
127
|
-
|
128
|
-
db.sql("""
|
129
|
-
create or replace table closure_label as select subject_id as id, array_agg(name) as closure_label from closure join nodes on object_id = id
|
130
|
-
group by subject_id
|
131
|
-
""")
|
132
|
-
|
133
|
-
edges_query = f"""
|
134
|
-
create or replace table denormalized_edges as
|
135
|
-
select edges.*,
|
136
|
-
{"".join([edge_columns(field) for field in edge_fields])}
|
137
|
-
{"".join([edge_columns(field, include_closure_fields=False) for field in edge_fields_to_label])}
|
138
|
-
{evidence_sum(evidence_fields)}
|
139
|
-
{grouping_key(grouping_fields)}
|
140
|
-
from edges
|
141
|
-
{"".join([edge_joins(field) for field in edge_fields])}
|
142
|
-
{"".join([edge_joins(field, include_closure_joins=False) for field in edge_fields_to_label])}
|
143
|
-
"""
|
144
|
-
|
145
|
-
print(edges_query)
|
146
|
-
|
147
|
-
additional_node_constraints = f"where {additional_node_constraints}" if additional_node_constraints else ""
|
148
|
-
nodes_query = f"""
|
149
|
-
create or replace table denormalized_nodes as
|
150
|
-
select nodes.*,
|
151
|
-
{"".join([node_columns(node_field) for node_field in node_fields])}
|
152
|
-
from nodes
|
153
|
-
{node_joins('has_phenotype')}
|
154
|
-
{additional_node_constraints}
|
155
|
-
group by nodes.*
|
156
|
-
"""
|
157
|
-
print(nodes_query)
|
158
|
-
|
159
|
-
|
160
|
-
if not dry_run:
|
161
|
-
|
162
|
-
db.sql(edges_query)
|
163
|
-
|
164
|
-
edge_closure_replacements = [
|
165
|
-
f"""
|
166
|
-
list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
|
167
|
-
list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
|
168
|
-
"""
|
169
|
-
for field in edge_fields
|
170
|
-
]
|
171
|
-
|
172
|
-
edge_closure_replacements = "REPLACE (\n" + ",\n".join(edge_closure_replacements) + ")\n"
|
173
|
-
|
174
|
-
edges_export_query = f"""
|
175
|
-
-- write denormalized_edges as tsv
|
176
|
-
copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
|
177
|
-
"""
|
178
|
-
print(edges_export_query)
|
179
|
-
db.sql(edges_export_query)
|
180
|
-
|
181
|
-
db.sql(nodes_query)
|
182
|
-
nodes_export_query = f"""
|
183
|
-
-- write denormalized_nodes as tsv
|
184
|
-
copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
|
185
|
-
"""
|
186
|
-
print(nodes_export_query)
|
187
|
-
db.sql(nodes_export_query)
|
188
|
-
|
189
|
-
# Clean up extracted node & edge files
|
190
|
-
if os.path.exists(f"{node_file}"):
|
191
|
-
os.remove(f"{node_file}")
|
192
|
-
if os.path.exists(f"{edge_file}"):
|
193
|
-
os.remove(f"{edge_file}")
|