symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. symai/__init__.py +269 -173
  2. symai/backend/base.py +123 -110
  3. symai/backend/engines/drawing/engine_bfl.py +45 -44
  4. symai/backend/engines/drawing/engine_gpt_image.py +112 -97
  5. symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
  6. symai/backend/engines/embedding/engine_openai.py +25 -21
  7. symai/backend/engines/execute/engine_python.py +19 -18
  8. symai/backend/engines/files/engine_io.py +104 -95
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
  11. symai/backend/engines/index/engine_pinecone.py +124 -97
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +84 -56
  14. symai/backend/engines/lean/engine_lean4.py +96 -52
  15. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
  21. symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
  26. symai/backend/engines/ocr/engine_apilayer.py +23 -27
  27. symai/backend/engines/output/engine_stdout.py +10 -13
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
  29. symai/backend/engines/search/engine_openai.py +100 -88
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +44 -45
  32. symai/backend/engines/search/engine_serpapi.py +37 -34
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
  35. symai/backend/engines/text_to_speech/engine_openai.py +20 -26
  36. symai/backend/engines/text_vision/engine_clip.py +39 -37
  37. symai/backend/engines/userinput/engine_console.py +5 -6
  38. symai/backend/mixin/__init__.py +13 -0
  39. symai/backend/mixin/anthropic.py +48 -38
  40. symai/backend/mixin/deepseek.py +6 -5
  41. symai/backend/mixin/google.py +7 -4
  42. symai/backend/mixin/groq.py +2 -4
  43. symai/backend/mixin/openai.py +140 -110
  44. symai/backend/settings.py +87 -20
  45. symai/chat.py +216 -123
  46. symai/collect/__init__.py +7 -1
  47. symai/collect/dynamic.py +80 -70
  48. symai/collect/pipeline.py +67 -51
  49. symai/collect/stats.py +161 -109
  50. symai/components.py +707 -360
  51. symai/constraints.py +24 -12
  52. symai/core.py +1857 -1233
  53. symai/core_ext.py +83 -80
  54. symai/endpoints/api.py +166 -104
  55. symai/extended/.DS_Store +0 -0
  56. symai/extended/__init__.py +46 -12
  57. symai/extended/api_builder.py +29 -21
  58. symai/extended/arxiv_pdf_parser.py +23 -14
  59. symai/extended/bibtex_parser.py +9 -6
  60. symai/extended/conversation.py +156 -126
  61. symai/extended/document.py +50 -30
  62. symai/extended/file_merger.py +57 -14
  63. symai/extended/graph.py +51 -32
  64. symai/extended/html_style_template.py +18 -14
  65. symai/extended/interfaces/blip_2.py +2 -3
  66. symai/extended/interfaces/clip.py +4 -3
  67. symai/extended/interfaces/console.py +9 -1
  68. symai/extended/interfaces/dall_e.py +4 -2
  69. symai/extended/interfaces/file.py +2 -0
  70. symai/extended/interfaces/flux.py +4 -2
  71. symai/extended/interfaces/gpt_image.py +16 -7
  72. symai/extended/interfaces/input.py +2 -1
  73. symai/extended/interfaces/llava.py +1 -2
  74. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
  75. symai/extended/interfaces/naive_vectordb.py +9 -10
  76. symai/extended/interfaces/ocr.py +5 -3
  77. symai/extended/interfaces/openai_search.py +2 -0
  78. symai/extended/interfaces/parallel.py +30 -0
  79. symai/extended/interfaces/perplexity.py +2 -0
  80. symai/extended/interfaces/pinecone.py +12 -9
  81. symai/extended/interfaces/python.py +2 -0
  82. symai/extended/interfaces/serpapi.py +3 -1
  83. symai/extended/interfaces/terminal.py +2 -4
  84. symai/extended/interfaces/tts.py +3 -2
  85. symai/extended/interfaces/whisper.py +3 -2
  86. symai/extended/interfaces/wolframalpha.py +2 -1
  87. symai/extended/metrics/__init__.py +11 -1
  88. symai/extended/metrics/similarity.py +14 -13
  89. symai/extended/os_command.py +39 -29
  90. symai/extended/packages/__init__.py +29 -3
  91. symai/extended/packages/symdev.py +51 -43
  92. symai/extended/packages/sympkg.py +41 -35
  93. symai/extended/packages/symrun.py +63 -50
  94. symai/extended/repo_cloner.py +14 -12
  95. symai/extended/seo_query_optimizer.py +15 -13
  96. symai/extended/solver.py +116 -91
  97. symai/extended/summarizer.py +12 -10
  98. symai/extended/taypan_interpreter.py +17 -18
  99. symai/extended/vectordb.py +122 -92
  100. symai/formatter/__init__.py +9 -1
  101. symai/formatter/formatter.py +51 -47
  102. symai/formatter/regex.py +70 -69
  103. symai/functional.py +325 -176
  104. symai/imports.py +190 -147
  105. symai/interfaces.py +57 -28
  106. symai/memory.py +45 -35
  107. symai/menu/screen.py +28 -19
  108. symai/misc/console.py +66 -56
  109. symai/misc/loader.py +8 -5
  110. symai/models/__init__.py +17 -1
  111. symai/models/base.py +395 -236
  112. symai/models/errors.py +1 -2
  113. symai/ops/__init__.py +32 -22
  114. symai/ops/measures.py +24 -25
  115. symai/ops/primitives.py +1149 -731
  116. symai/post_processors.py +58 -50
  117. symai/pre_processors.py +86 -82
  118. symai/processor.py +21 -13
  119. symai/prompts.py +764 -685
  120. symai/server/huggingface_server.py +135 -49
  121. symai/server/llama_cpp_server.py +21 -11
  122. symai/server/qdrant_server.py +206 -0
  123. symai/shell.py +100 -42
  124. symai/shellsv.py +700 -492
  125. symai/strategy.py +630 -346
  126. symai/symbol.py +368 -322
  127. symai/utils.py +100 -78
  128. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
  129. symbolicai-1.1.0.dist-info/RECORD +168 -0
  130. symbolicai-0.21.0.dist-info/RECORD +0 -162
  131. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  132. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  133. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  134. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/backend/base.py CHANGED
