jaclang 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic. See the registry's release page for more details.

Files changed (114):
  1. jaclang/cli/cli.py +194 -10
  2. jaclang/cli/cmdreg.py +144 -8
  3. jaclang/compiler/__init__.py +6 -1
  4. jaclang/compiler/codeinfo.py +16 -1
  5. jaclang/compiler/constant.py +33 -8
  6. jaclang/compiler/jac.lark +154 -62
  7. jaclang/compiler/larkparse/jac_parser.py +2 -2
  8. jaclang/compiler/parser.py +656 -149
  9. jaclang/compiler/passes/__init__.py +2 -1
  10. jaclang/compiler/passes/ast_gen/__init__.py +5 -0
  11. jaclang/compiler/passes/ast_gen/base_ast_gen_pass.py +54 -0
  12. jaclang/compiler/passes/ast_gen/jsx_processor.py +344 -0
  13. jaclang/compiler/passes/ecmascript/__init__.py +25 -0
  14. jaclang/compiler/passes/ecmascript/es_unparse.py +576 -0
  15. jaclang/compiler/passes/ecmascript/esast_gen_pass.py +2068 -0
  16. jaclang/compiler/passes/ecmascript/estree.py +972 -0
  17. jaclang/compiler/passes/ecmascript/tests/__init__.py +1 -0
  18. jaclang/compiler/passes/ecmascript/tests/fixtures/advanced_language_features.jac +170 -0
  19. jaclang/compiler/passes/ecmascript/tests/fixtures/class_separate_impl.impl.jac +30 -0
  20. jaclang/compiler/passes/ecmascript/tests/fixtures/class_separate_impl.jac +14 -0
  21. jaclang/compiler/passes/ecmascript/tests/fixtures/client_jsx.jac +89 -0
  22. jaclang/compiler/passes/ecmascript/tests/fixtures/core_language_features.jac +195 -0
  23. jaclang/compiler/passes/ecmascript/tests/test_esast_gen_pass.py +167 -0
  24. jaclang/compiler/passes/ecmascript/tests/test_js_generation.py +239 -0
  25. jaclang/compiler/passes/main/__init__.py +0 -3
  26. jaclang/compiler/passes/main/annex_pass.py +23 -1
  27. jaclang/compiler/passes/main/def_use_pass.py +1 -0
  28. jaclang/compiler/passes/main/pyast_gen_pass.py +413 -255
  29. jaclang/compiler/passes/main/pyast_load_pass.py +48 -11
  30. jaclang/compiler/passes/main/pyjac_ast_link_pass.py +2 -0
  31. jaclang/compiler/passes/main/sym_tab_build_pass.py +18 -1
  32. jaclang/compiler/passes/main/tests/fixtures/autoimpl.cl.jac +7 -0
  33. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +3 -0
  34. jaclang/compiler/passes/main/tests/fixtures/checker_class_construct.jac +33 -0
  35. jaclang/compiler/passes/main/tests/fixtures/defuse_modpath.jac +7 -0
  36. jaclang/compiler/passes/main/tests/fixtures/member_access_type_resolve.jac +2 -1
  37. jaclang/compiler/passes/main/tests/test_checker_pass.py +31 -3
  38. jaclang/compiler/passes/main/tests/test_def_use_pass.py +12 -0
  39. jaclang/compiler/passes/main/tests/test_import_pass.py +23 -4
  40. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +13 -14
  41. jaclang/compiler/passes/main/tests/test_pyast_gen_pass.py +25 -0
  42. jaclang/compiler/passes/main/type_checker_pass.py +7 -0
  43. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +219 -20
  44. jaclang/compiler/passes/tool/fuse_comments_pass.py +1 -10
  45. jaclang/compiler/passes/tool/jac_formatter_pass.py +2 -2
  46. jaclang/compiler/passes/tool/tests/fixtures/import_fmt.jac +7 -1
  47. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +135 -29
  48. jaclang/compiler/passes/tool/tests/test_jac_format_pass.py +4 -1
  49. jaclang/compiler/passes/transform.py +9 -1
  50. jaclang/compiler/passes/uni_pass.py +5 -7
  51. jaclang/compiler/program.py +27 -26
  52. jaclang/compiler/tests/test_client_codegen.py +113 -0
  53. jaclang/compiler/tests/test_importer.py +12 -10
  54. jaclang/compiler/tests/test_parser.py +249 -3
  55. jaclang/compiler/type_system/type_evaluator.jac +1078 -0
  56. jaclang/compiler/type_system/type_utils.py +1 -1
  57. jaclang/compiler/type_system/types.py +6 -0
  58. jaclang/compiler/unitree.py +438 -82
  59. jaclang/langserve/engine.jac +224 -288
  60. jaclang/langserve/sem_manager.jac +12 -8
  61. jaclang/langserve/server.jac +48 -48
  62. jaclang/langserve/tests/fixtures/greet.py +17 -0
  63. jaclang/langserve/tests/fixtures/md_path.jac +22 -0
  64. jaclang/langserve/tests/fixtures/user.jac +15 -0
  65. jaclang/langserve/tests/test_server.py +66 -371
  66. jaclang/lib.py +17 -0
  67. jaclang/runtimelib/archetype.py +25 -25
  68. jaclang/runtimelib/client_bundle.py +169 -0
  69. jaclang/runtimelib/client_runtime.jac +586 -0
  70. jaclang/runtimelib/constructs.py +4 -2
  71. jaclang/runtimelib/machine.py +308 -139
  72. jaclang/runtimelib/meta_importer.py +111 -22
  73. jaclang/runtimelib/mtp.py +15 -0
  74. jaclang/runtimelib/server.py +1089 -0
  75. jaclang/runtimelib/tests/fixtures/client_app.jac +18 -0
  76. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  77. jaclang/runtimelib/tests/fixtures/savable_object.jac +4 -5
  78. jaclang/runtimelib/tests/fixtures/serve_api.jac +75 -0
  79. jaclang/runtimelib/tests/test_client_bundle.py +55 -0
  80. jaclang/runtimelib/tests/test_client_render.py +63 -0
  81. jaclang/runtimelib/tests/test_serve.py +1069 -0
  82. jaclang/settings.py +0 -3
  83. jaclang/tests/fixtures/attr_pattern_case.jac +18 -0
  84. jaclang/tests/fixtures/funccall_genexpr.jac +7 -0
  85. jaclang/tests/fixtures/funccall_genexpr.py +5 -0
  86. jaclang/tests/fixtures/iife_functions.jac +142 -0
  87. jaclang/tests/fixtures/iife_functions_client.jac +143 -0
  88. jaclang/tests/fixtures/multistatement_lambda.jac +116 -0
  89. jaclang/tests/fixtures/multistatement_lambda_client.jac +113 -0
  90. jaclang/tests/fixtures/needs_import_dup.jac +6 -4
  91. jaclang/tests/fixtures/py2jac_empty.py +0 -0
  92. jaclang/tests/fixtures/py_run.py +7 -5
  93. jaclang/tests/fixtures/pyfunc_fstr.py +2 -2
  94. jaclang/tests/fixtures/simple_lambda_test.jac +12 -0
  95. jaclang/tests/test_cli.py +134 -18
  96. jaclang/tests/test_language.py +120 -32
  97. jaclang/tests/test_reference.py +20 -3
  98. jaclang/utils/NonGPT.py +375 -0
  99. jaclang/utils/helpers.py +64 -20
  100. jaclang/utils/lang_tools.py +31 -4
  101. jaclang/utils/tests/test_lang_tools.py +5 -16
  102. jaclang/utils/treeprinter.py +8 -3
  103. {jaclang-0.8.8.dist-info → jaclang-0.8.10.dist-info}/METADATA +3 -3
  104. {jaclang-0.8.8.dist-info → jaclang-0.8.10.dist-info}/RECORD +106 -71
  105. jaclang/compiler/passes/main/binder_pass.py +0 -594
  106. jaclang/compiler/passes/main/tests/fixtures/sym_binder.jac +0 -47
  107. jaclang/compiler/passes/main/tests/test_binder_pass.py +0 -111
  108. jaclang/compiler/type_system/type_evaluator.py +0 -844
  109. jaclang/langserve/tests/session.jac +0 -294
  110. jaclang/langserve/tests/test_dev_server.py +0 -80
  111. jaclang/runtimelib/importer.py +0 -351
  112. jaclang/tests/test_typecheck.py +0 -542
  113. {jaclang-0.8.8.dist-info → jaclang-0.8.10.dist-info}/WHEEL +0 -0
  114. {jaclang-0.8.8.dist-info → jaclang-0.8.10.dist-info}/entry_points.txt +0 -0
