polycoding 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,346 @@
1
+ """PostgreSQL-based flow state persistence using SQLAlchemy with JSONB."""
2
+
3
+ import logging
4
+ from datetime import datetime, timezone
5
+ from typing import Any, Optional
6
+
7
+ from crewai.flow.async_feedback.types import PendingFeedbackContext
8
+ from crewai.flow.persistence import FlowPersistence, SQLiteFlowPersistence
9
+ from pydantic import BaseModel
10
+ from sqlalchemy import Index, create_engine
11
+ from sqlalchemy.dialects.postgresql import JSONB
12
+ from sqlalchemy.orm import (
13
+ DeclarativeBase,
14
+ Mapped,
15
+ mapped_column,
16
+ sessionmaker,
17
+ )
18
+ from sqlalchemy.sql.expression import text
19
+ from sqlalchemy.types import JSON, DateTime, Integer, String, TypeDecorator
20
+
21
+ from persistence.config import settings
22
+
23
# Connection string sourced from application settings; may be unset, in which
# case the SQLite fallback at the bottom of this module is used.
DATABASE_URL = settings.DATABASE_URL

# Module-level engine/session factory used by the request/payment helpers below.
# NOTE(review): created at import time — presumably DATABASE_URL is always a
# valid SQLAlchemy URL here; create_engine would raise otherwise. TODO confirm.
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

logger = logging.getLogger(__name__)
29
+
30
+
31
class Base(DeclarativeBase):
    """SQLAlchemy declarative base shared by all models in this module."""
