jaclang-0.5.17-py3-none-any.whl → jaclang-0.5.18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic. (The link to further details was not preserved in this copy.)

jaclang/compiler/jac.lark CHANGED
@@ -317,9 +317,9 @@ cmp_op: KW_ISN
317
317
 
318
318
  // Arithmetic expressions
319
319
  arithmetic: (arithmetic (MINUS | PLUS))? term
320
- term: (term (MOD | DIV | FLOOR_DIV | STAR_MUL | DECOR_OP))? factor
321
- factor: (BW_NOT | MINUS | PLUS) factor | power
322
- power: (power STAR_POW)? connect
320
+ term: (term (MOD | DIV | FLOOR_DIV | STAR_MUL | DECOR_OP))? power
321
+ power: (power STAR_POW)? factor
322
+ factor: (BW_NOT | MINUS | PLUS) factor | connect
323
323
 
324
324
  // Connect expressions
325
325
  connect: (connect (connect_op | disconnect_op))? atomic_pipe
@@ -539,16 +539,16 @@ KW_POST_INIT: "postinit"
539
539
  KW_SUPER: "super"
540
540
  KW_ROOT: "root"
541
541
 
542
- FLOAT: /(\d+(\.\d*)|\.\d+)([eE][+-]?\d+)?/
542
+ FLOAT: /(\d+(\.\d*)|\.\d+)([eE][+-]?\d+)?|\d+([eE][-+]?\d+)/
543
543
  DOC_STRING.1: /"""(.|\n|\r)*?"""|'''(.|\n|\r)*?'''/
544
544
  PYNLINE: /::py::(.|\n|\r)*?::py::/
545
545
  STRING: /(r?b?|b?r?)"[^"\r\n]*"|(r?b?|b?r?)'[^'\r\n]*'/
546
546
  BOOL.1: /True|False/
547
547
  KW_NIN.1: /\bnot\s+in\b/
548
548
  KW_ISN.1: /\bis\s+not\b/
549
- HEX: /0[xX][0-9a-fA-F_]+/
550
- BIN: /0[bB][01_]+/
551
- OCT: /0[oO][0-7_]+/
549
+ HEX.1: /0[xX][0-9a-fA-F_]+/
550
+ BIN.1: /0[bB][01_]+/
551
+ OCT.1: /0[oO][0-7_]+/
552
552
  INT: /[0-9][0-9_]*/
553
553
  NULL.1: /None/
554
554
  KWESC_NAME: /<>[a-zA-Z_][a-zA-Z0-9_]*/
@@ -287,17 +287,22 @@ class JacParser(Pass):
287
287
  if len(kid) == 1 and isinstance(kid[0], ast.Import):
288
288
  return self.nu(kid[0])
289
289
  lang = kid[1]
290
- paths = [i for i in kid if isinstance(i, ast.ModulePath)]
290
+ from_path = kid[3] if isinstance(kid[3], ast.ModulePath) else None
291
+ if from_path:
292
+ items = kid[-2] if isinstance(kid[-2], ast.SubNodeList) else None
293
+ else:
294
+ paths = [i for i in kid if isinstance(i, ast.ModulePath)]
295
+ items = ast.SubNodeList[ast.ModulePath](
296
+ items=paths, delim=Tok.COMMA, kid=kid[2:-1]
297
+ )
298
+ kid = kid[:2] + [items] + kid[-1:]
291
299
 
292
- items = kid[-2] if isinstance(kid[-2], ast.SubNodeList) else None
293
300
  is_absorb = False
