symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +269 -173
- symai/backend/base.py +123 -110
- symai/backend/engines/drawing/engine_bfl.py +45 -44
- symai/backend/engines/drawing/engine_gpt_image.py +112 -97
- symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
- symai/backend/engines/embedding/engine_openai.py +25 -21
- symai/backend/engines/execute/engine_python.py +19 -18
- symai/backend/engines/files/engine_io.py +104 -95
- symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
- symai/backend/engines/index/engine_pinecone.py +124 -97
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +84 -56
- symai/backend/engines/lean/engine_lean4.py +96 -52
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
- symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
- symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
- symai/backend/engines/ocr/engine_apilayer.py +23 -27
- symai/backend/engines/output/engine_stdout.py +10 -13
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
- symai/backend/engines/search/engine_openai.py +100 -88
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +44 -45
- symai/backend/engines/search/engine_serpapi.py +37 -34
- symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
- symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
- symai/backend/engines/text_to_speech/engine_openai.py +20 -26
- symai/backend/engines/text_vision/engine_clip.py +39 -37
- symai/backend/engines/userinput/engine_console.py +5 -6
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +48 -38
- symai/backend/mixin/deepseek.py +6 -5
- symai/backend/mixin/google.py +7 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +140 -110
- symai/backend/settings.py +87 -20
- symai/chat.py +216 -123
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +80 -70
- symai/collect/pipeline.py +67 -51
- symai/collect/stats.py +161 -109
- symai/components.py +707 -360
- symai/constraints.py +24 -12
- symai/core.py +1857 -1233
- symai/core_ext.py +83 -80
- symai/endpoints/api.py +166 -104
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +29 -21
- symai/extended/arxiv_pdf_parser.py +23 -14
- symai/extended/bibtex_parser.py +9 -6
- symai/extended/conversation.py +156 -126
- symai/extended/document.py +50 -30
- symai/extended/file_merger.py +57 -14
- symai/extended/graph.py +51 -32
- symai/extended/html_style_template.py +18 -14
- symai/extended/interfaces/blip_2.py +2 -3
- symai/extended/interfaces/clip.py +4 -3
- symai/extended/interfaces/console.py +9 -1
- symai/extended/interfaces/dall_e.py +4 -2
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +4 -2
- symai/extended/interfaces/gpt_image.py +16 -7
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -2
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
- symai/extended/interfaces/naive_vectordb.py +9 -10
- symai/extended/interfaces/ocr.py +5 -3
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +12 -9
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +3 -1
- symai/extended/interfaces/terminal.py +2 -4
- symai/extended/interfaces/tts.py +3 -2
- symai/extended/interfaces/whisper.py +3 -2
- symai/extended/interfaces/wolframalpha.py +2 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +14 -13
- symai/extended/os_command.py +39 -29
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +51 -43
- symai/extended/packages/sympkg.py +41 -35
- symai/extended/packages/symrun.py +63 -50
- symai/extended/repo_cloner.py +14 -12
- symai/extended/seo_query_optimizer.py +15 -13
- symai/extended/solver.py +116 -91
- symai/extended/summarizer.py +12 -10
- symai/extended/taypan_interpreter.py +17 -18
- symai/extended/vectordb.py +122 -92
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +51 -47
- symai/formatter/regex.py +70 -69
- symai/functional.py +325 -176
- symai/imports.py +190 -147
- symai/interfaces.py +57 -28
- symai/memory.py +45 -35
- symai/menu/screen.py +28 -19
- symai/misc/console.py +66 -56
- symai/misc/loader.py +8 -5
- symai/models/__init__.py +17 -1
- symai/models/base.py +395 -236
- symai/models/errors.py +1 -2
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +24 -25
- symai/ops/primitives.py +1149 -731
- symai/post_processors.py +58 -50
- symai/pre_processors.py +86 -82
- symai/processor.py +21 -13
- symai/prompts.py +764 -685
- symai/server/huggingface_server.py +135 -49
- symai/server/llama_cpp_server.py +21 -11
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +100 -42
- symai/shellsv.py +700 -492
- symai/strategy.py +630 -346
- symai/symbol.py +368 -322
- symai/utils.py +100 -78
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-0.21.0.dist-info/RECORD +0 -162
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/backend/base.py
CHANGED
|
@@ -1,115 +1,125 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
import os
|
|
3
2
|
import time
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
6
|
+
from ..collect import CollectionRepository, rec_serialize
|
|
7
|
+
from ..utils import UserMessage
|
|
7
8
|
from .settings import HOME_PATH
|
|
8
9
|
|
|
9
|
-
|
|
10
|
+
ENGINE_UNREGISTERED = "<UNREGISTERED/>"
|
|
10
11
|
|
|
12
|
+
COLLECTION_LOGGING_ENGINES = {
|
|
13
|
+
"GPTXChatEngine",
|
|
14
|
+
"GPTXCompletionEngine",
|
|
15
|
+
"SerpApiEngine",
|
|
16
|
+
"WolframAlphaEngine",
|
|
17
|
+
"SeleniumEngine",
|
|
18
|
+
"OCREngine",
|
|
19
|
+
}
|
|
11
20
|
|
|
12
|
-
ENGINE_UNREGISTERED = '<UNREGISTERED/>'
|
|
13
21
|
|
|
14
22
|
class Engine(ABC):
|
|
15
23
|
def __init__(self) -> None:
|
|
16
24
|
super().__init__()
|
|
17
|
-
self.verbose
|
|
18
|
-
self.logging
|
|
19
|
-
self.log_level
|
|
25
|
+
self.verbose = False
|
|
26
|
+
self.logging = False
|
|
27
|
+
self.log_level = logging.DEBUG
|
|
20
28
|
self.time_clock = False
|
|
21
29
|
self.collection = CollectionRepository()
|
|
22
30
|
self.collection.connect()
|
|
23
31
|
# create formatter
|
|
24
|
-
__root_dir__
|
|
25
|
-
|
|
32
|
+
__root_dir__ = HOME_PATH
|
|
33
|
+
__root_dir__.mkdir(parents=True, exist_ok=True)
|
|
26
34
|
__file_path__ = __root_dir__ / "engine.log"
|
|
27
|
-
logging.basicConfig(
|
|
28
|
-
|
|
35
|
+
logging.basicConfig(
|
|
36
|
+
filename=__file_path__,
|
|
37
|
+
filemode="a",
|
|
38
|
+
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
|
39
|
+
)
|
|
40
|
+
self.logger = logging.getLogger()
|
|
29
41
|
self.logger.setLevel(logging.DEBUG)
|
|
30
42
|
# logging to console
|
|
31
|
-
stream
|
|
32
|
-
streamformat
|
|
43
|
+
stream = logging.StreamHandler()
|
|
44
|
+
streamformat = logging.Formatter("%(asctime)s %(message)s")
|
|
33
45
|
stream.setLevel(logging.INFO)
|
|
34
46
|
stream.setFormatter(streamformat)
|
|
35
47
|
self.logger.addHandler(stream)
|
|
36
48
|
|
|
37
|
-
def __call__(self, argument: Any) ->
|
|
38
|
-
log = {
|
|
39
|
-
'Input': {
|
|
40
|
-
'self': self,
|
|
41
|
-
'args': argument.args,
|
|
42
|
-
**argument.kwargs
|
|
43
|
-
}
|
|
44
|
-
}
|
|
49
|
+
def __call__(self, argument: Any) -> tuple[list[str], dict]:
|
|
50
|
+
log = {"Input": {"self": self, "args": argument.args, **argument.kwargs}}
|
|
45
51
|
start_time = time.time()
|
|
46
52
|
|
|
47
|
-
|
|
48
|
-
if hasattr(argument.prop.instance, '_metadata') and hasattr(argument.prop.instance._metadata, 'input_handler'):
|
|
49
|
-
input_handler = argument.prop.instance._metadata.input_handler if hasattr(argument.prop.instance._metadata, 'input_handler') else None
|
|
50
|
-
if input_handler is not None:
|
|
51
|
-
input_handler((argument.prop.processed_input, argument))
|
|
52
|
-
# check for kwargs based input handler
|
|
53
|
-
if argument.prop.input_handler is not None:
|
|
54
|
-
argument.prop.input_handler((argument.prop.processed_input, argument))
|
|
53
|
+
self._trigger_input_handlers(argument)
|
|
55
54
|
|
|
56
|
-
# execute the engine
|
|
57
55
|
res, metadata = self.forward(argument)
|
|
58
56
|
|
|
59
|
-
# compute time
|
|
60
57
|
req_time = time.time() - start_time
|
|
61
|
-
metadata[
|
|
58
|
+
metadata["time"] = req_time
|
|
62
59
|
if self.time_clock:
|
|
63
|
-
|
|
64
|
-
log[
|
|
60
|
+
UserMessage(f"{argument.prop.func}: {req_time} sec")
|
|
61
|
+
log["Output"] = res
|
|
65
62
|
if self.verbose:
|
|
66
|
-
view
|
|
67
|
-
input_ = f"{str(log['Input']['self'])[:50]}, {
|
|
68
|
-
|
|
63
|
+
view = {k: v for k, v in list(log["Input"].items()) if k != "self"}
|
|
64
|
+
input_ = f"{str(log['Input']['self'])[:50]}, {argument.prop.func!s}, {view!s}"
|
|
65
|
+
UserMessage(f"{input_[:150]} {str(log['Output'])[:100]}")
|
|
69
66
|
if self.logging:
|
|
70
67
|
self.logger.log(self.log_level, log)
|
|
71
68
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
or str(self) == 'WolframAlphaEngine' \
|
|
77
|
-
or str(self) == 'SeleniumEngine' \
|
|
78
|
-
or str(self) == 'OCREngine':
|
|
79
|
-
self.collection.add(
|
|
80
|
-
forward={'args': rec_serialize(argument.args), 'kwds': rec_serialize(argument.kwargs)},
|
|
81
|
-
engine=str(self),
|
|
82
|
-
metadata={
|
|
83
|
-
'time': req_time,
|
|
84
|
-
'data': rec_serialize(metadata),
|
|
85
|
-
'argument': rec_serialize(argument)
|
|
86
|
-
}
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
# check for global object based output handler
|
|
90
|
-
if hasattr(argument.prop.instance, '_metadata') and hasattr(argument.prop.instance._metadata, 'output_handler'):
|
|
91
|
-
output_handler = argument.prop.instance._metadata.output_handler if hasattr(argument.prop.instance._metadata, 'output_handler') else None
|
|
92
|
-
if output_handler:
|
|
93
|
-
output_handler(res)
|
|
94
|
-
# check for kwargs based output handler
|
|
95
|
-
if argument.prop.output_handler:
|
|
96
|
-
argument.prop.output_handler((res, metadata))
|
|
69
|
+
if str(self) in COLLECTION_LOGGING_ENGINES:
|
|
70
|
+
self._record_collection_entry(argument, metadata, req_time)
|
|
71
|
+
|
|
72
|
+
self._trigger_output_handlers(argument, res, metadata)
|
|
97
73
|
return res, metadata
|
|
98
74
|
|
|
75
|
+
def _trigger_input_handlers(self, argument: Any) -> None:
|
|
76
|
+
instance_metadata = getattr(argument.prop.instance, "_metadata", None)
|
|
77
|
+
if instance_metadata is not None:
|
|
78
|
+
input_handler = getattr(instance_metadata, "input_handler", None)
|
|
79
|
+
if input_handler is not None:
|
|
80
|
+
input_handler((argument.prop.processed_input, argument))
|
|
81
|
+
argument_handler = argument.prop.input_handler
|
|
82
|
+
if argument_handler is not None:
|
|
83
|
+
argument_handler((argument.prop.processed_input, argument))
|
|
84
|
+
|
|
85
|
+
def _trigger_output_handlers(self, argument: Any, result: Any, metadata: dict | None) -> None:
|
|
86
|
+
instance_metadata = getattr(argument.prop.instance, "_metadata", None)
|
|
87
|
+
if instance_metadata is not None:
|
|
88
|
+
output_handler = getattr(instance_metadata, "output_handler", None)
|
|
89
|
+
if output_handler:
|
|
90
|
+
output_handler(result)
|
|
91
|
+
argument_handler = argument.prop.output_handler
|
|
92
|
+
if argument_handler:
|
|
93
|
+
argument_handler((result, metadata))
|
|
94
|
+
|
|
95
|
+
def _record_collection_entry(self, argument: Any, metadata: dict, req_time: float) -> None:
|
|
96
|
+
self.collection.add(
|
|
97
|
+
forward={"args": rec_serialize(argument.args), "kwds": rec_serialize(argument.kwargs)},
|
|
98
|
+
engine=str(self),
|
|
99
|
+
metadata={
|
|
100
|
+
"time": req_time,
|
|
101
|
+
"data": rec_serialize(metadata),
|
|
102
|
+
"argument": rec_serialize(argument),
|
|
103
|
+
},
|
|
104
|
+
)
|
|
105
|
+
|
|
99
106
|
def id(self) -> str:
|
|
100
107
|
return ENGINE_UNREGISTERED
|
|
101
108
|
|
|
102
109
|
def preview(self, argument):
|
|
103
|
-
#
|
|
104
|
-
from ..symbol import
|
|
110
|
+
# Used here to avoid backend.base <-> symbol circular import.
|
|
111
|
+
from ..symbol import ( # noqa
|
|
112
|
+
Symbol,
|
|
113
|
+
)
|
|
114
|
+
|
|
105
115
|
class Preview(Symbol):
|
|
106
116
|
def __repr__(self) -> str:
|
|
107
|
-
|
|
117
|
+
"""
|
|
108
118
|
Get the representation of the Symbol object as a string.
|
|
109
119
|
|
|
110
120
|
Returns:
|
|
111
121
|
str: The representation of the Symbol object.
|
|
112
|
-
|
|
122
|
+
"""
|
|
113
123
|
return str(self.value.prop.prepared_input)
|
|
114
124
|
|
|
115
125
|
def prepared_input(self):
|
|
@@ -117,36 +127,38 @@ class Engine(ABC):
|
|
|
117
127
|
|
|
118
128
|
return Preview(argument), {}
|
|
119
129
|
|
|
120
|
-
|
|
121
|
-
|
|
130
|
+
@abstractmethod
|
|
131
|
+
def forward(self, *args: Any, **kwds: Any) -> list[str]:
|
|
132
|
+
raise NotADirectoryError
|
|
122
133
|
|
|
134
|
+
@abstractmethod
|
|
123
135
|
def prepare(self, argument):
|
|
124
|
-
raise NotImplementedError
|
|
125
|
-
|
|
126
|
-
def command(self, *
|
|
127
|
-
if kwargs.get(
|
|
128
|
-
self.verbose = kwargs[
|
|
129
|
-
if kwargs.get(
|
|
130
|
-
self.logging = kwargs[
|
|
131
|
-
if kwargs.get(
|
|
132
|
-
self.log_level = kwargs[
|
|
133
|
-
if kwargs.get(
|
|
134
|
-
self.time_clock = kwargs[
|
|
136
|
+
raise NotImplementedError
|
|
137
|
+
|
|
138
|
+
def command(self, *_args, **kwargs):
|
|
139
|
+
if kwargs.get("verbose"):
|
|
140
|
+
self.verbose = kwargs["verbose"]
|
|
141
|
+
if kwargs.get("logging"):
|
|
142
|
+
self.logging = kwargs["logging"]
|
|
143
|
+
if kwargs.get("log_level"):
|
|
144
|
+
self.log_level = kwargs["log_level"]
|
|
145
|
+
if kwargs.get("time_clock"):
|
|
146
|
+
self.time_clock = kwargs["time_clock"]
|
|
135
147
|
|
|
136
148
|
def __str__(self) -> str:
|
|
137
149
|
return self.__class__.__name__
|
|
138
150
|
|
|
139
151
|
def __repr__(self) -> str:
|
|
140
|
-
|
|
152
|
+
"""
|
|
141
153
|
Get the representation of the Symbol object as a string.
|
|
142
154
|
|
|
143
155
|
Returns:
|
|
144
156
|
str: The representation of the Symbol object.
|
|
145
|
-
|
|
157
|
+
"""
|
|
146
158
|
# class with full path
|
|
147
|
-
class_ = self.__class__.__module__ +
|
|
148
|
-
hex_
|
|
149
|
-
return f
|
|
159
|
+
class_ = self.__class__.__module__ + "." + self.__class__.__name__
|
|
160
|
+
hex_ = hex(id(self))
|
|
161
|
+
return f"<class {class_} at {hex_}>"
|
|
150
162
|
|
|
151
163
|
|
|
152
164
|
class BatchEngine(Engine):
|
|
@@ -155,41 +167,42 @@ class BatchEngine(Engine):
|
|
|
155
167
|
self.time_clock = True
|
|
156
168
|
self.allows_batching = True
|
|
157
169
|
|
|
158
|
-
def __call__(self, arguments:
|
|
170
|
+
def __call__(self, arguments: list[Any]) -> list[tuple[Any, dict]]:
|
|
159
171
|
start_time = time.time()
|
|
160
172
|
for arg in arguments:
|
|
161
|
-
|
|
162
|
-
input_handler = getattr(arg.prop.instance._metadata, 'input_handler', None)
|
|
163
|
-
if input_handler is not None:
|
|
164
|
-
input_handler((arg.prop.processed_input, arg))
|
|
165
|
-
if arg.prop.input_handler is not None:
|
|
166
|
-
arg.prop.input_handler((arg.prop.processed_input, arg))
|
|
173
|
+
self._trigger_input_handlers(arg)
|
|
167
174
|
|
|
168
|
-
|
|
169
|
-
results, metadata_list = self.forward(arguments)
|
|
170
|
-
except Exception as e:
|
|
171
|
-
results = [e] * len(arguments)
|
|
172
|
-
metadata_list = [None] * len(arguments)
|
|
175
|
+
results, metadata_list = self._execute_batch(arguments)
|
|
173
176
|
|
|
174
177
|
total_time = time.time() - start_time
|
|
175
178
|
if self.time_clock:
|
|
176
|
-
|
|
179
|
+
UserMessage(f"Total execution time: {total_time} sec")
|
|
177
180
|
|
|
178
|
-
|
|
181
|
+
return self._prepare_batch_results(arguments, results, metadata_list, total_time)
|
|
179
182
|
|
|
180
|
-
|
|
183
|
+
def _execute_batch(self, arguments: list[Any]) -> tuple[list[Any], list[dict | None]]:
|
|
184
|
+
try:
|
|
185
|
+
return self.forward(arguments)
|
|
186
|
+
except Exception as error:
|
|
187
|
+
return [error] * len(arguments), [None] * len(arguments)
|
|
188
|
+
|
|
189
|
+
def _prepare_batch_results(
|
|
190
|
+
self,
|
|
191
|
+
arguments: list[Any],
|
|
192
|
+
results: list[Any],
|
|
193
|
+
metadata_list: list[dict | None],
|
|
194
|
+
total_time: float,
|
|
195
|
+
) -> list[tuple[Any, dict | None]]:
|
|
196
|
+
return_list = []
|
|
197
|
+
for arg, result, metadata in zip(arguments, results, metadata_list, strict=False):
|
|
181
198
|
if metadata is not None:
|
|
182
|
-
metadata[
|
|
183
|
-
|
|
184
|
-
if hasattr(arg.prop.instance, '_metadata') and hasattr(arg.prop.instance._metadata, 'output_handler'):
|
|
185
|
-
output_handler = getattr(arg.prop.instance._metadata, 'output_handler', None)
|
|
186
|
-
if output_handler:
|
|
187
|
-
output_handler(result)
|
|
188
|
-
if arg.prop.output_handler:
|
|
189
|
-
arg.prop.output_handler((result, metadata))
|
|
199
|
+
metadata["time"] = total_time / len(arguments)
|
|
190
200
|
|
|
201
|
+
self._trigger_output_handlers(arg, result, metadata)
|
|
191
202
|
return_list.append((result, metadata))
|
|
192
203
|
return return_list
|
|
193
204
|
|
|
194
|
-
def forward(self,
|
|
195
|
-
|
|
205
|
+
def forward(self, _arguments: list[Any]) -> tuple[list[Any], list[dict]]:
|
|
206
|
+
msg = "Subclasses must implement forward method"
|
|
207
|
+
UserMessage(msg)
|
|
208
|
+
raise NotImplementedError(msg)
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import tempfile
|
|
3
3
|
import time
|
|
4
|
-
from
|
|
4
|
+
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import requests
|
|
7
7
|
|
|
8
8
|
from ....symbol import Result
|
|
9
|
+
from ....utils import UserMessage
|
|
9
10
|
from ...base import Engine
|
|
10
11
|
from ...settings import SYMAI_CONFIG
|
|
11
12
|
|
|
@@ -19,89 +20,89 @@ class FluxResult(Result):
|
|
|
19
20
|
def __init__(self, value, **kwargs):
|
|
20
21
|
super().__init__(value, **kwargs)
|
|
21
22
|
# unpack the result
|
|
22
|
-
|
|
23
|
-
|
|
23
|
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
|
24
|
+
path = tmp_file.name
|
|
25
|
+
url = value.get("result").get("sample")
|
|
24
26
|
request = requests.get(url, allow_redirects=True)
|
|
25
27
|
request.raise_for_status()
|
|
26
|
-
with open(
|
|
28
|
+
with Path(path).open("wb") as f:
|
|
27
29
|
f.write(request.content)
|
|
28
30
|
self._value = [path]
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class DrawingEngine(Engine):
|
|
32
|
-
def __init__(self, api_key:
|
|
34
|
+
def __init__(self, api_key: str | None = None, model: str | None = None):
|
|
33
35
|
super().__init__()
|
|
34
36
|
self.config = SYMAI_CONFIG
|
|
35
|
-
self.api_key = self.config[
|
|
36
|
-
self.model = self.config[
|
|
37
|
+
self.api_key = self.config["DRAWING_ENGINE_API_KEY"] if api_key is None else api_key
|
|
38
|
+
self.model = self.config["DRAWING_ENGINE_MODEL"] if model is None else model
|
|
37
39
|
self.name = self.__class__.__name__
|
|
38
40
|
|
|
39
41
|
def id(self) -> str:
|
|
40
|
-
if
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
if self.config["DRAWING_ENGINE_API_KEY"] and self.config["DRAWING_ENGINE_MODEL"].startswith(
|
|
43
|
+
"flux"
|
|
44
|
+
):
|
|
45
|
+
return "drawing"
|
|
46
|
+
return super().id() # default to unregistered
|
|
43
47
|
|
|
44
48
|
def command(self, *args, **kwargs):
|
|
45
49
|
super().command(*args, **kwargs)
|
|
46
|
-
if
|
|
47
|
-
self.api_key = kwargs[
|
|
48
|
-
if
|
|
49
|
-
self.model = kwargs[
|
|
50
|
+
if "DRAWING_ENGINE_API_KEY" in kwargs:
|
|
51
|
+
self.api_key = kwargs["DRAWING_ENGINE_API_KEY"]
|
|
52
|
+
if "DRAWING_ENGINE_MODEL" in kwargs:
|
|
53
|
+
self.model = kwargs["DRAWING_ENGINE_MODEL"]
|
|
50
54
|
|
|
51
55
|
def forward(self, argument):
|
|
52
56
|
prompt = argument.prop.prepared_input
|
|
53
57
|
kwargs = argument.kwargs
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
safety_tolerance = kwargs.get('safety_tolerance', 2)
|
|
62
|
-
interval = kwargs.get('interval', None)
|
|
63
|
-
output_format = kwargs.get('output_format', 'png')
|
|
64
|
-
except_remedy = kwargs.get('except_remedy', None)
|
|
58
|
+
width = kwargs.get("width", 1024)
|
|
59
|
+
height = kwargs.get("height", 768)
|
|
60
|
+
steps = kwargs.get("steps", 40)
|
|
61
|
+
seed = kwargs.get("seed", None)
|
|
62
|
+
guidance = kwargs.get("guidance", None)
|
|
63
|
+
safety_tolerance = kwargs.get("safety_tolerance", 2)
|
|
64
|
+
except_remedy = kwargs.get("except_remedy", None)
|
|
65
65
|
|
|
66
66
|
headers = {
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
67
|
+
"accept": "application/json",
|
|
68
|
+
"x-key": self.api_key,
|
|
69
|
+
"Content-Type": "application/json",
|
|
70
70
|
}
|
|
71
71
|
|
|
72
72
|
payload = {
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
73
|
+
"prompt": prompt,
|
|
74
|
+
"width": width,
|
|
75
|
+
"height": height,
|
|
76
|
+
"num_inference_steps": steps,
|
|
77
|
+
"guidance_scale": guidance,
|
|
78
|
+
"seed": seed,
|
|
79
|
+
"safety_tolerance": safety_tolerance,
|
|
80
80
|
}
|
|
81
81
|
# drop any None values so Flux API won't return 500
|
|
82
82
|
payload = {k: v for k, v in payload.items() if v is not None}
|
|
83
83
|
|
|
84
|
-
if kwargs.get(
|
|
84
|
+
if kwargs.get("operation") == "create":
|
|
85
85
|
try:
|
|
86
86
|
response = requests.post(
|
|
87
|
-
f
|
|
88
|
-
headers=headers,
|
|
89
|
-
json=payload
|
|
87
|
+
f"https://api.us1.bfl.ai/v1/{self.model}", headers=headers, json=payload
|
|
90
88
|
)
|
|
91
89
|
# fail early on HTTP errors
|
|
92
90
|
response.raise_for_status()
|
|
93
91
|
data = response.json()
|
|
94
92
|
request_id = data.get("id")
|
|
95
93
|
if not request_id:
|
|
96
|
-
|
|
94
|
+
UserMessage(
|
|
95
|
+
f"Failed to get request ID! Response payload: {data}",
|
|
96
|
+
raise_with=Exception,
|
|
97
|
+
)
|
|
97
98
|
|
|
98
99
|
while True:
|
|
99
100
|
time.sleep(5)
|
|
100
101
|
|
|
101
102
|
result = requests.get(
|
|
102
|
-
|
|
103
|
+
"https://api.us1.bfl.ai/v1/get_result",
|
|
103
104
|
headers=headers,
|
|
104
|
-
params={
|
|
105
|
+
params={"id": request_id},
|
|
105
106
|
)
|
|
106
107
|
|
|
107
108
|
result.raise_for_status()
|
|
@@ -118,8 +119,8 @@ class DrawingEngine(Engine):
|
|
|
118
119
|
|
|
119
120
|
metadata = {}
|
|
120
121
|
return [rsp], metadata
|
|
121
|
-
|
|
122
|
-
|
|
122
|
+
UserMessage(f"Unknown operation: {kwargs['operation']}", raise_with=Exception)
|
|
123
|
+
return [], {}
|
|
123
124
|
|
|
124
125
|
def prepare(self, argument):
|
|
125
126
|
argument.prop.prepared_input = str(argument.prop.processed_input)
|