@@ -1,115 +1,125 @@
1
1
  import logging
2
- import os
3
2
  import time
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
4
5
 
5
- from abc import ABC
6
- from typing import Any, List, Tuple
6
+ from ..collect import CollectionRepository, rec_serialize
7
+ from ..utils import UserMessage
7
8
  from .settings import HOME_PATH
8
9
 
9
- from ..collect import CollectionRepository, rec_serialize
10
+ ENGINE_UNREGISTERED = "<UNREGISTERED/>"
10
11
 
12
+ COLLECTION_LOGGING_ENGINES = {
13
+ "GPTXChatEngine",
14
+ "GPTXCompletionEngine",
15
+ "SerpApiEngine",
16
+ "WolframAlphaEngine",
17
+ "SeleniumEngine",
18
+ "OCREngine",
19
+ }
11
20
 
12
- ENGINE_UNREGISTERED = '<UNREGISTERED/>'
13
21
 
14
22
  class Engine(ABC):
15
23
  def __init__(self) -> None:
16
24
  super().__init__()
17
- self.verbose = False
18
- self.logging = False
19
- self.log_level = logging.DEBUG
25
+ self.verbose = False
26
+ self.logging = False
27
+ self.log_level = logging.DEBUG
20
28
  self.time_clock = False
21
29
  self.collection = CollectionRepository()
22
30
  self.collection.connect()
23
31
  # create formatter
24
- __root_dir__ = HOME_PATH
25
- os.makedirs(__root_dir__, exist_ok=True)
32
+ __root_dir__ = HOME_PATH
33
+ __root_dir__.mkdir(parents=True, exist_ok=True)
26
34
  __file_path__ = __root_dir__ / "engine.log"
