genxai-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. cli/__init__.py +3 -0
  2. cli/commands/__init__.py +6 -0
  3. cli/commands/approval.py +85 -0
  4. cli/commands/audit.py +127 -0
  5. cli/commands/metrics.py +25 -0
  6. cli/commands/tool.py +389 -0
  7. cli/main.py +32 -0
  8. genxai/__init__.py +81 -0
  9. genxai/api/__init__.py +5 -0
  10. genxai/api/app.py +21 -0
  11. genxai/config/__init__.py +5 -0
  12. genxai/config/settings.py +37 -0
  13. genxai/connectors/__init__.py +19 -0
  14. genxai/connectors/base.py +122 -0
  15. genxai/connectors/kafka.py +92 -0
  16. genxai/connectors/postgres_cdc.py +95 -0
  17. genxai/connectors/registry.py +44 -0
  18. genxai/connectors/sqs.py +94 -0
  19. genxai/connectors/webhook.py +73 -0
  20. genxai/core/__init__.py +37 -0
  21. genxai/core/agent/__init__.py +32 -0
  22. genxai/core/agent/base.py +206 -0
  23. genxai/core/agent/config_io.py +59 -0
  24. genxai/core/agent/registry.py +98 -0
  25. genxai/core/agent/runtime.py +970 -0
  26. genxai/core/communication/__init__.py +6 -0
  27. genxai/core/communication/collaboration.py +44 -0
  28. genxai/core/communication/message_bus.py +192 -0
  29. genxai/core/communication/protocols.py +35 -0
  30. genxai/core/execution/__init__.py +22 -0
  31. genxai/core/execution/metadata.py +181 -0
  32. genxai/core/execution/queue.py +201 -0
  33. genxai/core/graph/__init__.py +30 -0
  34. genxai/core/graph/checkpoints.py +77 -0
  35. genxai/core/graph/edges.py +131 -0
  36. genxai/core/graph/engine.py +813 -0
  37. genxai/core/graph/executor.py +516 -0
  38. genxai/core/graph/nodes.py +161 -0
  39. genxai/core/graph/trigger_runner.py +40 -0
  40. genxai/core/memory/__init__.py +19 -0
  41. genxai/core/memory/base.py +72 -0
  42. genxai/core/memory/embedding.py +327 -0
  43. genxai/core/memory/episodic.py +448 -0
  44. genxai/core/memory/long_term.py +467 -0
  45. genxai/core/memory/manager.py +543 -0
  46. genxai/core/memory/persistence.py +297 -0
  47. genxai/core/memory/procedural.py +461 -0
  48. genxai/core/memory/semantic.py +526 -0
  49. genxai/core/memory/shared.py +62 -0
  50. genxai/core/memory/short_term.py +303 -0
  51. genxai/core/memory/vector_store.py +508 -0
  52. genxai/core/memory/working.py +211 -0
  53. genxai/core/state/__init__.py +6 -0
  54. genxai/core/state/manager.py +293 -0
  55. genxai/core/state/schema.py +115 -0
  56. genxai/llm/__init__.py +14 -0
  57. genxai/llm/base.py +150 -0
  58. genxai/llm/factory.py +329 -0
  59. genxai/llm/providers/__init__.py +1 -0
  60. genxai/llm/providers/anthropic.py +249 -0
  61. genxai/llm/providers/cohere.py +274 -0
  62. genxai/llm/providers/google.py +334 -0
  63. genxai/llm/providers/ollama.py +147 -0
  64. genxai/llm/providers/openai.py +257 -0
  65. genxai/llm/routing.py +83 -0
  66. genxai/observability/__init__.py +6 -0
  67. genxai/observability/logging.py +327 -0
  68. genxai/observability/metrics.py +494 -0
  69. genxai/observability/tracing.py +372 -0
  70. genxai/performance/__init__.py +39 -0
  71. genxai/performance/cache.py +256 -0
  72. genxai/performance/pooling.py +289 -0
  73. genxai/security/audit.py +304 -0
  74. genxai/security/auth.py +315 -0
  75. genxai/security/cost_control.py +528 -0
  76. genxai/security/default_policies.py +44 -0
  77. genxai/security/jwt.py +142 -0
  78. genxai/security/oauth.py +226 -0
  79. genxai/security/pii.py +366 -0
  80. genxai/security/policy_engine.py +82 -0
  81. genxai/security/rate_limit.py +341 -0
  82. genxai/security/rbac.py +247 -0
  83. genxai/security/validation.py +218 -0
  84. genxai/tools/__init__.py +21 -0
  85. genxai/tools/base.py +383 -0
  86. genxai/tools/builtin/__init__.py +131 -0
  87. genxai/tools/builtin/communication/__init__.py +15 -0
  88. genxai/tools/builtin/communication/email_sender.py +159 -0
  89. genxai/tools/builtin/communication/notification_manager.py +167 -0
  90. genxai/tools/builtin/communication/slack_notifier.py +118 -0
  91. genxai/tools/builtin/communication/sms_sender.py +118 -0
  92. genxai/tools/builtin/communication/webhook_caller.py +136 -0
  93. genxai/tools/builtin/computation/__init__.py +15 -0
  94. genxai/tools/builtin/computation/calculator.py +101 -0
  95. genxai/tools/builtin/computation/code_executor.py +183 -0
  96. genxai/tools/builtin/computation/data_validator.py +259 -0
  97. genxai/tools/builtin/computation/hash_generator.py +129 -0
  98. genxai/tools/builtin/computation/regex_matcher.py +201 -0
  99. genxai/tools/builtin/data/__init__.py +15 -0
  100. genxai/tools/builtin/data/csv_processor.py +213 -0
  101. genxai/tools/builtin/data/data_transformer.py +299 -0
  102. genxai/tools/builtin/data/json_processor.py +233 -0
  103. genxai/tools/builtin/data/text_analyzer.py +288 -0
  104. genxai/tools/builtin/data/xml_processor.py +175 -0
  105. genxai/tools/builtin/database/__init__.py +15 -0
  106. genxai/tools/builtin/database/database_inspector.py +157 -0
  107. genxai/tools/builtin/database/mongodb_query.py +196 -0
  108. genxai/tools/builtin/database/redis_cache.py +167 -0
  109. genxai/tools/builtin/database/sql_query.py +145 -0
  110. genxai/tools/builtin/database/vector_search.py +163 -0
  111. genxai/tools/builtin/file/__init__.py +17 -0
  112. genxai/tools/builtin/file/directory_scanner.py +214 -0
  113. genxai/tools/builtin/file/file_compressor.py +237 -0
  114. genxai/tools/builtin/file/file_reader.py +102 -0
  115. genxai/tools/builtin/file/file_writer.py +122 -0
  116. genxai/tools/builtin/file/image_processor.py +186 -0
  117. genxai/tools/builtin/file/pdf_parser.py +144 -0
  118. genxai/tools/builtin/test/__init__.py +15 -0
  119. genxai/tools/builtin/test/async_simulator.py +62 -0
  120. genxai/tools/builtin/test/data_transformer.py +99 -0
  121. genxai/tools/builtin/test/error_generator.py +82 -0
  122. genxai/tools/builtin/test/simple_math.py +94 -0
  123. genxai/tools/builtin/test/string_processor.py +72 -0
  124. genxai/tools/builtin/web/__init__.py +15 -0
  125. genxai/tools/builtin/web/api_caller.py +161 -0
  126. genxai/tools/builtin/web/html_parser.py +330 -0
  127. genxai/tools/builtin/web/http_client.py +187 -0
  128. genxai/tools/builtin/web/url_validator.py +162 -0
  129. genxai/tools/builtin/web/web_scraper.py +170 -0
  130. genxai/tools/custom/my_test_tool_2.py +9 -0
  131. genxai/tools/dynamic.py +105 -0
  132. genxai/tools/mcp_server.py +167 -0
  133. genxai/tools/persistence/__init__.py +6 -0
  134. genxai/tools/persistence/models.py +55 -0
  135. genxai/tools/persistence/service.py +322 -0
  136. genxai/tools/registry.py +227 -0
  137. genxai/tools/security/__init__.py +11 -0
  138. genxai/tools/security/limits.py +214 -0
  139. genxai/tools/security/policy.py +20 -0
  140. genxai/tools/security/sandbox.py +248 -0
  141. genxai/tools/templates.py +435 -0
  142. genxai/triggers/__init__.py +19 -0
  143. genxai/triggers/base.py +104 -0
  144. genxai/triggers/file_watcher.py +75 -0
  145. genxai/triggers/queue.py +68 -0
  146. genxai/triggers/registry.py +82 -0
  147. genxai/triggers/schedule.py +66 -0
  148. genxai/triggers/webhook.py +68 -0
  149. genxai/utils/__init__.py +1 -0
  150. genxai/utils/tokens.py +295 -0
  151. genxai_framework-0.1.0.dist-info/METADATA +495 -0
  152. genxai_framework-0.1.0.dist-info/RECORD +156 -0
  153. genxai_framework-0.1.0.dist-info/WHEEL +5 -0
  154. genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
  155. genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
  156. genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,211 @@
