setta 0.0.3.dev3__py3-none-any.whl → 0.0.3.dev5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
setta/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.0.3.dev3"
1
+ __version__ = "0.0.3.dev5"
setta/cli/__init__.py CHANGED
@@ -1 +1 @@
1
- from .logger import Setta, SettaList, SettaImg
1
+ from .logger import Setta, SettaImg, SettaList
@@ -402,8 +402,46 @@ def prune_and_topological_sort(code_dict, to_keep):
402
402
  return topological_sort(section_dependencies), section_dependencies
403
403
 
404
404
 
405
+ # TODO: eliminate redundancy between this and prune_and_topological_sort
406
+ def prune_and_find_top_nodes(code_dict, to_keep):
407
+ section_dependencies = {k: v["section_dependencies"] for k, v in code_dict.items()}
408
+ section_dependencies = prune_dict(section_dependencies, to_keep)
409
+ return find_top_nodes(section_dependencies), section_dependencies
410
+
411
+
412
+ def get_import_order_for_top_node(top_node, dependency_dict):
413
+ # Build a subgraph consisting of top_node and all its descendants.
414
+ subgraph = get_subgraph(top_node, dependency_dict)
415
+ # Sort the subgraph. topological_sort reverses its post-order DFS, so each node precedes its dependencies (top_node first).
416
+ sorted_nodes = topological_sort(subgraph)
417
+ # Imports must run dependencies-first, so reverse the sorted list:
418
+ # the deepest dependency is then imported first and top_node last.
419
+ import_order = sorted_nodes[::-1]
420
+ return import_order
421
+
422
+
423
+ def find_top_nodes(dependency_dict):
424
+ all_nodes = set(dependency_dict.keys())
425
+ all_deps = {dep for deps in dependency_dict.values() for dep in deps}
426
+ return all_nodes - all_deps
427
+
428
+
429
+ def get_subgraph(top_node, dependency_dict):
430
+ visited = set()
431
+
432
+ def dfs(node):
433
+ if node not in visited:
434
+ visited.add(node)
435
+ for dep in dependency_dict.get(node, []):
436
+ dfs(dep)
437
+
438
+ dfs(top_node)
439
+ return {node: dependency_dict.get(node, []) for node in visited}
440
+
441
+
442
+ # TODO: is this function actually necessary anymore?
405
443
  def topological_sort(objects):
406
- graph = {id: refs for id, refs in objects.items()}
444
+ graph = {node: refs for node, refs in objects.items()}
407
445
  visiting, visited = set(), set()
408
446
  order = []
409
447
 
@@ -412,16 +450,15 @@ def topological_sort(objects):
412
450
  raise ValueError("Circular reference detected")
413
451
  if node not in visited:
414
452
  visiting.add(node)
415
- for neighbour in graph[node]:
416
- dfs(neighbour)
453
+ for neighbor in graph.get(node, []):
454
+ dfs(neighbor)
417
455
  visiting.remove(node)
418
456
  visited.add(node)
419
457
  order.append(node)
420
458
 
421
- for id in objects.keys():
422
- if id not in visited:
423
- dfs(id)
424
-
459
+ for node in graph.keys():
460
+ if node not in visited:
461
+ dfs(node)
425
462
  return order[::-1]
426
463
 
427
464
 
@@ -4,7 +4,8 @@ from pydantic import BaseModel
4
4
 
5
5
  from setta.code_gen.create_runnable_scripts import (
6
6
  generate_final_code_for_sections,
7
- prune_and_topological_sort,
7
+ get_import_order_for_top_node,
8
+ prune_and_find_top_nodes,
8
9
  sanitize_section_path_full_name,
9
10
  )