@@ -16,6 +16,8 @@ from __future__ import annotations
16
16
 
17
17
  import ast as py_ast
18
18
  import os
19
+ import re
20
+ from threading import Event
19
21
  from typing import Optional, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
20
22
 
21
23
  import jaclang.compiler.unitree as uni
@@ -32,11 +34,16 @@ T = TypeVar("T", bound=uni.UniNode)
32
34
  class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
33
35
  """Jac Parser."""
34
36
 
35
- def __init__(self, ir_in: uni.PythonModuleAst, prog: JacProgram) -> None:
37
+ def __init__(
38
+ self,
39
+ ir_in: uni.PythonModuleAst,
40
+ prog: JacProgram,
41
+ cancel_token: Event | None = None,
42
+ ) -> None:
36
43
  """Initialize parser."""
37
44
  self.mod_path = ir_in.loc.mod_path
38
45
  self.orig_src = ir_in.loc.orig_src
39
- Transform.__init__(self, ir_in=ir_in, prog=prog)
46
+ Transform.__init__(self, ir_in=ir_in, prog=prog, cancel_token=cancel_token)
40
47
 
41
48
  def nu(self, node: T) -> T:
42
49
  """Update node."""
@@ -45,6 +52,8 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
45
52
 
46
53
  def convert(self, node: py_ast.AST) -> uni.UniNode:
