jaclang-0.7.11-py3-none-any.whl → jaclang-0.7.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic; see the registry's release page for more details.

Files changed (31)
  1. jaclang/cli/cli.py +10 -2
  2. jaclang/compiler/absyntree.py +19 -6
  3. jaclang/compiler/parser.py +6 -1
  4. jaclang/compiler/passes/main/import_pass.py +1 -0
  5. jaclang/compiler/passes/main/pyast_gen_pass.py +239 -40
  6. jaclang/compiler/passes/main/pyast_load_pass.py +4 -1
  7. jaclang/compiler/passes/main/tests/test_import_pass.py +5 -1
  8. jaclang/compiler/passes/main/type_check_pass.py +0 -17
  9. jaclang/compiler/passes/tool/fuse_comments_pass.py +14 -2
  10. jaclang/compiler/passes/tool/jac_formatter_pass.py +22 -10
  11. jaclang/compiler/tests/test_importer.py +1 -1
  12. jaclang/core/importer.py +126 -89
  13. jaclang/langserve/engine.py +173 -169
  14. jaclang/langserve/server.py +19 -7
  15. jaclang/langserve/tests/fixtures/base_module_structure.jac +28 -2
  16. jaclang/langserve/tests/fixtures/import_include_statements.jac +1 -1
  17. jaclang/langserve/tests/test_server.py +77 -64
  18. jaclang/langserve/utils.py +266 -0
  19. jaclang/plugin/default.py +4 -2
  20. jaclang/plugin/feature.py +2 -2
  21. jaclang/plugin/spec.py +2 -2
  22. jaclang/tests/fixtures/blankwithentry.jac +3 -0
  23. jaclang/tests/fixtures/deep/one_lev.jac +3 -0
  24. jaclang/tests/fixtures/needs_import.jac +1 -1
  25. jaclang/tests/test_cli.py +6 -6
  26. jaclang/tests/test_language.py +9 -0
  27. jaclang/tests/test_man_code.py +17 -0
  28. {jaclang-0.7.11.dist-info → jaclang-0.7.14.dist-info}/METADATA +1 -1
  29. {jaclang-0.7.11.dist-info → jaclang-0.7.14.dist-info}/RECORD +31 -30
  30. {jaclang-0.7.11.dist-info → jaclang-0.7.14.dist-info}/WHEEL +0 -0
  31. {jaclang-0.7.11.dist-info → jaclang-0.7.14.dist-info}/entry_points.txt +0 -0