27
- logging.basicConfig(filename=__file_path__, filemode="a", format='%(asctime)s %(name)s %(levelname)s %(message)s')
28
- self.logger = logging.getLogger()
35
+ logging.basicConfig(
36
+ filename=__file_path__,
37
+ filemode="a",
38
+ format="%(asctime)s %(name)s %(levelname)s %(message)s",
39
+ )
40
+ self.logger = logging.getLogger()
29
41
  self.logger.setLevel(logging.DEBUG)
30
42
  # logging to console
31
- stream = logging.StreamHandler()
32
- streamformat = logging.Formatter("%(asctime)s %(message)s")
43
+ stream = logging.StreamHandler()
44
+ streamformat = logging.Formatter("%(asctime)s %(message)s")
33
45
  stream.setLevel(logging.INFO)
34
46
  stream.setFormatter(streamformat)
35
47
  self.logger.addHandler(stream)
36
48
 
37
- def __call__(self, argument: Any) -> Tuple[List[str], dict]:
38
- log = {
39
- 'Input': {
40
- 'self': self,
41
- 'args': argument.args,
42
- **argument.kwargs
43
- }
44
- }
49
+ def __call__(self, argument: Any) -> tuple[list[str], dict]:
50
+ log = {"Input": {"self": self, "args": argument.args, **argument.kwargs}}
45
51
  start_time = time.time()
46
52
 
47
- # check for global object based input handler
48
- if hasattr(argument.prop.instance, '_metadata') and hasattr(argument.prop.instance._metadata, 'input_handler'):
49
- input_handler = argument.prop.instance._metadata.input_handler if hasattr(argument.prop.instance._metadata, 'input_handler') else None
50
- if input_handler is not None:
51
- input_handler((argument.prop.processed_input, argument))
52
- # check for kwargs based input handler
53
- if argument.prop.input_handler is not None:
54
- argument.prop.input_handler((argument.prop.processed_input, argument))
53
+ self._trigger_input_handlers(argument)
55
54
 
56
- # execute the engine
57
55
  res, metadata = self.forward(argument)
58
56
 
59
- # compute time
60
57
  req_time = time.time() - start_time
61
- metadata['time'] = req_time
58
+ metadata["time"] = req_time
62
59
  if self.time_clock:
63
- print(f"{argument.prop.func}: {req_time} sec")
64
- log['Output'] = res
60
+ UserMessage(f"{argument.prop.func}: {req_time} sec")
61
+ log["Output"] = res
65
62
  if self.verbose:
66
- view = {k: v for k, v in list(log['Input'].items()) if k != 'self'}
67
- input_ = f"{str(log['Input']['self'])[:50]}, {str(argument.prop.func)}, {str(view)}"
68
- print(input_[:150], str(log['Output'])[:100])
63
+ view = {k: v for k, v in list(log["Input"].items()) if k != "self"}
64
+ input_ = f"{str(log['Input']['self'])[:50]}, {argument.prop.func!s}, {view!s}"
65
+ UserMessage(f"{input_[:150]} {str(log['Output'])[:100]}")
69
66
  if self.logging:
70
67
  self.logger.log(self.log_level, log)
71
68
 