47
54
  """Get python node type."""
55
+ if self.is_canceled():
56
+ raise StopIteration
48
57
  if hasattr(self, f"proc_{pascal_to_snake(type(node).__name__)}"):
49
58
  ret = getattr(self, f"proc_{pascal_to_snake(type(node).__name__)}")(node)
50
59
  else:
@@ -96,6 +105,8 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
96
105
  body: list[stmt]
97
106
  type_ignores: list[TypeIgnore]
98
107
  """
108
+ if not node.body:
109
+ return uni.Module.make_stub(inject_src=self.ir_in)
99
110
  elements: list[uni.UniNode] = [self.convert(i) for i in node.body]
100
111
  elements[0] = (
101
112
  elements[0].expr
@@ -1158,7 +1169,7 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
1158
1169
  else:
1159
1170
  raise self.ice()
1160
1171
 
1161
- def proc_formatted_value(self, node: py_ast.FormattedValue) -> uni.ExprStmt:
1172
+ def proc_formatted_value(self, node: py_ast.FormattedValue) -> uni.FormattedValue:
1162
1173
  """Process python node.
1163
1174
 
1164
1175
  class FormattedValue(expr):
@@ -1169,10 +1180,15 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
1169
1180
  format_spec: expr | None
1170
1181
  """
1171
1182
  value = self.convert(node.value)
1183
+ if node.format_spec:
1184
+ fmt_spec = cast(uni.Expr, self.convert(node.format_spec))
1185
+ else:
1186
+ fmt_spec = None
1172
1187
  if isinstance(value, uni.Expr):
1173
- ret = uni.ExprStmt(
1174
- expr=value,
1175
- in_fstring=True,
1188
+ ret = uni.FormattedValue(
1189
+ format_part=value,
1190
+ conversion=node.conversion,
1191
+ format_spec=fmt_spec,
1176
1192
  kid=[value],
1177
1193
  )
1178
1194
  else:
@@ -1367,13 +1383,34 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
1367
1383
  __match_args__ = ("values",)
1368
1384
  values: list[expr]
1369
1385
  """
