closurizer 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
closurizer/cli.py CHANGED
@@ -4,7 +4,8 @@ from closurizer.closurizer import add_closure
4
4
 
5
5
 
6
6
  @click.command()
7
- @click.option('--kg', required=True, help='KGX tar.gz archive')
7
+ @click.option('--kg', required=False, help='KGX tar.gz archive to import (optional - if not provided, uses existing database)')
8
+ @click.option('--database', required=False, default='monarch-kg.duckdb', help='Database path - if kg is provided, data will be loaded into this database; if kg is not provided, this database must already exist with nodes and edges tables (default: monarch-kg.duckdb)')
8
9
  @click.option('--closure', required=True, help='TSV file of closure triples')
9
10
  @click.option('--nodes-output', required=True, help='file write nodes kgx file with closure fields added')
10
11
  @click.option('--edges-output', required=True, help='file write edges kgx file with closure fields added')
@@ -14,8 +15,12 @@ from closurizer.closurizer import add_closure
14
15
  @click.option('--edge-fields-to-label', multiple=True, help='edge fields to with category, label, etc but not full closure exansion')
15
16
  @click.option('--node-fields', multiple=True, help='node fields to expand with closure IDs, labels, etc')
16
17
  @click.option('--grouping-fields', multiple=True, help='fields to populate a single value grouping_key field')
18
+ @click.option('--multivalued-fields', multiple=True, help='fields containing pipe-delimited values to convert to varchar[] arrays in database')
17
19
  @click.option('--dry-run', is_flag=True, help='A dry run will not write the output file, but will print the SQL query')
20
+ @click.option('--export-edges', is_flag=True, help='Export denormalized edges to TSV file (default: database only)')
21
+ @click.option('--export-nodes', is_flag=True, help='Export denormalized nodes to TSV file (default: database only)')
18
22
  def main(kg: str,
23
+ database: str,
19
24
  closure: str,
20
25
  nodes_output: str,
21
26
  edges_output: str,
@@ -24,17 +29,25 @@ def main(kg: str,
24
29
  edge_fields: List[str] = None,
25
30
  edge_fields_to_label: List[str] = None,
26
31
  node_fields: List[str] = None,
27
- grouping_fields: List[str] = None):
28
- add_closure(kg_archive=kg,
29
- closure_file=closure,
30
- edge_fields=edge_fields,
31
- edge_fields_to_label=edge_fields_to_label,
32
- node_fields=node_fields,
33
- edges_output_file=edges_output,
32
+ grouping_fields: List[str] = None,
33
+ multivalued_fields: List[str] = None,
34
+ export_edges: bool = False,
35
+ export_nodes: bool = False):
36
+
37
+ add_closure(closure_file=closure,
34
38
  nodes_output_file=nodes_output,
39
+ edges_output_file=edges_output,
40
+ kg_archive=kg,
41
+ database_path=database,
42
+ edge_fields=edge_fields or ['subject', 'object'],
43
+ edge_fields_to_label=edge_fields_to_label or [],
44
+ node_fields=node_fields or [],
35
45
  additional_node_constraints=additional_node_constraints,
36
46
  dry_run=dry_run,
37
- grouping_fields=grouping_fields)
47
+ grouping_fields=grouping_fields or ['subject', 'negated', 'predicate', 'object'],
48
+ multivalued_fields=multivalued_fields or ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
49
+ export_edges=export_edges,
50
+ export_nodes=export_nodes)
38
51
 
39
52
  if __name__ == "__main__":
40
53
  main()
closurizer/closurizer.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  import tarfile
5
5
  import duckdb
6
6
 
7
- def edge_columns(field: str, include_closure_fields: bool =True):
7
+ def edge_columns(field: str, include_closure_fields: bool =True, node_column_names: list = None):
8
8
  column_text = f"""
9
9
  {field}.name as {field}_label,
10
10
  {field}.category as {field}_category,
@@ -16,23 +16,46 @@ def edge_columns(field: str, include_closure_fields: bool =True):
16
16
  {field}_closure_label.closure_label as {field}_closure_label,
17
17
  """
18
18
 