10
11
  from setta.code_gen.export_selected import (
@@ -13,6 +14,7 @@ from setta.code_gen.export_selected import (
13
14
  get_section_code,
14
15
  get_section_type,
15
16
  )
17
+ from setta.code_gen.find_placeholders import parse_template_var
16
18
  from setta.tasks.fns.utils import replace_template_vars_with_random_names
17
19
  from setta.utils.constants import C
18
20
  from setta.utils.utils import multireplace
@@ -23,7 +25,7 @@ router = APIRouter()
23
25
 
24
26
 
25
27
  class UpdateInteractiveCodeRequest(BaseModel):
26
- project: dict
28
+ projects: list
27
29
 
28
30
 
29
31
  class FormatCodeRequest(BaseModel):
@@ -37,7 +39,19 @@ async def route_update_interactive_code(
37
39
  tasks=Depends(get_tasks),
38
40
  lsp_writers=Depends(get_lsp_writers),
39
41
  ):
40
- p = x.project
42
+ idx = 0
43
+ dependencies = set()
44
+ content = []
45
+ for p in x.projects:
46
+ result = await update_interactive_code(p, tasks, lsp_writers, idx)
47
+ dependencies.update(result["dependencies"])
48
+ content.extend(result["content"])
49
+ idx += 1
50
+ dependencies = list(dependencies)
51
+ return {"dependencies": dependencies, "content": content}
52
+
53
+
54
+ async def update_interactive_code(p, tasks, lsp_writers, idx):
41
55
  exporter_obj = export_selected(
42
56
  p, always_export_args_objs=False, force_include_template_var=True
43
57
  )
@@ -46,11 +60,11 @@ async def route_update_interactive_code(
46
60
  template_var_replacement_values = {}
47
61
  for variant in p["sectionVariants"].values():
48
62
  for t in variant["templateVars"]:
49
- if not t["sectionId"]:
50
- continue
51
- template_var_replacement_values[
52
- t["keyword"]
53
- ] = create_in_memory_module_name(p, t["sectionId"])
63
+ _, keyword_type = parse_template_var(t["keyword"])
64
+ if t["sectionId"] and keyword_type == C.TEMPLATE_VAR_IMPORT_PATH_SUFFIX:
65
+ template_var_replacement_values[
66
+ t["keyword"]
67
+ ] = create_in_memory_module_name(p, t["sectionId"])
54
68
 
55
69
  code_dict = await generate_final_code_for_sections(
56
70
  p,
@@ -62,24 +76,29 @@ async def route_update_interactive_code(
62
76
  template_var_replacement_values=template_var_replacement_values,
63
77
  )
64
78
 
65
- to_import, _ = prune_and_topological_sort(code_dict, p["importCodeBlocks"])
66
- to_import = to_import[::-1] # we want to import the dependencies first
67
- code_list = []
68
- for section_id in to_import:
69
- v = code_dict[section_id]
70
- task_name = create_in_memory_module_name(p, section_id)
71
- code_list.append(
79
+ top_node_ids, section_dependencies = prune_and_find_top_nodes(
80
+ code_dict, p["runCodeBlocks"]
81
+ )
82
+ code_graph = []
83
+ project_config_id = p["projectConfig"]["id"]
84
+ for section_id in top_node_ids:
85
+ code_graph.append(
72
86
  {
73
- "code": v["code"],
74
- "module_name": task_name,
87
+ "subprocess_key": f"{project_config_id}-{section_id}-{idx}",
88
+ "code": code_dict[section_id]["code"],
89
+ "imports": get_import_order_for_top_node(
90
+ section_id, section_dependencies
91
+ ),
92
+ "module_name": create_in_memory_module_name(p, section_id),
75
93
  }
76
94
  )
77
95
 
78
- metadata, error_msgs, content = await tasks.add_custom_fns(
79
- code_list,
96
+ dependencies, content = await tasks.add_custom_fns(
97
+ code_graph,
80
98
  to_cache=exporter_obj_in_memory,
81
99
  )
82
- return {"metadata": metadata, "errorMsgs": error_msgs, "content": content}
100
+
101
+ return {"dependencies": dependencies, "content": content}
83
102
 
84
103
 
85
104
  @router.post(C.ROUTE_FORMAT_CODE)