72
- # share data statistics with the collection repository
73
- if str(self) == 'GPTXChatEngine' \
74
- or str(self) == 'GPTXCompletionEngine' \
75
- or str(self) == 'SerpApiEngine' \
76
- or str(self) == 'WolframAlphaEngine' \
77
- or str(self) == 'SeleniumEngine' \
78
- or str(self) == 'OCREngine':
79
- self.collection.add(
80
- forward={'args': rec_serialize(argument.args), 'kwds': rec_serialize(argument.kwargs)},
81
- engine=str(self),
82
- metadata={
83
- 'time': req_time,
84
- 'data': rec_serialize(metadata),
85
- 'argument': rec_serialize(argument)
86
- }
87
- )
88
-
89
- # check for global object based output handler
90
- if hasattr(argument.prop.instance, '_metadata') and hasattr(argument.prop.instance._metadata, 'output_handler'):
91
- output_handler = argument.prop.instance._metadata.output_handler if hasattr(argument.prop.instance._metadata, 'output_handler') else None
92
- if output_handler:
93
- output_handler(res)
94
- # check for kwargs based output handler
95
- if argument.prop.output_handler:
96
- argument.prop.output_handler((res, metadata))
69
+ if str(self) in COLLECTION_LOGGING_ENGINES:
70
+ self._record_collection_entry(argument, metadata, req_time)
71
+
72
+ self._trigger_output_handlers(argument, res, metadata)
97
73
  return res, metadata
98
74
 
75
+ def _trigger_input_handlers(self, argument: Any) -> None:
76
+ instance_metadata = getattr(argument.prop.instance, "_metadata", None)
77
+ if instance_metadata is not None:
78
+ input_handler = getattr(instance_metadata, "input_handler", None)
79
+ if input_handler is not None:
80
+ input_handler((argument.prop.processed_input, argument))
81
+ argument_handler = argument.prop.input_handler
82
+ if argument_handler is not None:
83
+ argument_handler((argument.prop.processed_input, argument))
84
+
85
+ def _trigger_output_handlers(self, argument: Any, result: Any, metadata: dict | None) -> None:
86
+ instance_metadata = getattr(argument.prop.instance, "_metadata", None)
87
+ if instance_metadata is not None:
88
+ output_handler = getattr(instance_metadata, "output_handler", None)
89
+ if output_handler:
90
+ output_handler(result)
91
+ argument_handler = argument.prop.output_handler
92
+ if argument_handler:
93
+ argument_handler((result, metadata))
94
+
95
+ def _record_collection_entry(self, argument: Any, metadata: dict, req_time: float) -> None:
96
+ self.collection.add(
97
+ forward={"args": rec_serialize(argument.args), "kwds": rec_serialize(argument.kwargs)},
98
+ engine=str(self),
99
+ metadata={
100
+ "time": req_time,
101
+ "data": rec_serialize(metadata),
102
+ "argument": rec_serialize(argument),
103
+ },
104
+ )
105
+
99
106
  def id(self) -> str:
100
107
  return ENGINE_UNREGISTERED
101
108
 
102
109
  def preview(self, argument):
103
- # used here to avoid circular import
104
- from ..symbol import Symbol
110
+ # Used here to avoid backend.base <-> symbol circular import.
111
+ from ..symbol import ( # noqa
112
+ Symbol,
113
+ )
114
+
105
115
  class Preview(Symbol):
106
116
  def __repr__(self) -> str:
107
- '''
117
+ """
108
118
  Get the representation of the Symbol object as a string.
109
119
 
110
120
  Returns:
111
121
  str: The representation of the Symbol object.
112
- '''
122
+ """
113
123
  return str(self.value.prop.prepared_input)
114
124
 
115
125
  def prepared_input(self):
@@ -117,36 +127,38 @@ class Engine(ABC):
117
127
 
118
128
  return Preview(argument), {}
119
129
 
120
- def forward(self, *args: Any, **kwds: Any) -> List[str]:
121
- raise NotADirectoryError()
130
+ @abstractmethod
131
+ def forward(self, *args: Any, **kwds: Any) -> list[str]:
132
+ raise NotADirectoryError
122
133
 
134
+ @abstractmethod
123
135
  def prepare(self, argument):
124
- raise NotImplementedError()
125
-
126
- def command(self, *args, **kwargs):
127
- if kwargs.get('verbose', None):
128
- self.verbose = kwargs['verbose']
129
- if kwargs.get('logging', None):
130
- self.logging = kwargs['logging']
131
- if kwargs.get('log_level', None):
132
- self.log_level = kwargs['log_level']
133
- if kwargs.get('time_clock', None):
134
- self.time_clock = kwargs['time_clock']
136
+ raise NotImplementedError
137
+
138
+ def command(self, *_args, **kwargs):
139
+ if kwargs.get("verbose"):
140
+ self.verbose = kwargs["verbose"]
141
+ if kwargs.get("logging"):
142
+ self.logging = kwargs["logging"]
143
+ if kwargs.get("log_level"):
144
+ self.log_level = kwargs["log_level"]
145
+ if kwargs.get("time_clock"):
146
+ self.time_clock = kwargs["time_clock"]
135
147
 
136
148
  def __str__(self) -> str:
137
149
  return self.__class__.__name__
138
150
 
139
151
  def __repr__(self) -> str:
140
- '''
152
+ """
141
153
  Get the representation of the Symbol object as a string.
