symbolicai 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. symai/__init__.py +198 -134
  2. symai/backend/base.py +51 -51
  3. symai/backend/engines/drawing/engine_bfl.py +33 -33
  4. symai/backend/engines/drawing/engine_gpt_image.py +4 -10
  5. symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
  6. symai/backend/engines/embedding/engine_openai.py +22 -16
  7. symai/backend/engines/execute/engine_python.py +16 -16
  8. symai/backend/engines/files/engine_io.py +51 -49
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
  11. symai/backend/engines/index/engine_pinecone.py +116 -88
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +78 -52
  14. symai/backend/engines/lean/engine_lean4.py +65 -25
  15. symai/backend/engines/neurosymbolic/__init__.py +35 -28
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
  21. symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
  26. symai/backend/engines/ocr/engine_apilayer.py +18 -20
  27. symai/backend/engines/output/engine_stdout.py +9 -9
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
  29. symai/backend/engines/search/engine_openai.py +95 -83
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +40 -41
  32. symai/backend/engines/search/engine_serpapi.py +33 -28
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
  35. symai/backend/engines/text_to_speech/engine_openai.py +15 -19
  36. symai/backend/engines/text_vision/engine_clip.py +34 -28
  37. symai/backend/engines/userinput/engine_console.py +3 -4
  38. symai/backend/mixin/__init__.py +4 -0
  39. symai/backend/mixin/anthropic.py +48 -40
  40. symai/backend/mixin/cerebras.py +9 -0
  41. symai/backend/mixin/deepseek.py +4 -5
  42. symai/backend/mixin/google.py +5 -4
  43. symai/backend/mixin/groq.py +2 -4
  44. symai/backend/mixin/openai.py +132 -110
  45. symai/backend/settings.py +14 -14
  46. symai/chat.py +164 -94
  47. symai/collect/dynamic.py +13 -11
  48. symai/collect/pipeline.py +39 -31
  49. symai/collect/stats.py +109 -69
  50. symai/components.py +578 -238
  51. symai/constraints.py +14 -5
  52. symai/core.py +1495 -1210
  53. symai/core_ext.py +55 -50
  54. symai/endpoints/api.py +113 -58
  55. symai/extended/api_builder.py +22 -17
  56. symai/extended/arxiv_pdf_parser.py +13 -5
  57. symai/extended/bibtex_parser.py +8 -4
  58. symai/extended/conversation.py +88 -69
  59. symai/extended/document.py +40 -27
  60. symai/extended/file_merger.py +45 -7
  61. symai/extended/graph.py +38 -24
  62. symai/extended/html_style_template.py +17 -11
  63. symai/extended/interfaces/blip_2.py +1 -1
  64. symai/extended/interfaces/clip.py +4 -2
  65. symai/extended/interfaces/console.py +5 -3
  66. symai/extended/interfaces/dall_e.py +3 -1
  67. symai/extended/interfaces/file.py +2 -0
  68. symai/extended/interfaces/flux.py +3 -1
  69. symai/extended/interfaces/gpt_image.py +15 -6
  70. symai/extended/interfaces/input.py +2 -1
  71. symai/extended/interfaces/llava.py +1 -1
  72. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
  73. symai/extended/interfaces/naive_vectordb.py +2 -2
  74. symai/extended/interfaces/ocr.py +4 -2
  75. symai/extended/interfaces/openai_search.py +2 -0
  76. symai/extended/interfaces/parallel.py +30 -0
  77. symai/extended/interfaces/perplexity.py +2 -0
  78. symai/extended/interfaces/pinecone.py +6 -4
  79. symai/extended/interfaces/python.py +2 -0
  80. symai/extended/interfaces/serpapi.py +2 -0
  81. symai/extended/interfaces/terminal.py +0 -1
  82. symai/extended/interfaces/tts.py +2 -1
  83. symai/extended/interfaces/whisper.py +2 -1
  84. symai/extended/interfaces/wolframalpha.py +1 -0
  85. symai/extended/metrics/__init__.py +1 -1
  86. symai/extended/metrics/similarity.py +5 -2
  87. symai/extended/os_command.py +31 -22
  88. symai/extended/packages/symdev.py +39 -34
  89. symai/extended/packages/sympkg.py +30 -27
  90. symai/extended/packages/symrun.py +46 -35
  91. symai/extended/repo_cloner.py +10 -9
  92. symai/extended/seo_query_optimizer.py +15 -12
  93. symai/extended/solver.py +104 -76
  94. symai/extended/summarizer.py +8 -7
  95. symai/extended/taypan_interpreter.py +10 -9
  96. symai/extended/vectordb.py +28 -15
  97. symai/formatter/formatter.py +39 -31
  98. symai/formatter/regex.py +46 -44
  99. symai/functional.py +184 -86
  100. symai/imports.py +85 -51
  101. symai/interfaces.py +1 -1
  102. symai/memory.py +33 -24
  103. symai/menu/screen.py +28 -19
  104. symai/misc/console.py +27 -27
  105. symai/misc/loader.py +4 -3
  106. symai/models/base.py +147 -76
  107. symai/models/errors.py +1 -1
  108. symai/ops/__init__.py +1 -1
  109. symai/ops/measures.py +17 -14
  110. symai/ops/primitives.py +933 -635
  111. symai/post_processors.py +28 -24
  112. symai/pre_processors.py +58 -52
  113. symai/processor.py +15 -9
  114. symai/prompts.py +714 -649
  115. symai/server/huggingface_server.py +115 -32
  116. symai/server/llama_cpp_server.py +14 -6
  117. symai/server/qdrant_server.py +206 -0
  118. symai/shell.py +98 -39
  119. symai/shellsv.py +307 -223
  120. symai/strategy.py +135 -81
  121. symai/symbol.py +276 -225
  122. symai/utils.py +62 -46
  123. {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/METADATA +19 -9
  124. symbolicai-1.1.1.dist-info/RECORD +169 -0
  125. symbolicai-1.0.0.dist-info/RECORD +0 -163
  126. {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/WHEEL +0 -0
  127. {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/entry_points.txt +0 -0
  128. {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/licenses/LICENSE +0 -0
  129. {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/top_level.txt +0 -0
symai/misc/loader.py CHANGED
@@ -8,7 +8,8 @@ from prompt_toolkit import print_formatted_text
8
8
 
9
9
  from .console import ConsoleStyle
10
10
 
11
- print = print_formatted_text # noqa
11
+ print = print_formatted_text # noqa
12
+
12
13
 
13
14
  class Loader:
14
15
  def __init__(self, desc="Loading...", end="\n", timeout=0.1):
@@ -31,7 +32,7 @@ class Loader:
31
32
  for c in cycle(self.steps):
32
33
  if self.done.is_set():
33
34
  break
34
- with ConsoleStyle('debug'):
35
+ with ConsoleStyle("debug"):
35
36
  sys.stdout.write(f"\r{self.desc} {c} ")
36
37
  sys.stdout.flush()
37
38
  sys.stdout.write(f"\r{self.end}")
@@ -46,7 +47,7 @@ class Loader:
46
47
  self.done.set()
47
48
  self._thread.join()
48
49
  cols = get_terminal_size((80, 20)).columns
49
- with ConsoleStyle('debug'):
50
+ with ConsoleStyle("debug"):
50
51
  sys.stdout.write("\r" + " " * cols)
51
52
  sys.stdout.flush()
52
53
  sys.stdout.write(f"\r{self.end}")
symai/models/base.py CHANGED
@@ -24,7 +24,7 @@ class CustomConstraint:
24
24
 
25
25
 
26
26
  def Const(value: str):
27
- return Field(default=value, json_schema_extra={'const_value': value})
27
+ return Field(default=value, json_schema_extra={"const_value": value})
28
28
 
29
29
 
30
30
  class LLMDataModel(BaseModel):
@@ -35,11 +35,9 @@ class LLMDataModel(BaseModel):
35
35
 
36
36
  _MAX_RECURSION_DEPTH = 50
37
37
 
38
- section_header: str = Field(
39
- default=None, exclude=True, frozen=True
40
- )
38
+ section_header: str = Field(default=None, exclude=True, frozen=True)
41
39
 
42
- @model_validator(mode='before')
40
+ @model_validator(mode="before")
43
41
  @classmethod
44
42
  def validate_const_fields(cls, values):
45
43
  """Validate that const fields have their expected values."""
@@ -47,7 +45,7 @@ class LLMDataModel(BaseModel):
47
45
  if cls._is_const_field(field_info):
48
46
  const_value = cls._get_const_value(field_info)
49
47
  if field_name in values and values[field_name] != const_value:
50
- UserMessage(f'{field_name} must be {const_value!r}', raise_with=ValueError)
48
+ UserMessage(f"{field_name} must be {const_value!r}", raise_with=ValueError)
51
49
  return values
52
50
 
53
51
  @staticmethod
@@ -74,27 +72,32 @@ class LLMDataModel(BaseModel):
74
72
  def _is_collection_type(field_type: Any) -> bool:
75
73
  """Check if a type is a collection (list, set, tuple, dict, etc.)."""
76
74
  origin = get_origin(field_type)
77
- return origin in (list, set, frozenset, tuple, dict) or field_type in (list, set, frozenset, tuple, dict)
75
+ return origin in (list, set, frozenset, tuple, dict) or field_type in (
76
+ list,
77
+ set,
78
+ frozenset,
79
+ tuple,
80
+ dict,
81
+ )
78
82
 
79
83
  @staticmethod
80
84
  def _is_const_field(field_info) -> bool:
81
85
  """Check if a field is a const field."""
82
- return (
83
- field_info.json_schema_extra and
84
- 'const_value' in field_info.json_schema_extra
85
- )
86
+ return field_info.json_schema_extra and "const_value" in field_info.json_schema_extra
86
87
 
87
88
  @staticmethod
88
89
  def _get_const_value(field_info):
89
90
  """Get the const value from a field."""
90
- return field_info.json_schema_extra.get('const_value')
91
+ return field_info.json_schema_extra.get("const_value")
91
92
 
92
93
  @staticmethod
93
94
  def _has_default_value(field_info) -> bool:
94
95
  """Check if a field has a default value."""
95
96
  return field_info.default != ... and field_info.default != PydanticUndefined
96
97
 
97
- def format_field(self, key: str, value: Any, indent: int = 0, visited: set | None = None, depth: int = 0) -> str:
98
+ def format_field(
99
+ self, key: str, value: Any, indent: int = 0, visited: set | None = None, depth: int = 0
100
+ ) -> str:
98
101
  """Formats a field value for string representation, handling nested structures."""
99
102
  visited = visited or set()
100
103
  formatter = self._get_formatter_for_value(value)
@@ -118,15 +121,21 @@ class LLMDataModel(BaseModel):
118
121
 
119
122
  return self._format_primitive_field
120
123
 
121
- def _format_none_field(self, key: str, _value: Any, indent: int, _visited: set, _depth: int) -> str:
124
+ def _format_none_field(
125
+ self, key: str, _value: Any, indent: int, _visited: set, _depth: int
126
+ ) -> str:
122
127
  """Format a None value."""
123
128
  return f"{' ' * indent}{key}: None"
124
129
 
125
- def _format_enum_field(self, key: str, value: Enum, indent: int, _visited: set, _depth: int) -> str:
130
+ def _format_enum_field(
131
+ self, key: str, value: Enum, indent: int, _visited: set, _depth: int
132
+ ) -> str:
126
133
  """Format an Enum value."""
127
134
  return f"{' ' * indent}{key}: {value.value}"
128
135
 
129
- def _format_model_field(self, key: str, value: "LLMDataModel", indent: int, visited: set, depth: int) -> str:
136
+ def _format_model_field(
137
+ self, key: str, value: "LLMDataModel", indent: int, visited: set, depth: int
138
+ ) -> str:
130
139
  """Format a nested model field."""
131
140
  obj_id = id(value)
132
141
  indent_str = " " * indent
@@ -139,7 +148,9 @@ class LLMDataModel(BaseModel):
139
148
  visited.discard(obj_id)
140
149
  return f"{indent_str}{key}:\n{indent_str} {nested_str}"
141
150
 
142
- def _format_list_field(self, key: str, value: list, indent: int, visited: set, depth: int) -> str:
151
+ def _format_list_field(
152
+ self, key: str, value: list, indent: int, visited: set, depth: int
153
+ ) -> str:
143
154
  """Format a list field."""
144
155
  indent_str = " " * indent
145
156
  if not value:
@@ -161,12 +172,16 @@ class LLMDataModel(BaseModel):
161
172
  visited.add(obj_id)
162
173
  item_str = item.__str__(indent + 2, visited, depth + 1).strip()
163
174
  visited.discard(obj_id)
164
- items.append(f"{indent_str} - : {item_str}" if item_str else f"{indent_str} - :")
175
+ items.append(
176
+ f"{indent_str} - : {item_str}" if item_str else f"{indent_str} - :"
177
+ )
165
178
  else:
166
179
  items.append(f"{indent_str} - : {item}" if item != "" else f"{indent_str} - :")
167
180
  return f"{indent_str}{key}:\n" + "\n".join(items)
168
181
 
169
- def _format_dict_field(self, key: str, value: dict, indent: int, visited: set, depth: int) -> str:
182
+ def _format_dict_field(
183
+ self, key: str, value: dict, indent: int, visited: set, depth: int
184
+ ) -> str:
170
185
  """Format a dictionary field."""
171
186
  indent_str = " " * indent
172
187
  if not value:
@@ -192,7 +207,9 @@ class LLMDataModel(BaseModel):
192
207
  visited.discard(obj_id)
193
208
  return f"{indent_str}{key}:\n" + "\n".join(items) if key else "\n".join(items)
194
209
 
195
- def _format_primitive_field(self, key: str, value: Any, indent: int, _visited: set, _depth: int) -> str:
210
+ def _format_primitive_field(
211
+ self, key: str, value: Any, indent: int, _visited: set, _depth: int
212
+ ) -> str:
196
213
  """Format a primitive field."""
197
214
  return f"{' ' * indent}{key}: {value}"
198
215
 
@@ -207,10 +224,7 @@ class LLMDataModel(BaseModel):
207
224
  field_list = [
208
225
  self.format_field(name, getattr(self, name), indent + 2, visited, depth)
209
226
  for name, field in type(self).model_fields.items()
210
- if (
211
- not getattr(field, "exclude", False)
212
- and name != "section_header"
213
- )
227
+ if (not getattr(field, "exclude", False) and name != "section_header")
214
228
  ]
215
229
 
216
230
  fields = "\n".join(field_list) + "\n" if field_list else ""
@@ -280,8 +294,15 @@ class LLMDataModel(BaseModel):
280
294
  return schema.get("$defs", schema.get("definitions", {}))
281
295
 
282
296
  @classmethod
283
- def _format_schema_field(cls, name: str, field_schema: dict, required: bool,
284
- definitions: dict, indent_level: int, visited: set | None = None) -> str:
297
+ def _format_schema_field(
298
+ cls,
299
+ name: str,
300
+ field_schema: dict,
301
+ required: bool,
302
+ definitions: dict,
303
+ indent_level: int,
304
+ visited: set | None = None,
305
+ ) -> str:
285
306
  """Format a single schema field without descriptions (kept for definitions)."""
286
307
  visited = visited or set()
287
308
 
@@ -311,27 +332,38 @@ class LLMDataModel(BaseModel):
311
332
  return result
312
333
 
313
334
  @classmethod
314
- def _format_referenced_object_fields(cls, ref_name: str, definitions: dict,
315
- indent_level: int, visited: set) -> str:
335
+ def _format_referenced_object_fields(
336
+ cls, ref_name: str, definitions: dict, indent_level: int, visited: set
337
+ ) -> str:
316
338
  """Format nested fields for a referenced object definition by name."""
317
339
  if ref_name in definitions and ref_name not in visited:
318
340
  visited.add(ref_name)
319
341
  return cls._format_schema_fields(
320
342
  definitions[ref_name].get("properties", {}),
321
- definitions[ref_name], definitions, indent_level + 1, visited.copy()
343
+ definitions[ref_name],
344
+ definitions,
345
+ indent_level + 1,
346
+ visited.copy(),
322
347
  )
323
348
  return ""
324
349
 
325
350
  @classmethod
326
- def _format_array_referenced_object_fields(cls, field_schema: dict, definitions: dict,
327
- indent_level: int, visited: set) -> str:
351
+ def _format_array_referenced_object_fields(
352
+ cls, field_schema: dict, definitions: dict, indent_level: int, visited: set
353
+ ) -> str:
328
354
  """Format nested fields for arrays referencing object definitions."""
329
355
  ref_name = field_schema.get("items", {}).get("$ref", "").split("/")[-1]
330
356
  return cls._format_referenced_object_fields(ref_name, definitions, indent_level, visited)
331
357
 
332
358
  @classmethod
333
- def _format_schema_fields(cls, properties: dict, schema: dict, definitions: dict,
334
- indent_level: int, visited: set | None = None) -> str:
359
+ def _format_schema_fields(
360
+ cls,
361
+ properties: dict,
362
+ schema: dict,
363
+ definitions: dict,
364
+ indent_level: int,
365
+ visited: set | None = None,
366
+ ) -> str:
335
367
  """Format multiple schema fields."""
336
368
  visited = visited or set()
337
369
  required_fields = set(schema.get("required", []))
@@ -342,8 +374,12 @@ class LLMDataModel(BaseModel):
342
374
  continue
343
375
  lines.append(
344
376
  cls._format_schema_field(
345
- name, field_schema, name in required_fields,
346
- definitions, indent_level, visited.copy()
377
+ name,
378
+ field_schema,
379
+ name in required_fields,
380
+ definitions,
381
+ indent_level,
382
+ visited.copy(),
347
383
  )
348
384
  )
349
385
 
@@ -397,10 +433,7 @@ class LLMDataModel(BaseModel):
397
433
  @classmethod
398
434
  def _resolve_union_type(cls, schemas: list, definitions: dict, separator: str) -> str:
399
435
  """Resolve union types (anyOf/oneOf)."""
400
- subtypes = [
401
- cls._resolve_field_type(subschema, definitions)
402
- for subschema in schemas
403
- ]
436
+ subtypes = [cls._resolve_field_type(subschema, definitions) for subschema in schemas]
404
437
  return separator.join(subtypes)
405
438
 
406
439
  @classmethod
@@ -584,7 +617,7 @@ class LLMDataModel(BaseModel):
584
617
  }
585
618
  for prefix, handler in handlers.items():
586
619
  if type_desc.startswith(prefix):
587
- item_type = type_desc[len(prefix):]
620
+ item_type = type_desc[len(prefix) :]
588
621
  return handler(item_type)
589
622
  return None
590
623
 
@@ -630,7 +663,7 @@ class LLMDataModel(BaseModel):
630
663
  }
631
664
  for prefix, template in nested_mappings.items():
632
665
  if item_type.startswith(prefix):
633
- inner = item_type[len(prefix):]
666
+ inner = item_type[len(prefix) :]
634
667
  return template.format(inner)
635
668
  return None
636
669
 
@@ -671,12 +704,15 @@ class LLMDataModel(BaseModel):
671
704
 
672
705
  if LLMDataModel._has_default_value(model_field):
673
706
  default_val = model_field.default
674
- desc = getattr(model_field, 'description', None)
675
- ann = getattr(model_field, 'annotation', None)
707
+ desc = getattr(model_field, "description", None)
708
+ ann = getattr(model_field, "annotation", None)
676
709
  is_desc_like = isinstance(default_val, str) and (
677
- (desc and default_val.strip() == str(desc).strip()) or
678
- len(default_val) >= 30 or
679
- any(kw in default_val for kw in ["represents", "should", "Always use", "This is", "This represents"])
710
+ (desc and default_val.strip() == str(desc).strip())
711
+ or len(default_val) >= 30
712
+ or any(
713
+ kw in default_val
714
+ for kw in ["represents", "should", "Always use", "This is", "This represents"]
715
+ )
680
716
  )
681
717
  if is_desc_like and (ann is str or ann is Any or ann is None):
682
718
  return "example_string"
@@ -687,18 +723,16 @@ class LLMDataModel(BaseModel):
687
723
  default_val = model_field.default_factory()
688
724
  if isinstance(default_val, (list, dict, set, tuple)) and len(default_val) == 0:
689
725
  # Generate example data instead of using empty default
690
- return LLMDataModel._generate_value_for_type(
691
- model_field.annotation, visited_models
692
- )
726
+ return LLMDataModel._generate_value_for_type(model_field.annotation, visited_models)
693
727
  return default_val
694
- return LLMDataModel._generate_value_for_type(
695
- model_field.annotation, visited_models
696
- )
728
+ return LLMDataModel._generate_value_for_type(model_field.annotation, visited_models)
697
729
 
698
730
  @staticmethod
699
731
  def _generate_value_for_type(field_type: Any, visited_models: set) -> Any:
700
732
  """Generate a value for a specific type (standard behavior)."""
701
- return LLMDataModel._generate_value_for_type_generic(field_type, visited_models, prefer_non_null=False)
733
+ return LLMDataModel._generate_value_for_type_generic(
734
+ field_type, visited_models, prefer_non_null=False
735
+ )
702
736
 
703
737
  @staticmethod
704
738
  def _generate_union_value(field_type: Any, visited_models: set) -> Any:
@@ -706,7 +740,9 @@ class LLMDataModel(BaseModel):
706
740
  subtypes = LLMDataModel._get_union_types(field_type, exclude_none=True)
707
741
  if not subtypes:
708
742
  return None
709
- return LLMDataModel._generate_value_for_type_generic(subtypes[0], visited_models, prefer_non_null=False)
743
+ return LLMDataModel._generate_value_for_type_generic(
744
+ subtypes[0], visited_models, prefer_non_null=False
745
+ )
710
746
 
711
747
  @staticmethod
712
748
  def _generate_collection_value(field_type: Any, visited_models: set) -> Any:
@@ -731,7 +767,10 @@ class LLMDataModel(BaseModel):
731
767
 
732
768
  if LLMDataModel._is_union_type(item_type):
733
769
  subtypes = LLMDataModel._get_union_types(item_type)
734
- return [LLMDataModel._generate_value_for_type_generic(subtype, visited_models, False) for subtype in subtypes[:2]]
770
+ return [
771
+ LLMDataModel._generate_value_for_type_generic(subtype, visited_models, False)
772
+ for subtype in subtypes[:2]
773
+ ]
735
774
 
736
775
  return [LLMDataModel._generate_value_for_type_generic(item_type, visited_models, False)]
737
776
 
@@ -741,7 +780,11 @@ class LLMDataModel(BaseModel):
741
780
  key_type, value_type = get_args(field_type) if get_args(field_type) else (Any, Any)
742
781
 
743
782
  example_key = LLMDataModel._example_key_for_type(key_type, visited_models)
744
- return {example_key: LLMDataModel._generate_value_for_type_generic(value_type, visited_models, False)}
783
+ return {
784
+ example_key: LLMDataModel._generate_value_for_type_generic(
785
+ value_type, visited_models, False
786
+ )
787
+ }
745
788
 
746
789
  @staticmethod
747
790
  def _generate_set_value(field_type: Any, visited_models: set) -> list:
@@ -754,7 +797,10 @@ class LLMDataModel(BaseModel):
754
797
  """Generate a value for a tuple type."""
755
798
  types = get_args(field_type)
756
799
  if types:
757
- return tuple(LLMDataModel._generate_value_for_type_generic(t, visited_models, False) for t in types)
800
+ return tuple(
801
+ LLMDataModel._generate_value_for_type_generic(t, visited_models, False)
802
+ for t in types
803
+ )
758
804
  return ("item1", "item2")
759
805
 
760
806
  @staticmethod
@@ -818,10 +864,7 @@ class LLMDataModel(BaseModel):
818
864
  @classmethod
819
865
  def _find_non_header_fields(cls) -> dict:
820
866
  """Find all fields except section_header."""
821
- return {
822
- name: f for name, f in cls.model_fields.items()
823
- if name != "section_header"
824
- }
867
+ return {name: f for name, f in cls.model_fields.items() if name != "section_header"}
825
868
 
826
869
  @classmethod
827
870
  def _is_single_value_model(cls, fields: dict) -> bool:
@@ -878,7 +921,9 @@ class LLMDataModel(BaseModel):
878
921
  return submodel.generate_example_json()
879
922
 
880
923
  @classmethod
881
- def _generate_non_null_example_for_model(cls, model: type[BaseModel], visited_models: set | None = None) -> dict:
924
+ def _generate_non_null_example_for_model(
925
+ cls, model: type[BaseModel], visited_models: set | None = None
926
+ ) -> dict:
882
927
  """Generate an example for a model, preferring non-null for Optional fields (recursive)."""
883
928
  if visited_models is None:
884
929
  visited_models = set()
@@ -897,13 +942,17 @@ class LLMDataModel(BaseModel):
897
942
  chosen = non_none_types[0] if non_none_types else Any
898
943
  example[field_name] = cls._generate_value_for_type_non_null(chosen, visited_models)
899
944
  else:
900
- example[field_name] = cls._generate_value_for_type_non_null(model_field.annotation, visited_models)
945
+ example[field_name] = cls._generate_value_for_type_non_null(
946
+ model_field.annotation, visited_models
947
+ )
901
948
  return example
902
949
 
903
950
  @classmethod
904
951
  def _generate_value_for_type_non_null(cls, field_type: Any, visited_models: set) -> Any:
905
952
  """Generate a value ensuring non-null choices for unions/optionals."""
906
- return cls._generate_value_for_type_generic(field_type, visited_models, prefer_non_null=True)
953
+ return cls._generate_value_for_type_generic(
954
+ field_type, visited_models, prefer_non_null=True
955
+ )
907
956
 
908
957
  @classmethod
909
958
  def _example_key_for_type(cls, key_type: Any, visited_models: set) -> Any:
@@ -916,14 +965,20 @@ class LLMDataModel(BaseModel):
916
965
  return True
917
966
  if key_type is tuple or get_origin(key_type) is tuple:
918
967
  tuple_args = get_args(key_type) if get_args(key_type) else (str, int)
919
- return tuple(cls._generate_value_for_type_generic(t, visited_models, True) for t in tuple_args)
968
+ return tuple(
969
+ cls._generate_value_for_type_generic(t, visited_models, True) for t in tuple_args
970
+ )
920
971
  if key_type is frozenset or get_origin(key_type) is frozenset:
921
972
  item_type = get_args(key_type)[0] if get_args(key_type) else str
922
- return frozenset([cls._generate_value_for_type_generic(item_type, visited_models, True)])
973
+ return frozenset(
974
+ [cls._generate_value_for_type_generic(item_type, visited_models, True)]
975
+ )
923
976
  return "example_string"
924
977
 
925
978
  @classmethod
926
- def _generate_value_for_type_generic(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> Any:
979
+ def _generate_value_for_type_generic(
980
+ cls, field_type: Any, visited_models: set, prefer_non_null: bool
981
+ ) -> Any:
927
982
  """Unified generator for example values; prefer_non_null to avoid None variants."""
928
983
  origin = get_origin(field_type) or field_type
929
984
 
@@ -967,7 +1022,9 @@ class LLMDataModel(BaseModel):
967
1022
  return False, None
968
1023
 
969
1024
  @classmethod
970
- def _handle_model_type(cls, origin: Any, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
1025
+ def _handle_model_type(
1026
+ cls, origin: Any, field_type: Any, visited_models: set, prefer_non_null: bool
1027
+ ) -> tuple[bool, Any]:
971
1028
  """Handle Pydantic BaseModel subclasses."""
972
1029
  if not (isinstance(origin, type) and issubclass(origin, BaseModel)):
973
1030
  return False, None
@@ -975,11 +1032,17 @@ class LLMDataModel(BaseModel):
975
1032
  if model_name in visited_models:
976
1033
  return True, {}
977
1034
  visited_models.add(model_name)
978
- generator = cls._generate_non_null_example_for_model if prefer_non_null else LLMDataModel._generate_example_for_model
1035
+ generator = (
1036
+ cls._generate_non_null_example_for_model
1037
+ if prefer_non_null
1038
+ else LLMDataModel._generate_example_for_model
1039
+ )
979
1040
  return True, generator(field_type, visited_models.copy())
980
1041
 
981
1042
  @classmethod
982
- def _handle_union_type(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
1043
+ def _handle_union_type(
1044
+ cls, field_type: Any, visited_models: set, prefer_non_null: bool
1045
+ ) -> tuple[bool, Any]:
983
1046
  """Handle Optional/Union annotations."""
984
1047
  if not LLMDataModel._is_union_type(field_type):
985
1048
  return False, None
@@ -991,7 +1054,9 @@ class LLMDataModel(BaseModel):
991
1054
  return True, value
992
1055
 
993
1056
  @classmethod
994
- def _handle_collection_type(cls, field_type: Any, visited_models: set, prefer_non_null: bool) -> tuple[bool, Any]:
1057
+ def _handle_collection_type(
1058
+ cls, field_type: Any, visited_models: set, prefer_non_null: bool
1059
+ ) -> tuple[bool, Any]:
995
1060
  """Handle list/dict/set/tuple-like annotations."""
996
1061
  if not LLMDataModel._is_collection_type(field_type):
997
1062
  return False, None
@@ -1001,16 +1066,22 @@ class LLMDataModel(BaseModel):
1001
1066
 
1002
1067
  if origin is list:
1003
1068
  item_type = args[0] if args else Any
1004
- value = [cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)]
1069
+ value = [
1070
+ cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)
1071
+ ]
1005
1072
  return True, value
1006
1073
  if origin is dict:
1007
1074
  key_type, value_type = args if args else (Any, Any)
1008
1075
  example_key = cls._example_key_for_type(key_type, visited_models)
1009
- example_value = cls._generate_value_for_type_generic(value_type, visited_models, prefer_non_null)
1076
+ example_value = cls._generate_value_for_type_generic(
1077
+ value_type, visited_models, prefer_non_null
1078
+ )
1010
1079
  return True, {example_key: example_value}
1011
1080
  if origin in (set, frozenset):
1012
1081
  item_type = args[0] if args else Any
1013
- value = [cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)]
1082
+ value = [
1083
+ cls._generate_value_for_type_generic(item_type, visited_models, prefer_non_null)
1084
+ ]
1014
1085
  return True, value
1015
1086
  if origin is tuple:
1016
1087
  if args:
@@ -1054,8 +1125,8 @@ def build_dynamic_llm_datamodel(py_type: Any) -> type[LLMDataModel]:
1054
1125
  description="This is a dynamically generated data model. This description is general. "
1055
1126
  "If you're dealing with a complex type, or nested types in combination with unions, make sure you "
1056
1127
  "understand the instructions provided in the prompt, and select the appropriate data model based on the "
1057
- "type at hand, as described in the schema section."
1058
- )
1128
+ "type at hand, as described in the schema section.",
1129
+ ),
1059
1130
  ),