1
+ """Working memory implementation for active processing."""
2
+
3
+ from typing import Any, Dict, List, Optional
4
+ from datetime import datetime
5
+ import logging
6
+ from collections import deque
7
+
8
+ from genxai.core.memory.base import Memory, MemoryType
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class WorkingMemory:
    """Working memory for active processing and temporary storage.

    Working memory holds information that is currently being processed:
    - Current task context
    - Intermediate results
    - Active goals
    - Temporary computations

    Capacity is bounded. When ``capacity`` items are already held, adding a
    new key evicts the oldest item (first in, first out by insertion order).
    Re-adding an existing key replaces its item and moves it to the newest
    position.

    Each stored item is a dict with the keys ``"key"``, ``"value"``,
    ``"metadata"`` and ``"timestamp"`` (a naive ``datetime.now()`` value).
    """

    def __init__(
        self,
        capacity: int = 5,
    ) -> None:
        """Initialize working memory.

        Args:
            capacity: Maximum number of items to hold. A negative value makes
                ``deque`` raise ``ValueError``; ``0`` means nothing is kept.
        """
        self._capacity = capacity
        # Items in insertion order (oldest on the left). ``maxlen`` makes
        # ``append`` drop the leftmost item automatically once full.
        self._items: deque = deque(maxlen=capacity)
        # key -> the very same item dict held in ``_items``, for O(1) lookup.
        self._item_map: Dict[str, Any] = {}

        logger.info(f"Initialized working memory with capacity {capacity}")

    def add(
        self,
        key: str,
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add an item to working memory, evicting the oldest item if full.

        Args:
            key: Item key/identifier. An existing item with the same key is
                replaced and moved to the newest position.
            value: Item value
            metadata: Optional metadata (stored as ``{}`` when omitted)
        """
        item = {
            "key": key,
            "value": value,
            "metadata": metadata or {},
            "timestamp": datetime.now(),
        }

        # Remove old item with same key if exists
        if key in self._item_map:
            self.remove(key)

        # ``deque.append`` silently discards the leftmost item when the deque
        # is full, so determine which key is about to go *before* appending.
        # With capacity 0 the deque is always full and the new item itself is
        # the one discarded.
        evicted_key: Optional[str] = None
        will_evict = len(self._items) == self._capacity
        if will_evict:
            evicted_key = self._items[0]["key"] if self._items else key

        self._items.append(item)
        self._item_map[key] = item

        if will_evict:
            del self._item_map[evicted_key]
            logger.debug(f"Evicted item from working memory: {evicted_key}")

        logger.debug(f"Added to working memory: {key}")

    def get(self, key: str) -> Optional[Any]:
        """Get an item's value from working memory.

        Args:
            key: Item key

        Returns:
            Item value if found, None otherwise (a stored ``None`` value is
            indistinguishable from a missing key; use ``contains`` for that).
        """
        item = self._item_map.get(key)
        if item is None:
            return None
        return item["value"]

    def get_all(self) -> List[Dict[str, Any]]:
        """Get all items in working memory, oldest first.

        Returns:
            A new list of the stored item dicts (the dicts themselves are
            shared with this instance, not copied).
        """
        return list(self._items)

    def get_recent(self, n: int = 3) -> List[Dict[str, Any]]:
        """Get the ``n`` most recently added items, oldest first.

        Args:
            n: Number of items to retrieve. Values ``<= 0`` yield ``[]``.

        Returns:
            List of recent items (all items if fewer than ``n`` are held)
        """
        if n <= 0:
            # ``seq[-0:]`` would return the whole list, so handle it here.
            return []
        # Slicing past the start simply returns every item.
        return list(self._items)[-n:]

    def remove(self, key: str) -> bool:
        """Remove an item from working memory.

        Args:
            key: Item key

        Returns:
            True if removed, False if not found
        """
        item = self._item_map.pop(key, None)
        if item is None:
            return False

        # The map and the deque share the same dict object, so remove it from
        # the deque by identity rather than rebuilding the whole deque.
        for index, candidate in enumerate(self._items):
            if candidate is item:
                del self._items[index]
                break

        logger.debug(f"Removed from working memory: {key}")
        return True

    def clear(self) -> None:
        """Clear all items from working memory."""
        count = len(self._items)
        self._items.clear()
        self._item_map.clear()
        logger.info(f"Cleared {count} items from working memory")

    def contains(self, key: str) -> bool:
        """Check if key exists in working memory.

        Args:
            key: Item key

        Returns:
            True if exists, False otherwise
        """
        return key in self._item_map

    def get_size(self) -> int:
        """Get current number of items.

        Returns:
            Number of items
        """
        return len(self._items)

    def get_capacity(self) -> int:
        """Get maximum capacity.

        Returns:
            Capacity
        """
        return self._capacity

    def is_full(self) -> bool:
        """Check if working memory is full.

        Returns:
            True if the number of items has reached capacity, False otherwise
        """
        return len(self._items) >= self._capacity

    def get_stats(self) -> Dict[str, Any]:
        """Get working memory statistics.

        Returns:
            Dict with ``size``, ``capacity`` and ``utilization``; when not
            empty, also ``oldest_item``/``newest_item`` (ISO timestamps) and
            ``keys`` (oldest first).
        """
        if not self._items:
            return {
                "size": 0,
                "capacity": self._capacity,
                "utilization": 0.0,
            }

        return {
            "size": len(self._items),
            "capacity": self._capacity,
            "utilization": len(self._items) / self._capacity,
            "oldest_item": self._items[0]["timestamp"].isoformat(),
            "newest_item": self._items[-1]["timestamp"].isoformat(),
            "keys": [item["key"] for item in self._items],
        }

    def __len__(self) -> int:
        """Get number of items."""
        return len(self._items)

    def __contains__(self, key: str) -> bool:
        """Check if key exists."""
        return key in self._item_map

    def __repr__(self) -> str:
        """String representation."""
        return f"WorkingMemory(size={len(self._items)}/{self._capacity})"
@@ -0,0 +1,6 @@
1
+ """State management for GenXAI workflows."""
2
+
3
+ from genxai.core.state.manager import StateManager
4
+ from genxai.core.state.schema import StateSchema
5
+
6
+ __all__ = ["StateManager", "StateSchema"]
@@ -0,0 +1,293 @@
1
+ """State manager for workflow execution."""
2
+
3
+ import json
4
+ from typing import Any, Dict, Optional
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ import logging
8
+
9
+ from genxai.core.state.schema import StateSchema
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class StateManager:
    """Manage workflow state with a change history, checkpoints and optional persistence.

    Every mutating call (``set``, ``delete`` of an existing key, ``clear``)
    increments an integer version and appends an entry to an in-memory
    history list. When ``enable_persistence`` is true, the current state is
    written to ``<persistence_path>/current_state.json`` after each mutation,
    and checkpoints are written to ``<persistence_path>/checkpoint_<name>.json``.
    Persistence is best-effort: I/O and serialization failures are logged,
    not raised.
    """

    def __init__(
        self,
        schema: Optional[StateSchema] = None,
        enable_persistence: bool = False,
        persistence_path: Optional[Path] = None,
    ) -> None:
        """Initialize state manager.

        Args:
            schema: State schema used by :meth:`validate`
            enable_persistence: Whether mutations are written to disk
            persistence_path: Directory for persisted files (a ``str`` is
                also accepted). Defaults to ``.genxai/state`` relative to the
                current working directory.
        """
        self.schema = schema
        self.enable_persistence = enable_persistence
        self.persistence_path = (
            Path(persistence_path) if persistence_path else Path(".genxai/state")
        )
        self._state: Dict[str, Any] = {}
        # Ordered log of mutations and checkpoints. Every entry carries
        # "version", "timestamp" (ISO string) and "action".
        self._history: list[Dict[str, Any]] = []
        self._version = 0

    def get(self, key: str, default: Any = None) -> Any:
        """Get value from state.

        Args:
            key: State key
            default: Default value if key not found

        Returns:
            State value or default
        """
        return self._state.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a value in state, bump the version and record it in history.

        History stores references to ``old_value``/``value``, not copies.

        Args:
            key: State key
            value: Value to set
        """
        old_value = self._state.get(key)
        self._state[key] = value
        self._version += 1

        # Record in history
        self._history.append(
            {
                "version": self._version,
                "timestamp": datetime.now().isoformat(),
                "action": "set",
                "key": key,
                "old_value": old_value,
                "new_value": value,
            }
        )

        logger.debug(f"State updated: {key} = {value}")

        # Persist if enabled
        if self.enable_persistence:
            self._persist()

    def update(self, updates: Dict[str, Any]) -> None:
        """Update multiple state values.

        Each entry goes through :meth:`set`, so every key gets its own version
        and history entry (and its own write when persistence is enabled).

        Args:
            updates: Dictionary of updates
        """
        for key, value in updates.items():
            self.set(key, value)

    def delete(self, key: str) -> None:
        """Delete a key from state.

        Deleting a missing key is a no-op: no version bump, no history entry.

        Args:
            key: State key to delete
        """
        if key not in self._state:
            return

        old_value = self._state.pop(key)
        self._version += 1

        self._history.append(
            {
                "version": self._version,
                "timestamp": datetime.now().isoformat(),
                "action": "delete",
                "key": key,
                "old_value": old_value,
            }
        )

        logger.debug(f"State key deleted: {key}")

        if self.enable_persistence:
            self._persist()

    def get_all(self) -> Dict[str, Any]:
        """Get all state values.

        Returns:
            Shallow copy of the state dictionary
        """
        return self._state.copy()

    def clear(self) -> None:
        """Clear all state.

        The history entry does not record the removed values, so a clear can
        only be undone by rolling back to a checkpoint.
        """
        self._state.clear()
        self._version += 1
        self._history.append(
            {
                "version": self._version,
                "timestamp": datetime.now().isoformat(),
                "action": "clear",
            }
        )
        logger.info("State cleared")

        if self.enable_persistence:
            self._persist()

    def validate(self) -> bool:
        """Validate current state against schema.

        Returns:
            True if valid (always True when no schema is configured)

        Raises:
            ValueError: If validation fails
        """
        if self.schema is None:
            return True

        return self.schema.validate_state(self._state)

    def checkpoint(self, name: str) -> None:
        """Create a named checkpoint of the current state.

        The snapshot is a shallow copy: nested mutable values are shared with
        the live state. Creating a checkpoint does not bump the version.

        Args:
            name: Checkpoint name (also used in the persisted file name)
        """
        checkpoint = {
            "name": name,
            "version": self._version,
            "timestamp": datetime.now().isoformat(),
            "state": self._state.copy(),
        }

        self._history.append(
            {
                "version": self._version,
                "timestamp": datetime.now().isoformat(),
                "action": "checkpoint",
                "checkpoint": checkpoint,
            }
        )

        logger.info(f"Checkpoint created: {name}")

        if self.enable_persistence:
            self._persist_checkpoint(name, checkpoint)

    def rollback(self, version: Optional[int] = None) -> None:
        """Restore state from the most recent checkpoint taken at a version.

        Only checkpoint entries can be restored. If no checkpoint exists for
        the target version, a warning is logged and state is left unchanged.

        Args:
            version: Version to roll back to (default: current version - 1).
                ``0`` is a valid target.
        """
        if not self._history:
            logger.warning("No history to rollback to")
            return

        # Compare with None explicitly so that version 0 is honoured.
        target_version = version if version is not None else self._version - 1

        # Find state at target version
        for entry in reversed(self._history):
            if (
                entry.get("version") == target_version
                and entry.get("action") == "checkpoint"
            ):
                self._state = entry["checkpoint"]["state"].copy()
                self._version = target_version
                logger.info(f"Rolled back to version {target_version}")
                # Keep the on-disk state in sync, as other mutations do.
                if self.enable_persistence:
                    self._persist()
                return

        logger.warning(f"Version {target_version} not found in history")

    def get_history(self, limit: Optional[int] = None) -> list[Dict[str, Any]]:
        """Get state change history, oldest first.

        Args:
            limit: Maximum number of (most recent) entries to return.
                ``None`` returns everything; ``0`` or less returns ``[]``.

        Returns:
            New list of history entries
        """
        if limit is None:
            return self._history.copy()
        if limit <= 0:
            return []
        return self._history[-limit:]

    @staticmethod
    def _write_json_atomic(target: Path, payload: Dict[str, Any]) -> None:
        """Write ``payload`` as JSON to ``target`` via a temp file and rename.

        A failed serialization leaves any previous ``target`` untouched.
        Non-JSON values are converted with ``str``. Exceptions propagate.
        """
        tmp_file = target.with_name(target.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, default=str)
            # Path.replace overwrites an existing target (os.replace semantics).
            tmp_file.replace(target)
        except BaseException:
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass
            raise

    def _persist(self) -> None:
        """Persist the current version and state to ``current_state.json`` (best-effort)."""
        if not self.persistence_path:
            return

        state_file = self.persistence_path / "current_state.json"

        try:
            self.persistence_path.mkdir(parents=True, exist_ok=True)
            self._write_json_atomic(
                state_file,
                {
                    "version": self._version,
                    "timestamp": datetime.now().isoformat(),
                    "state": self._state,
                },
            )
            logger.debug(f"State persisted to {state_file}")
        except Exception as e:
            logger.error(f"Failed to persist state: {e}")

    def _persist_checkpoint(self, name: str, checkpoint: Dict[str, Any]) -> None:
        """Persist a checkpoint to ``checkpoint_<name>.json`` (best-effort).

        Args:
            name: Checkpoint name; names containing path separators are
                rejected so the file cannot land outside ``persistence_path``.
            checkpoint: Checkpoint data
        """
        if not self.persistence_path:
            return

        file_name = f"checkpoint_{name}.json"
        if Path(file_name).name != file_name:
            logger.error(f"Failed to persist checkpoint: invalid checkpoint name {name!r}")
            return

        checkpoint_file = self.persistence_path / file_name

        try:
            self.persistence_path.mkdir(parents=True, exist_ok=True)
            self._write_json_atomic(checkpoint_file, checkpoint)
            logger.debug(f"Checkpoint persisted to {checkpoint_file}")
        except Exception as e:
            logger.error(f"Failed to persist checkpoint: {e}")

    def load(self, path: Optional[Path] = None) -> None:
        """Load state and version from disk, replacing the current state.

        History is not modified. Errors are logged and leave state unchanged.

        Args:
            path: File to load from (a ``str`` is also accepted; default:
                ``<persistence_path>/current_state.json``)
        """
        load_path = Path(path) if path else self.persistence_path / "current_state.json"

        if not load_path.exists():
            logger.warning(f"State file not found: {load_path}")
            return

        try:
            with open(load_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            logger.error(f"Failed to load state: {e}")
            return

        # Validate the shape before touching any attribute, so a malformed
        # file cannot leave the manager half-updated.
        if not isinstance(data, dict) or not isinstance(data.get("state", {}), dict):
            logger.error(f"Failed to load state: unexpected format in {load_path}")
            return

        self._state = data.get("state", {})
        self._version = data.get("version", 0)
        logger.info(f"State loaded from {load_path}")

    def to_dict(self) -> Dict[str, Any]:
        """Convert state manager to dictionary.

        Returns:
            Dict with ``version``, a shallow copy of ``state`` and
            ``history_length``
        """
        return {
            "version": self._version,
            "state": self._state.copy(),
            "history_length": len(self._history),
        }

    def __repr__(self) -> str:
        """String representation."""
        return f"StateManager(version={self._version}, keys={len(self._state)})"
@@ -0,0 +1,115 @@
1
+ """State schema definition and validation."""
2
+
3
+ from typing import Any, Dict, Optional, Type
4
+ from pydantic import BaseModel, Field, create_model, ConfigDict
5
+
6
+
7
def _matches_type(value: Any, expected: Any) -> bool:
    """Return True if ``value`` conforms to ``expected``, which may be a typing construct.

    Rules:
    - Plain classes use ``isinstance``.
    - ``Any`` accepts everything.
    - ``Union``, ``Optional`` and ``X | Y`` accept a value matching any member.
    - ``Literal`` checks membership.
    - ``Annotated`` checks its underlying type.
    - Parameterised generics such as ``List[int]`` are checked against their
      container origin only; element types are not inspected.

    Constructs ``isinstance`` cannot handle (e.g. a ``TypeVar``) are accepted
    rather than raising ``TypeError``.
    """
    import types
    from typing import Annotated, Literal, Union, get_args, get_origin

    if expected is Any:
        return True

    origin = get_origin(expected)
    union_type = getattr(types, "UnionType", None)  # ``X | Y`` on Python 3.10+
    if origin is Union or (union_type is not None and origin is union_type):
        return any(_matches_type(value, arg) for arg in get_args(expected))
    if origin is Literal:
        return value in get_args(expected)
    if origin is Annotated:
        return _matches_type(value, get_args(expected)[0])

    try:
        return isinstance(value, origin if origin is not None else expected)
    except TypeError:
        return True


class StateSchema(BaseModel):
    """Schema for workflow state.

    Attributes:
        fields: Mapping of field name to its expected type. Plain classes and
            common ``typing`` constructs (``Optional``, ``Union``, ``List[...]``,
            ``Any``, ``Literal``) are supported by :meth:`validate_state`.
        required_fields: Names that must be present in a validated state.
        metadata: Free-form metadata carried through ``to_dict``/``from_dict``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    fields: Dict[str, Type[Any]] = Field(default_factory=dict)
    required_fields: set[str] = Field(default_factory=set)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def validate_state(self, state: Dict[str, Any]) -> bool:
        """Validate state against schema.

        Keys that are present in ``state`` but not in ``fields`` are not
        checked.

        Args:
            state: State dictionary to validate

        Returns:
            True if valid

        Raises:
            ValueError: If a required field is missing or a field has the
                wrong type
        """
        # Check required fields
        missing_fields = self.required_fields - set(state.keys())
        if missing_fields:
            raise ValueError(f"Missing required fields: {missing_fields}")

        # Type validation (tolerates typing generics that plain isinstance rejects)
        for field_name, field_type in self.fields.items():
            if field_name not in state:
                continue
            value = state[field_name]
            if not _matches_type(value, field_type):
                raise ValueError(
                    f"Field '{field_name}' has wrong type. "
                    f"Expected {field_type}, got {type(value)}"
                )

        return True

    def create_pydantic_model(self, model_name: str = "DynamicState") -> Type[BaseModel]:
        """Create a Pydantic model from the schema.

        Required fields become mandatory; the rest become ``Optional[...]``
        with a ``None`` default. Names in ``required_fields`` without an entry
        in ``fields`` are not included.

        Args:
            model_name: Name for the generated model

        Returns:
            Pydantic model class
        """
        field_definitions = {}
        for field_name, field_type in self.fields.items():
            if field_name in self.required_fields:
                field_definitions[field_name] = (field_type, ...)
            else:
                field_definitions[field_name] = (Optional[field_type], None)

        return create_model(model_name, **field_definitions)

    def add_field(
        self, name: str, field_type: Type[Any], required: bool = False
    ) -> None:
        """Add (or replace) a field in the schema.

        Args:
            name: Field name
            field_type: Field type
            required: Whether field is required
        """
        self.fields[name] = field_type
        if required:
            self.required_fields.add(name)

    def remove_field(self, name: str) -> None:
        """Remove a field (and its required flag) from the schema.

        Args:
            name: Field name to remove; unknown names are ignored
        """
        self.fields.pop(name, None)
        self.required_fields.discard(name)

    def to_dict(self) -> Dict[str, Any]:
        """Convert schema to dictionary.

        Returns:
            Dict with field types rendered via ``str`` (e.g. ``"<class 'int'>"``),
            ``required_fields`` as a sorted list (deterministic output) and
            ``metadata``
        """
        return {
            "fields": {name: str(type_) for name, type_ in self.fields.items()},
            "required_fields": sorted(self.required_fields),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StateSchema":
        """Create schema from dictionary.

        Field types are reconstructed only for builtin types. They may be
        written as produced by :meth:`to_dict` (e.g. ``"<class 'int'>"``) or
        as a bare name (``"int"``). Other type strings (e.g.
        ``"typing.List[int]"``) cannot be rebuilt without ``eval``; those
        fields are skipped and so are not type-checked by the returned schema.

        Args:
            data: Dictionary representation

        Returns:
            StateSchema instance
        """
        known_types: Dict[str, Any] = {}
        for builtin in (
            str, int, float, bool, complex, bytes,
            list, dict, set, frozenset, tuple, type(None),
        ):
            known_types[str(builtin)] = builtin
            known_types[builtin.__name__] = builtin

        raw_fields = data.get("fields")
        if not isinstance(raw_fields, dict):
            raw_fields = {}

        schema = cls()
        schema.fields = {
            name: known_types[type_name]
            for name, type_name in raw_fields.items()
            if isinstance(type_name, str) and type_name in known_types
        }
        schema.required_fields = set(data.get("required_fields", []))
        schema.metadata = data.get("metadata", {})
        return schema
genxai/llm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """LLM provider integrations for GenXAI."""
2
+
3
+ from genxai.llm.base import LLMProvider, LLMResponse
4
+ from genxai.llm.providers.openai import OpenAIProvider
5
+ from genxai.llm.providers.ollama import OllamaProvider
6
+ from genxai.llm.factory import LLMProviderFactory
7
+
8
+ __all__ = [
9
+ "LLMProvider",
10
+ "LLMResponse",
11
+ "OpenAIProvider",
12
+ "OllamaProvider",
13
+ "LLMProviderFactory",
14
+ ]