jaclang/cli/cli.py CHANGED
@@ -92,22 +92,30 @@ def run(
92
92
  base = base if base else "./"
93
93
  mod = mod[:-4]
94
94
  if filename.endswith(".jac"):
95
- loaded_mod = jac_import(
95
+ ret_module = jac_import(
96
96
  target=mod,
97
97
  base_path=base,
98
98
  cachable=cache,
99
99
  override_name="__main__" if main else None,
100
100
  )
101
+ if ret_module is None:
102
+ loaded_mod = None
103
+ else:
104
+ (loaded_mod,) = ret_module
101
105
  elif filename.endswith(".jir"):
102
106
  with open(filename, "rb") as f:
103
107
  ir = pickle.load(f)
104
- loaded_mod = jac_import(
108
+ ret_module = jac_import(
105
109
  target=mod,
106
110
  base_path=base,
107
111
  cachable=cache,
108
112
  override_name="__main__" if main else None,
109
113
  mod_bundle=ir,
110
114
  )
115
+ if ret_module is None:
116
+ loaded_mod = None
117
+ else:
118
+ (loaded_mod,) = ret_module
111
119
  else:
112
120
  print("Not a .jac file.")
113
121
  return
@@ -621,6 +621,7 @@ class Module(AstDocNode):
621
621
  doc: Optional[String],
622
622
  body: Sequence[ElementStmt | String | EmptyToken],
623
623
  is_imported: bool,
624
+ terminals: list[Token],
624
625
  kid: Sequence[AstNode],
625
626
  stub_only: bool = False,
626
627
  registry: Optional[SemRegistry] = None,
@@ -635,6 +636,7 @@ class Module(AstDocNode):
635
636
  self.test_mod: list[Module] = []
636
637
  self.mod_deps: dict[str, Module] = {}
637
638
  self.registry = registry
639
+ self.terminals: list[Token] = terminals
638
640
  AstNode.__init__(self, kid=kid)
639
641
  AstDocNode.__init__(self, doc=doc)
640
642
 
@@ -642,15 +644,15 @@ class Module(AstDocNode):
642
644
  def annexable_by(self) -> Optional[str]:
643
645
  """Get annexable by."""
644
646
  if not self.stub_only and (
645
- self.loc.mod_path.endswith("impl.jac")
646
- or self.loc.mod_path.endswith("test.jac")
647
+ self.loc.mod_path.endswith(".impl.jac")
648
+ or self.loc.mod_path.endswith(".test.jac")
647
649
  ):
648
650
  head_mod_name = self.name.split(".")[0]
649
651
  potential_path = os.path.join(
650
652
  os.path.dirname(self.loc.mod_path),
651
653
  f"{head_mod_name}.jac",
652
654
  )
653
- if os.path.exists(potential_path):
655
+ if os.path.exists(potential_path) and potential_path != self.loc.mod_path:
654
656
  return potential_path
655
657
  annex_dir = os.path.split(os.path.dirname(self.loc.mod_path))[-1]
656
658
  if annex_dir.endswith(".impl") or annex_dir.endswith(".test"):
@@ -661,7 +663,10 @@ class Module(AstDocNode):
661
663
  os.path.dirname(os.path.dirname(self.loc.mod_path)),
662
664
  f"{head_mod_name}.jac",
663
665
  )
664
- if os.path.exists(potential_path):
666
+ if (
667
+ os.path.exists(potential_path)
668
+ and potential_path != self.loc.mod_path
669
+ ):
665
670
  return potential_path
666
671
  return None
667
672
 
@@ -916,9 +921,17 @@ class ModulePath(AstSymbolNode):
916
921
  self.sub_module: Optional[Module] = None
917
922
 
918
923
  name_spec = alias if alias else path[0] if path else None
924
+
925
+ AstNode.__init__(self, kid=kid)
926
+ if not name_spec:
927
+ pkg_name = self.loc.mod_path
928
+ for _ in range(self.level):
929
+ pkg_name = os.path.dirname(pkg_name)
930
+ pkg_name = pkg_name.split(os.sep)[-1]
931
+ name_spec = Name.gen_stub_from_node(self, pkg_name)
932
+ self.level += 1
919
933
  if not isinstance(name_spec, Name):
920
934
  raise ValueError("ModulePath should have a name spec. Impossible.")
921
- AstNode.__init__(self, kid=kid)
922
935
  AstSymbolNode.__init__(
923
936
  self,
924
937
  sym_name=name_spec.sym_name,
@@ -930,7 +943,7 @@ class ModulePath(AstSymbolNode):
930
943
  def path_str(self) -> str:
931
944
  """Get path string."""
932
945
  return ("." * self.level) + ".".join(
933
- [p.value for p in self.path] if self.path else ""
946
+ [p.value for p in self.path] if self.path else [self.name_spec.sym_name]
934
947
  )
935
948
 
936
949
  def normalize(self, deep: bool = False) -> bool:
@@ -56,6 +56,7 @@ class JacParser(Pass):
56
56
  source=self.source,
57
57
  doc=None,
58
58
  body=[],
59
+ terminals=[],
59
60
  is_imported=False,
60
61
  kid=[ast.EmptyToken()],
61
62
  )
@@ -120,6 +121,7 @@ class JacParser(Pass):
120
121
  """Initialize transformer."""
121
122
  super().__init__(*args, **kwargs)
122
123
  self.parse_ref = parser
124
+ self.terminals: list[ast.Token] = []
123
125
 
124
126
  def ice(self) -> Exception:
125
127
  """Raise internal compiler error."""
@@ -131,7 +133,8 @@ class JacParser(Pass):
131
133
  def nu(self, node: ast.T) -> ast.T:
132
134
  """Update node."""
133
135
  self.parse_ref.cur_node = node
134
- self.parse_ref.node_list.append(node)
136
+ if node not in self.parse_ref.node_list:
137
+ self.parse_ref.node_list.append(node)
135
138
  return node
136
139
 
137
140
  def start(self, kid: list[ast.Module]) -> ast.Module:
@@ -159,6 +162,7 @@ class JacParser(Pass):
159
162
  doc=doc,
160
163
  body=body,
161
164
  is_imported=False,
165
+ terminals=self.terminals,
162
166
  kid=kid if len(kid) else [ast.EmptyToken()],
163
167
  )
164
168
  return self.nu(mod)
@@ -3982,4 +3986,5 @@ class JacParser(Pass):
3982
3986
  err.line = ret.loc.first_line
3983
3987
  err.column = ret.loc.col_start
3984
3988
  raise err
3989
+ self.terminals.append(ret)
3985
3990
  return self.nu(ret)
@@ -182,6 +182,7 @@ class JacImportPass(Pass):
182
182
  source=ast.JacSource("", mod_path=target),
183
183
  doc=None,
184
184
  body=[],
185
+ terminals=[],
185
186
  is_imported=False,
186
187
  stub_only=True,
187
188
  kid=[ast.EmptyToken()],
@@ -50,7 +50,15 @@ class PyastGenPass(Pass):
50
50
  level=0,
51
51
  ),