1370
- values = [self.convert(value) for value in node.values]
1371
- valid = [
1372
- value for value in values if isinstance(value, (uni.String, uni.ExprStmt))
1373
- ]
1386
+ valid: list[uni.Token | uni.FormattedValue] = []
1387
+ for i in node.values:
1388
+ if isinstance(i, py_ast.Constant) and isinstance(i.value, str):
1389
+ valid.append(self.operator(Tok.STRING, i.value))
1390
+ elif isinstance(i, py_ast.FormattedValue):
1391
+ converted = self.convert(i)
1392
+ if isinstance(converted, uni.FormattedValue):
1393
+ valid.append(converted)
1394
+ else:
1395
+ raise self.ice("Invalid node in joined str")
1396
+ ast_seg = py_ast.get_source_segment(self.orig_src.code, node)
1397
+ if ast_seg is None:
1398
+ ast_seg = 'f""'
1399
+ match = re.match(r"(?i)(fr|rf|f)('{3}|\"{3}|'|\")", ast_seg)
1400
+ if match:
1401
+ prefix, quote = match.groups()
1402
+ start = match.group(0)
1403
+ end = quote * (3 if len(quote) == 3 else 1)
1404
+ else:
1405
+ start = "f'"
1406
+ end = "'"
1407
+ tok_start = self.operator(Tok.STRING, start)
1408
+ tok_end = self.operator(Tok.STRING, end)
1374
1409
  fstr = uni.FString(
1410
+ start=tok_start,
1375
1411
  parts=valid,
1376
- kid=[*valid] if valid else [uni.EmptyToken()],
1412
+ end=tok_end,
1413
+ kid=[tok_start, *valid, tok_end] if valid else [uni.EmptyToken()],
1377
1414
  )
1378
1415
  return uni.MultiString(strings=[fstr], kid=[fstr])
1379
1416
 
@@ -27,6 +27,8 @@ class PyJacAstLinkPass(UniPass):
27
27
  self, jac_node: uni.UniNode, py_nodes: list[ast3.AST]
28
28
  ) -> None:
29
29
  """Link jac name ast to py ast nodes."""
30
+ if isinstance(jac_node, uni.ClientFacingNode) and jac_node.is_client_decl:
31
+ return
30
32
  jac_node.gen.py_ast = py_nodes
31
33
  for i in py_nodes:
32
34
  if isinstance(i.jac_link, list): # type: ignore