142
154
 
143
155
  Returns:
144
156
  str: The representation of the Symbol object.
145
- '''
157
+ """
146
158
  # class with full path
147
- class_ = self.__class__.__module__ + '.' + self.__class__.__name__
148
- hex_ = hex(id(self))
149
- return f'<class {class_} at {hex_}>'
159
+ class_ = self.__class__.__module__ + "." + self.__class__.__name__
160
+ hex_ = hex(id(self))
161
+ return f"<class {class_} at {hex_}>"
150
162
 
151
163
 
152
164
  class BatchEngine(Engine):
@@ -155,41 +167,42 @@ class BatchEngine(Engine):
155
167
  self.time_clock = True
156
168
  self.allows_batching = True
157
169
 
158
- def __call__(self, arguments: List[Any]) -> List[Tuple[Any, dict]]:
170
+ def __call__(self, arguments: list[Any]) -> list[tuple[Any, dict]]:
159
171
  start_time = time.time()
160
172
  for arg in arguments:
161
- if hasattr(arg.prop.instance, '_metadata') and hasattr(arg.prop.instance._metadata, 'input_handler'):
162
- input_handler = getattr(arg.prop.instance._metadata, 'input_handler', None)
163
- if input_handler is not None:
164
- input_handler((arg.prop.processed_input, arg))
165
- if arg.prop.input_handler is not None:
166
- arg.prop.input_handler((arg.prop.processed_input, arg))
173
+ self._trigger_input_handlers(arg)
167
174
 
168
- try:
169
- results, metadata_list = self.forward(arguments)
170
- except Exception as e:
171
- results = [e] * len(arguments)
172
- metadata_list = [None] * len(arguments)
175
+ results, metadata_list = self._execute_batch(arguments)
173
176
 
174
177
  total_time = time.time() - start_time
175
178
  if self.time_clock:
176
- print(f"Total execution time: {total_time} sec")
179
+ UserMessage(f"Total execution time: {total_time} sec")
177
180
 
178
- return_list = []
181
+ return self._prepare_batch_results(arguments, results, metadata_list, total_time)
179
182
 
180
- for arg, result, metadata in zip(arguments, results, metadata_list):
183
+ def _execute_batch(self, arguments: list[Any]) -> tuple[list[Any], list[dict | None]]:
184
+ try:
185
+ return self.forward(arguments)
186
+ except Exception as error:
187
+ return [error] * len(arguments), [None] * len(arguments)
188
+
189
+ def _prepare_batch_results(
190
+ self,
191
+ arguments: list[Any],
192
+ results: list[Any],
193
+ metadata_list: list[dict | None],
194
+ total_time: float,
195
+ ) -> list[tuple[Any, dict | None]]:
196
+ return_list = []
197
+ for arg, result, metadata in zip(arguments, results, metadata_list, strict=False):
181
198
  if metadata is not None:
182
- metadata['time'] = total_time / len(arguments)
183
-
184
- if hasattr(arg.prop.instance, '_metadata') and hasattr(arg.prop.instance._metadata, 'output_handler'):
185
- output_handler = getattr(arg.prop.instance._metadata, 'output_handler', None)
186
- if output_handler:
187
- output_handler(result)
188
- if arg.prop.output_handler:
189
- arg.prop.output_handler((result, metadata))
199
+ metadata["time"] = total_time / len(arguments)
190
200
 
201
+ self._trigger_output_handlers(arg, result, metadata)
191
202
  return_list.append((result, metadata))
192
203
  return return_list
193
204
 
194
- def forward(self, arguments: List[Any]) -> Tuple[List[Any], List[dict]]:
195
- raise NotImplementedError("Subclasses must implement forward method")
205
+ def forward(self, _arguments: list[Any]) -> tuple[list[Any], list[dict]]:
206
+ msg = "Subclasses must implement forward method"
207
+ UserMessage(msg)
208
+ raise NotImplementedError(msg)
@@ -1,11 +1,12 @@
1
1
  import logging
2
2
  import tempfile
3
3
  import time
4
- from typing import Optional
4
+ from pathlib import Path
5
5
 
6
6
  import requests
7
7
 
8
8
  from ....symbol import Result
9
+ from ....utils import UserMessage
9
10
  from ...base import Engine
10
11
  from ...settings import SYMAI_CONFIG
11
12
 
@@ -19,89 +20,89 @@ class FluxResult(Result):
19
20
  def __init__(self, value, **kwargs):
20
21
  super().__init__(value, **kwargs)
21
22
  # unpack the result
22
- path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
23
- url = value.get('result').get('sample')
23
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
24
+ path = tmp_file.name
25
+ url = value.get("result").get("sample")
24
26
  request = requests.get(url, allow_redirects=True)
25
27
  request.raise_for_status()
26
- with open(path, "wb") as f:
28
+ with Path(path).open("wb") as f:
27
29
  f.write(request.content)
28
30
  self._value = [path]
29
31
 
30
32
 
31
33
  class DrawingEngine(Engine):
32
- def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
34
+ def __init__(self, api_key: str | None = None, model: str | None = None):
33
35
  super().__init__()
34
36
  self.config = SYMAI_CONFIG
35
- self.api_key = self.config['DRAWING_ENGINE_API_KEY'] if api_key is None else api_key
36
- self.model = self.config['DRAWING_ENGINE_MODEL'] if model is None else model
37
+ self.api_key = self.config["DRAWING_ENGINE_API_KEY"] if api_key is None else api_key
38
+ self.model = self.config["DRAWING_ENGINE_MODEL"] if model is None else model
37
39
  self.name = self.__class__.__name__
38
40
 
39
41
  def id(self) -> str:
40
- if self.config['DRAWING_ENGINE_API_KEY'] and self.config['DRAWING_ENGINE_MODEL'].startswith("flux"):
41
- return 'drawing'
42
- return super().id() # default to unregistered
42
+ if self.config["DRAWING_ENGINE_API_KEY"] and self.config["DRAWING_ENGINE_MODEL"].startswith(
43
+ "flux"
44
+ ):
45
+ return "drawing"
46
+ return super().id() # default to unregistered
43
47
 
44
48
  def command(self, *args, **kwargs):
45
49
  super().command(*args, **kwargs)
46
- if 'DRAWING_ENGINE_API_KEY' in kwargs:
47
- self.api_key = kwargs['DRAWING_ENGINE_API_KEY']
48
- if 'DRAWING_ENGINE_MODEL' in kwargs:
49
- self.model = kwargs['DRAWING_ENGINE_MODEL']
50
+ if "DRAWING_ENGINE_API_KEY" in kwargs:
51
+ self.api_key = kwargs["DRAWING_ENGINE_API_KEY"]
52
+ if "DRAWING_ENGINE_MODEL" in kwargs:
53
+ self.model = kwargs["DRAWING_ENGINE_MODEL"]
50
54
 
51
55
  def forward(self, argument):
52
56
  prompt = argument.prop.prepared_input
53
57
  kwargs = argument.kwargs
54
- model = kwargs.get('model', self.model)
55
- width = kwargs.get('width', 1024)
56
- height = kwargs.get('height', 768)
57
- steps = kwargs.get('steps', 40)
58
- prompt_upsampling = kwargs.get('prompt_upsampling', False)
59
- seed = kwargs.get('seed', None)
60
- guidance = kwargs.get('guidance', None)
61
- safety_tolerance = kwargs.get('safety_tolerance', 2)
62
- interval = kwargs.get('interval', None)
63
- output_format = kwargs.get('output_format', 'png')
64
- except_remedy = kwargs.get('except_remedy', None)
58
+ width = kwargs.get("width", 1024)
59
+ height = kwargs.get("height", 768)
60
+ steps = kwargs.get("steps", 40)
61
+ seed = kwargs.get("seed", None)
62
+ guidance = kwargs.get("guidance", None)
63
+ safety_tolerance = kwargs.get("safety_tolerance", 2)
64
+ except_remedy = kwargs.get("except_remedy", None)
65
65
 
66
66
  headers = {
67
- 'accept': 'application/json',
68
- 'x-key': self.api_key,
69
- 'Content-Type': 'application/json',
67
+ "accept": "application/json",
68
+ "x-key": self.api_key,
69
+ "Content-Type": "application/json",
70
70
  }
71
71
 
72
72
  payload = {
73
- 'prompt': prompt,
74
- 'width': width,
75
- 'height': height,
76
- 'num_inference_steps': steps,
77
- 'guidance_scale': guidance,
78
- 'seed': seed,
79
- 'safety_tolerance': safety_tolerance,
73
+ "prompt": prompt,
74
+ "width": width,
75
+ "height": height,
76
+ "num_inference_steps": steps,
77
+ "guidance_scale": guidance,
78
+ "seed": seed,
79
+ "safety_tolerance": safety_tolerance,
80
80
  }
81
81
  # drop any None values so Flux API won't return 500
82
82
  payload = {k: v for k, v in payload.items() if v is not None}
83
83
 
84
- if kwargs.get('operation') == 'create':
84
+ if kwargs.get("operation") == "create":
85
85
  try:
86
86
  response = requests.post(
87
- f'https://api.us1.bfl.ai/v1/{self.model}',
88
- headers=headers,
89
- json=payload
87
+ f"https://api.us1.bfl.ai/v1/{self.model}", headers=headers, json=payload
90
88
  )
91
89
  # fail early on HTTP errors
92
90
  response.raise_for_status()
93
91
  data = response.json()
94
92
  request_id = data.get("id")
95
93
  if not request_id:
96
- raise Exception(f"Failed to get request ID! Response payload: {data}")
94
+ UserMessage(
95
+ f"Failed to get request ID! Response payload: {data}",
96
+ raise_with=Exception,
97
+ )
97
98
 
98
99
  while True:
99
100
  time.sleep(5)
100
101
 
101
102
  result = requests.get(
102
- 'https://api.us1.bfl.ai/v1/get_result',
103
+ "https://api.us1.bfl.ai/v1/get_result",
103
104
  headers=headers,
104
- params={'id': request_id}
105
+ params={"id": request_id},
105
106
  )
106
107
 
107
108
  result.raise_for_status()
@@ -118,8 +119,8 @@ class DrawingEngine(Engine):
118
119
 
119
120
  metadata = {}
120
121
  return [rsp], metadata
121
- else:
122
- raise Exception(f"Unknown operation: {kwargs['operation']}")
122
+ UserMessage(f"Unknown operation: {kwargs['operation']}", raise_with=Exception)
123
+ return [], {}
123
124
 
124
125
  def prepare(self, argument):
125
126
  argument.prop.prepared_input = str(argument.prop.processed_input)