19
- if field in ['subject', 'object']:
20
- column_text += f"""
21
- {field}.in_taxon as {field}_taxon,
22
- {field}.in_taxon_label as {field}_taxon_label,
19
+ # Only add taxon fields if they exist in the nodes table
20
+ if field in ['subject', 'object'] and node_column_names:
21
+ if 'in_taxon' in node_column_names:
22
+ column_text += f"""
23
+ {field}.in_taxon as {field}_taxon,"""
24
+ if 'in_taxon_label' in node_column_names:
25
+ column_text += f"""
26
+ {field}.in_taxon_label as {field}_taxon_label,"""
27
+ column_text += """
23
28
  """
24
29
  return column_text
25
30
 
26
- def edge_joins(field: str, include_closure_joins: bool =True):
27
- return f"""
28
- left outer join nodes as {field} on edges.{field} = {field}.id
31
+ def edge_joins(field: str, include_closure_joins: bool =True, is_multivalued: bool = False):
32
+ if is_multivalued:
33
+ # For VARCHAR[] fields, use array containment with list_contains
34
+ join_condition = f"list_contains(edges.{field}, {field}.id)"
35
+ else:
36
+ # For VARCHAR fields, use direct equality
37
+ join_condition = f"edges.{field} = {field}.id"
38
+
39
+ joins = f"""
40
+ left outer join nodes as {field} on {join_condition}"""
41
+
42
+ if include_closure_joins:
43
+ joins += f"""
29
44
  left outer join closure_id as {field}_closure on {field}.id = {field}_closure.id
30
- left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id
31
- """
32
-
33
- def evidence_sum(evidence_fields: List[str]):
34
- """ Sum together the length of each field after splitting on | """
35
- evidence_count_sum = "+".join([f"ifnull(len(split({field}, '|')),0)" for field in evidence_fields])
45
+ left outer join closure_label as {field}_closure_label on {field}.id = {field}_closure_label.id"""
46
+
47
+ return joins + "\n "
48
+
49
+ def evidence_sum(evidence_fields: List[str], edges_column_names: list = None):
50
+ """ Sum together the length of each field - assumes fields are VARCHAR[] arrays """
51
+ evidence_count_parts = []
52
+ for field in evidence_fields:
53
+ # Only include fields that actually exist in the edges table
54
+ if not edges_column_names or field in edges_column_names:
55
+ # All evidence fields are expected to be VARCHAR[] arrays
56
+ evidence_count_parts.append(f"ifnull(array_length({field}),0)")
57
+
58
+ evidence_count_sum = "+".join(evidence_count_parts) if evidence_count_parts else "0"
36
59
  return f"{evidence_count_sum} as evidence_count,"
37
60
 
38
61
 
@@ -41,11 +64,11 @@ def node_columns(predicate):
41
64
  field = predicate.replace('biolink:','')
42
65
 
43
66
  return f"""
44
- string_agg({field}_edges.object, '|') as {field},
45
- string_agg({field}_edges.object_label, '|') as {field}_label,
67
+ case when count(distinct {field}_edges.object) > 0 then array_agg(distinct {field}_edges.object) else null end as {field},
68
+ case when count(distinct {field}_edges.object_label) > 0 then array_agg(distinct {field}_edges.object_label) else null end as {field}_label,
46
69
  count (distinct {field}_edges.object) as {field}_count,
47
- list_aggregate(list_distinct(flatten(array_agg({field}_closure.closure))), 'string_agg', '|') as {field}_closure,
48
- list_aggregate(list_distinct(flatten(array_agg({field}_closure_label.closure_label))), 'string_agg', '|') as {field}_closure_label,
70
+ case when count({field}_closure.closure) > 0 then list_distinct(flatten(array_agg({field}_closure.closure))) else null end as {field}_closure,
71
+ case when count({field}_closure_label.closure_label) > 0 then list_distinct(flatten(array_agg({field}_closure_label.closure_label))) else null end as {field}_closure_label,
49
72
  """
50
73
 
51
74
  def node_joins(predicate):
@@ -62,59 +85,172 @@ def node_joins(predicate):
62
85
  """
63
86
 
64
87
 
