jaclang-0.7.1-py3-none-any.whl → jaclang-0.7.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic; more details are available on the source registry's page for this release.

Files changed (48):
  1. jaclang/compiler/absyntree.py +51 -14
  2. jaclang/compiler/passes/main/def_impl_match_pass.py +9 -3
  3. jaclang/compiler/passes/main/fuse_typeinfo_pass.py +20 -1
  4. jaclang/compiler/passes/main/import_pass.py +4 -1
  5. jaclang/compiler/passes/main/pyast_gen_pass.py +14 -6
  6. jaclang/compiler/passes/main/pyast_load_pass.py +2 -1
  7. jaclang/compiler/passes/main/pyjac_ast_link_pass.py +6 -1
  8. jaclang/compiler/passes/main/pyout_pass.py +3 -1
  9. jaclang/compiler/passes/main/tests/test_import_pass.py +8 -0
  10. jaclang/compiler/passes/main/tests/test_type_check_pass.py +1 -1
  11. jaclang/compiler/passes/tool/jac_formatter_pass.py +14 -2
  12. jaclang/compiler/passes/tool/tests/fixtures/doc_string.jac +15 -0
  13. jaclang/compiler/passes/tool/tests/test_jac_format_pass.py +7 -5
  14. jaclang/compiler/passes/tool/tests/test_unparse_validate.py +1 -2
  15. jaclang/compiler/symtable.py +21 -1
  16. jaclang/core/aott.py +107 -11
  17. jaclang/core/construct.py +171 -5
  18. jaclang/core/llms/anthropic.py +31 -2
  19. jaclang/core/llms/base.py +3 -3
  20. jaclang/core/llms/groq.py +4 -1
  21. jaclang/core/llms/huggingface.py +4 -1
  22. jaclang/core/llms/ollama.py +4 -1
  23. jaclang/core/llms/openai.py +6 -2
  24. jaclang/core/llms/togetherai.py +4 -1
  25. jaclang/langserve/engine.py +99 -115
  26. jaclang/langserve/server.py +27 -5
  27. jaclang/langserve/tests/fixtures/circle_pure.impl.jac +8 -4
  28. jaclang/langserve/tests/fixtures/circle_pure.jac +2 -2
  29. jaclang/langserve/tests/test_server.py +123 -0
  30. jaclang/langserve/utils.py +100 -10
  31. jaclang/plugin/default.py +25 -83
  32. jaclang/plugin/feature.py +10 -12
  33. jaclang/plugin/tests/test_features.py +0 -33
  34. jaclang/settings.py +1 -0
  35. jaclang/tests/fixtures/byllmissue.jac +3 -0
  36. jaclang/tests/fixtures/hash_init_check.jac +17 -0
  37. jaclang/tests/fixtures/math_question.jpg +0 -0
  38. jaclang/tests/fixtures/nosigself.jac +19 -0
  39. jaclang/tests/fixtures/walker_override.jac +21 -0
  40. jaclang/tests/fixtures/with_llm_vision.jac +25 -0
  41. jaclang/tests/test_language.py +61 -11
  42. jaclang/utils/treeprinter.py +19 -2
  43. {jaclang-0.7.1.dist-info → jaclang-0.7.2.dist-info}/METADATA +3 -2
  44. {jaclang-0.7.1.dist-info → jaclang-0.7.2.dist-info}/RECORD +46 -41
  45. jaclang/core/memory.py +0 -48
  46. jaclang/core/shelve_storage.py +0 -55
  47. {jaclang-0.7.1.dist-info → jaclang-0.7.2.dist-info}/WHEEL +0 -0
  48. {jaclang-0.7.1.dist-info → jaclang-0.7.2.dist-info}/entry_points.txt +0 -0
jaclang/core/aott.py CHANGED
@@ -4,18 +4,30 @@ AOTT: Automated Operational Type Transformation.
4
4
  This has all the necessary functions to perform the AOTT operations.