294
- if isinstance(lang, ast.SubTag) and (
295
- isinstance(items, ast.SubNodeList) or items is None
296
- ):
301
+ if isinstance(lang, ast.SubTag) and (isinstance(items, ast.SubNodeList)):
297
302
  return self.nu(
298
303
  ast.Import(
299
304
  hint=lang,
300
- paths=paths,
305
+ from_loc=from_path,
301
306
  items=items,
302
307
  is_absorb=is_absorb,
303
308
  kid=kid,
@@ -342,14 +347,20 @@ class JacParser(Pass):
342
347
  include_stmt: KW_INCLUDE sub_name import_path SEMI
343
348
  """
344
349
  lang = kid[1]
345
- paths = [i for i in kid if isinstance(i, ast.ModulePath)]
350
+ from_path = kid[2]
351
+ if not isinstance(from_path, ast.ModulePath):
352
+ raise self.ice()
353
+ items = ast.SubNodeList[ast.ModulePath](
354
+ items=[from_path], delim=Tok.COMMA, kid=[from_path]
355
+ )
356
+ kid = kid[:2] + [items] + kid[3:]
346
357
  is_absorb = True
347
358
  if isinstance(lang, ast.SubTag):
348
359
  return self.nu(
349
360
  ast.Import(
350
361
  hint=lang,
351
- paths=paths,
352
- items=None,
362
+ from_loc=None,
363
+ items=items,
353
364
  is_absorb=is_absorb,
354
365
  kid=kid,
355
366
  )
@@ -1919,13 +1930,17 @@ class JacParser(Pass):
1919
1930
  sig_kid.append(params)
1920
1931
  if return_type:
1921
1932
  sig_kid.append(return_type)
1922
- signature = ast.FuncSignature(
1923
- params=params,
1924
- return_type=return_type,
1925
- kid=sig_kid,
1933
+ signature = (
1934
+ ast.FuncSignature(
1935
+ params=params,
1936
+ return_type=return_type,
1937
+ kid=sig_kid,
1938
+ )
1939
+ if params or return_type
1940
+ else None
1926
1941
  )
1927
1942
  new_kid = [i for i in kid if i != params and i != return_type]
1928
- new_kid.insert(1, signature)
1943
+ new_kid.insert(1, signature) if signature else None
1929
1944
  if isinstance(chomp[0], ast.Expr):
1930
1945
  return self.nu(
1931
1946
  ast.LambdaExpr(
@@ -1,16 +1,18 @@
1
1
  """Abstract class for IR Passes for Jac."""
2
2
 
3
- from typing import Optional, Type
3
+ from typing import Optional, Type, TypeVar
4
4
 
5
5
  import jaclang.compiler.absyntree as ast
6
6
  from jaclang.compiler.passes.transform import Transform
7
7
  from jaclang.utils.helpers import pascal_to_snake
8
8
 
9
+ T = TypeVar("T", bound=ast.AstNode)
9
10
 
10
- class Pass(Transform[ast.T]):
11
+
12
+ class Pass(Transform[T]):
11
13
  """Abstract class for IR passes."""
12
14
 
13
- def __init__(self, input_ir: ast.T, prior: Optional[Transform]) -> None:
15
+ def __init__(self, input_ir: T, prior: Optional[Transform]) -> None:
14
16
  """Initialize parser."""
15
17
  self.term_signal = False
16
18
  self.prune_signal = False
@@ -45,10 +47,10 @@ class Pass(Transform[ast.T]):
45
47
 
46
48
  @staticmethod
47
49
  def get_all_sub_nodes(
48
- node: ast.AstNode, typ: Type[ast.T], brute_force: bool = False
49
- ) -> list[ast.T]:
50
+ node: ast.AstNode, typ: Type[T], brute_force: bool = False
51
+ ) -> list[T]:
50
52
  """Get all sub nodes of type."""
51
- result: list[ast.T] = []
53
+ result: list[T] = []
52
54
  # Assumes pass built the sub node table
53
55
  if not node:
54
56
  return result
@@ -69,7 +71,7 @@ class Pass(Transform[ast.T]):
69
71
  return result
70
72
 
71
73
  @staticmethod
72
- def has_parent_of_type(node: ast.AstNode, typ: Type[ast.T]) -> Optional[ast.T]:
74
+ def has_parent_of_type(node: ast.AstNode, typ: Type[T]) -> Optional[T]:
73
75
  """Check if node has parent of type."""
74
76
  while node.parent:
75
77
  if isinstance(node.parent, typ):
@@ -97,7 +99,7 @@ class Pass(Transform[ast.T]):
97
99
 
98
100
  # Transform Implementations
99
101
  # -------------------------
100
- def transform(self, ir: ast.T) -> ast.AstNode:
102
+ def transform(self, ir: T) -> ast.AstNode:
101
103
  """Run pass."""
102
104
  # Only performs passes on proper ASTs
103
105
  if not isinstance(ir, ast.AstNode):
@@ -6,15 +6,16 @@ mypy apis into Jac and use jac py ast in it.
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
- import os
10
9
  import traceback
11
10
  from typing import Callable, TypeVar
12
11
 
13
12
  import jaclang.compiler.absyntree as ast
14
13
  from jaclang.compiler.passes import Pass
14
+ from jaclang.settings import settings
15
15
  from jaclang.utils.helpers import pascal_to_snake
16
16
  from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack
17
17
 
18
+
18
19
  import mypy.nodes as MypyNodes # noqa N812
19
20
  import mypy.types as MypyTypes # noqa N812
20
21
  from mypy.checkexpr import Type as MyType
@@ -29,7 +30,7 @@ class FuseTypeInfoPass(Pass):
29
30
  node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}
30
31
 
31
32
  def __debug_print(self, *argv: object) -> None:
32
- if "FuseTypeInfoDebug" in os.environ:
33
+ if settings.fuse_type_info_debug:
33
34
  print("FuseTypeInfo::", *argv)
34
35
 
35
36
  def __call_type_handler(
@@ -7,7 +7,6 @@ symbols are available for matching.
7
7
 
8
8
  import ast as py_ast
9
9
  import importlib.util
10
- import os
11
10
  import sys
12
11
  from os import path
13
12
  from typing import Optional
@@ -16,6 +15,7 @@ from typing import Optional
16
15
  import jaclang.compiler.absyntree as ast
17
16
  from jaclang.compiler.passes import Pass
18
17
  from jaclang.compiler.passes.main import SubNodeTabPass
18
+ from jaclang.settings import settings
19
19
  from jaclang.utils.helpers import import_target_to_relative_path
20
20
 
21
21
 
@@ -36,7 +36,8 @@ class ImportPass(Pass):
36
36
  self.run_again = False
37
37
  all_imports = self.get_all_sub_nodes(node, ast.ModulePath)
38
38
  for i in all_imports:
39
- if i.parent.hint.tag.value == "jac" and not i.sub_module:
39
+ lang = i.parent_of_type(ast.Import).hint.tag.value
40
+ if lang == "jac" and not i.sub_module:
40
41
  self.run_again = True
41
42
  mod = self.import_module(
42
43
  node=i,
@@ -48,9 +49,7 @@ class ImportPass(Pass):
48
49
  self.annex_impl(mod)
49
50
  i.sub_module = mod
50
51
  i.add_kids_right([mod], pos_update=False)
51
- elif i.parent.hint.tag.value == "py" and os.environ.get(
52
- "JAC_PROC_DEBUG", False
53
- ):
52
+ elif lang == "py" and settings.jac_proc_debug:
54
53
  mod = self.import_py_module(node=i, mod_path=node.loc.mod_path)
55
54
  i.sub_module = mod
56
55
  i.add_kids_right([mod], pos_update=False)
@@ -492,17 +492,20 @@ class PyastGenPass(Pass):
492
492
  py_nodes.append(
493
493
  self.sync(ast3.Expr(value=node.doc.gen.py_ast[0]), jac_node=node.doc)
494
494
  )
495
- py_compat_path_str = []
496
- path_alias = {}
497
- for path in node.paths:
498
- py_compat_path_str.append(path.path_str.lstrip("."))
499
- path_alias[path.path_str] = path.alias.sym_name if path.alias else None
495
+ path_alias: dict[str, Optional[str]] = (
496
+ {node.from_loc.path_str: None} if node.from_loc else {}
497
+ )
500
498
  imp_from = {}
501
499
  if node.items:
502
500
  for item in node.items.items:
503
- imp_from[item.name.sym_name] = (
504
- item.alias.sym_name if item.alias else False
505
- )
501
+ if isinstance(item, ast.ModuleItem):
502
+ imp_from[item.name.sym_name] = (
503
+ item.alias.sym_name if item.alias else False
504
+ )
505
+ elif isinstance(item, ast.ModulePath):
506
+ path_alias[item.path_str] = (
507
+ item.alias.sym_name if item.alias else None
508
+ )
506
509
 
507
510
  keys = []
508
511
  values = []
@@ -594,10 +597,13 @@ class PyastGenPass(Pass):
594
597
  )
595
598
  )
596
599
  if node.is_absorb:
600
+ source = node.items.items[0]
601
+ if not isinstance(source, ast.ModulePath):
602
+ raise self.ice()
597
603
  py_nodes.append(
598
604
  self.sync(
599
605
  py_node=ast3.ImportFrom(
600
- module=py_compat_path_str[0] if py_compat_path_str[0] else None,
606
+ module=(source.path_str.lstrip(".") if source else None),
601
607
  names=[self.sync(ast3.alias(name="*"), node)],
602
608
  level=0,
603
609
  ),
@@ -608,15 +614,17 @@ class PyastGenPass(Pass):
608
614
  self.warning(
609
615
  "Includes import * in target module into current namespace."
610
616
  )
611
- if not node.items:
612
- py_nodes.append(
613
- self.sync(ast3.Import(names=[i.gen.py_ast[0] for i in node.paths]))
614
- )
617
+ if not node.from_loc:
618
+ py_nodes.append(self.sync(ast3.Import(names=node.items.gen.py_ast)))
615
619
  else:
616
620
  py_nodes.append(
617
621
  self.sync(
618
622
  ast3.ImportFrom(
619
- module=py_compat_path_str[0] if py_compat_path_str[0] else None,
623
+ module=(
624
+ node.from_loc.path_str.lstrip(".")
625
+ if node.from_loc
626
+ else None
627
+ ),
620
628
  names=node.items.gen.py_ast,
621
629
  level=0,
622
630
  )
@@ -1674,7 +1682,7 @@ class PyastGenPass(Pass):
1674
1682
  self.sync(
1675
1683
  ast3.Try(
1676
1684
  body=self.resolve_stmt_block(node.body),
1677
- handlers=node.excepts.gen.py_ast if node.excepts else None,
1685
+ handlers=node.excepts.gen.py_ast if node.excepts else [],
1678
1686
  orelse=node.else_body.gen.py_ast if node.else_body else [],
1679
1687
  finalbody=node.finally_body.gen.py_ast if node.finally_body else [],
1680
1688
  )
@@ -2376,7 +2384,19 @@ class PyastGenPass(Pass):
2376
2384
  node.gen.py_ast = [
2377
2385
  self.sync(
2378
2386
  ast3.Lambda(
2379
- args=node.signature.gen.py_ast[0],
2387
+ args=(
2388
+ node.signature.gen.py_ast[0]
2389
+ if node.signature
2390
+ else self.sync(
2391
+ ast3.arguments(
2392
+ posonlyargs=[],
2393
+ args=[],
2394
+ kwonlyargs=[],
2395
+ kw_defaults=[],
2396
+ defaults=[],
2397
+ )
2398
+ )
2399
+ ),
2380
2400
  body=node.body.gen.py_ast[0],
2381
2401
  )
2382
2402
  )
@@ -2435,11 +2455,16 @@ class PyastGenPass(Pass):
2435
2455
  )
2436
2456
  ]
2437
2457
  elif node.op.name in [Tok.STAR_MUL]:
2458
+ ctx_val = (
2459
+ node.operand.py_ctx_func()
2460
+ if isinstance(node.operand, ast.AstSymbolNode)
2461
+ else ast3.Load()
2462
+ )
2438
2463
  node.gen.py_ast = [
2439
2464
  self.sync(
2440
2465
  ast3.Starred(
2441
2466
  value=node.operand.gen.py_ast[0],
2442
- ctx=ast3.Load(),
2467
+ ctx=ctx_val,
2443
2468
  )
2444
2469
  )
2445
2470
  ]
@@ -3648,7 +3673,22 @@ class PyastGenPass(Pass):
3648
3673
  pos_start: int,
3649
3674
  pos_end: int,
3650
3675
  """
3651
- node.gen.py_ast = [self.sync(ast3.Constant(value=int(node.value)))]
3676
+
3677
+ def handle_node_value(value: str) -> int:
3678
+ if value.startswith(("0x", "0X")):
3679
+ return int(value, 16)
3680
+ elif value.startswith(("0b", "0B")):
3681
+ return int(value, 2)
3682
+ elif value.startswith(("0o", "0O")):
3683
+ return int(value, 8)
3684
+ else:
3685
+ return int(value)
3686
+
3687
+ node.gen.py_ast = [
3688
+ self.sync(
3689
+ ast3.Constant(value=handle_node_value(str(node.value)), kind=None)
3690
+ )
3691
+ ]
3652
3692
 
3653
3693
  def exit_string(self, node: ast.String) -> None:
3654
3694
  """Sub objects.
@@ -1372,12 +1372,13 @@ class PyastBuildPass(Pass[ast.PythonModuleAst]):
1372
1372
  pos_end=0,
1373
1373
  )
1374
1374
  pytag = ast.SubTag[ast.Name](tag=lang, kid=[lang])
1375
+ items = ast.SubNodeList[ast.ModulePath](items=paths, delim=Tok.COMMA, kid=paths)
1375
1376
  ret = ast.Import(
1376
1377
  hint=pytag,
1377
- paths=paths,
1378
- items=None,
1378
+ from_loc=None,
1379
+ items=items,
1379
1380
  is_absorb=False,
1380
- kid=[pytag, *paths],
1381
+ kid=[pytag, items],
1381
1382
  )
1382
1383
  return ret
1383
1384
 
@@ -1449,7 +1450,7 @@ class PyastBuildPass(Pass[ast.PythonModuleAst]):
1449
1450
  pytag = ast.SubTag[ast.Name](tag=lang, kid=[lang])
1450
1451
  ret = ast.Import(
1451
1452
  hint=pytag,
1452
- paths=[path],
1453
+ from_loc=path,
1453
1454
  items=items,
1454
1455
  is_absorb=False,
1455
1456
  kid=[pytag, path, items],
@@ -61,14 +61,21 @@ class SymTabPass(Pass):
61
61
  node.sym_name_node.py_ctx_func = ast3.Store
62
62
  if isinstance(node, (ast.TupleVal, ast.ListVal)) and node.values:
63
63
 
64
- def fix(item: ast.TupleVal | ast.ListVal) -> None:
65
- for i in item.values.items if item.values else []:
66
- if isinstance(i, ast.AstSymbolNode):
67
- i.py_ctx_func = ast3.Store
68
- elif isinstance(i, ast.AtomTrailer):
69
- self.chain_def_insert(self.unwind_atom_trailer(i))
70
- if isinstance(i, (ast.TupleVal, ast.ListVal)):
71
- fix(i)
64
+ # Handling of UnaryExpr case for item is only necessary for
65
+ # the generation of Starred nodes in the AST for examples
66
+ # like `(a, *b) = (1, 2, 3, 4)`.
67
+ def fix(item: ast.TupleVal | ast.ListVal | ast.UnaryExpr) -> None:
68
+ if isinstance(item, ast.UnaryExpr):
69
+ if isinstance(item.operand, ast.AstSymbolNode):
70
+ item.operand.py_ctx_func = ast3.Store
71
+ elif isinstance(item, (ast.TupleVal, ast.ListVal)):
72
+ for i in item.values.items if item.values else []:
73
+ if isinstance(i, ast.AstSymbolNode):
74
+ i.py_ctx_func = ast3.Store
75
+ elif isinstance(i, ast.AtomTrailer):
76
+ self.chain_def_insert(self.unwind_atom_trailer(i))
77
+ if isinstance(i, (ast.TupleVal, ast.ListVal, ast.UnaryExpr)):
78
+ fix(i)
72
79
 
73
80
  fix(node)
74
81
  self.handle_hit_outcome(node)
@@ -381,16 +388,22 @@ class SymTabBuildPass(SymTabPass):
381
388
  is_absorb: bool,
382
389
  sub_module: Optional[Module],
383
390
  """
384
- if node.items:
391
+ if not node.is_absorb:
385
392
  for i in node.items.items:
386
393
  self.def_insert(i, single_decl="import item")
387
394
  elif node.is_absorb and node.hint.tag.value == "jac":
388
- if not node.paths[0].sub_module or not node.paths[0].sub_module.sym_tab:
395
+ source = node.items.items[0]
396
+ if (
397
+ not isinstance(source, ast.ModulePath)
398
+ or not source.sub_module
399
+ or not source.sub_module.sym_tab
400
+ ):
389
401
  self.error(
390
- f"Module {node.paths[0].path_str} not found to include *, or ICE occurred!"
402
+ f"Module {node.from_loc.path_str if node.from_loc else 'from location'}"
403
+ f" not found to include *, or ICE occurred!"
391
404
  )
392
405
  else:
393
- for v in node.paths[0].sub_module.sym_tab.tab.values():
406
+ for v in source.sub_module.sym_tab.tab.values():
394
407
  self.def_insert(v.decl, table_override=self.cur_scope())
395
408
 
396
409
  def enter_module_path(self, node: ast.ModulePath) -> None:
@@ -1443,10 +1443,10 @@ class JacFormatPass(Pass):
1443
1443
  body: Expr,
1444
1444
  """
1445
1445
  out = ""
1446
- if node.signature.params:
1446
+ if node.signature and node.signature.params:
1447
1447
  self.comma_sep_node_list(node.signature.params)
1448
1448
  out += node.signature.params.gen.jac
1449
- if node.signature.return_type:
1449
+ if node.signature and node.signature.return_type:
1450
1450
  out += f" -> {node.signature.return_type.gen.jac}"
1451
1451
  self.emit(node, f"with {out} can {node.body.gen.jac}")
1452
1452
 
@@ -171,9 +171,7 @@ class Workspace:
171
171
  [
172
172
  i
173
173
  for i in mod_ir.get_all_sub_nodes(ast.ModulePath)
174
- if i.parent
175
- and isinstance(i.parent, ast.Import)
176
- and i.parent.hint.tag.value == "jac"
174
+ if i.parent_of_type(ast.Import).hint.tag.value == "jac"
177
175
  ]
178
176
  if mod_ir
179
177
  else []
@@ -184,9 +182,7 @@ class Workspace:
184
182
  i
185
183
  for i in mod_ir.get_all_sub_nodes(ast.ModulePath)
186
184
  if i.loc.mod_path == file_path
187
- and i.parent
188
- and isinstance(i.parent, ast.Import)
189
- and i.parent.hint.tag.value == "jac"
185
+ and i.parent_of_type(ast.Import).hint.tag.value == "jac"
190
186
  ]
191
187
  if mod_ir
192
188
  else []
jaclang/core/llms.py CHANGED
@@ -1,7 +1,5 @@
1
1
  """LLMs (Large Language Models) module for Jaclang."""
2
2
 
3
- import anthropic
4
-
5
3
 
6
4
  class Anthropic:
7
5
  """Anthropic API client for Large Language Models (LLMs)."""
@@ -12,6 +10,8 @@ class Anthropic:
12
10
 
13
11
  def __init__(self, **kwargs: dict) -> None:
14
12
  """Initialize the Anthropic API client."""
13
+ import anthropic
14
+
15
15
  self.client = anthropic.Anthropic()
16
16
  self.model_name = kwargs.get("model_name", "claude-3-sonnet-20240229")
17
17
  self.temperature = kwargs.get("temperature", 0.7)
@@ -27,3 +27,85 @@ class Anthropic:
27
27
  messages=messages,
28
28
  )
29
29
  return output.content[0].text
30
+
31
+
32
+ class Huggingface:
33
+ """Huggingface API client for Large Language Models (LLMs)."""
34
+
35
+ MTLLM_PROMPT: str = ""
36
+ MTLLM_REASON_SUFFIX: str = ""
37
+ MTLLM_WO_REASON_SUFFIX: str = ""
38
+
39
+ def __init__(self, **kwargs: dict) -> None:
40
+ """Initialize the Huggingface API client."""
41
+ import torch
42
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
43
+
44
+ torch.random.manual_seed(0)
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ kwargs.get("model_name", "microsoft/Phi-3-mini-128k-instruct"),
47
+ device_map=kwargs.get("device_map", "cuda"),
48
+ torch_dtype="auto",
49
+ trust_remote_code=True,
50
+ )
51
+ tokenizer = AutoTokenizer.from_pretrained(
52
+ kwargs.get("model_name", "microsoft/Phi-3-mini-128k-instruct")
53
+ )
54
+ self.pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
55
+ self.temperature = kwargs.get("temperature", 0.7)
56
+ self.max_tokens = kwargs.get("max_new_tokens", 1024)
57
+
58
+ def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
59
+ """Infer a response from the input meaning."""
60
+ messages = [{"role": "user", "content": meaning_in}]
61
+ output = self.pipe(
62
+ messages,
63
+ temperature=kwargs.get("temperature", self.temperature),
64
+ max_length=kwargs.get("max_new_tokens", self.max_tokens),
65
+ **kwargs
66
+ )
67
+ return output[0]["generated_text"][-1]["content"]
68
+
69
+
70
+ class Ollama:
71
+ """Ollama API client for Large Language Models (LLMs)."""
72
+
73
+ MTLLM_PROMPT: str = ""
74
+ MTLLM_REASON_SUFFIX: str = ""
75
+ MTLLM_WO_REASON_SUFFIX: str = ""
76
+
77
+ def __init__(self, **kwargs: dict) -> None:
78
+ """Initialize the Ollama API client."""
79
+ import ollama
80
+
81
+ self.client = ollama.Client(host=kwargs.get("host", "http://localhost:11434"))
82
+ self.model_name = kwargs.get("model_name", "phi3")
83
+ self.default_model_params = {
84
+ k: v for k, v in kwargs.items() if k not in ["model_name", "host"]
85
+ }
86
+
87
+ def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
88
+ """Infer a response from the input meaning."""
89
+ model = str(kwargs.get("model_name", self.model_name))
90
+ if not self.check_model(model):
91
+ self.download_model(model)
92
+ model_params = {k: v for k, v in kwargs.items() if k not in ["model_name"]}
93
+ messages = [{"role": "user", "content": meaning_in}]
94
+ output = self.client.chat(
95
+ model=model,
96
+ messages=messages,
97
+ options={**self.default_model_params, **model_params},
98
+ )
99
+ return output["message"]["content"]
100
+
101
+ def check_model(self, model_name: str) -> bool:
102
+ """Check if the model is available."""
103
+ try:
104
+ self.client.show(model_name)
105
+ return True
106
+ except Exception:
107
+ return False
108
+
109
+ def download_model(self, model_name: str) -> None:
110
+ """Download the model."""
111
+ self.client.pull(model_name)
jaclang/plugin/default.py CHANGED
@@ -3,6 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import fnmatch
6
+ import html
6
7
  import os
7
8
  import pickle
8
9
  import types
@@ -85,7 +86,11 @@ class JacFeatureDefaults:
85
86
  for i in on_entry + on_exit:
86
87
  i.resolve(cls)
87
88
  if not issubclass(cls, arch_base):
89
+ # Saving the module path and reassign it after creating cls
90
+ # So the jac modules are part of the correct module
91
+ cur_module = cls.__module__
88
92
  cls = type(cls.__name__, (cls, arch_base), {})
93
+ cls.__module__ = cur_module
89
94
  cls._jac_entry_funcs_ = on_entry # type: ignore
90
95
  cls._jac_exit_funcs_ = on_exit # type: ignore
91
96
  else:
@@ -691,13 +696,16 @@ class JacBuiltin:
691
696
  for source, target, edge in connections:
692
697
  dot_content += (
693
698
  f"{visited_nodes.index(source)} -> {visited_nodes.index(target)} "
694
- f' [label="{edge._jac_.obj.__class__.__name__} "];\n'
699
+ f' [label="{html.escape(str(edge._jac_.obj.__class__.__name__))} "];\n'
695
700
  )
696
701
  for node_ in visited_nodes:
697
702
  color = (
698
703
  colors[node_depths[node_]] if node_depths[node_] < 25 else colors[24]
699
704
  )
700
- dot_content += f'{visited_nodes.index(node_)} [label="{node_._jac_.obj}" fillcolor="{color}"];\n'
705
+ dot_content += (
706
+ f'{visited_nodes.index(node_)} [label="{html.escape(str(node_._jac_.obj))}"'
707
+ f'fillcolor="{color}"];\n'
708
+ )
701
709
  if dot_file:
702
710
  with open(dot_file, "w") as f:
703
711
  f.write(dot_content + "}")