65
- def grouping_key(grouping_fields):
88
+ def grouping_key(grouping_fields, edges_column_names: list = None):
89
+ if not grouping_fields:
90
+ return "null as grouping_key"
66
91
  fragments = []
67
92
  for field in grouping_fields:
68
- if field == 'negated':
69
- fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
70
- else:
71
- fragments.append(field)
93
+ # Only include fields that actually exist in the edges table
94
+ if not edges_column_names or field in edges_column_names:
95
+ if field == 'negated':
96
+ fragments.append(f"coalesce(cast({field} as varchar).replace('true','NOT'), '')")
97
+ else:
98
+ fragments.append(field)
99
+ if not fragments:
100
+ return "null as grouping_key"
72
101
  grouping_key_fragments = ", ".join(fragments)
73
102
  return f"concat_ws('|', {grouping_key_fragments}) as grouping_key"
74
103
 
75
104
 
76
- def add_closure(kg_archive: str,
77
- closure_file: str,
105
+
106
+
107
+ def load_from_archive(kg_archive: str, db, multivalued_fields: List[str]):
108
+ """Load nodes and edges tables from tar.gz archive"""
109
+
110
+ tar = tarfile.open(f"{kg_archive}")
111
+
112
+ print("Loading node table...")
113
+ node_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv') ][0]
114
+ tar.extract(node_file_name,)
115
+ node_file = f"{node_file_name}"
116
+ print(f"node_file: {node_file}")
117
+
118
+ db.sql(f"""
119
+ create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
120
+ """)
121
+
122
+ edge_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv') ][0]
123
+ tar.extract(edge_file_name)
124
+ edge_file = f"{edge_file_name}"
125
+ print(f"edge_file: {edge_file}")
126
+
127
+ db.sql(f"""
128
+ create or replace table edges as select * from read_csv('{edge_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
129
+ """)
130
+
131
+ # Convert multivalued fields to arrays
132
+ prepare_multivalued_fields(db, multivalued_fields)
133
+
134
+ # Clean up extracted files
135
+ if os.path.exists(f"{node_file}"):
136
+ os.remove(f"{node_file}")
137
+ if os.path.exists(f"{edge_file}"):
138
+ os.remove(f"{edge_file}")
139
+
140
+
141
+ def prepare_multivalued_fields(db, multivalued_fields: List[str]):
142
+ """Convert specified fields to varchar[] arrays in both nodes and edges tables"""
143
+
144
+ # Convert multivalued fields in nodes table to varchar[] arrays
145
+ nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
146
+ node_column_names = [col[0] for col in nodes_table_info]
147
+ node_column_types = {col[0]: col[1] for col in nodes_table_info}
148
+
149
+ for field in multivalued_fields:
150
+ if field in node_column_names:
151
+ # Check if field is already VARCHAR[] - if so, skip conversion
152
+ if 'VARCHAR[]' in node_column_types[field].upper():
153
+ print(f"Field '{field}' in nodes table is already VARCHAR[], skipping conversion")
154
+ continue
155
+
156
+ print(f"Converting field '{field}' in nodes table to VARCHAR[]")
157
+ # Create a new column with proper array type and replace the original
158
+ db.sql(f"""
159
+ alter table nodes add column {field}_array VARCHAR[]
160
+ """)
161
+ db.sql(f"""
162
+ update nodes set {field}_array =
163
+ case
164
+ when {field} is null or {field} = '' then null
165
+ else split({field}, '|')
166
+ end
167
+ """)
168
+ db.sql(f"""
169
+ alter table nodes drop column {field}
170
+ """)
171
+ db.sql(f"""
172
+ alter table nodes rename column {field}_array to {field}
173
+ """)
174
+
175
+ # Convert multivalued fields in edges table to varchar[] arrays
176
+ edges_table_info = db.sql("DESCRIBE edges").fetchall()
177
+ edge_column_names = [col[0] for col in edges_table_info]
178
+ edge_column_types = {col[0]: col[1] for col in edges_table_info}
179
+
180
+ for field in multivalued_fields:
181
+ if field in edge_column_names:
182
+ # Check if field is already VARCHAR[] - if so, skip conversion
183
+ if 'VARCHAR[]' in edge_column_types[field].upper():
184
+ print(f"Field '{field}' in edges table is already VARCHAR[], skipping conversion")
185
+ continue
186
+
187
+ print(f"Converting field '{field}' in edges table to VARCHAR[]")
188
+ # Create a new column with proper array type and replace the original
189
+ db.sql(f"""
190
+ alter table edges add column {field}_array VARCHAR[]
191
+ """)
192
+ db.sql(f"""
193
+ update edges set {field}_array =
194
+ case
195
+ when {field} is null or {field} = '' then null
196
+ else split({field}, '|')
197
+ end
198
+ """)
199
+ db.sql(f"""
200
+ alter table edges drop column {field}
201
+ """)
202
+ db.sql(f"""
203
+ alter table edges rename column {field}_array to {field}
204
+ """)
205
+
206
+
207
+ def add_closure(closure_file: str,
78
208
  nodes_output_file: str,
79
209
  edges_output_file: str,
210
+ kg_archive: Optional[str] = None,
211
+ database_path: str = 'monarch-kg.duckdb',
80
212
  node_fields: List[str] = [],
81
213
  edge_fields: List[str] = ['subject', 'object'],
82
214
  edge_fields_to_label: List[str] = [],
83
215
  additional_node_constraints: Optional[str] = None,
84
216
  dry_run: bool = False,
85
217
  evidence_fields: List[str] = ['has_evidence', 'publications'],
86
- grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object']
218
+ grouping_fields: List[str] = ['subject', 'negated', 'predicate', 'object'],
219
+ multivalued_fields: List[str] = ['has_evidence', 'publications', 'in_taxon', 'in_taxon_label'],
220
+ export_edges: bool = False,
221
+ export_nodes: bool = False
87
222
  ):
223
+ # Validate input parameters
224
+ if not kg_archive and not os.path.exists(database_path):
225
+ raise ValueError("Either kg_archive must be specified or database_path must exist")
226
+
88
227
  print("Generating closure KG...")
89
- print(f"kg_archive: {kg_archive}")
228
+ if kg_archive:
229
+ print(f"kg_archive: {kg_archive}")
230
+ print(f"database_path: {database_path}")
90
231
  print(f"closure_file: {closure_file}")
91
232
 
92
- db = duckdb.connect(database='monarch-kg.duckdb')
233
+ # Connect to database
234
+ db = duckdb.connect(database=database_path)
93
235
 
94
236
  if not dry_run:
95
237
  print(f"fields: {','.join(edge_fields)}")
96
238
  print(f"output_file: {edges_output_file}")
97
239
 
98
- tar = tarfile.open(f"{kg_archive}")
99
-
100
- print("Loading node table...")
101
- node_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_nodes.tsv') ][0]
102
- tar.extract(node_file_name,)
103
- node_file = f"{node_file_name}"
104
- print(f"node_file: {node_file}")
105
-
106
- db.sql(f"""
107
- create or replace table nodes as select *, substr(id, 1, instr(id,':') -1) as namespace from read_csv('{node_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
108
- """)
109
-
110
- edge_file_name = [member.name for member in tar.getmembers() if member.name.endswith('_edges.tsv') ][0]
111
- tar.extract(edge_file_name)
112
- edge_file = f"{edge_file_name}"
113
- print(f"edge_file: {edge_file}")
114
-
115
- db.sql(f"""
116
- create or replace table edges as select * from read_csv('{edge_file_name}', header=True, sep='\t', AUTO_DETECT=TRUE)
117
- """)
240
+ # Load data based on input method
241
+ if kg_archive:
242
+ load_from_archive(kg_archive, db, multivalued_fields)
243
+ else:
244
+ # Database already exists and contains data
245
+ # Check if namespace column exists, add it if needed
246
+ node_column_names = [col[0] for col in db.sql("DESCRIBE nodes").fetchall()]
247
+ if 'namespace' not in node_column_names:
248
+ print("Adding namespace column to nodes table...")
249
+ db.sql("ALTER TABLE nodes ADD COLUMN namespace VARCHAR")
250
+ db.sql("UPDATE nodes SET namespace = substr(id, 1, instr(id,':') -1)")
251
+
252
+ # Convert multivalued fields to arrays
253
+ prepare_multivalued_fields(db, multivalued_fields)
118
254
 
119
255
  # Load the relation graph tsv in long format mapping a node to each of it's ancestors
120
256
  db.sql(f"""
@@ -130,62 +266,182 @@ def add_closure(kg_archive: str,
130
266
  group by subject_id
131
267
  """)
132
268
 
269
+ db.sql("""
270
+ create or replace table descendants_id as
271
+ select object_id as id, array_agg(subject_id) as descendants
272
+ from closure
273
+ group by object_id
274
+ """)
275
+
276
+ db.sql("""
277
+ create or replace table descendants_label as
278
+ select object_id as id, array_agg(name) as descendants_label
279
+ from closure
280
+ join nodes on subject_id = nodes.id
281
+ group by object_id
282
+ """)
283
+
284
+ # Get edges table schema to determine which fields are VARCHAR[]
285
+ edges_table_info = db.sql("DESCRIBE edges").fetchall()
286
+ edges_table_types = {col[0]: col[1] for col in edges_table_info}
287
+ edges_column_names = [col[0] for col in edges_table_info]
288
+
289
+ # Get nodes table schema to check for available columns
290
+ nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
291
+ node_column_names = [col[0] for col in nodes_table_info]
292
+
293
+ # Build edge joins with proper multivalued field handling
294
+ edge_field_joins = []
295
+ for field in edge_fields:
296
+ is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
297
+ edge_field_joins.append(edge_joins(field, is_multivalued=is_multivalued))
298
+
299
+ edge_field_to_label_joins = []
300
+ for field in edge_fields_to_label:
301
+ is_multivalued = field in multivalued_fields and 'VARCHAR[]' in edges_table_types.get(field, '').upper()
302
+ edge_field_to_label_joins.append(edge_joins(field, include_closure_joins=False, is_multivalued=is_multivalued))
303
+
133
304
  edges_query = f"""
134
305
  create or replace table denormalized_edges as
135
306
  select edges.*,
136
- {"".join([edge_columns(field) for field in edge_fields])}
137
- {"".join([edge_columns(field, include_closure_fields=False) for field in edge_fields_to_label])}
138
- {evidence_sum(evidence_fields)}
139
- {grouping_key(grouping_fields)}
307
+ {"".join([edge_columns(field, node_column_names=node_column_names) for field in edge_fields])}
308
+ {"".join([edge_columns(field, include_closure_fields=False, node_column_names=node_column_names) for field in edge_fields_to_label])}
309
+ {evidence_sum(evidence_fields, edges_column_names)}
310
+ {grouping_key(grouping_fields, edges_column_names)}
140
311
  from edges
141
- {"".join([edge_joins(field) for field in edge_fields])}
142
- {"".join([edge_joins(field, include_closure_joins=False) for field in edge_fields_to_label])}
312
+ {"".join(edge_field_joins)}
313
+ {"".join(edge_field_to_label_joins)}
143
314
  """
144
315
 
145
316
  print(edges_query)
146
317
 
147
318
  additional_node_constraints = f"where {additional_node_constraints}" if additional_node_constraints else ""
319
+
320
+ # Get nodes table info to handle multivalued fields in the query
321
+ nodes_table_info = db.sql("DESCRIBE nodes").fetchall()
322
+ nodes_table_column_names = [col[0] for col in nodes_table_info]
323
+ nodes_table_types = {col[0]: col[1] for col in nodes_table_info}
324
+
325
+ # Create field selections for nodes, converting VARCHAR[] back to pipe-delimited strings
326
+ nodes_field_selections = []
327
+ for field in nodes_table_column_names:
328
+ if field in multivalued_fields and 'VARCHAR[]' in nodes_table_types[field].upper():
329
+ # Convert VARCHAR[] back to pipe-delimited string
330
+ nodes_field_selections.append(f"list_aggregate({field}, 'string_agg', '|') as {field}")
331
+ else:
332
+ # Regular field, use as-is (but need to specify for GROUP BY)
333
+ nodes_field_selections.append(f"nodes.{field}")
334
+
335
+ nodes_base_fields = ",\n ".join(nodes_field_selections)
336
+
148
337
  nodes_query = f"""
149
338
  create or replace table denormalized_nodes as
150
- select nodes.*,
339
+ select {nodes_base_fields},
151
340
  {"".join([node_columns(node_field) for node_field in node_fields])}
152
341
  from nodes
153
342
  {node_joins('has_phenotype')}
154
343
  {additional_node_constraints}
155
- group by nodes.*
344
+ group by {", ".join([f"nodes.{field}" for field in nodes_table_column_names])}
156
345
  """
157
346
  print(nodes_query)
158
347
 
159
348
 
160
349
  if not dry_run:
161
350
 
162
-
163
- edge_closure_replacements = [
164
- f"""
165
- list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
166
- list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
351
+ db.sql(edges_query)
352
+
353
+ # Export edges to TSV only if requested
354
+ if export_edges:
355
+ edge_closure_replacements = [
356
+ f"""
357
+ list_aggregate({field}_closure, 'string_agg', '|') as {field}_closure,
358
+ list_aggregate({field}_closure_label, 'string_agg', '|') as {field}_closure_label
359
+ """
360
+ for field in edge_fields
361
+ ]
362
+
363
+ # Add conversions for original multivalued fields back to pipe-delimited strings
364
+ edge_table_info = db.sql("DESCRIBE denormalized_edges").fetchall()
365
+ edge_table_column_names = [col[0] for col in edge_table_info]
366
+ edge_table_types = {col[0]: col[1] for col in edge_table_info}
367
+
368
+ # Create set of closure fields already handled by edge_closure_replacements
369
+ closure_fields_handled = set()
370
+ for field in edge_fields:
371
+ closure_fields_handled.add(f"{field}_closure")
372
+ closure_fields_handled.add(f"{field}_closure_label")
373
+
374
+ multivalued_replacements = [
375
+ f"list_aggregate({field}, 'string_agg', '|') as {field}"
376
+ for field in multivalued_fields
377
+ if field in edge_table_column_names and 'VARCHAR[]' in edge_table_types[field].upper()
378
+ and field not in closure_fields_handled
379
+ ]
380
+
381
+ all_replacements = edge_closure_replacements + multivalued_replacements
382
+ edge_closure_replacements = "REPLACE (\n" + ",\n".join(all_replacements) + ")\n"
383
+
384
+ edges_export_query = f"""
385
+ -- write denormalized_edges as tsv
386
+ copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
167
387
  """
168
- for field in edge_fields
169
- ]
170
-
171
- edge_closure_replacements = "REPLACE (\n" + ",\n".join(edge_closure_replacements) + ")\n"
172
-
173
- edges_export_query = f"""
174
- -- write denormalized_edges as tsv
175
- copy (select * {edge_closure_replacements} from denormalized_edges) to '{edges_output_file}' (header, delimiter '\t')
176
- """
177
- print(edges_export_query)
178
- db.query(edges_export_query)
179
-
180
- nodes_export_query = f"""
181
- -- write denormalized_nodes as tsv
182
- copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
183
- """
184
- print(nodes_export_query)
185
-
186
-
187
- # Clean up extracted node & edge files
188
- if os.path.exists(f"{node_file}"):
189
- os.remove(f"{node_file}")
190
- if os.path.exists(f"{edge_file}"):
191
- os.remove(f"{edge_file}")
388
+ print(edges_export_query)
389
+ db.sql(edges_export_query)
390
+
391
+ db.sql(nodes_query)
392
+
393
+ # Add descendant columns separately to avoid memory issues with large GROUP BY
394
+ print("Adding descendant columns to denormalized_nodes...")
395
+ db.sql("alter table denormalized_nodes add column has_descendant VARCHAR[]")
396
+ db.sql("alter table denormalized_nodes add column has_descendant_label VARCHAR[]")
397
+ db.sql("alter table denormalized_nodes add column has_descendant_count INTEGER")
398
+
399
+ db.sql("""
400
+ update denormalized_nodes
401
+ set has_descendant = descendants_id.descendants
402
+ from descendants_id
403
+ where denormalized_nodes.id = descendants_id.id
404
+ """)
405
+
406
+ db.sql("""
407
+ update denormalized_nodes
408
+ set has_descendant_label = descendants_label.descendants_label
409
+ from descendants_label
410
+ where denormalized_nodes.id = descendants_label.id
411
+ """)
412
+
413
+ db.sql("""
414
+ update denormalized_nodes
415
+ set has_descendant_count = coalesce(array_length(has_descendant), 0)
416
+ """)
417
+
418
+ # Export nodes to TSV only if requested
419
+ if export_nodes:
420
+ # Get denormalized_nodes table info to handle array fields in export
421
+ denorm_nodes_table_info = db.sql("DESCRIBE denormalized_nodes").fetchall()
422
+ denorm_nodes_column_names = [col[0] for col in denorm_nodes_table_info]
423
+ denorm_nodes_types = {col[0]: col[1] for col in denorm_nodes_table_info}
424
+
425
+ # Find all VARCHAR[] fields that need conversion to pipe-delimited strings
426
+ array_field_replacements = [
427
+ f"list_aggregate({field}, 'string_agg', '|') as {field}"
428
+ for field in denorm_nodes_column_names
429
+ if 'VARCHAR[]' in denorm_nodes_types[field].upper()
430
+ ]
431
+
432
+ # The descendants fields are already handled by the general VARCHAR[] logic above
433
+ # No need to add them separately
434
+
435
+ if array_field_replacements:
436
+ nodes_replacements = "REPLACE (\n" + ",\n".join(array_field_replacements) + ")\n"
437
+ nodes_export_query = f"""
438
+ -- write denormalized_nodes as tsv
439
+ copy (select * {nodes_replacements} from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
440
+ """
441
+ else:
442
+ nodes_export_query = f"""
443
+ -- write denormalized_nodes as tsv
444
+ copy (select * from denormalized_nodes) to '{nodes_output_file}' (header, delimiter '\t')
445
+ """
446
+ print(nodes_export_query)
447
+ db.sql(nodes_export_query)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: closurizer
3
- Version: 0.7.2
3
+ Version: 0.8.0
4
4
  Summary: Add closure expansion fields to kgx files following the Golr pattern
5
5
  Author: Kevin Schaper
6
6
  Author-email: kevin@tislab.org
@@ -0,0 +1,6 @@
1
+ closurizer/cli.py,sha256=o5Cjm6iTcKPD0StW20iqiEa2ZIw_S_NreT4eKT1gE14,3263
2
+ closurizer/closurizer.py,sha256=RIjYvXrHDUzvhbhzk-pfKDU7wgXSZYNJpZwR-aXNv_o,19630
3
+ closurizer-0.8.0.dist-info/METADATA,sha256=N-jbflUDV3z3av12TzmKFxsugwHxeJfr1vSkaZIE88A,661
4
+ closurizer-0.8.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
5
+ closurizer-0.8.0.dist-info/entry_points.txt,sha256=MnAVu1lgP6DqDb3BZGNzVs2AnDMsp4sThi3ccWbONFo,50
6
+ closurizer-0.8.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.0.1
2
+ Generator: poetry-core 2.1.3
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,6 +0,0 @@
1
- closurizer/cli.py,sha256=xTFscsxGDnaKoTNhn1FefRPPeldI5tZvvp3DygNai7Y,2069
2
- closurizer/closurizer.py,sha256=RpP4XrIMLwLfflgYtO8j0n6AWCRNnQKkNOHGjl4dQqg,7494
3
- closurizer-0.7.2.dist-info/METADATA,sha256=IDV6B4WPiu2EIJ-MPFTUwUIefAhgudbMiJx9rewfk2E,661
4
- closurizer-0.7.2.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
5
- closurizer-0.7.2.dist-info/entry_points.txt,sha256=MnAVu1lgP6DqDb3BZGNzVs2AnDMsp4sThi3ccWbONFo,50
6
- closurizer-0.7.2.dist-info/RECORD,,