52
52
  jac_node=self.ir,
53
- )
53
+ ),
54
+ self.sync(
55
+ ast3.ImportFrom(
56
+ module="typing",
57
+ names=[self.sync(ast3.alias(name="TYPE_CHECKING", asname=None))],
58
+ level=0,
59
+ ),
60
+ jac_node=self.ir,
61
+ ),
54
62
  ]
55
63
 
56
64
  def enter_node(self, node: ast.AstNode) -> None:
@@ -426,14 +434,7 @@ class PyastGenPass(Pass):
426
434
  body: SubNodeList[CodeBlockStmt],
427
435
  doc: Optional[String],
428
436
  """
429
- if node.doc:
430
- doc = self.sync(ast3.Expr(value=node.doc.gen.py_ast[0]), jac_node=node.doc)
431
- if isinstance(node.body.gen.py_ast, list):
432
- node.gen.py_ast = [doc] + node.body.gen.py_ast
433
- else:
434
- raise self.ice()
435
- else:
436
- node.gen.py_ast = node.body.gen.py_ast
437
+ node.gen.py_ast = self.resolve_stmt_block(node.body, doc=node.doc)
437
438
  if node.name:
438
439
  node.gen.py_ast = [
439
440
  self.sync(
@@ -480,7 +481,7 @@ class PyastGenPass(Pass):
480
481
  def exit_import(self, node: ast.Import) -> None:
481
482
  """Sub objects.
482
483
 
483
- lang: SubTag[Name],
484
+ hint: SubTag[Name],
484
485
  paths: list[ModulePath],
485
486
  alias: Optional[Name],
486
487
  items: Optional[SubNodeList[ModuleItem]],
@@ -488,12 +489,6 @@ class PyastGenPass(Pass):
488
489
  doc: Optional[String],
489
490
  sub_module: Optional[Module],