@@ -19,6 +19,7 @@ type checking, and semantic analysis throughout the compilation process.
19
19
  """
20
20
 
21
21
  import jaclang.compiler.unitree as uni
22
+ from jaclang.compiler.constant import SymbolAccess
22
23
  from jaclang.compiler.passes import UniPass
23
24
  from jaclang.compiler.unitree import UniScopeNode
24
25
 
@@ -75,7 +76,7 @@ class SymTabBuildPass(UniPass):
75
76
  def exit_module_path(self, node: uni.ModulePath) -> None:
76
77
  if node.alias:
77
78
  node.alias.sym_tab.def_insert(node.alias, single_decl="import")
78
- elif node.path:
79
+ elif node.path and not node.is_import_from:
79
80
  if node.parent_of_type(uni.Import) and not (
80
81
  node.parent_of_type(uni.Import).from_loc
81
82
  and node.parent_of_type(uni.Import).is_jac
@@ -84,9 +85,25 @@ class SymTabBuildPass(UniPass):
84
85
  else:
85
86
  pass # Need to support pythonic import symbols with dots in it
86
87
 
88
+ # There will be symbols for
89
+ # import from math {sqrt} <- math will have a symbol but no symtab entry
90
+ # import math as m <- m will have a symbol and symtab entry
91
+ if node.path and (node.is_import_from or (node.alias)):
92
+ for n in node.path:
93
+ n.sym = n.create_symbol(
94
+ access=SymbolAccess.PUBLIC,
95
+ imported=True,
96
+ )
97
+
87
98
  def exit_module_item(self, node: uni.ModuleItem) -> None:
88
99
  sym_node = node.alias or node.name
89
100
  sym_node.sym_tab.def_insert(sym_node, single_decl="import")
101
+ if node.alias:
102
+ # create symbol for module item
103
+ node.name.sym = node.name.create_symbol(
104
+ access=SymbolAccess.PUBLIC,
105
+ imported=True,
106
+ )
90
107
 
91
108
  def enter_archetype(self, node: uni.Archetype) -> None:
92
109
  self.push_scope_and_link(node)
@@ -0,0 +1,7 @@
1
+ node Button {
2
+ has title: str;
3
+ }
4
+
5
+ def component() {
6
+ print("client component");
7
+ }
@@ -15,4 +15,7 @@ with entry {
15
15
  f.with_default_args(1, 2); # <-- Ok
16
16
  f.with_default_args(1, 2, 3); # <-- Error
17
17
  f.with_default_args(); # <-- Error
18
+
19
+ # Argument unpacking
20
+ f.with_default_args(*[1, 2]); # <-- Ok
18
21
  }
@@ -0,0 +1,33 @@
1
+ glob RAD = 5.0;
2
+
3
+ obj Circle1{
4
+ has radius:float ,age:int;
5
+ def init(color:str){
6
+ self.color = color;
7
+ }
8
+ }
9
+ with entry {
10
+ c1 = Circle1(RAD);
11
+ }
12
+ # ---------------------------------------------------------------
13
+ glob length = 5.0;
14
+
15
+ obj Square{
16
+ has side_length:float ,age:int;
17
+ }
18
+
19
+ with entry {
20
+ c2 = Square(length);
21
+ }
22
+ # ---------------------------------------------------------------
23
+
24
+ glob name = "John";
25
+
26
+ obj Person{
27
+ has name:str ,age:int=90;
28
+ }
29
+
30
+ with entry {
31
+ c = Person(name=name, age=25);
32
+ c = Person();
33
+ }
@@ -0,0 +1,7 @@
1
+ import from math { sqrt as square_root }
2
+ import from concurrent.futures { ThreadPoolExecutor }
3
+
4
+ with entry {
5
+ square_root(16);
6
+ # math;
7
+ }
@@ -5,7 +5,8 @@ node Foo {
5
5
  has bar: Bar;
6
6
  }
7
7
  with entry {
8
- f: Foo = Foo();
8
+ bar_obj: Bar = Bar(23);
9
+ f: Foo = Foo(bar=bar_obj);
9
10
  i: int = f.bar.baz; # <-- Ok
10
11
  s: str = f.bar.baz; # <-- Error
11
12
  }
@@ -1,8 +1,6 @@
1
1
 
2
2
  """Tests for typechecker pass (the pyright implementation)."""
3
3
 
4
- from tempfile import NamedTemporaryFile
5
-
6
4
  from jaclang.utils.test import TestCase
7
5
  from jaclang.compiler.passes.main import TypeCheckPass
8
6
  from jaclang.compiler.program import JacProgram
@@ -175,7 +173,6 @@ class TypeCheckerPassTests(TestCase):
175
173
  """, program.errors_had[0].pretty_print())
176
174
 
177
175
  def test_param_arg_match(self) -> None:
178
- path = self.fixture_abs_path("checker_param_types.jac")
179
176
  program = JacProgram()