5
5
  """
6
6
 
7
+ import base64
8
+ import logging
7
9
  import re
8
10
  from enum import Enum
11
+ from io import BytesIO
9
12
  from typing import Any
10
13
 
14
+
15
+ try:
16
+ from PIL import Image
17
+ except ImportError:
18
+ Image = None
19
+
11
20
  from jaclang.core.llms.base import BaseLLM
12
21
  from jaclang.core.registry import SemInfo, SemRegistry, SemScope
13
22
 
14
23
 
24
+ IMG_FORMATS = ["PngImageFile", "JpegImageFile"]
25
+
26
+
15
27
  def aott_raise(
16
28
  model: BaseLLM,
17
29
  information: str,
18
- inputs_information: str,
30
+ inputs_information: str | list[dict],
19
31
  output_information: str,
20
32
  type_explanations: str,
21
33
  action: str,
@@ -25,18 +37,43 @@ def aott_raise(
25
37
  model_params: dict,
26
38
  ) -> str:
27
39
  """AOTT Raise uses the information (Meanings types values) provided to generate a prompt(meaning in)."""
40
+ system_prompt = model.MTLLM_SYSTEM_PROMPT
41
+ meaning_in: str | list[dict]
28
42
  if method != "ReAct":
29
- system_prompt = model.MTLLM_SYSTEM_PROMPT
30
- mtllm_prompt = model.MTLLM_PROMPT.format(
31
- information=information,
32
- inputs_information=inputs_information,
33
- output_information=output_information,
34
- type_explanations=type_explanations,
35
- action=action,
36
- context=context,
37
- )
38
43
  method_prompt = model.MTLLM_METHOD_PROMPTS[method]
39
- meaning_in = f"{system_prompt}\n{mtllm_prompt}\n{method_prompt}"
44
+ if isinstance(inputs_information, str):
45
+ mtllm_prompt = model.MTLLM_PROMPT.format(
46
+ information=information,
47
+ inputs_information=inputs_information,
48
+ output_information=output_information,
49
+ type_explanations=type_explanations,
50
+ action=action,
51
+ context=context,
52
+ ).strip()
53
+ meaning_in = f"{system_prompt}\n{mtllm_prompt}\n{method_prompt}".strip()
54
+ else:
55
+ upper_half = model.MTLLM_PROMPT.split("{inputs_information}")[0]
56
+ lower_half = model.MTLLM_PROMPT.split("{inputs_information}")[1]
57
+ upper_half = upper_half.format(
58
+ information=information,
59
+ context=context,
60
+ )
61
+ lower_half = lower_half.format(
62
+ output_information=output_information,
63
+ type_explanations=type_explanations,
64
+ action=action,
65
+ )
66
+ meaning_in = (
67
+ [
68
+ {"type": "text", "text": system_prompt},
69
+ {"type": "text", "text": upper_half},
70
+ ]
71
+ + inputs_information
72
+ + [
73
+ {"type": "text", "text": lower_half},
74
+ {"type": "text", "text": method_prompt},
75
+ ]
76
+ )
40
77
  return model(meaning_in, **model_params)
41
78
  else:
42
79
  assert tools, "Tools must be provided for the ReAct method."
@@ -212,3 +249,62 @@ class Tool:
212
249
  """Initialize the Tool class."""
213
250
  # TODO: Implement the Tool class
214
251
  pass
252
+
253
+
254
+ def get_input_information(
255
+ inputs: list[tuple[str, str, str, Any]], type_collector: list
256
+ ) -> str | list[dict]:
257
+ """
258
+ Get the input information for the AOTT operation.
259
+
260
+ Returns:
261
+ str | list[dict]: If the input does not contain images, returns a string with the input information.
262
+ If the input contains images, returns a list of dictionaries representing the input information,
263
+ where each dictionary contains either text or image_url.
264
+
265
+ """
266
+ contains_imgs = any(get_type_annotation(i[3]) in IMG_FORMATS for i in inputs)
267
+ if not contains_imgs:
268
+ inputs_information_list = []
269
+ for i in inputs:
270
+ typ_anno = get_type_annotation(i[3])
271
+ type_collector.extend(extract_non_primary_type(typ_anno))
272
+ inputs_information_list.append(
273
+ f"{i[0]} ({i[2]}) ({typ_anno}) = {get_object_string(i[3])}"
274
+ )
275
+ return "\n".join(inputs_information_list)
276
+ else:
277
+ inputs_information_dict_list: list[dict] = []
278
+ for i in inputs:
279
+ if get_type_annotation(i[3]) in IMG_FORMATS:
280
+ img_base64 = image_to_base64(i[3])
281
+ image_repr: list[dict] = [
282
+ {
283
+ "type": "text",
284
+ "text": f"{i[0]} ({i[2]}) (Image) = ",
285
+ },
286
+ {"type": "image_url", "image_url": {"url": img_base64}},
287
+ ]
288
+ inputs_information_dict_list.extend(image_repr)
289
+ continue
290
+ typ_anno = get_type_annotation(i[3])
291
+ type_collector.extend(extract_non_primary_type(typ_anno))
292
+ inputs_information_dict_list.append(
293
+ {
294
+ "type": "text",
295
+ "text": f"{i[0]} ({i[2]}) ({typ_anno}) = {get_object_string(i[3])}",
296
+ }
297
+ )
298
+ return inputs_information_dict_list
299
+
300
+
301
+ def image_to_base64(image: Image) -> str:
302
+ """Convert an image to base64 expected by OpenAI."""
303
+ if not Image:
304
+ log = logging.getLogger(__name__)
305
+ log.error("Pillow is not installed. Please install Pillow to use images.")
306
+ return ""
307
+ img_format = image.format
308
+ with BytesIO() as buffer:
309
+ image.save(buffer, format=img_format, quality=100)
310
+ return f"data:image/{img_format.lower()};base64,{base64.b64encode(buffer.getvalue()).decode()}"
jaclang/core/construct.py CHANGED
@@ -2,14 +2,15 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import shelve
5
6
  import unittest
7
+ from contextvars import ContextVar
6
8
  from dataclasses import dataclass, field
7
9
  from typing import Callable, Optional
8
10
  from uuid import UUID, uuid4
9
11
 
10
12
  from jaclang.compiler.constant import EdgeDir
11
13
  from jaclang.core.utils import collect_node_connections
12
- from jaclang.plugin.feature import JacFeature as Jac
13
14
  from jaclang.plugin.spec import DSFunc
14
15
 
15
16
 
@@ -57,6 +58,8 @@ class NodeAnchor(ObjectAnchor):
57
58
 
58
59
  def populate_edges(self) -> None:
59
60
  """Populate edges from edge ids."""
61
+ from jaclang.plugin.feature import JacFeature as Jac
62
+
60
63
  if len(self.edges) == 0 and len(self.edge_ids) > 0:
61
64
  for e_id in self.edge_ids:
62
65
  edge = Jac.context().get_obj(e_id)
@@ -354,11 +357,15 @@ class NodeArchitype(Architype):
354
357
 
355
358
  def __init__(self) -> None:
356
359
  """Create node architype."""
360
+ from jaclang.plugin.feature import JacFeature as Jac
361
+
357
362
  self._jac_: NodeAnchor = NodeAnchor(obj=self)
358
363
  Jac.context().save_obj(self, persistent=self._jac_.persistent)
359
364
 
360
365
  def save(self) -> None:
361
366
  """Save the node to the memory/storage hierarchy."""
367
+ from jaclang.plugin.feature import JacFeature as Jac
368
+
362
369
  self._jac_.persistent = True
363
370
  Jac.context().save_obj(self, persistent=True)
364
371
 
@@ -383,11 +390,15 @@ class EdgeArchitype(Architype):
383
390
 
384
391
  def __init__(self) -> None:
385
392
  """Create edge architype."""
393
+ from jaclang.plugin.feature import JacFeature as Jac
394
+
386
395
  self._jac_: EdgeAnchor = EdgeAnchor(obj=self)
387
396
  Jac.context().save_obj(self, persistent=self.persistent)
388
397
 
389
398
  def save(self) -> None:
390
399
  """Save the edge to the memory/storage hierarchy."""
400
+ from jaclang.plugin.feature import JacFeature as Jac
401
+
391
402
  self.persistent = True
392
403
  Jac.context().save_obj(self, persistent=True)
393
404
 
@@ -405,6 +416,8 @@ class EdgeArchitype(Architype):
405
416
 
406
417
  def populate_nodes(self) -> None:
407
418
  """Populate nodes for the edges from node ids."""
419
+ from jaclang.plugin.feature import JacFeature as Jac
420
+
408
421
  if self._jac_.source_id:
409
422
  obj = Jac.context().get_obj(self._jac_.source_id)
410
423
  if obj is None:
@@ -439,6 +452,13 @@ class WalkerArchitype(Architype):
439
452
  self._jac_: WalkerAnchor = WalkerAnchor(obj=self)
440
453
 
441
454
 
455
+ class GenericEdge(EdgeArchitype):
456
+ """Generic Root Node."""
457
+
458
+ _jac_entry_funcs_ = []
459
+ _jac_exit_funcs_ = []
460
+
461
+
442
462
  class Root(NodeArchitype):
443
463
  """Generic Root Node."""
444
464
 
@@ -460,11 +480,157 @@ class Root(NodeArchitype):
460
480
  self._jac_.edges = []
461
481
 
462
482
 
463
- class GenericEdge(EdgeArchitype):
464
- """Generic Root Node."""
483
+ class Memory:
484
+ """Memory module interface."""
465
485
 
466
- _jac_entry_funcs_ = []
467
- _jac_exit_funcs_ = []
486
+ mem: dict[UUID, Architype] = {}
487
+ save_obj_list: dict[UUID, Architype] = {}
488
+
489
+ def __init__(self) -> None:
490
+ """init."""
491
+ pass
492
+
493
+ def get_obj(self, obj_id: UUID) -> Architype | None:
494
+ """Get object from memory."""
495
+ return self.get_obj_from_store(obj_id)
496
+
497
+ def get_obj_from_store(self, obj_id: UUID) -> Architype | None:
498
+ """Get object from the underlying store."""
499
+ ret = self.mem.get(obj_id)
500
+ return ret
501
+
502
+ def has_obj(self, obj_id: UUID) -> bool:
503
+ """Check if the object exists."""
504
+ return self.has_obj_in_store(obj_id)
505
+
506
+ def has_obj_in_store(self, obj_id: UUID) -> bool:
507
+ """Check if the object exists in the underlying store."""
508
+ return obj_id in self.mem
509
+
510
+ def save_obj(self, item: Architype, persistent: bool) -> None:
511
+ """Save object."""
512
+ self.mem[item._jac_.id] = item
513
+ if persistent:
514
+ # TODO: check if it needs to be saved, i.e. dirty or not
515
+ self.save_obj_list[item._jac_.id] = item
516
+
517
+ def commit(self) -> None:
518
+ """Commit changes to persistent storage, if applicable."""
519
+ pass
520
+
521
+ def close(self) -> None:
522
+ """Close any connection, if applicable."""
523
+ self.mem.clear()
524
+
525
+
526
+ class ShelveStorage(Memory):
527
+ """Shelve storage for jaclang runtime object."""
528
+
529
+ storage: shelve.Shelf | None = None
530
+
531
+ def __init__(self, session: str = "") -> None:
532
+ """Init shelve storage."""
533
+ super().__init__()
534
+ if session:
535
+ self.connect(session)
536
+
537
+ def get_obj_from_store(self, obj_id: UUID) -> Architype | None:
538
+ """Get object from the underlying store."""
539
+ obj = super().get_obj_from_store(obj_id)
540
+ if obj is None and self.storage:
541
+ obj = self.storage.get(str(obj_id))
542
+ if obj is not None:
543
+ self.mem[obj_id] = obj
544
+
545
+ return obj
546
+
547
+ def has_obj_in_store(self, obj_id: UUID | str) -> bool:
548
+ """Check if the object exists in the underlying store."""
549
+ return obj_id in self.mem or (
550
+ str(obj_id) in self.storage if self.storage else False
551
+ )
552
+
553
+ def commit(self) -> None:
554
+ """Commit changes to persistent storage."""
555
+ if self.storage is not None:
556
+ for obj_id, obj in self.save_obj_list.items():
557
+ self.storage[str(obj_id)] = obj
558
+ self.save_obj_list.clear()
559
+
560
+ def connect(self, session: str) -> None:
561
+ """Connect to storage."""
562
+ self.session = session
563
+ self.storage = shelve.open(session)
564
+
565
+ def close(self) -> None:
566
+ """Close the storage."""
567
+ super().close()
568
+ self.commit()
569
+ if self.storage:
570
+ self.storage.close()
571
+ self.storage = None
572
+
573
+
574
+ class ExecutionContext:
575
+ """Default Execution Context implementation."""
576
+
577
+ mem: Optional[Memory]
578
+ root: Optional[Root]
579
+
580
+ def __init__(self) -> None:
581
+ """Create execution context."""
582
+ super().__init__()
583
+ self.mem = ShelveStorage()
584
+ self.root = None
585
+
586
+ def init_memory(self, session: str = "") -> None:
587
+ """Initialize memory."""
588
+ if session:
589
+ self.mem = ShelveStorage(session)
590
+ else:
591
+ self.mem = Memory()
592
+
593
+ def get_root(self) -> Root:
594
+ """Get the root object."""
595
+ if self.mem is None:
596
+ raise ValueError("Memory not initialized")
597
+
598
+ if not self.root:
599
+ root = self.mem.get_obj(UUID(int=0))
600
+ if root is None:
601
+ self.root = Root()
602
+ self.mem.save_obj(self.root, persistent=self.root._jac_.persistent)
603
+ elif not isinstance(root, Root):
604
+ raise ValueError(f"Invalid root object: {root}")
605
+ else:
606
+ self.root = root
607
+ return self.root
608
+
609
+ def get_obj(self, obj_id: UUID) -> Architype | None:
610
+ """Get object from memory."""
611
+ if self.mem is None:
612
+ raise ValueError("Memory not initialized")
613
+
614
+ return self.mem.get_obj(obj_id)
615
+
616
+ def save_obj(self, item: Architype, persistent: bool) -> None:
617
+ """Save object to memory."""
618
+ if self.mem is None:
619
+ raise ValueError("Memory not initialized")
620
+
621
+ self.mem.save_obj(item, persistent)
622
+
623
+ def reset(self) -> None:
624
+ """Reset the execution context."""
625
+ if self.mem:
626
+ self.mem.close()
627
+ self.mem = None
628
+ self.root = None
629
+
630
+
631
+ exec_context: ContextVar[ExecutionContext | None] = ContextVar(
632
+ "ExecutionContext", default=None
633
+ )
468
634
 
469
635
 
470
636
  class JacTestResult(unittest.TextTestResult):
jaclang/core/llms/anthropic.py CHANGED
@@ -45,12 +45,41 @@ class Anthropic(BaseLLM):
45
45
  self.client = anthropic.Anthropic()
46
46
  self.verbose = verbose
47
47
  self.max_tries = max_tries
48
- self.model_name = kwargs.get("model_name", "claude-3-sonnet-20240229")
48
+ self.model_name = str(kwargs.get("model_name", "claude-3-sonnet-20240229"))
49
49
  self.temperature = kwargs.get("temperature", 0.7)
50
50
  self.max_tokens = kwargs.get("max_tokens", 1024)
51
51
 
52
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
52
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
53
53
  """Infer a response from the input meaning."""
54
+ if not isinstance(meaning_in, str):
55
+ assert self.model_name.startswith(
56
+ ("claude-3-opus", "claude-3-sonnet", "claude-3-haiku")
57
+ ), f"Model {self.model_name} is not multimodal, use a multimodal model instead."
58
+
59
+ import re
60
+
61
+ formatted_meaning_in = []
62
+ for item in meaning_in:
63
+ if item["type"] == "image_url":
64
+ # "_string"
65
+ img_match = re.match(
66
+ r"data:(image/[a-zA-Z]*);base64,(.*)", item["source"]
67
+ )
68
+ if img_match:
69
+ media_type, base64_string = img_match.groups()
70
+ formatted_meaning_in.append(
71
+ {
72
+ "type": "image",
73
+ "source": {
74
+ "type": "base64",
75
+ "media_type": media_type,
76
+ "data": base64_string,
77
+ },
78
+ }
79
+ )
80
+ continue
81
+ formatted_meaning_in.append(item)
82
+ meaning_in = formatted_meaning_in
54
83
  messages = [{"role": "user", "content": meaning_in}]
55
84
  output = self.client.messages.create(
56
85
  model=kwargs.get("model_name", self.model_name),
jaclang/core/llms/base.py CHANGED
@@ -112,11 +112,11 @@ class BaseLLM:
112
112
  self.max_tries = max_tries
113
113
  raise NotImplementedError
114
114
 
115
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
115
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
116
116
  """Infer a response from the input meaning."""
117
117
  raise NotImplementedError
118
118
 
119
- def __call__(self, input_text: str, **kwargs: dict) -> str:
119
+ def __call__(self, input_text: str | list[dict], **kwargs: dict) -> str:
120
120
  """Infer a response from the input text."""
121
121
  if self.verbose:
122
122
  logger.info(f"Meaning In\n{input_text}")
@@ -131,7 +131,7 @@ class BaseLLM:
131
131
  ) -> str:
132
132
  """Resolve the output string to return the reasoning and output."""
133
133
  if self.verbose:
134
- logger.opt(colors=True).info(f"Meaning Out\n<green>{meaning_out}</green>")
134
+ logger.info(f"Meaning Out\n{meaning_out}")
135
135
  output_match = re.search(r"\[Output\](.*)", meaning_out)
136
136
  output = output_match.group(1).strip() if output_match else None
137
137
  if not output_match:
jaclang/core/llms/groq.py CHANGED
@@ -49,8 +49,11 @@ class Groq(BaseLLM):
49
49
  self.temperature = kwargs.get("temperature", 0.7)
50
50
  self.max_tokens = kwargs.get("max_tokens", 1024)
51
51
 
52
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
52
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
53
53
  """Infer a response from the input meaning."""
54
+ assert isinstance(
55
+ meaning_in, str
56
+ ), "Currently Multimodal models are not supported. Please provide a string input."
54
57
  messages = [{"role": "user", "content": meaning_in}]
55
58
  model_params = {
56
59
  k: v
jaclang/core/llms/huggingface.py CHANGED
@@ -61,8 +61,11 @@ class Huggingface(BaseLLM):
61
61
  self.temperature = kwargs.get("temperature", 0.7)
62
62
  self.max_tokens = kwargs.get("max_new_tokens", 1024)
63
63
 
64
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
64
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
65
65
  """Infer a response from the input meaning."""
66
+ assert isinstance(
67
+ meaning_in, str
68
+ ), "Currently Multimodal models are not supported. Please provide a string input."
66
69
  messages = [{"role": "user", "content": meaning_in}]
67
70
  output = self.pipe(
68
71
  messages,
jaclang/core/llms/ollama.py CHANGED
@@ -51,8 +51,11 @@ class Ollama(BaseLLM):
51
51
  k: v for k, v in kwargs.items() if k not in ["model_name", "host"]
52
52
  }
53
53
 
54
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
54
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
55
55
  """Infer a response from the input meaning."""
56
+ assert isinstance(
57
+ meaning_in, str
58
+ ), "Currently Multimodal models are not supported. Please provide a string input."
56
59
  model = str(kwargs.get("model_name", self.model_name))
57
60
  if not self.check_model(model):
58
61
  self.download_model(model)
jaclang/core/llms/openai.py CHANGED
@@ -45,12 +45,16 @@ class OpenAI(BaseLLM):
45
45
  self.client = openai.OpenAI()
46
46
  self.verbose = verbose
47
47
  self.max_tries = max_tries
48
- self.model_name = kwargs.get("model_name", "gpt-3.5-turbo")
48
+ self.model_name = str(kwargs.get("model_name", "gpt-3.5-turbo"))
49
49
  self.temperature = kwargs.get("temperature", 0.7)
50
50
  self.max_tokens = kwargs.get("max_tokens", 1024)
51
51
 
52
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
52
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
53
53
  """Infer a response from the input meaning."""
54
+ if not isinstance(meaning_in, str):
55
+ assert self.model_name.startswith(
56
+ ("gpt-4o", "gpt-4-turbo")
57
+ ), f"Model {self.model_name} is not multimodal, use a multimodal model instead."
54
58
  messages = [{"role": "user", "content": meaning_in}]
55
59
  output = self.client.chat.completions.create(
56
60
  model=kwargs.get("model_name", self.model_name),
jaclang/core/llms/togetherai.py CHANGED
@@ -48,8 +48,11 @@ class TogetherAI(BaseLLM):
48
48
  self.temperature = kwargs.get("temperature", 0.7)
49
49
  self.max_tokens = kwargs.get("max_tokens", 1024)
50
50
 
51
- def __infer__(self, meaning_in: str, **kwargs: dict) -> str:
51
+ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
52
52
  """Infer a response from the input meaning."""
53
+ assert isinstance(
54
+ meaning_in, str
55
+ ), "Currently Multimodal models are not supported. Please provide a string input."
53
56
  messages = [{"role": "user", "content": meaning_in}]
54
57
  output = self.client.chat.completions.create(
55
58
  model=kwargs.get("model_name", self.model_name),