490
491
  """
491
- py_nodes: list[ast3.AST] = []
492
-
493
- if node.doc:
494
- py_nodes.append(
495
- self.sync(ast3.Expr(value=node.doc.gen.py_ast[0]), jac_node=node.doc)
496
- )
497
492
  path_alias: dict[str, Optional[str]] = (
498
493
  {node.from_loc.path_str: None} if node.from_loc else {}
499
494
  )
@@ -502,25 +497,75 @@ class PyastGenPass(Pass):
502
497
  for item in node.items.items:
503
498
  if isinstance(item, ast.ModuleItem):
504
499
  imp_from[item.name.sym_name] = (
505
- item.alias.sym_name if item.alias else False
500
+ item.alias.sym_name if item.alias else None
506
501
  )
507
502
  elif isinstance(item, ast.ModulePath):
508
503
  path_alias[item.path_str] = (
509
504
  item.alias.sym_name if item.alias else None
510
505
  )
511
506
 
512
- keys = []
513
- values = []
514
- for k in imp_from.keys():
515
- keys.append(self.sync(ast3.Constant(value=k)))
516
- for v in imp_from.values():
517
- values.append(self.sync(ast3.Constant(value=v)))
518
-
507
+ item_keys = []
508
+ item_values = []
509
+ for k, v in imp_from.items():
510
+ item_keys.append(self.sync(ast3.Constant(value=k)))
511
+ item_values.append(self.sync(ast3.Constant(value=v)))
519
512
  self.needs_jac_import()
520
- for p, a in path_alias.items():
513
+ path_named_value: str
514
+ py_nodes: list[ast3.AST] = []
515
+ typecheck_nodes: list[ast3.AST] = []
516
+ runtime_nodes: list[ast3.AST] = []
517
+
518
+ if node.doc:
521
519
  py_nodes.append(
520
+ self.sync(ast3.Expr(value=node.doc.gen.py_ast[0]), jac_node=node.doc)
521
+ )
522
+
523
+ for path, alias in path_alias.items():
524
+ path_named_value = ("_jac_inc_" if node.is_absorb else "") + (
525
+ alias if alias else path
526
+ ).lstrip(".").split(".")[0]
527
+ # target_named_value = ""
528
+ # for i in path.split("."):
529
+ # target_named_value += i if i else "."
530
+ # if i:
531
+ # break
532
+ runtime_nodes.append(
522
533
  self.sync(
523
- ast3.Expr(
534
+ ast3.Assign(
535
+ targets=(
536
+ [
537
+ self.sync(
538
+ ast3.Tuple(
539
+ elts=(
540
+ [
541
+ self.sync(
542
+ ast3.Name(
543
+ id=path_named_value,
544
+ ctx=ast3.Store(),
545
+ )
546
+ )
547
+ ]
548
+ if not len(item_keys)
549
+ else []
550
+ + [
551
+ self.sync(
552
+ ast3.Name(
553
+ id=(
554
+ v.value
555
+ if v.value
556
+ else k.value
557
+ ),
558
+ ctx=ast3.Store(),
559
+ )
560
+ )
561
+ for k, v in zip(item_keys, item_values)
562
+ ]
563
+ ),
564
+ ctx=ast3.Store(),
565
+ )
566
+ )
567
+ ]
568
+ ),
524
569
  value=self.sync(
525
570
  ast3.Call(
526
571
  func=self.sync(
@@ -532,7 +577,7 @@ class PyastGenPass(Pass):
532
577
  ast3.keyword(
533
578
  arg="target",
534
579
  value=self.sync(
535
- ast3.Constant(value=p),
580
+ ast3.Constant(value=path),
536
581
  ),
537
582
  )
538
583
  ),
@@ -541,7 +586,8 @@ class PyastGenPass(Pass):
541
586
  arg="base_path",
542
587
  value=self.sync(
543
588
  ast3.Name(
544
- id="__file__", ctx=ast3.Load()
589
+ id="__file__",
590
+ ctx=ast3.Load(),
545
591
  )
546
592
  ),
547
593
  )
@@ -580,7 +626,7 @@ class PyastGenPass(Pass):
580
626
  ast3.keyword(
581
627
  arg="mdl_alias",
582
628
  value=self.sync(
583
- ast3.Constant(value=a),
629
+ ast3.Constant(value=alias),
584
630
  ),
585
631
  )
586
632
  ),
@@ -588,21 +634,170 @@ class PyastGenPass(Pass):
588
634
  ast3.keyword(
589
635
  arg="items",
590
636
  value=self.sync(
591
- ast3.Dict(keys=keys, values=values),
637
+ ast3.Dict(
638
+ keys=item_keys, values=item_values
639
+ ),
592
640
  ),
593
641
  )
594
642
  ),
595
643
  ],
596
644
  )
597
- )
645
+ ),
598
646
  ),
647
+ ),
648
+ )
649
+ if node.is_absorb:
650
+ absorb_exec = f"={path_named_value}.__dict__['"
651
+ runtime_nodes.append(
652
+ self.sync(
653
+ ast3.For(
654
+ target=self.sync(ast3.Name(id="i", ctx=ast3.Store())),
655
+ iter=self.sync(
656
+ ast3.IfExp(
657
+ test=self.sync(
658
+ ast3.Compare(
659
+ left=self.sync(ast3.Constant(value="__all__")),
660
+ ops=[self.sync(ast3.In())],
661
+ comparators=[
662
+ self.sync(
663
+ ast3.Attribute(
664
+ value=self.sync(
665
+ ast3.Name(
666
+ id=path_named_value,
667
+ ctx=ast3.Load(),
668
+ )
669
+ ),
670
+ attr="__dict__",
671
+ ctx=ast3.Load(),
672
+ )
673
+ )
674
+ ],
675
+ )
676
+ ),
677
+ body=self.sync(
678
+ ast3.Attribute(
679
+ value=self.sync(
680
+ ast3.Name(
681
+ id=path_named_value, ctx=ast3.Load()
682
+ )
683
+ ),
684
+ attr="__all__",
685
+ ctx=ast3.Load(),
686
+ )
687
+ ),
688
+ orelse=self.sync(
689
+ ast3.Attribute(
690
+ value=self.sync(
691
+ ast3.Name(
692
+ id=path_named_value, ctx=ast3.Load()
693
+ )
694
+ ),
695
+ attr="__dict__",
696
+ ctx=ast3.Load(),
697
+ )
698
+ ),
699
+ )
700
+ ),
701
+ body=[
702
+ self.sync(
703
+ ast3.If(
704
+ test=self.sync(
705
+ ast3.UnaryOp(
706
+ op=self.sync(ast3.Not()),
707
+ operand=self.sync(
708
+ ast3.Call(
709
+ func=self.sync(
710
+ ast3.Attribute(
711
+ value=self.sync(
712
+ ast3.Name(
713
+ id="i",
714
+ ctx=ast3.Load(),
715
+ )
716
+ ),
717
+ attr="startswith",
718
+ ctx=ast3.Load(),
719
+ )
720
+ ),
721
+ args=[
722
+ self.sync(
723
+ ast3.Constant(value="_")
724
+ )
725
+ ],
726
+ keywords=[],
727
+ )
728
+ ),
729
+ )
730
+ ),
731
+ body=[
732
+ self.sync(
733
+ ast3.Expr(
734
+ value=self.sync(
735
+ ast3.Call(
736
+ func=self.sync(
737
+ ast3.Name(
738
+ id="exec",
739
+ ctx=ast3.Load(),
740
+ )
741
+ ),
742
+ args=[
743
+ self.sync(
744
+ ast3.JoinedStr(
745
+ values=[
746
+ self.sync(
747
+ ast3.FormattedValue(
748
+ value=self.sync(
749
+ ast3.Name(
750
+ id="i",
751
+ ctx=ast3.Load(),
752
+ )
753
+ ),
754
+ conversion=-1,
755
+ )
756
+ ),
757
+ self.sync(
758
+ ast3.Constant(
759
+ value=absorb_exec
760
+ )
761
+ ),
762
+ self.sync(
763
+ ast3.FormattedValue(
764
+ value=self.sync(
765
+ ast3.Name(
766
+ id="i",
767
+ ctx=ast3.Load(),
768
+ )
769
+ ),
770
+ conversion=-1,
771
+ )
772
+ ),
773
+ self.sync(
774
+ ast3.Constant(
775
+ value="']"
776
+ )
777
+ ),
778
+ ]
779
+ )
780
+ )
781
+ ],
782
+ keywords=[],
783
+ )
784
+ )
785
+ )
786
+ )
787
+ ],
788
+ orelse=[],
789
+ )
790
+ )
791
+ ],
792
+ orelse=[],
793
+ )
599
794
  )
600
795
  )
601
796
  if node.is_absorb:
602
797
  source = node.items.items[0]
603
798
  if not isinstance(source, ast.ModulePath):
604
799
  raise self.ice()
605
- py_nodes.append(
800
+ typecheck_nodes.append(
606
801
  self.sync(
607
802
  py_node=ast3.ImportFrom(
608
803
  module=(source.path_str.lstrip(".") if source else None),
@@ -612,15 +807,10 @@ class PyastGenPass(Pass):
612
807
  jac_node=node,
613
808
  )
614
809
  )
615
- if node.items:
616
- pass
617
- # self.warning(
618
- # "Includes import * in target module into current namespace."
619
- # )
620
- if not node.from_loc:
621
- py_nodes.append(self.sync(ast3.Import(names=node.items.gen.py_ast)))
810
+ elif not node.from_loc:
811
+ typecheck_nodes.append(self.sync(ast3.Import(names=node.items.gen.py_ast)))
622
812
  else:
623
- py_nodes.append(
813
+ typecheck_nodes.append(
624
814
  self.sync(
625
815
  ast3.ImportFrom(
626
816
  module=(
@@ -633,6 +823,15 @@ class PyastGenPass(Pass):
633
823
  )
634
824
  )
635
825
  )
826
+ py_nodes.append(
827
+ self.sync(
828
+ ast3.If(
829
+ test=self.sync(ast3.Name(id="TYPE_CHECKING", ctx=ast3.Load())),
830
+ body=typecheck_nodes,
831
+ orelse=runtime_nodes,
832
+ )
833
+ )
834
+ )
636
835
  node.gen.py_ast = py_nodes
637
836
 
638
837
  def exit_module_path(self, node: ast.ModulePath) -> None:
@@ -124,6 +124,7 @@ class PyastBuildPass(Pass[ast.PythonModuleAst]):
124
124
  source=ast.JacSource("", mod_path=self.mod_path),
125
125
  doc=doc_str,
126
126
  body=valid[1:] if valid and isinstance(valid[0], ast.String) else valid,
127
+ terminals=[],
127
128
  is_imported=False,
128
129
  kid=valid,
129
130
  )
@@ -1507,11 +1508,13 @@ class PyastBuildPass(Pass[ast.PythonModuleAst]):
1507
1508
  pos_end=0,
1508
1509
  )
1509
1510
  )
1511
+ moddots = [self.operator(Tok.DOT, ".") for _ in range(node.level)]
1512
+ modparts = moddots + modpaths
1510
1513
  path = ast.ModulePath(
1511
1514
  path=modpaths,
1512
1515
  level=node.level,
1513
1516
  alias=None,
1514
- kid=modpaths,
1517
+ kid=modparts,
1515
1518
  )
1516
1519
  names = [self.convert(name) for name in node.names]
1517
1520
  valid_names = []
@@ -47,8 +47,12 @@ class ImportPassPassTests(TestCase):
47
47
  state = jac_file_to_pass(
48
48
  self.fixture_abs_path("incautoimpl.jac"), JacImportPass
49
49
  )
50
+ count = 0
50
51
  for i in state.ir.get_all_sub_nodes(ast.Module):
51
- self.assertEqual(i.annexable_by, self.fixture_abs_path("autoimpl.jac"))
52
+ if i.name != "autoimpl":
53
+ count += 1
54
+ self.assertEqual(i.annexable_by, self.fixture_abs_path("autoimpl.jac"))
55
+ self.assertEqual(count, 3)
52
56
 
53
57
  def test_py_resolve_list(self) -> None:
54
58
  """Basic test for pass."""
@@ -94,29 +94,12 @@ class JacTypeCheckPass(Pass):
94
94
  mypy_graph[module.name] = st
95
95
  new_modules.append(st)
96
96
 
97
- # def get_stub(mod: str) -> myab.BuildSource:
98
- # """Get stub file path."""
99
- # return myab.BuildSource(
100
- # path=str(
101
- # pathlib.Path(os.path.dirname(jaclang.__file__)).parent
102
- # / "stubs"
103
- # / "jaclang"
104
- # / "plugin"
105
- # / f"{mod}.pyi"
106
- # ),
107
- # module=f"jaclang.plugin.{mod}",
108
- # )
109
-
110
97
  graph = myab.load_graph(
111
98
  [
112
99
  myab.BuildSource(
113
100
  path=str(self.__path / "typeshed" / "stdlib" / "builtins.pyi"),
114
101
  module="builtins",
115
102
  ),
116
- # get_stub("default"),
117
- # get_stub("feature"),
118
- # get_stub("spec"),
119
- # get_stub("builtin"),
120
103
  ],
121
104
  manager,
122
105
  old_graph=mypy_graph,
@@ -27,7 +27,12 @@ class FuseCommentsPass(Pass):
27
27
  """Insert comment tokens into all_tokens."""
28
28
  comment_stream = iter(self.comments) # Iterator for comments
29
29
  code_stream = iter(self.all_tokens) # Iterator for code tokens
30
- new_stream: list[ast.AstNode] = [] # New stream to hold ordered tokens
30
+ new_stream: list[ast.Token] = [] # New stream to hold ordered tokens
31
+
32
+ if not isinstance(self.ir, ast.Module):
33
+ raise self.ice(
34
+ f"FuseCommentsPass can only be run on a Module, not a {type(self.ir)}"
35
+ )
31
36
 
32
37
  try:
33
38
  next_comment = next(comment_stream) # Get the first comment
@@ -39,12 +44,20 @@ class FuseCommentsPass(Pass):
39
44
  except StopIteration:
40
45
  next_code = None
41
46
 
47
+ if next_comment and (not next_code or is_comment_next(next_comment, next_code)):
48
+ self.ir.terminals.insert(0, next_comment)
49
+
42
50
  while next_comment or next_code:
43
51
  if next_comment and (
44
52
  not next_code or is_comment_next(next_comment, next_code)
45
53
  ):
46
54
  # Add the comment to the new stream
55
+ last_tok = new_stream[-1] if len(new_stream) else None
47
56
  new_stream.append(next_comment)
57
+ if last_tok:
58
+ self.ir.terminals.insert(
59
+ self.ir.terminals.index(last_tok) + 1, next_comment
60
+ )
48
61
  try:
49
62
  next_comment = next(comment_stream)
50
63
  except StopIteration:
@@ -70,7 +83,6 @@ class FuseCommentsPass(Pass):
70
83
  parent_kids.insert(insert_index, token)
71
84
  prev_token.parent.set_kids(parent_kids)
72
85
  else:
73
- prev_token.pp()
74
86
  raise self.ice(
75
87
  "Token without parent in AST should be impossible"
76
88
  )