180
177
  path = self.fixture_abs_path("checker_arg_param_match.jac")
181
178
  mod = program.compile(path)
@@ -271,6 +268,37 @@ class TypeCheckerPassTests(TestCase):
271
268
  for i, expected in enumerate(expected_errors):
272
269
  self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
273
270
 
271
+ def test_class_construct(self) -> None:
272
+ program = JacProgram()
273
+ path = self.fixture_abs_path("checker_class_construct.jac")
274
+ mod = program.compile(path)
275
+ TypeCheckPass(ir_in=mod, prog=program)
276
+ self.assertEqual(len(program.errors_had), 3)
277
+
278
+ expected_errors = [
279
+ """
280
+ Cannot assign <class float> to parameter 'color' of type <class str>
281
+ with entry {
282
+ c1 = Circle1(RAD);
283
+ ^^^
284
+ """,
285
+ """
286
+ Not all required parameters were provided in the function call: 'age'
287
+ with entry {
288
+ c2 = Square(length);
289
+ ^^^^^^^^^^^^^^
290
+ """,
291
+ """
292
+ Not all required parameters were provided in the function call: 'name'
293
+ c = Person(name=name, age=25);
294
+ c = Person();
295
+ ^^^^^^^^
296
+ """,
297
+ ]
298
+
299
+ for i, expected in enumerate(expected_errors):
300
+ self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
301
+
274
302
  def test_self_type_inference(self) -> None:
275
303
  path = self.fixture_abs_path("checker_self_type.jac")
276
304
  program = JacProgram()
@@ -21,3 +21,15 @@ class DefUsePassTests(TestCase):
21
21
  self.assertEqual(len(uses[1]), 1)
22
22
  self.assertIn("output", [uses[0][0].sym_name, uses[1][0].sym_name])
23
23
  self.assertIn("message", [uses[0][0].sym_name, uses[1][0].sym_name])
24
+
25
+ def test_def_use_modpath(self) -> None:
26
+ """Basic test for pass."""
27
+ state = JacProgram().compile(
28
+ file_path=self.fixture_abs_path("defuse_modpath.jac")
29
+ )
30
+ all_symbols = list(
31
+ state.sym_tab.names_in_scope.values()
32
+ )
33
+ self.assertEqual(len(all_symbols), 2)
34
+ self.assertEqual(all_symbols[0].sym_name, "square_root")
35
+ self.assertEqual(all_symbols[1].sym_name, "ThreadPoolExecutor")
@@ -30,29 +30,31 @@ class ImportPassPassTests(TestCase):
30
30
  (prog := JacProgram()).compile(self.fixture_abs_path("autoimpl.jac"))
31
31
  num_modules = len(list(prog.mod.hub.values())[0].impl_mod)
32
32
  mod_names = [i.name for i in list(prog.mod.hub.values())[0].impl_mod]
33
- self.assertEqual(num_modules, 4)
33
+ self.assertEqual(num_modules, 5)
34
34
  self.assertIn("getme.impl", mod_names)
35
35
  self.assertIn("autoimpl.impl", mod_names)
36
36
  self.assertIn("autoimpl.something.else.impl", mod_names)
37
+ self.assertIn("autoimpl.cl", mod_names)
37
38
 
38
39
  def test_import_include_auto_impl(self) -> None:
39
40
  """Basic test for pass."""
40
41
  (prog := JacProgram()).build(self.fixture_abs_path("incautoimpl.jac"))
41
42
  num_modules = len(list(prog.mod.hub.values())[1].impl_mod) + 1
42
43
  mod_names = [i.name for i in list(prog.mod.hub.values())[1].impl_mod]
43
- self.assertEqual(num_modules, 5)
44
+ self.assertEqual(num_modules, 6)
44
45
  self.assertEqual("incautoimpl", list(prog.mod.hub.values())[0].name)