1060
1131
  )
1061
1132
 
symai/models/errors.py CHANGED
@@ -1,9 +1,9 @@
1
-
2
1
  class ExceptionWithUsage(Exception):
3
2
  def __init__(self, message, usage):
4
3
  super().__init__(message)
5
4
  self.usage = usage
6
5
 
6
+
7
7
  class TypeValidationError(Exception):
8
8
  def __init__(self, prompt: str, result: str, violations: list[str], *args):
9
9
  super().__init__(*args)
symai/ops/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from . import primitives as _primitives
2
2
 
3
- __all__ = getattr(_primitives, "__all__", None) # noqa
3
+ __all__ = getattr(_primitives, "__all__", None) # noqa
4
4
  if __all__ is None:
5
5
  __all__ = [name for name in dir(_primitives) if not name.startswith("_")]
6
6
 
symai/ops/measures.py CHANGED
@@ -32,18 +32,17 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
32
32
  sigma1 = np.atleast_2d(sigma1)
33
33
  sigma2 = np.atleast_2d(sigma2)
34
34
 
35
- assert mu1.shape == mu2.shape, \
36
- 'Training and test mean vectors have different lengths'
37
- assert sigma1.shape == sigma2.shape, \
38
- 'Training and test covariances have different dimensions'
35
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
36
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
39
37
 