33
+
34
+
35
class JSONType(TypeDecorator):
    """JSON column type that prefers PostgreSQL's binary JSONB.

    Falls back to the generic JSON type on every other dialect so the
    models stay portable (e.g. SQLite in tests).
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Choose the dialect-specific type descriptor at bind time.
        chosen = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(chosen)
45
+
46
+
47
class Payments(Base):
    """Payment record tied to a GitHub issue.

    NOTE(review): column types are inferred from the Mapped[...] annotations;
    ``amount`` is an integer — presumably minor currency units (e.g. cents),
    TODO confirm.
    """

    __tablename__ = "payments"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # GitHub issue this payment is associated with.
    issue_number: Mapped[int] = mapped_column()
    # External payment provider identifier ("none" for manual placeholder rows,
    # see ensure_request_exists).
    payment_id: Mapped[str] = mapped_column()
    amount: Mapped[int] = mapped_column()
    currency: Mapped[str] = mapped_column()
    payment_method: Mapped[str] = mapped_column()
    status: Mapped[str] = mapped_column()
    # Set by the database at insert time.
    created_at: Mapped[datetime | None] = mapped_column(server_default=text("CURRENT_TIMESTAMP"))
    # Remains NULL until the payment is verified.
    verified_at: Mapped[datetime | None] = mapped_column(default=None)
59
+
60
+
61
class Requests(Base):
    """Work request row keyed by GitHub issue number."""

    __tablename__ = "requests"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    issue_number: Mapped[int] = mapped_column(Integer, nullable=False)
    # Issue body text captured when the request was created.
    request_text: Mapped[str] = mapped_column(String, nullable=False)
    status: Mapped[str] = mapped_column(String, nullable=False)
    # Commit SHA recorded by update_request_status; NULL until work lands.
    commit: Mapped[str] = mapped_column(String, nullable=True)
    # NOTE(review): the column is timezone-naive DateTime but the default is an
    # aware UTC datetime — most drivers accept this, but consider
    # DateTime(timezone=True) for consistency with FlowState. TODO confirm.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
72
+
73
+
74
def update_request_status(
    session: sessionmaker,
    issue_number: int,
    status: str,
    commit: Optional[str] = None,
) -> bool:
    """Update the status (and stored commit SHA) of a request by issue_number.

    Args:
        session: SQLAlchemy session factory
        issue_number: The issue_number to update
        status: The new status value
        commit: Commit SHA to store alongside the status. Note that the
            default of None overwrites any previously stored commit with
            NULL — pass the existing SHA if it must be preserved.

    Returns:
        True if at least one row was updated, False otherwise
    """
    with session() as sess:
        # Bulk UPDATE: affects every row with this issue_number.
        updated_rows = (
            sess.query(Requests)
            .filter_by(issue_number=issue_number)
            .update({"status": status, "commit": commit})
        )
        sess.commit()
        return updated_rows > 0
94
+
95
+
96
def ensure_request_exists(
    session: sessionmaker,
    issue_number: int,
    body: str,
    status: str = "pending",
) -> bool:
    """Ensure a request exists for the given issue_number, inserting if needed.

    A placeholder "manual" payment row is inserted alongside the request.
    Both rows are committed in a single transaction so a failure between
    inserts cannot leave an orphaned payment without its request (the
    original implementation committed them separately).

    NOTE(review): check-then-insert is racy under concurrent callers; a
    unique constraint on Requests.issue_number would make this safe —
    TODO confirm.

    Args:
        session: SQLAlchemy session factory
        issue_number: The issue_number to check/insert
        body: The issue body text
        status: The status for new requests (default: "pending")

    Returns:
        True if a new request was inserted, False if it already existed
    """
    with session() as sess:
        if sess.query(Requests).filter_by(issue_number=issue_number).first():
            return False

        placeholder_payment = Payments(
            issue_number=issue_number,
            status="manual",
            payment_id="none",
            amount=0,
            currency="USD",
            payment_method="none",
        )
        new_request = Requests(issue_number=issue_number, request_text=body, status=status)
        # One atomic commit for both rows.
        sess.add_all([placeholder_payment, new_request])
        sess.commit()
        return True
132
+
133
+
134
class FlowState(Base):
    """Append-only log of flow states; the newest row per flow_uuid is current."""

    __tablename__ = "flow_states"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    flow_uuid: Mapped[str] = mapped_column(String(255), nullable=False)
    # Name of the flow method that produced this state snapshot.
    method_name: Mapped[str] = mapped_column(String(255), nullable=False)
    timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Full serialized flow state (JSONB on PostgreSQL, JSON elsewhere).
    state_json: Mapped[dict[str, Any]] = mapped_column(JSONType, nullable=False)

    # Speeds up the "latest state for a flow" lookup in load_state.
    __table_args__ = (Index("idx_flow_states_uuid", "flow_uuid"),)
146
+
147
+
148
class PendingFeedback(Base):
    """Pending feedback marker for async human-in-the-loop flows.

    At most one row per flow (flow_uuid is unique); upserted by
    PostgresFlowPersistence.save_pending_feedback.
    """

    __tablename__ = "pending_feedback"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    flow_uuid: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
    # Serialized PendingFeedbackContext needed to resume the flow.
    context_json: Mapped[dict[str, Any]] = mapped_column(JSONType, nullable=False)
    # Snapshot of the flow state at the moment feedback was requested.
    state_json: Mapped[dict[str, Any]] = mapped_column(JSONType, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    __table_args__ = (Index("idx_pending_feedback_uuid", "flow_uuid"),)
160
+
161
+
162
class PostgresFlowPersistence(FlowPersistence):
    """PostgreSQL-based implementation of flow state persistence.

    Production-grade persistence using SQLAlchemy with JSONB for efficient
    querying of flow states. States are append-only; the newest row per
    flow UUID is treated as current.

    Example:
        ```python
        persistence = PostgresFlowPersistence(
            connection_string="postgresql://user:pass@localhost/db"
        )

        # Start a flow with async feedback
        try:
            flow = MyFlow(persistence=persistence)
            result = flow.kickoff()
        except HumanFeedbackPending as e:
            # Flow is paused, state is already persisted
            print(f"Waiting for feedback: {e.context.flow_id}")

        # Later, resume with feedback
        flow = MyFlow.from_pending("abc-123", persistence)
        result = flow.resume("looks good!")
        ```
    """

    def __init__(self, connection_string: str) -> None:
        """Initialize PostgreSQL persistence.

        Args:
            connection_string: PostgreSQL connection string.
                Format: postgresql://user:password@host:port/database

        Raises:
            ValueError: If connection_string is empty.
        """
        if not connection_string:
            raise ValueError("Connection string must be provided")

        self.connection_string = connection_string
        self.engine = create_engine(connection_string)
        self.Session = sessionmaker(bind=self.engine)
        self.init_db()

    def init_db(self) -> None:
        """Create the necessary tables if they don't exist."""
        Base.metadata.create_all(self.engine)

    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        """Append the current flow state to PostgreSQL.

        Args:
            flow_uuid: Unique identifier for the flow instance
            method_name: Name of the method that just completed
            state_data: Current state data (either dict or Pydantic model)
        """
        state_dict = self._to_dict(state_data)

        with self.Session() as session:
            session.add(
                FlowState(
                    flow_uuid=flow_uuid,
                    method_name=method_name,
                    timestamp=datetime.now(timezone.utc),
                    state_json=state_dict,
                )
            )
            session.commit()

    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
        """Load the most recent state for a given flow UUID.

        Args:
            flow_uuid: Unique identifier for the flow instance

        Returns:
            The most recent state as a dictionary, or None if no state exists
        """
        with self.Session() as session:
            # Highest id == most recently appended state for this flow.
            state = (
                session.query(FlowState)
                .filter(FlowState.flow_uuid == flow_uuid)
                .order_by(FlowState.id.desc())
                .first()
            )
            return state.state_json if state else None

    def save_pending_feedback(
        self,
        flow_uuid: str,
        context: PendingFeedbackContext,
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        """Save state with a pending feedback marker.

        Stores both the flow state and the pending feedback context so the
        flow can be resumed later when feedback arrives. The marker is
        upserted: at most one pending row per flow_uuid.

        NOTE(review): the state append and the marker upsert run in two
        separate transactions; a crash in between leaves state saved with no
        marker — TODO confirm this is acceptable.

        Args:
            flow_uuid: Unique identifier for the flow instance
            context: The pending feedback context with all resume information
            state_data: Current state data
        """
        state_dict = self._to_dict(state_data)

        # Pass the already-converted dict; the original passed state_data and
        # forced save_state to run _to_dict a second time.
        self.save_state(flow_uuid, context.method_name, state_dict)

        with self.Session() as session:
            existing = (
                session.query(PendingFeedback)
                .filter(PendingFeedback.flow_uuid == flow_uuid)
                .first()
            )

            if existing:
                existing.context_json = context.to_dict()
                existing.state_json = state_dict
                existing.created_at = datetime.now(timezone.utc)
            else:
                session.add(
                    PendingFeedback(
                        flow_uuid=flow_uuid,
                        context_json=context.to_dict(),
                        state_json=state_dict,
                        created_at=datetime.now(timezone.utc),
                    )
                )

            session.commit()

    def load_pending_feedback(
        self,
        flow_uuid: str,
    ) -> tuple[dict[str, Any], PendingFeedbackContext] | None:
        """Load state and pending feedback context.

        Args:
            flow_uuid: Unique identifier for the flow instance

        Returns:
            Tuple of (state_data, pending_context) if pending feedback exists,
            None otherwise.
        """
        with self.Session() as session:
            pending = (
                session.query(PendingFeedback)
                .filter(PendingFeedback.flow_uuid == flow_uuid)
                .first()
            )

            if pending is None:
                return None
            context = PendingFeedbackContext.from_dict(pending.context_json)
            return (pending.state_json, context)

    def clear_pending_feedback(self, flow_uuid: str) -> None:
        """Clear the pending feedback marker after successful resume.

        Args:
            flow_uuid: Unique identifier for the flow instance
        """
        with self.Session() as session:
            session.query(PendingFeedback).filter(PendingFeedback.flow_uuid == flow_uuid).delete()
            session.commit()

    def _to_dict(self, state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
        """Convert state_data to a plain dict.

        Args:
            state_data: Current state data (either dict or Pydantic model)

        Returns:
            Dictionary representation of state_data

        Raises:
            ValueError: If state_data is not a dict or Pydantic model
        """
        if isinstance(state_data, BaseModel):
            return state_data.model_dump()
        if isinstance(state_data, dict):
            return state_data
        raise ValueError(f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}")
340
+
341
+
342
# Pick the persistence backend from the configured URL: PostgreSQL when the
# URL scheme indicates it, otherwise the SQLite fallback.
if not (DATABASE_URL and DATABASE_URL.startswith("postgres")):
    persistence = SQLiteFlowPersistence()
else:
    logger.info("📊 Connecting persistence with postgres")
    persistence = PostgresFlowPersistence(connection_string=DATABASE_URL)
@@ -0,0 +1,111 @@
1
+ """SQLAlchemy model registry with auto-registration via __init_subclass__."""
2
+
3
+ import logging
4
+ from typing import Type
5
+
6
+ from sqlalchemy import MetaData
7
+ from sqlalchemy.orm import DeclarativeBase
8
+
9
log = logging.getLogger(__name__)


# Shared MetaData with a deterministic constraint-naming convention so that
# index/constraint names are stable across runs (presumably for migration
# tooling such as Alembic — TODO confirm).
METADATA = MetaData(
    naming_convention={
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(column_0_name)s",
        "ck": "ck_%(table_name)s_%(constraint_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    }
)
21
+
22
+
23
class ModelRegistry:
    """Central registry for ORM models from all modules.

    State lives at class level and is shared process-wide; models register
    themselves at import time via RegisteredBase.__init_subclass__.
    """

    # Maps "module.table" -> model class.
    _models: dict[str, type[DeclarativeBase]] = {}
    # Names of modules that have been processed.
    _modules: set[str] = set()

    @classmethod
    def register_model(cls, model: type[DeclarativeBase], module_name: str) -> None:
        """Register a single model under its module name."""
        key = f"{module_name}.{model.__tablename__}"
        cls._models[key] = model

    @classmethod
    def register_module(cls, module_name: str) -> None:
        """Mark a module as having been processed."""
        cls._modules.add(module_name)

    @classmethod
    def is_registered(cls, module_name: str) -> bool:
        """Return True if the module has been marked as processed."""
        return module_name in cls._modules

    @classmethod
    def create_all(cls, engine) -> None:
        """Create all registered tables in one pass."""
        METADATA.create_all(bind=engine)
        # Lazy %-args: message is only rendered when INFO is enabled.
        log.info("📊 Created %d tables from %d modules", len(cls._models), len(cls._modules))

    @classmethod
    def get_models_for_module(cls, module_name: str) -> list[type[DeclarativeBase]]:
        """Return all models belonging to a module."""
        prefix = f"{module_name}."
        return [model for key, model in cls._models.items() if key.startswith(prefix)]

    @classmethod
    def all_models(cls) -> dict[str, type[DeclarativeBase]]:
        """Return all registered models as {module.table: model}."""
        return dict(cls._models)

    @classmethod
    def reset(cls) -> None:
        """Clear registry (for testing)."""
        cls._models.clear()
        cls._modules.clear()
66
+
67
+
68
class RegisteredBase(DeclarativeBase):
    """Base class for ORM models with auto-registration.

    All models across all modules inherit from this. Each model must
    declare __module_name__ to identify its owning module.

    Usage:

        class MyModel(RegisteredBase):
            __module_name__ = "my_module"
            __tablename__ = "my_table"

            id: Mapped[int] = mapped_column(primary_key=True)

    The __init_subclass__ hook automatically registers the model with
    ModelRegistry when the class is defined (at import time).

    If __module_name__ is omitted, the registry attempts to infer it from
    the class's __module__ attribute (e.g., 'src.retro.persistence' -> 'retro').
    """

    # All subclasses share the metadata with the project naming convention.
    metadata = METADATA
    # Declared for type-checkers; subclasses set the actual value.
    __module_name__: str

    def __init_subclass__(cls, **kwargs) -> None:
        # Let SQLAlchemy's declarative machinery process the subclass first.
        super().__init_subclass__(**kwargs)

        # Abstract bases are not mapped to tables; skip registration.
        if getattr(cls, "__abstract__", False):
            return

        module_name = getattr(cls, "__module_name__", None)
        if not module_name:
            # Infer from the dotted import path: 'src.<module>...' -> '<module>'.
            parts = cls.__module__.split(".")
            if len(parts) >= 2 and parts[0] == "src":
                module_name = parts[1]

        if module_name:
            ModelRegistry.register_model(cls, module_name)
            ModelRegistry.register_module(module_name)
            log.debug(f"📊 Auto-registered: {module_name}.{cls.__tablename__}")
        else:
            # No declared name and the heuristic failed; model stays unregistered.
            log.warning(
                f"⚠️ {cls.__name__} has no __module_name__ and cannot be inferred from __module__={cls.__module__!r}"
            )
persistence/tasks.py ADDED
@@ -0,0 +1,178 @@
1
+ """Celery task tracking in PostgreSQL."""
2
+
3
+ from datetime import datetime, timedelta, timezone
4
+
5
+ from sqlalchemy import Index
6
+ from sqlalchemy.orm import Mapped, mapped_column
7
+ from sqlalchemy.types import DateTime, Integer, String, Text
8
+
9
+ from .postgres import Base
10
+
11
+
12
class CeleryTask(Base):
    """Celery task tracking model.

    Tracks all Celery tasks for monitoring and debugging purposes.
    """

    __tablename__ = "tasks"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Celery task UUID; unique, so one row per tracked task.
    task_id: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
    # Flow this task belongs to (indexed below).
    flow_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # e.g. 'implement_story', 'test_story' (see CeleryTaskTracker.create_task).
    task_type: Mapped[str] = mapped_column(String(100), nullable=False)
    # Lifecycle set by CeleryTaskTracker: 'pending' -> 'running' -> 'completed' | 'failed'.
    status: Mapped[str] = mapped_column(String(50), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # Serialized result payload, if any.
    result: Mapped[str | None] = mapped_column(Text, nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    retry_count: Mapped[int] = mapped_column(Integer, default=0)
    # Optional GitHub issue number associated with the task.
    issue_number: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Indexes matching the tracker's common lookups (by flow, status, type).
    __table_args__ = (
        Index("idx_tasks_flow_id", "flow_id"),
        Index("idx_tasks_status", "status"),
        Index("idx_tasks_task_type", "task_type"),
    )
38
+
39
+
40
class CeleryTaskTracker:
    """Track Celery tasks in PostgreSQL.

    All lifecycle transitions (started/completed/failed) share a single
    fetch-update-commit helper; the original duplicated that pattern in
    three methods.
    """

    def __init__(self, session_factory):
        """Initialize task tracker.

        Args:
            session_factory: SQLAlchemy session factory
        """
        self.Session = session_factory

    def create_task(
        self,
        task_id: str,
        flow_id: str,
        task_type: str,
        issue_number: int | None = None,
    ) -> None:
        """Create a new task record.

        Args:
            task_id: Celery task ID
            flow_id: Flow ID this task belongs to
            task_type: Type of task (e.g., 'implement_story', 'test_story')
            issue_number: Optional GitHub issue number
        """
        with self.Session() as session:
            task = CeleryTask(
                task_id=task_id,
                flow_id=flow_id,
                task_type=task_type,
                status="pending",
                created_at=datetime.now(timezone.utc),
                issue_number=issue_number,
            )
            session.add(task)
            session.commit()

    def _update_fields(self, task_id: str, **fields) -> None:
        """Apply attribute updates to the task row and commit, if it exists.

        Silently does nothing when no row matches task_id (matching the
        original per-method behavior).
        """
        with self.Session() as session:
            task = session.query(CeleryTask).filter(CeleryTask.task_id == task_id).first()
            if task:
                for name, value in fields.items():
                    setattr(task, name, value)
                session.commit()

    def update_task_started(self, task_id: str) -> None:
        """Mark task as started.

        Args:
            task_id: Celery task ID
        """
        self._update_fields(task_id, status="running", started_at=datetime.now(timezone.utc))

    def update_task_completed(self, task_id: str, result: str | None = None) -> None:
        """Mark task as completed.

        Args:
            task_id: Celery task ID
            result: Optional result data
        """
        self._update_fields(
            task_id,
            status="completed",
            completed_at=datetime.now(timezone.utc),
            result=result,
        )

    def update_task_failed(self, task_id: str, error_message: str) -> None:
        """Mark task as failed.

        Args:
            task_id: Celery task ID
            error_message: Error message
        """
        self._update_fields(
            task_id,
            status="failed",
            completed_at=datetime.now(timezone.utc),
            error_message=error_message,
        )

    def increment_retry(self, task_id: str) -> None:
        """Increment retry count.

        Kept inline (not via _update_fields) because it is a read-modify-write
        of the current value.

        Args:
            task_id: Celery task ID
        """
        with self.Session() as session:
            task = session.query(CeleryTask).filter(CeleryTask.task_id == task_id).first()
            if task:
                task.retry_count += 1
                session.commit()

    def get_task(self, task_id: str) -> CeleryTask | None:
        """Get task by ID.

        Args:
            task_id: Celery task ID

        Returns:
            CeleryTask if found, None otherwise. NOTE(review): the returned
            instance is detached from its (closed) session — loaded attributes
            are readable, but lazy loads would fail. TODO confirm callers only
            read plain columns.
        """
        with self.Session() as session:
            return session.query(CeleryTask).filter(CeleryTask.task_id == task_id).first()

    def get_flow_tasks(self, flow_id: str) -> list[CeleryTask]:
        """Get all tasks for a flow, ordered by creation time.

        Args:
            flow_id: Flow ID

        Returns:
            List of (detached) CeleryTask objects
        """
        with self.Session() as session:
            return session.query(CeleryTask).filter(CeleryTask.flow_id == flow_id).order_by(CeleryTask.created_at).all()

    def cleanup_completed_tasks(self, days_old: int = 7) -> int:
        """Delete completed/failed tasks older than specified days.

        Args:
            days_old: Number of days to keep completed tasks

        Returns:
            Number of tasks deleted
        """
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)

        with self.Session() as session:
            deleted = (
                session.query(CeleryTask)
                .filter(
                    CeleryTask.status.in_(["completed", "failed"]),
                    CeleryTask.completed_at < cutoff_date,
                )
                .delete()
            )
            session.commit()
            return deleted