45
46
  self.assertEqual("autoimpl", list(prog.mod.hub.values())[1].name)
46
47
  self.assertIn("getme.impl", mod_names)
47
48
  self.assertIn("autoimpl.impl", mod_names)
48
49
  self.assertIn("autoimpl.something.else.impl", mod_names)
50
+ self.assertIn("autoimpl.cl", mod_names)
49
51
 
50
52
  def test_annexalbe_by_discovery(self) -> None:
51
53
  """Basic test for pass."""
52
54
  (prog := JacProgram()).build(self.fixture_abs_path("incautoimpl.jac"))
53
55
  count = 0
54
56
  all_mods = prog.mod.hub.values()
55
- self.assertEqual(len(all_mods), 6)
57
+ self.assertEqual(len(all_mods), 7)
56
58
  for main_mod in all_mods:
57
59
  for i in main_mod.impl_mod:
58
60
  if i.name not in ["autoimpl", "incautoimpl"]:
@@ -60,7 +62,24 @@ class ImportPassPassTests(TestCase):
60
62
  self.assertEqual(
61
63
  i.annexable_by, self.fixture_abs_path("autoimpl.jac")
62
64
  )
63
- self.assertEqual(count, 4)
65
+ self.assertEqual(count, 5)
66
+
67
+ def test_cl_annex_marked_client(self) -> None:
68
+ """Ensure .cl.jac annex files are autoloaded and marked client."""
69
+
70
+ (prog := JacProgram()).compile(self.fixture_abs_path("autoimpl.jac"))
71
+ main_mod = list(prog.mod.hub.values())[0]
72
+ cl_mod = next(
73
+ (mod for mod in main_mod.impl_mod if mod.name.endswith(".cl")), None
74
+ )
75
+ self.assertIsNotNone(cl_mod, "Expected .cl annex module to be loaded")
76
+ abilities = cl_mod.get_all_sub_nodes(uni.Ability)
77
+ self.assertTrue(abilities, "Expected abilities in .cl annex module")
78
+ for ability in abilities:
79
+ self.assertTrue(
80
+ ability.is_client_decl,
81
+ "All client annex abilities should be marked as client declarations",
82
+ )
64
83
 
65
84
  @unittest.skip("TODO: Fix when we have the type checker")
66
85
  def test_py_raise_map(self) -> None:
@@ -4,7 +4,7 @@ import io
4
4
  import os
5
5
  import sys
6
6
 
7
- from jaclang.compiler.program import JacProgram
7
+ from jaclang.compiler.program import JacProgram, py_code_gen
8
8
  from jaclang.utils.test import TestCase
9
9
  from jaclang.settings import settings
10
10
  from jaclang.compiler.passes.main import PreDynamoPass
@@ -16,42 +16,41 @@ class PreDynamoPassTests(TestCase):
16
16
 
17
17
  def setUp(self) -> None:
18
18
  """Set up test."""
19
+ settings.predynamo_pass = True
19
20
  return super().setUp()
20
-
21
+
22
+ def tearDown(self) -> None:
23
+ """Tear down test."""
24
+ settings.predynamo_pass = False
25
+ # Remove PreDynamoPass from global py_code_gen list if it was added
26
+ if PreDynamoPass in py_code_gen:
27
+ py_code_gen.remove(PreDynamoPass)
28
+ return super().tearDown()
29
+
21
30
  def test_predynamo_where_assign(self) -> None:
22
31
  """Test torch.where transformation."""
23
32
  captured_output = io.StringIO()
24
33
  sys.stdout = captured_output
25
- os.environ["JAC_PREDYNAMO_PASS"] = "True"
26
- settings.load_env_vars()
34
+ settings.predynamo_pass = True
27
35
  code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_assign.jac"))
28
36
  sys.stdout = sys.__stdout__
29
37
  self.assertIn("torch.where", code_gen.unparse())
