symbolicai 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +198 -134
- symai/backend/base.py +51 -51
- symai/backend/engines/drawing/engine_bfl.py +33 -33
- symai/backend/engines/drawing/engine_gpt_image.py +4 -10
- symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
- symai/backend/engines/embedding/engine_openai.py +22 -16
- symai/backend/engines/execute/engine_python.py +16 -16
- symai/backend/engines/files/engine_io.py +51 -49
- symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
- symai/backend/engines/index/engine_pinecone.py +116 -88
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +78 -52
- symai/backend/engines/lean/engine_lean4.py +65 -25
- symai/backend/engines/neurosymbolic/__init__.py +35 -28
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
- symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
- symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
- symai/backend/engines/ocr/engine_apilayer.py +18 -20
- symai/backend/engines/output/engine_stdout.py +9 -9
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
- symai/backend/engines/search/engine_openai.py +95 -83
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +40 -41
- symai/backend/engines/search/engine_serpapi.py +33 -28
- symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
- symai/backend/engines/text_to_speech/engine_openai.py +15 -19
- symai/backend/engines/text_vision/engine_clip.py +34 -28
- symai/backend/engines/userinput/engine_console.py +3 -4
- symai/backend/mixin/__init__.py +4 -0
- symai/backend/mixin/anthropic.py +48 -40
- symai/backend/mixin/cerebras.py +9 -0
- symai/backend/mixin/deepseek.py +4 -5
- symai/backend/mixin/google.py +5 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +132 -110
- symai/backend/settings.py +14 -14
- symai/chat.py +164 -94
- symai/collect/dynamic.py +13 -11
- symai/collect/pipeline.py +39 -31
- symai/collect/stats.py +109 -69
- symai/components.py +578 -238
- symai/constraints.py +14 -5
- symai/core.py +1495 -1210
- symai/core_ext.py +55 -50
- symai/endpoints/api.py +113 -58
- symai/extended/api_builder.py +22 -17
- symai/extended/arxiv_pdf_parser.py +13 -5
- symai/extended/bibtex_parser.py +8 -4
- symai/extended/conversation.py +88 -69
- symai/extended/document.py +40 -27
- symai/extended/file_merger.py +45 -7
- symai/extended/graph.py +38 -24
- symai/extended/html_style_template.py +17 -11
- symai/extended/interfaces/blip_2.py +1 -1
- symai/extended/interfaces/clip.py +4 -2
- symai/extended/interfaces/console.py +5 -3
- symai/extended/interfaces/dall_e.py +3 -1
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +3 -1
- symai/extended/interfaces/gpt_image.py +15 -6
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -1
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
- symai/extended/interfaces/naive_vectordb.py +2 -2
- symai/extended/interfaces/ocr.py +4 -2
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +6 -4
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +2 -0
- symai/extended/interfaces/terminal.py +0 -1
- symai/extended/interfaces/tts.py +2 -1
- symai/extended/interfaces/whisper.py +2 -1
- symai/extended/interfaces/wolframalpha.py +1 -0
- symai/extended/metrics/__init__.py +1 -1
- symai/extended/metrics/similarity.py +5 -2
- symai/extended/os_command.py +31 -22
- symai/extended/packages/symdev.py +39 -34
- symai/extended/packages/sympkg.py +30 -27
- symai/extended/packages/symrun.py +46 -35
- symai/extended/repo_cloner.py +10 -9
- symai/extended/seo_query_optimizer.py +15 -12
- symai/extended/solver.py +104 -76
- symai/extended/summarizer.py +8 -7
- symai/extended/taypan_interpreter.py +10 -9
- symai/extended/vectordb.py +28 -15
- symai/formatter/formatter.py +39 -31
- symai/formatter/regex.py +46 -44
- symai/functional.py +184 -86
- symai/imports.py +85 -51
- symai/interfaces.py +1 -1
- symai/memory.py +33 -24
- symai/menu/screen.py +28 -19
- symai/misc/console.py +27 -27
- symai/misc/loader.py +4 -3
- symai/models/base.py +147 -76
- symai/models/errors.py +1 -1
- symai/ops/__init__.py +1 -1
- symai/ops/measures.py +17 -14
- symai/ops/primitives.py +933 -635
- symai/post_processors.py +28 -24
- symai/pre_processors.py +58 -52
- symai/processor.py +15 -9
- symai/prompts.py +714 -649
- symai/server/huggingface_server.py +115 -32
- symai/server/llama_cpp_server.py +14 -6
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +98 -39
- symai/shellsv.py +307 -223
- symai/strategy.py +135 -81
- symai/symbol.py +276 -225
- symai/utils.py +62 -46
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/METADATA +19 -9
- symbolicai-1.1.1.dist-info/RECORD +169 -0
- symbolicai-1.0.0.dist-info/RECORD +0 -163
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/WHEEL +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/entry_points.txt +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/top_level.txt +0 -0
symai/misc/loader.py
CHANGED

@@ -8,7 +8,8 @@ from prompt_toolkit import print_formatted_text
 
 from .console import ConsoleStyle
 
-print = print_formatted_text
+print = print_formatted_text  # noqa
+
 
 class Loader:
     def __init__(self, desc="Loading...", end="\n", timeout=0.1):
@@ -31,7 +32,7 @@ class Loader:
         for c in cycle(self.steps):
             if self.done.is_set():
                 break
-            with ConsoleStyle('debug'):
+            with ConsoleStyle("debug"):
                 sys.stdout.write(f"\r{self.desc} {c} ")
                 sys.stdout.flush()
         sys.stdout.write(f"\r{self.end}")
@@ -46,7 +47,7 @@ class Loader:
         self.done.set()
         self._thread.join()
         cols = get_terminal_size((80, 20)).columns
-        with ConsoleStyle('debug'):
+        with ConsoleStyle("debug"):
            sys.stdout.write("\r" + " " * cols)
            sys.stdout.flush()
        sys.stdout.write(f"\r{self.end}")
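The loader.py hunks are a formatting pass over symai's terminal spinner (a background thread cycles frames until an Event is set, then the line is cleared). For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the same idea; the class name, frames, and demo are illustrative, not symai's exact implementation.

# Sketch of the spinner pattern touched by the hunks above (illustrative names).
import sys
import time
from itertools import cycle
from shutil import get_terminal_size
from threading import Event, Thread


class SpinnerSketch:
    def __init__(self, desc="Loading...", end="\n", timeout=0.1):
        self.desc = desc
        self.end = end
        self.timeout = timeout
        self.steps = ["|", "/", "-", "\\"]
        self.done = Event()
        self._thread = Thread(target=self._animate, daemon=True)

    def start(self):
        self._thread.start()
        return self

    def _animate(self):
        # Cycle through the frames until stop() signals the Event.
        for c in cycle(self.steps):
            if self.done.is_set():
                break
            sys.stdout.write(f"\r{self.desc} {c} ")
            sys.stdout.flush()
            time.sleep(self.timeout)

    def stop(self):
        self.done.set()
        self._thread.join()
        cols = get_terminal_size((80, 20)).columns
        sys.stdout.write("\r" + " " * cols)  # blank out the spinner line
        sys.stdout.write(f"\r{self.end}")


if __name__ == "__main__":
    loader = SpinnerSketch("Working").start()
    time.sleep(1)
    loader.stop()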
symai/models/base.py
CHANGED

@@ -24,7 +24,7 @@ class CustomConstraint:
 
 
 def Const(value: str):
-    return Field(default=value, json_schema_extra={'const_value': value})
+    return Field(default=value, json_schema_extra={"const_value": value})
 
 
 class LLMDataModel(BaseModel):
@@ -35,11 +35,9 @@ class LLMDataModel(BaseModel):
 
     _MAX_RECURSION_DEPTH = 50
 
-    section_header: str = Field(
-        default=None, exclude=True, frozen=True
-    )
+    section_header: str = Field(default=None, exclude=True, frozen=True)
 
-    @model_validator(mode='before')
+    @model_validator(mode="before")
     @classmethod
     def validate_const_fields(cls, values):
         """Validate that const fields have their expected values."""
@@ -47,7 +45,7 @@ class LLMDataModel(BaseModel):
             if cls._is_const_field(field_info):
                 const_value = cls._get_const_value(field_info)
                 if field_name in values and values[field_name] != const_value:
-                    UserMessage(f'{field_name} must be {const_value!r}', raise_with=ValueError)
+                    UserMessage(f"{field_name} must be {const_value!r}", raise_with=ValueError)
         return values
 
     @staticmethod
@@ -74,27 +72,32 @@ class LLMDataModel(BaseModel):
     def _is_collection_type(field_type: Any) -> bool:
         """Check if a type is a collection (list, set, tuple, dict, etc.)."""
         origin = get_origin(field_type)
-        return origin in (list, set, frozenset, tuple, dict) or field_type in (list, set, frozenset, tuple, dict)
+        return origin in (list, set, frozenset, tuple, dict) or field_type in (
+            list,
+            set,
+            frozenset,
+            tuple,
+            dict,
+        )
 
     @staticmethod
     def _is_const_field(field_info) -> bool:
         """Check if a field is a const field."""
-        return (
-            field_info.json_schema_extra and
-            'const_value' in field_info.json_schema_extra
-        )
+        return field_info.json_schema_extra and "const_value" in field_info.json_schema_extra
 
     @staticmethod
     def _get_const_value(field_info):
         """Get the const value from a field."""
-        return field_info.json_schema_extra.get('const_value')
+        return field_info.json_schema_extra.get("const_value")
 
     @staticmethod
     def _has_default_value(field_info) -> bool:
         """Check if a field has a default value."""
         return field_info.default != ... and field_info.default != PydanticUndefined
 
-    def format_field(self, key: str, value: Any, indent: int = 0, visited: set | None = None, depth: int = 0) -> str:
+    def format_field(
+        self, key: str, value: Any, indent: int = 0, visited: set | None = None, depth: int = 0
+    ) -> str:
         """Formats a field value for string representation, handling nested structures."""
         visited = visited or set()
         formatter = self._get_formatter_for_value(value)
@@ -118,15 +121,21 @@ class LLMDataModel(BaseModel):
 
         return self._format_primitive_field
 
-    def _format_none_field(self, key: str, _value: Any, indent: int, _visited: set, _depth: int) -> str:
+    def _format_none_field(
+        self, key: str, _value: Any, indent: int, _visited: set, _depth: int
+    ) -> str:
         """Format a None value."""
         return f"{' ' * indent}{key}: None"
 
-    def _format_enum_field(self, key: str, value: Enum, indent: int, _visited: set, _depth: int) -> str:
+    def _format_enum_field(
+        self, key: str, value: Enum, indent: int, _visited: set, _depth: int
+    ) -> str:
         """Format an Enum value."""
         return f"{' ' * indent}{key}: {value.value}"
 
-    def _format_model_field(self, key: str, value: "LLMDataModel", indent: int, visited: set, depth: int) -> str:
+    def _format_model_field(
+        self, key: str, value: "LLMDataModel", indent: int, visited: set, depth: int
+    ) -> str:
         """Format a nested model field."""
         obj_id = id(value)
         indent_str = " " * indent
@@ -139,7 +148,9 @@ class LLMDataModel(BaseModel):
         visited.discard(obj_id)
         return f"{indent_str}{key}:\n{indent_str} {nested_str}"
 
-    def _format_list_field(self, key: str, value: list, indent: int, visited: set, depth: int) -> str:
+    def _format_list_field(
+        self, key: str, value: list, indent: int, visited: set, depth: int
+    ) -> str:
         """Format a list field."""
         indent_str = " " * indent
         if not value:
@@ -161,12 +172,16 @@ class LLMDataModel(BaseModel):
                 visited.add(obj_id)
                 item_str = item.__str__(indent + 2, visited, depth + 1).strip()
                 visited.discard(obj_id)
-                items.append(f"{indent_str} - : {item_str}" if item_str else f"{indent_str} - :")
+                items.append(
+                    f"{indent_str} - : {item_str}" if item_str else f"{indent_str} - :"
+                )
             else:
                 items.append(f"{indent_str} - : {item}" if item != "" else f"{indent_str} - :")
         return f"{indent_str}{key}:\n" + "\n".join(items)
 
-    def _format_dict_field(self, key: str, value: dict, indent: int, visited: set, depth: int) -> str:
+    def _format_dict_field(
+        self, key: str, value: dict, indent: int, visited: set, depth: int
+    ) -> str:
         """Format a dictionary field."""
         indent_str = " " * indent
         if not value:
@@ -192,7 +207,9 @@ class LLMDataModel(BaseModel):
         visited.discard(obj_id)
         return f"{indent_str}{key}:\n" + "\n".join(items) if key else "\n".join(items)
 
-    def _format_primitive_field(self, key: str, value: Any, indent: int, _visited: set, _depth: int) -> str:
+    def _format_primitive_field(
+        self, key: str, value: Any, indent: int, _visited: set, _depth: int
+    ) -> str:
         """Format a primitive field."""
         return f"{' ' * indent}{key}: {value}"
 
@@ -207,10 +224,7 @@ class LLMDataModel(BaseModel):
         field_list = [
             self.format_field(name, getattr(self, name), indent + 2, visited, depth)
             for name, field in type(self).model_fields.items()
-            if (
-                not getattr(field, "exclude", False)
-                and name != "section_header"
-            )
+            if (not getattr(field, "exclude", False) and name != "section_header")
         ]
 
         fields = "\n".join(field_list) + "\n" if field_list else ""
@@ -280,8 +294,15 @@ class LLMDataModel(BaseModel):
         return schema.get("$defs", schema.get("definitions", {}))
 
     @classmethod
-    def _format_schema_field(
-        cls, name: str, field_schema: dict, required: bool, definitions: dict, indent_level: int, visited: set | None = None) -> str:
+    def _format_schema_field(
+        cls,
+        name: str,
+        field_schema: dict,
+        required: bool,
+        definitions: dict,
+        indent_level: int,
+        visited: set | None = None,
+    ) -> str:
         """Format a single schema field without descriptions (kept for definitions)."""
         visited = visited or set()
 
@@ -311,27 +332,38 @@ class LLMDataModel(BaseModel):
         return result
 
     @classmethod
-    def _format_referenced_object_fields(
-        cls, ref_name: str, definitions: dict, indent_level: int, visited: set) -> str:
+    def _format_referenced_object_fields(
+        cls, ref_name: str, definitions: dict, indent_level: int, visited: set
+    ) -> str:
         """Format nested fields for a referenced object definition by name."""
         if ref_name in definitions and ref_name not in visited:
             visited.add(ref_name)
             return cls._format_schema_fields(
                 definitions[ref_name].get("properties", {}),
-                definitions[ref_name], definitions, indent_level + 1, visited.copy(),
+                definitions[ref_name],
+                definitions,
+                indent_level + 1,
+                visited.copy(),
             )
         return ""
 
     @classmethod
-    def _format_array_referenced_object_fields(
-        cls, field_schema: dict, definitions: dict, indent_level: int, visited: set) -> str:
+    def _format_array_referenced_object_fields(
+        cls, field_schema: dict, definitions: dict, indent_level: int, visited: set
+    ) -> str:
         """Format nested fields for arrays referencing object definitions."""
         ref_name = field_schema.get("items", {}).get("$ref", "").split("/")[-1]
         return cls._format_referenced_object_fields(ref_name, definitions, indent_level, visited)
 
     @classmethod
-    def _format_schema_fields(
-        cls, properties: dict, schema: dict, definitions: dict, indent_level: int, visited: set | None = None) -> str:
+    def _format_schema_fields(
+        cls,
+        properties: dict,
+        schema: dict,
+        definitions: dict,
+        indent_level: int,
+        visited: set | None = None,
+    ) -> str:
         """Format multiple schema fields."""
         visited = visited or set()
         required_fields = set(schema.get("required", []))
@@ -342,8 +374,12 @@ class LLMDataModel(BaseModel):
                 continue
             lines.append(
                 cls._format_schema_field(
-                    name, field_schema, name in required_fields,
-                    definitions, indent_level, visited.copy(),
+                    name,
+                    field_schema,
+                    name in required_fields,
+                    definitions,
+                    indent_level,
+                    visited.copy(),
                 )
             )
 
@@ -397,10 +433,7 @@ class LLMDataModel(BaseModel):
     @classmethod
     def _resolve_union_type(cls, schemas: list, definitions: dict, separator: str) -> str:
         """Resolve union types (anyOf/oneOf)."""
-        subtypes = [
-            cls._resolve_field_type(subschema, definitions)
-            for subschema in schemas
-        ]
+        subtypes = [cls._resolve_field_type(subschema, definitions) for subschema in schemas]
         return separator.join(subtypes)
 
     @classmethod
@@ -584,7 +617,7 @@ class LLMDataModel(BaseModel):
         }
         for prefix, handler in handlers.items():
             if type_desc.startswith(prefix):
-                item_type = type_desc[len(prefix):]
+                item_type = type_desc[len(prefix) :]
                 return handler(item_type)
         return None
 
@@ -630,7 +663,7 @@ class LLMDataModel(BaseModel):
         }
         for prefix, template in nested_mappings.items():
             if item_type.startswith(prefix):
-                inner = item_type[len(prefix):]
+                inner = item_type[len(prefix) :]
                 return template.format(inner)
         return None
 
@@ -671,12 +704,15 @@ class LLMDataModel(BaseModel):
 
         if LLMDataModel._has_default_value(model_field):
             default_val = model_field.default
-            desc = getattr(model_field, 'description', None)
-            ann = getattr(model_field, 'annotation', None)
+            desc = getattr(model_field, "description", None)
+            ann = getattr(model_field, "annotation", None)
            is_desc_like = isinstance(default_val, str) and (
-                (desc and default_val.strip() == str(desc).strip()) or
-                len(default_val) >= 30 or
-                any(kw in default_val for kw in ['represents', 'should', 'Always use', 'This is', 'This represents'])
+                (desc and default_val.strip() == str(desc).strip())
+                or len(default_val) >= 30
+                or any(
+                    kw in default_val
+                    for kw in ["represents", "should", "Always use", "This is", "This represents"]
+                )
            )
            if is_desc_like and (ann is str or ann is Any or ann is None):
                return "example_string"
@@ -687,18 +723,16 @@ class LLMDataModel(BaseModel):
             default_val = model_field.default_factory()
             if isinstance(default_val, (list, dict, set, tuple)) and len(default_val) == 0:
                 # Generate example data instead of using empty default
-                return LLMDataModel._generate_value_for_type(
-                    model_field.annotation, visited_models
-                )
+                return LLMDataModel._generate_value_for_type(model_field.annotation, visited_models)
             return default_val
-        return LLMDataModel._generate_value_for_type(
-            model_field.annotation, visited_models
-        )
+        return LLMDataModel._generate_value_for_type(model_field.annotation, visited_models)
 
     @staticmethod
     def _generate_value_for_type(field_type: Any, visited_models: set) -> Any:
         """Generate a value for a specific type (standard behavior)."""
-        return LLMDataModel._generate_value_for_type_generic(field_type, visited_models, prefer_non_null=False)
+        return LLMDataModel._generate_value_for_type_generic(
+            field_type, visited_models, prefer_non_null=False
+        )
 
     @staticmethod
     def _generate_union_value(field_type: Any, visited_models: set) -> Any:
@@ -706,7 +740,9 @@ class LLMDataModel(BaseModel):
         subtypes = LLMDataModel._get_union_types(field_type, exclude_none=True)
         if not subtypes:
             return None
-        return LLMDataModel._generate_value_for_type_generic(subtypes[0], visited_models, prefer_non_null=False)
+        return LLMDataModel._generate_value_for_type_generic(
+            subtypes[0], visited_models, prefer_non_null=False
+        )
 
     @staticmethod
     def _generate_collection_value(field_type: Any, visited_models: set) -> Any:
@@ -731,7 +767,10 @@ class LLMDataModel(BaseModel):
 
         if LLMDataModel._is_union_type(item_type):
             subtypes = LLMDataModel._get_union_types(item_type)
-            return [LLMDataModel._generate_value_for_type_generic(subtype, visited_models, False) for subtype in subtypes[:2]]
+            return [
+                LLMDataModel._generate_value_for_type_generic(subtype, visited_models, False)
+                for subtype in subtypes[:2]
+            ]
 
         return [LLMDataModel._generate_value_for_type_generic(item_type, visited_models, False)]
 
@@ -741,7 +780,11 @@ class LLMDataModel(BaseModel):
         key_type, value_type = get_args(field_type) if get_args(field_type) else (Any, Any)
 
         example_key = LLMDataModel._example_key_for_type(key_type, visited_models)
-        return {example_key: LLMDataModel._generate_value_for_type_generic(value_type, visited_models, False)}
+        return {
+            example_key: LLMDataModel._generate_value_for_type_generic(
+                value_type, visited_models, False
+            )
+        }
 
     @staticmethod
     def _generate_set_value(field_type: Any, visited_models: set) -> list:
@@ -754,7 +797,10 @@ class LLMDataModel(BaseModel):
         """Generate a value for a tuple type."""
         types = get_args(field_type)
         if types:
-            return tuple(LLMDataModel._generate_value_for_type_generic(t, visited_models, False) for t in types)
+            return tuple(
+                LLMDataModel._generate_value_for_type_generic(t, visited_models, False)
+                for t in types
+            )
         return ("item1", "item2")
 
     @staticmethod
@@ -818,10 +864,7 @@ class LLMDataModel(BaseModel):
     @classmethod
     def _find_non_header_fields(cls) -> dict:
         """Find all fields except section_header."""
-        return {
-            name: f for name, f in cls.model_fields.items()
-            if name != "section_header"
-        }
+        return {name: f for name, f in cls.model_fields.items() if name != "section_header"}
 
     @classmethod
     def _is_single_value_model(cls, fields: dict) -> bool:
@@ -878,7 +921,9 @@ class LLMDataModel(BaseModel):
         return submodel.generate_example_json()
 
     @classmethod
-    def _generate_non_null_example_for_model(cls, model: type[BaseModel], visited_models: set | None = None) -> dict:
+    def _generate_non_null_example_for_model(
+        cls, model: type[BaseModel], visited_models: set | None = None
+    ) -> dict:
         """Generate an example for a model, preferring non-null for Optional fields (recursive)."""
         if visited_models is None:
             visited_models = set()
@@ -897,13 +942,17 @@ class LLMDataModel(BaseModel):
                 chosen = non_none_types[0] if non_none_types else Any
                 example[field_name] = cls._generate_value_for_type_non_null(chosen, visited_models)
             else:
-                example[field_name] = cls._generate_value_for_type_non_null(model_field.annotation, visited_models)
+                example[field_name] = cls._generate_value_for_type_non_null(
+                    model_field.annotation, visited_models
+                )
         return example
 
     @classmethod
     def _generate_value_for_type_non_null(cls, field_type: Any, visited_models: set) -> Any:
         """Generate a value ensuring non-null choices for unions/optionals."""
-        return cls._generate_value_for_type_generic(field_type, visited_models, prefer_non_null=True)
+        return cls._generate_value_for_type_generic(
+            field_type, visited_models, prefer_non_null=True
+        )
 
     @classmethod
     def _example_key_for_type(cls, key_type: Any, visited_models: set) -> Any:
@@ -916,14 +965,20 @@ class LLMDataModel(BaseModel):
             return True
         if key_type is tuple or get_origin(key_type) is tuple:
             tuple_args = get_args(key_type) if get_args(key_type) else (str, int)
-            return tuple(cls._generate_value_for_type_generic(t, visited_models, True) for t in tuple_args)
+            return tuple(
+                cls._generate_value_for_type_generic(t, visited_models, True) for t in tuple_args
+            )
         if key_type is frozenset or get_origin(key_type) is frozenset:
             item_type = get_args(key_type)[0] if get_args(key_type) else str
-            return frozenset([cls._generate_value_for_type_generic(item_type, visited_models, True)])
+            return frozenset(
+                [cls._generate_value_for_type_generic(item_type, visited_models, True)]
+            )
         return "example_string"
 
     @classmethod
-    def _generate_value_for_type_generic(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> Any:
+    def _generate_value_for_type_generic(
+        cls, field_type: Any, visited_models: set, prefer_non_null: bool
+    ) -> Any:
         """Unified generator for example values; prefer_non_null to avoid None variants."""
         origin = get_origin(field_type) or field_type
 
@@ -967,7 +1022,9 @@ class LLMDataModel(BaseModel):
         return False, None
 
     @classmethod
-    def _handle_model_type(cls, origin: Any, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
+    def _handle_model_type(
+        cls, origin: Any, field_type: Any, visited_models: set, prefer_non_null: bool
+    ) -> tuple[bool, Any]:
         """Handle Pydantic BaseModel subclasses."""
         if not (isinstance(origin, type) and issubclass(origin, BaseModel)):
             return False, None
@@ -975,11 +1032,17 @@ class LLMDataModel(BaseModel):
         if model_name in visited_models:
             return True, {}
         visited_models.add(model_name)
-        generator = cls._generate_non_null_example_for_model if prefer_non_null else LLMDataModel._generate_example_for_model
+        generator = (
+            cls._generate_non_null_example_for_model
+            if prefer_non_null
+            else LLMDataModel._generate_example_for_model
+        )
         return True, generator(field_type, visited_models.copy())
 
     @classmethod
-    def _handle_union_type(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
+    def _handle_union_type(
+        cls, field_type: Any, visited_models: set, prefer_non_null: bool
+    ) -> tuple[bool, Any]:
         """Handle Optional/Union annotations."""
         if not LLMDataModel._is_union_type(field_type):
             return False, None
@@ -991,7 +1054,9 @@ class LLMDataModel(BaseModel):
         return True, value
 
     @classmethod
-    def _handle_collection_type(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
+    def _handle_collection_type(
+        cls, field_type: Any, visited_models: set, prefer_non_null: bool
+    ) -> tuple[bool, Any]:
         """Handle list/dict/set/tuple-like annotations."""
         if not LLMDataModel._is_collection_type(field_type):
             return False, None
@@ -1001,16 +1066,22 @@ class LLMDataModel(BaseModel):
 
         if origin is list:
             item_type = args[0] if args else Any
-            value = [cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)]
+            value = [
+                cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)
+            ]
             return True, value
         if origin is dict:
             key_type, value_type = args if args else (Any, Any)
             example_key = cls._example_key_for_type(key_type, visited_models)
-            example_value = cls._generate_value_for_type_generic(value_type, visited_models, prefer_non_null)
+            example_value = cls._generate_value_for_type_generic(
+                value_type, visited_models, prefer_non_null
+            )
            return True, {example_key: example_value}
        if origin in (set, frozenset):
            item_type = args[0] if args else Any
-            value = [cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)]
+            value = [
+                cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)
+            ]
            return True, value
        if origin is tuple:
            if args:
@@ -1054,8 +1125,8 @@ def build_dynamic_llm_datamodel(py_type: Any) -> type[LLMDataModel]:
                 description="This is a dynamically generated data model. This description is general. "
                 "If you're dealing with a complex type, or nested types in combination with unions, make sure you "
                 "understand the instructions provided in the prompt, and select the appropriate data model based on the "
-                "type at hand, as described in the schema section."
-            )
+                "type at hand, as described in the schema section.",
+            ),
         ),
     )
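The base.py hunks revolve around the Const helper and the model_validator(mode="before") that enforces it: Const() stores the expected value under json_schema_extra["const_value"], and the before-validator rejects any input that deviates from it. Below is a standalone Pydantic v2 sketch of that pattern; the model names and the inline const check are illustrative, not symai's exact classes.

# Sketch of the const-field pattern visible in the hunks above (illustrative names).
from typing import Any

from pydantic import BaseModel, Field, model_validator


def Const(value: str):
    # Store the expected constant in the field's schema extras.
    return Field(default=value, json_schema_extra={"const_value": value})


class ConstModelSketch(BaseModel):
    @model_validator(mode="before")
    @classmethod
    def validate_const_fields(cls, values: Any):
        if not isinstance(values, dict):
            return values
        for name, info in cls.model_fields.items():
            extra = info.json_schema_extra or {}
            if "const_value" in extra and name in values:
                if values[name] != extra["const_value"]:
                    raise ValueError(f"{name} must be {extra['const_value']!r}")
        return values


class Answer(ConstModelSketch):
    kind: str = Const("answer")
    text: str


Answer(text="ok")                 # fine: kind defaults to "answer"
# Answer(kind="other", text="x")  # would fail validation with the message above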
symai/models/errors.py
CHANGED

@@ -1,9 +1,9 @@
-
 class ExceptionWithUsage(Exception):
     def __init__(self, message, usage):
         super().__init__(message)
         self.usage = usage
 
+
 class TypeValidationError(Exception):
     def __init__(self, prompt: str, result: str, violations: list[str], *args):
         super().__init__(*args)
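ExceptionWithUsage simply attaches a usage payload to an exception so callers can inspect token accounting after a failure. A brief, hedged consumption sketch; the import path is inferred from the file path and the usage dict is a made-up example, not symai's actual schema.

# Hypothetical consumer of ExceptionWithUsage (illustrative payload).
from symai.models.errors import ExceptionWithUsage

try:
    raise ExceptionWithUsage("rate limited", {"prompt_tokens": 1200, "completion_tokens": 0})
except ExceptionWithUsage as err:
    print(err, err.usage)  # the message plus the attached usage payload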
symai/ops/__init__.py
CHANGED
symai/ops/measures.py
CHANGED

@@ -32,18 +32,17 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     sigma1 = np.atleast_2d(sigma1)
     sigma2 = np.atleast_2d(sigma2)
 
-    assert mu1.shape == mu2.shape, \
-        'Training and test mean vectors have different lengths'
-    assert sigma1.shape == sigma2.shape, \
-        'Training and test covariances have different dimensions'
+    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
+    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
 
     diff = mu1 - mu2
 
     # Product might be almost singular
     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
     if not np.isfinite(covmean).all():
-        msg = (
-            f'fid calculation produces singular product; adding {eps} to diagonal of cov estimates')
+        msg = (
+            f"fid calculation produces singular product; adding {eps} to diagonal of cov estimates"
+        )
         UserMessage(msg)
         offset = np.eye(sigma1.shape[0]) * eps
         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -52,14 +51,14 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     if np.iscomplexobj(covmean):
         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
             m = np.max(np.abs(covmean.imag))
-            UserMessage(f'Imaginary component {m}', raise_with=ValueError)
+            UserMessage(f"Imaginary component {m}", raise_with=ValueError)
         covmean = covmean.real
 
     tr_covmean = np.trace(covmean)
     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
 
 
-def calculate_mmd(x, y, kernel='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, eps=1e-9):
+def calculate_mmd(x, y, kernel="rbf", kernel_mul=2.0, kernel_num=5, fix_sigma=None, eps=1e-9):
     def gaussian_kernel(source, target, kernel_mul, kernel_num, fix_sigma):
         n_samples = source.shape[0] + target.shape[0]
         total = np.concatenate([source, target], axis=0)
@@ -67,21 +66,25 @@ def calculate_mmd(x, y, kernel='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=No
         total1 = np.expand_dims(total, 1)
         L2_distance = np.sum((total0 - total1) ** 2, axis=2)
 
-        bandwidth = fix_sigma or np.sum(L2_distance) / (n_samples ** 2 - n_samples + eps)
+        bandwidth = fix_sigma or np.sum(L2_distance) / (n_samples**2 - n_samples + eps)
         bandwidth /= kernel_mul ** (kernel_num // 2)
-        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
-        kernel_val = [np.exp(-L2_distance / (bandwidth_temp + eps)) for bandwidth_temp in bandwidth_list]
+        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
+        kernel_val = [
+            np.exp(-L2_distance / (bandwidth_temp + eps)) for bandwidth_temp in bandwidth_list
+        ]
         return np.sum(kernel_val, axis=0)
 
     def linear_mmd2(f_of_X, f_of_Y):
         delta = f_of_X.mean(axis=0) - f_of_Y.mean(axis=0)
         return np.dot(delta, delta.T)
 
-    if kernel == 'linear':
+    if kernel == "linear":
         return linear_mmd2(x, y)
-    if kernel == 'rbf':
+    if kernel == "rbf":
         batch_size = x.shape[0]
-        kernels = gaussian_kernel(x, y, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
+        kernels = gaussian_kernel(
+            x, y, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma
+        )
         xx = np.mean(kernels[:batch_size, :batch_size])
         yy = np.mean(kernels[batch_size:, batch_size:])
         xy = np.mean(kernels[:batch_size, batch_size:])