40
38
  diff = mu1 - mu2
41
39
 
42
40
  # Product might be almost singular
43
41
  covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
44
42
  if not np.isfinite(covmean).all():
45
- msg = ('fid calculation produces singular product; '
46
- f'adding {eps} to diagonal of cov estimates')
43
+ msg = (
44
+ f"fid calculation produces singular product; adding {eps} to diagonal of cov estimates"
45
+ )
47
46
  UserMessage(msg)
48
47
  offset = np.eye(sigma1.shape[0]) * eps
49
48
  covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -52,14 +51,14 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
52
51
  if np.iscomplexobj(covmean):
53
52
  if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
54
53
  m = np.max(np.abs(covmean.imag))
55
- UserMessage(f'Imaginary component {m}', raise_with=ValueError)
54
+ UserMessage(f"Imaginary component {m}", raise_with=ValueError)
56
55
  covmean = covmean.real
57
56
 
58
57
  tr_covmean = np.trace(covmean)
59
58
  return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
60
59
 
61
60
 
62
- def calculate_mmd(x, y, kernel='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, eps=1e-9):
61
+ def calculate_mmd(x, y, kernel="rbf", kernel_mul=2.0, kernel_num=5, fix_sigma=None, eps=1e-9):
63
62
  def gaussian_kernel(source, target, kernel_mul, kernel_num, fix_sigma):
64
63
  n_samples = source.shape[0] + target.shape[0]
65
64
  total = np.concatenate([source, target], axis=0)
@@ -67,21 +66,25 @@ def calculate_mmd(x, y, kernel='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=No
67
66
  total1 = np.expand_dims(total, 1)
68
67
  L2_distance = np.sum((total0 - total1) ** 2, axis=2)
69
68
 
70
- bandwidth = fix_sigma or np.sum(L2_distance) / (n_samples ** 2 - n_samples + eps)
69
+ bandwidth = fix_sigma or np.sum(L2_distance) / (n_samples**2 - n_samples + eps)
71
70
  bandwidth /= kernel_mul ** (kernel_num // 2)
72
- bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
73
- kernel_val = [np.exp(-L2_distance / (bandwidth_temp + eps)) for bandwidth_temp in bandwidth_list]
71
+ bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
72
+ kernel_val = [
73
+ np.exp(-L2_distance / (bandwidth_temp + eps)) for bandwidth_temp in bandwidth_list
74
+ ]
74
75
  return np.sum(kernel_val, axis=0)
75
76
 
76
77
  def linear_mmd2(f_of_X, f_of_Y):
77
78
  delta = f_of_X.mean(axis=0) - f_of_Y.mean(axis=0)
78
79
  return np.dot(delta, delta.T)
79
80
 
80
- if kernel == 'linear':
81
+ if kernel == "linear":
81
82
  return linear_mmd2(x, y)
82
- if kernel == 'rbf':
83
+ if kernel == "rbf":
83
84
  batch_size = x.shape[0]
84
- kernels = gaussian_kernel(x, y, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
85
+ kernels = gaussian_kernel(
86
+ x, y, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma
87
+ )
85
88
  xx = np.mean(kernels[:batch_size, :batch_size])
86
89
  yy = np.mean(kernels[batch_size:, batch_size:])
87
90
  xy = np.mean(kernels[:batch_size, batch_size:])