30
- os.environ["JAC_PREDYNAMO_PASS"] = "false"
31
- settings.load_env_vars()
32
38
 
33
39
  def test_predynamo_where_return(self) -> None:
34
40
  """Test torch.where transformation."""
35
41
  captured_output = io.StringIO()
36
42
  sys.stdout = captured_output
37
- os.environ["JAC_PREDYNAMO_PASS"] = "True"
38
- settings.load_env_vars()
39
43
  code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_return.jac"))
40
44
  sys.stdout = sys.__stdout__
41
45
  self.assertIn("torch.where", code_gen.unparse())
42
- os.environ["JAC_PREDYNAMO_PASS"] = "false"
43
- settings.load_env_vars()
46
+
44
47
 
45
48
  def test_predynamo_fix3(self) -> None:
46
49
  """Test torch.where transformation."""
47
50
  captured_output = io.StringIO()
48
51
  sys.stdout = captured_output
49
- os.environ["JAC_PREDYNAMO_PASS"] = "True"
50
- settings.load_env_vars()
51
52
  code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_fix3.jac"))
52
53
  sys.stdout = sys.__stdout__
53
54
  unparsed_code = code_gen.unparse()
54
55
  self.assertIn("__inv_freq = torch.where(", unparsed_code)
55
56
  self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)
56
- os.environ["JAC_PREDYNAMO_PASS"] = "false"
57
- settings.load_env_vars()
@@ -153,6 +153,31 @@ class PyastGenPassTests(TestCaseMicroSuite, AstSyncTestMixin):
153
153
 
154
154
  self.assertFalse(out.errors_had)
155
155
 
156
+ def test_iife_fixture_executes(self) -> None:
157
+ """Ensure IIFE and block lambdas lower to executable Python."""
158
+ fixture_path = self.lang_fixture_abs_path("iife_functions.jac")
159
+ code_gen = (prog := JacProgram()).compile(fixture_path)
160
+ self.assertFalse(prog.errors_had)
161
+ if code_gen.gen.py_ast and isinstance(code_gen.gen.py_ast[0], ast3.Module):
162
+ module_ast = code_gen.gen.py_ast[0]
163
+ compiled = compile(module_ast, filename="<ast>", mode="exec")
164
+ captured = io.StringIO()
165
+ original_stdout = sys.stdout
166
+ try:
167
+ sys.stdout = captured
168
+ module = types.ModuleType("__main__")
169
+ module.__dict__["__file__"] = code_gen.loc.mod_path
170
+ exec(compiled, module.__dict__)
171
+ finally:
172
+ sys.stdout = original_stdout
173
+ output = captured.getvalue()
174
+ self.assertIn("Test 1 - Basic IIFE: 42", output)
175
+ self.assertIn(
176
+ "Test 6 - IIFE returning function, adder(5): 15",
177
+ output,
178
+ )
179
+ self.assertIn("All IIFE tests completed!", output)
180
+
156
181
  def parent_scrub(self, node: uni.UniNode) -> bool:
157
182
  """Validate every node has parent."""
158
183
  success = True
@@ -69,10 +69,17 @@ class TypeCheckPass(UniPass):
69
69
 
70
70
  def exit_import(self, node: uni.Import) -> None:
71
71
  """Exit an import node."""
72
+ # import from math {sqrt, sin as s}
72
73
  if node.from_loc:
74
+ self.evaluator.get_type_of_module(node.from_loc)
73
75
  for item in node.items:
74
76
  if isinstance(item, uni.ModuleItem):
75
77
  self.evaluator.get_type_of_module_item(item)
78
+ else:
79
+ # import math as m, os, sys;
80
+ for item in node.items:
81
+ if isinstance(item, uni.ModulePath):
82
+ self.evaluator.get_type_of_module(item)
76
83
 
77
84
  def exit_assignment(self, node: uni.Assignment) -> None:
78
85
  """Pyright: Checker.visitAssignment(node: AssignmentNode): boolean."""