orchestrator-core 4.5.3__py3-none-any.whl → 4.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. orchestrator/__init__.py +2 -2
  2. orchestrator/agentic_app.py +3 -23
  3. orchestrator/api/api_v1/api.py +5 -0
  4. orchestrator/api/api_v1/endpoints/agent.py +49 -0
  5. orchestrator/api/api_v1/endpoints/search.py +120 -201
  6. orchestrator/app.py +1 -1
  7. orchestrator/cli/database.py +3 -0
  8. orchestrator/cli/generate.py +11 -4
  9. orchestrator/cli/generator/generator/migration.py +7 -3
  10. orchestrator/cli/main.py +1 -1
  11. orchestrator/cli/scheduler.py +15 -22
  12. orchestrator/cli/search/resize_embedding.py +28 -22
  13. orchestrator/cli/search/speedtest.py +4 -6
  14. orchestrator/db/__init__.py +6 -0
  15. orchestrator/db/models.py +75 -0
  16. orchestrator/llm_settings.py +18 -1
  17. orchestrator/migrations/helpers.py +47 -39
  18. orchestrator/schedules/scheduler.py +32 -15
  19. orchestrator/schedules/validate_products.py +1 -1
  20. orchestrator/schemas/search.py +8 -85
  21. orchestrator/search/agent/__init__.py +2 -2
  22. orchestrator/search/agent/agent.py +26 -30
  23. orchestrator/search/agent/json_patch.py +51 -0
  24. orchestrator/search/agent/prompts.py +35 -9
  25. orchestrator/search/agent/state.py +28 -2
  26. orchestrator/search/agent/tools.py +192 -53
  27. orchestrator/search/core/embedding.py +2 -2
  28. orchestrator/search/core/exceptions.py +6 -0
  29. orchestrator/search/core/types.py +1 -0
  30. orchestrator/search/export.py +199 -0
  31. orchestrator/search/indexing/indexer.py +13 -4
  32. orchestrator/search/indexing/registry.py +14 -1
  33. orchestrator/search/llm_migration.py +55 -0
  34. orchestrator/search/retrieval/__init__.py +3 -2
  35. orchestrator/search/retrieval/builder.py +5 -1
  36. orchestrator/search/retrieval/engine.py +66 -23
  37. orchestrator/search/retrieval/pagination.py +46 -56
  38. orchestrator/search/retrieval/query_state.py +61 -0
  39. orchestrator/search/retrieval/retrievers/base.py +26 -40
  40. orchestrator/search/retrieval/retrievers/fuzzy.py +10 -9
  41. orchestrator/search/retrieval/retrievers/hybrid.py +11 -8
  42. orchestrator/search/retrieval/retrievers/semantic.py +9 -8
  43. orchestrator/search/retrieval/retrievers/structured.py +6 -6
  44. orchestrator/search/schemas/parameters.py +17 -13
  45. orchestrator/search/schemas/results.py +4 -1
  46. orchestrator/settings.py +1 -0
  47. orchestrator/utils/auth.py +3 -2
  48. orchestrator/workflow.py +23 -6
  49. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0.dist-info}/METADATA +16 -11
  50. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0.dist-info}/RECORD +52 -48
  51. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0.dist-info}/WHEEL +0 -0
  52. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0.dist-info}/licenses/LICENSE +0 -0

orchestrator/cli/scheduler.py CHANGED
@@ -12,26 +12,23 @@
 # limitations under the License.
 
-import logging
 import time
 
 import typer
 
 from orchestrator.schedules.scheduler import (
-    get_paused_scheduler,
+    get_all_scheduler_tasks,
+    get_scheduler,
+    get_scheduler_task,
 )
 
-log = logging.getLogger(__name__)
-
 app: typer.Typer = typer.Typer()
 
 
 @app.command()
 def run() -> None:
     """Start scheduler and loop eternally to keep thread alive."""
-    with get_paused_scheduler() as scheduler:
-        scheduler.resume()
-
+    with get_scheduler():
         while True:
             time.sleep(1)
@@ -42,27 +39,23 @@ def show_schedule() -> None:
 
     in cli underscore is replaced by a dash `show-schedule`
     """
-    with get_paused_scheduler() as scheduler:
-        jobs = scheduler.get_jobs()
-
-        for job in jobs:
-            typer.echo(f"[{job.id}] Next run: {job.next_run_time} | Trigger: {job.trigger}")
+    for task in get_all_scheduler_tasks():
+        typer.echo(f"[{task.id}] Next run: {task.next_run_time} | Trigger: {task.trigger}")
 
 
 @app.command()
-def force(job_id: str) -> None:
-    """Force the execution of (a) scheduler(s) based on a job_id."""
-    with get_paused_scheduler() as scheduler:
-        job = scheduler.get_job(job_id)
+def force(task_id: str) -> None:
+    """Force the execution of (a) scheduler(s) based on a task_id."""
+    task = get_scheduler_task(task_id)
 
-    if not job:
-        typer.echo(f"Job '{job_id}' not found.")
+    if not task:
+        typer.echo(f"Task '{task_id}' not found.")
         raise typer.Exit(code=1)
 
-    typer.echo(f"Running job [{job.id}] now...")
+    typer.echo(f"Running Task [{task.id}] now...")
     try:
-        job.func(*job.args or (), **job.kwargs or {})
-        typer.echo("Job executed successfully.")
+        task.func(*task.args or (), **task.kwargs or {})
+        typer.echo("Task executed successfully.")
     except Exception as e:
-        typer.echo(f"Job execution failed: {e}")
+        typer.echo(f"Task execution failed: {e}")
         raise typer.Exit(code=1)
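
Note on the reworked CLI: commands now read tasks straight from the job store instead of starting a paused scheduler. A minimal sketch of exercising the new `force` command through Typer's test runner (assumes orchestrator-core 4.6.0 with a reachable database; the task id is hypothetical):

    # Sketch only: invokes the Typer app defined in orchestrator/cli/scheduler.py.
    from typer.testing import CliRunner

    from orchestrator.cli.scheduler import app

    runner = CliRunner()
    result = runner.invoke(app, ["force", "my-task-id"])  # "my-task-id" is hypothetical
    print(result.output)  # either "Running Task [...] now..." or "Task 'my-task-id' not found."
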

orchestrator/cli/search/resize_embedding.py CHANGED
@@ -4,7 +4,7 @@ from sqlalchemy import text
 from sqlalchemy.exc import SQLAlchemyError
 
 from orchestrator.db import db
-from orchestrator.db.models import AiSearchIndex
+from orchestrator.db.models import AiSearchIndex, SearchQueryTable
 from orchestrator.llm_settings import llm_settings
 
 logger = structlog.get_logger(__name__)
@@ -40,17 +40,20 @@ def get_current_embedding_dimension() -> int | None:
     return None
 
 
-def drop_all_embeddings() -> int:
-    """Drop all records from the ai_search_index table.
+def drop_all_embeddings() -> tuple[int, int]:
+    """Drop all records from ai_search_index and search_queries tables.
 
     Returns:
-        Number of records deleted
+        Tuple of (ai_search_index records deleted, search_queries records deleted)
     """
     try:
-        result = db.session.query(AiSearchIndex).delete()
+        index_deleted = db.session.query(AiSearchIndex).delete()
+        query_deleted = db.session.query(SearchQueryTable).delete()
         db.session.commit()
-        logger.info(f"Deleted {result} records from ai_search_index")
-        return result
+        logger.info(
+            f"Deleted {index_deleted} records from ai_search_index and {query_deleted} records from search_queries"
+        )
+        return index_deleted, query_deleted
 
     except SQLAlchemyError as e:
         db.session.rollback()
@@ -59,34 +62,34 @@ def drop_all_embeddings() -> int:
 
 
 def alter_embedding_column_dimension(new_dimension: int) -> None:
-    """Alter the embedding column to use the new dimension size.
+    """Alter the embedding columns in both ai_search_index and search_queries tables.
 
     Args:
         new_dimension: New vector dimension size
     """
     try:
-        drop_query = text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding")
-        db.session.execute(drop_query)
+        db.session.execute(text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding"))
+        db.session.execute(text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})"))
 
-        add_query = text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})")
-        db.session.execute(add_query)
+        db.session.execute(text("ALTER TABLE search_queries DROP COLUMN IF EXISTS query_embedding"))
+        db.session.execute(text(f"ALTER TABLE search_queries ADD COLUMN query_embedding vector({new_dimension})"))
 
         db.session.commit()
-        logger.info(f"Altered embedding column to dimension {new_dimension}")
+        logger.info(f"Altered embedding columns to dimension {new_dimension} in ai_search_index and search_queries")
 
     except SQLAlchemyError as e:
         db.session.rollback()
-        logger.error("Failed to alter embedding column dimension", error=str(e))
+        logger.error("Failed to alter embedding column dimensions", error=str(e))
         raise
 
 
 @app.command("resize")
 def resize_embeddings_command() -> None:
-    """Resize vector dimensions of the ai_search_index embedding column.
+    """Resize vector dimensions of embedding columns in ai_search_index and search_queries tables.
 
     Compares the current embedding dimension in the database with the configured
-    dimension in llm_settings. If they differ, drops all records and alters the
-    column to match the new dimension.
+    dimension in llm_settings. If they differ, drops all records and alters both
+    embedding columns to match the new dimension.
     """
     new_dimension = llm_settings.EMBEDDING_DIMENSION
 
@@ -107,22 +110,25 @@ def resize_embeddings_command() -> None:
 
     logger.info("Dimension mismatch detected", current_dimension=current_dimension, new_dimension=new_dimension)
 
-    if not typer.confirm("This will DELETE ALL RECORDS from ai_search_index and alter the embedding column. Continue?"):
+    if not typer.confirm(
+        "This will DELETE ALL RECORDS from ai_search_index and search_queries tables and alter embedding columns. Continue?"
+    ):
         logger.info("Operation cancelled by user")
         return
 
     try:
         # Drop all records first.
         logger.info("Dropping all embedding records...")
-        deleted_count = drop_all_embeddings()
+        index_deleted, query_deleted = drop_all_embeddings()
 
-        # Then alter column dimension.
-        logger.info(f"Altering embedding column to dimension {new_dimension}...")
+        # Then alter column dimensions.
+        logger.info(f"Altering embedding columns to dimension {new_dimension}...")
        alter_embedding_column_dimension(new_dimension)
 
         logger.info(
             "Embedding dimension resize completed successfully",
-            records_deleted=deleted_count,
+            index_records_deleted=index_deleted,
+            query_records_deleted=query_deleted,
             new_dimension=new_dimension,
         )

orchestrator/cli/search/speedtest.py CHANGED
@@ -13,7 +13,6 @@ from orchestrator.search.core.embedding import QueryEmbedder
 from orchestrator.search.core.types import EntityType
 from orchestrator.search.core.validators import is_uuid
 from orchestrator.search.retrieval.engine import execute_search
-from orchestrator.search.retrieval.pagination import PaginationParams
 from orchestrator.search.schemas.parameters import BaseSearchParameters
 
 logger = structlog.get_logger(__name__)
@@ -54,17 +53,16 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[
 async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]:
     search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30)
 
+    query_embedding = None
+
     if is_uuid(query):
-        pagination_params = PaginationParams()
         logger.debug("Using fuzzy-only ranking for full UUID", query=query)
     else:
-
-        cached_embedding = embedding_lookup[query]
-        pagination_params = PaginationParams(q_vec_override=cached_embedding)
+        query_embedding = embedding_lookup[query]
 
     with db.session as session:
         start_time = time.perf_counter()
-        response = await execute_search(search_params, session, pagination_params=pagination_params)
+        response = await execute_search(search_params, session, cursor=None, query_embedding=query_embedding)
         end_time = time.perf_counter()
 
     return {
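
`execute_search` now accepts a pre-computed vector directly via `query_embedding`, with `cursor` handling pagination, rather than tunneling the vector through `PaginationParams`. A sketch of the new call shape (assumes an async caller and a configured database; the wrapper function is illustrative):

    # Sketch only: the 4.6.0 call shape shown in the diff above.
    from orchestrator.db import db
    from orchestrator.search.core.types import EntityType
    from orchestrator.search.retrieval.engine import execute_search
    from orchestrator.search.schemas.parameters import BaseSearchParameters

    async def search_subscriptions(query: str, embedding: list[float] | None = None):
        params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30)
        with db.session as session:
            # cursor=None requests the first page; pass a cached vector to skip re-embedding.
            return await execute_search(params, session, cursor=None, query_embedding=embedding)
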

orchestrator/db/__init__.py CHANGED
@@ -17,6 +17,7 @@ from structlog import get_logger
 from orchestrator.db.database import BaseModel as DbBaseModel
 from orchestrator.db.database import Database, transactional
 from orchestrator.db.models import (  # noqa: F401
+    AgentRunTable,
     EngineSettingsTable,
     FixedInputTable,
     InputStateTable,
@@ -26,6 +27,7 @@ from orchestrator.db.models import (  # noqa: F401
     ProductBlockTable,
     ProductTable,
     ResourceTypeTable,
+    SearchQueryTable,
     SubscriptionCustomerDescriptionTable,
     SubscriptionInstanceRelationTable,
     SubscriptionInstanceTable,
@@ -74,6 +76,8 @@ def init_database(settings: AppSettings) -> Database:
 
 __all__ = [
     "transactional",
+    "SearchQueryTable",
+    "AgentRunTable",
     "SubscriptionTable",
     "ProcessSubscriptionTable",
     "ProcessTable",
@@ -97,6 +101,8 @@ __all__ = [
 ]
 
 ALL_DB_MODELS: list[type[DbBaseModel]] = [
+    SearchQueryTable,
+    AgentRunTable,
     FixedInputTable,
     ProcessStepTable,
     ProcessSubscriptionTable,

orchestrator/db/models.py CHANGED
@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import enum
 from datetime import datetime, timezone
+from typing import TYPE_CHECKING
 from uuid import UUID
 
 import sqlalchemy
@@ -58,6 +59,9 @@ from orchestrator.targets import Target
 from orchestrator.utils.datetime import nowtz
 from orchestrator.version import GIT_COMMIT_HASH
 
+if TYPE_CHECKING:
+    from orchestrator.search.retrieval.query_state import SearchQueryState
+
 logger = structlog.get_logger(__name__)
 
 TAG_LENGTH = 20
@@ -674,6 +678,76 @@ class SubscriptionSearchView(BaseModel):
     subscription = relationship("SubscriptionTable", foreign_keys=[subscription_id])
 
 
+class AgentRunTable(BaseModel):
+    """Agent conversation/session tracking."""
+
+    __tablename__ = "agent_runs"
+
+    run_id = mapped_column("run_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    agent_type = mapped_column(String(50), nullable=False)
+    created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    queries = relationship("SearchQueryTable", back_populates="run", cascade="delete", passive_deletes=True)
+
+    __table_args__ = (Index("ix_agent_runs_created_at", "created_at"),)
+
+
+class SearchQueryTable(BaseModel):
+    """Search query execution - used by both agent runs and regular API searches.
+
+    When run_id is NULL: standalone API search query
+    When run_id is NOT NULL: query belongs to an agent conversation run
+    """
+
+    __tablename__ = "search_queries"
+
+    query_id = mapped_column("query_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    run_id = mapped_column(
+        "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=True, index=True
+    )
+    query_number = mapped_column(Integer, nullable=False)
+
+    # Search parameters as JSONB (maps to BaseSearchParameters subclasses)
+    parameters = mapped_column(pg.JSONB, nullable=False)
+
+    # Query embedding for semantic search (pgvector)
+    query_embedding = mapped_column(Vector(llm_settings.EMBEDDING_DIMENSION), nullable=True)
+
+    executed_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    run = relationship("AgentRunTable", back_populates="queries")
+
+    __table_args__ = (
+        Index("ix_search_queries_run_id", "run_id"),
+        Index("ix_search_queries_executed_at", "executed_at"),
+        Index("ix_search_queries_query_id", "query_id"),
+    )
+
+    @classmethod
+    def from_state(
+        cls,
+        state: "SearchQueryState",
+        run_id: "UUID | None" = None,
+        query_number: int = 1,
+    ) -> "SearchQueryTable":
+        """Create a SearchQueryTable instance from a SearchQueryState.
+
+        Args:
+            state: The search query state with parameters and embedding
+            run_id: Optional agent run ID (NULL for regular API searches)
+            query_number: Query number within the run (default=1)
+
+        Returns:
+            SearchQueryTable instance ready to be added to the database.
+        """
+        return cls(
+            run_id=run_id,
+            query_number=query_number,
+            parameters=state.parameters.model_dump(),
+            query_embedding=state.query_embedding,
+        )
+
+
 class EngineSettingsTable(BaseModel):
     __tablename__ = "engine_settings"
     global_lock = mapped_column(Boolean(), default=False, nullable=False, primary_key=True)
@@ -705,6 +779,7 @@ class AiSearchIndex(BaseModel):
         UUIDType,
         nullable=False,
     )
+    entity_title = mapped_column(TEXT, nullable=True)
 
     # Ltree path for hierarchical data
     path = mapped_column(LtreeType, nullable=False, index=True)
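
`from_state` keeps persistence out of the retrieval layer: it reads only `state.parameters` and `state.query_embedding`. A sketch of recording one agent query (assumes an active session; the `agent_type` value and the flush-then-commit handling are illustrative):

    # Sketch only: persists an agent run plus its first search query.
    from orchestrator.db import AgentRunTable, SearchQueryTable, db
    from orchestrator.search.retrieval.query_state import SearchQueryState

    def record_agent_query(state: SearchQueryState) -> None:
        run = AgentRunTable(agent_type="search")  # value is illustrative
        db.session.add(run)
        db.session.flush()  # populates run.run_id from the uuid server default
        db.session.add(SearchQueryTable.from_state(state, run_id=run.run_id, query_number=1))
        db.session.commit()
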

orchestrator/llm_settings.py CHANGED
@@ -10,14 +10,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Annotated
+
 from pydantic import Field, field_validator
 from pydantic_settings import BaseSettings
 from structlog import get_logger
 
 logger = get_logger(__name__)
 
+EMBEDDING_DIMENSION_MIN = 100
+EMBEDDING_DIMENSION_MAX = 2000
+EMBEDDING_DIMENSION_DEFAULT = 1536
+
+EMBEDDING_DIMENSION_FIELD = Annotated[
+    int,
+    Field(
+        ge=EMBEDDING_DIMENSION_MIN,
+        le=EMBEDDING_DIMENSION_MAX,
+        default=EMBEDDING_DIMENSION_DEFAULT,
+        description="Embedding dimension: when embeddings are generated at a higher resolution than this setting, the least significant numbers will be truncated",
+    ),
+]
+
 
 class LLMSettings(BaseSettings):
+
     # Feature flags for LLM functionality
     SEARCH_ENABLED: bool = False  # Enable search/indexing with embeddings
     AGENT_ENABLED: bool = False  # Enable agentic functionality
@@ -27,7 +44,7 @@ class LLMSettings(BaseSettings):
     AGENT_MODEL_VERSION: str = "2025-01-01-preview"
     OPENAI_API_KEY: str = ""  # Change per provider (Azure, etc).
     # Embedding settings
-    EMBEDDING_DIMENSION: int = 1536
+    EMBEDDING_DIMENSION: EMBEDDING_DIMENSION_FIELD = 1536
     EMBEDDING_MODEL: str = "openai/text-embedding-3-small"  # See litellm docs for supported models.
     EMBEDDING_SAFE_MARGIN_PERCENT: float = Field(
         0.1, description="Safety margin as a percentage (e.g., 0.1 for 10%) for token budgeting.", ge=0, le=1
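
The `Annotated` field makes an out-of-range dimension fail when settings load instead of surfacing later as a pgvector mismatch. A quick sketch of the bounds check (values match the constants defined above; assumes no conflicting environment variables are set):

    # Sketch only: ge/le bounds reject dimensions outside 100..2000 at validation time.
    from pydantic import ValidationError

    from orchestrator.llm_settings import LLMSettings

    LLMSettings(EMBEDDING_DIMENSION=1536)  # fine: within [100, 2000]
    try:
        LLMSettings(EMBEDDING_DIMENSION=50)
    except ValidationError as err:
        print(err)  # input should be greater than or equal to 100
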

orchestrator/migrations/helpers.py CHANGED
@@ -170,49 +170,57 @@ def create_workflow(conn: sa.engine.Connection, workflow: dict) -> None:
     }
     >>> create_workflow(conn, workflow)
     """
-    if not workflow.get("is_task", False):
-        workflow["is_task"] = False
+    params = workflow.copy()
+    params.setdefault("is_task", False)
+    params.setdefault("product_tag", None)
 
-    if not workflow.get("product_tag"):
-        workflow["product_tag"] = None
+    query_parts = []
 
     if has_table_column(table_name="workflows", column_name="is_task", conn=conn):
-        query = """
-        WITH new_workflow AS (
-            INSERT INTO workflows (name, target, is_task, description)
-            VALUES (:name, :target, :is_task, :description)
-            ON CONFLICT DO NOTHING
-            RETURNING workflow_id)
-        INSERT
-        INTO products_workflows (product_id, workflow_id)
-        SELECT p.product_id,
-               nw.workflow_id
-        FROM products AS p
-        CROSS JOIN new_workflow AS nw
-        WHERE p.product_type = :product_type
-        AND (:product_tag IS NULL OR p.tag = :product_tag)
-        ON CONFLICT DO NOTHING
-        """
+        query_parts.append(
+            """
+            WITH new_workflow AS (
+                INSERT INTO workflows (name, target, is_task, description)
+                VALUES (:name, :target, :is_task, :description)
+                ON CONFLICT DO NOTHING
+                RETURNING workflow_id
+            )
+            """
+        )
     else:
-        # Remove is_task from workflow dict and insert SQL
-        workflow = {k: v for k, v in workflow.items() if k != "is_task"}
-        query = """
-        WITH new_workflow AS (
-            INSERT INTO workflows (name, target, description)
-            VALUES (:name, :target, :description)
-            ON CONFLICT DO NOTHING
-            RETURNING workflow_id)
-        INSERT
-        INTO products_workflows (product_id, workflow_id)
-        SELECT p.product_id, nw.workflow_id
-        FROM products AS p
-        CROSS JOIN new_workflow AS nw
-        WHERE p.product_type = :product_type
-        AND (:product_tag IS NULL OR p.tag = :product_tag)
-        ON CONFLICT DO NOTHING
-        """
+        params.pop("is_task", None)
+        query_parts.append(
+            """
+            WITH new_workflow AS (
+                INSERT INTO workflows (name, target, description)
+                VALUES (:name, :target, :description)
+                ON CONFLICT DO NOTHING
+                RETURNING workflow_id
+            )
+            """
+        )
+
+    query_parts.append(
+        """
+        INSERT INTO products_workflows (product_id, workflow_id)
+        SELECT p.product_id, nw.workflow_id
+        FROM products AS p
+        CROSS JOIN new_workflow AS nw
+        """
+    )
+
+    query_parts.append("WHERE p.product_type = :product_type")
+
+    if params.get("product_tag") is not None:
+        query_parts.append("AND p.tag = :product_tag")
+    else:
+        params.pop("product_tag", None)
+
+    query_parts.append("ON CONFLICT DO NOTHING")
+
+    query = "\n".join(query_parts)
 
-    conn.execute(sa.text(query), workflow)
+    conn.execute(sa.text(query), params)
 
 
 def create_task(conn: sa.engine.Connection, task: dict) -> None:
@@ -229,7 +237,7 @@ def create_task(conn: sa.engine.Connection, task: dict) -> None:
         "name": "task_name",
         "description": "task description",
     }
-    >>> create_workflow(conn, task)
+    >>> create_task(conn, task)
     """
     if has_table_column(table_name="workflows", column_name="is_task", conn=conn):
         query = """
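
With the tag filter now appended only when a tag is supplied, the same helper covers tagged and untagged migrations. A sketch of calling it from an Alembic `upgrade()` (dict keys follow the helper's docstring; the values are illustrative):

    # Sketch only: illustrative migration snippet, names are not from this diff.
    from alembic import op

    from orchestrator.migrations.helpers import create_workflow

    def upgrade() -> None:
        conn = op.get_bind()
        create_workflow(
            conn,
            {
                "name": "create_node",
                "target": "CREATE",
                "description": "Create a node",
                "product_type": "Node",
                # omit "product_tag" (or set it to None) to attach to all products of the type
            },
        )
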

orchestrator/schedules/scheduler.py CHANGED
@@ -17,16 +17,16 @@ from datetime import datetime
 from typing import Any, Generator
 
 from apscheduler.executors.pool import ThreadPoolExecutor
-from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
+from apscheduler.jobstores.sqlalchemy import Job, SQLAlchemyJobStore
 from apscheduler.schedulers.background import BackgroundScheduler
 from more_itertools import partition
 from pydantic import BaseModel
 
+from orchestrator.db import db
 from orchestrator.db.filters import Filter
 from orchestrator.db.filters.filters import CallableErrorHandler
 from orchestrator.db.sorting import Sort
 from orchestrator.db.sorting.sorting import SortOrder
-from orchestrator.settings import app_settings
 from orchestrator.utils.helpers import camel_to_snake, to_camel
 
 executors = {
@@ -40,18 +40,37 @@ scheduler = BackgroundScheduler(executors=executors, job_defaults=job_defaults)
 
 
 @contextmanager
-def get_paused_scheduler() -> Generator[BackgroundScheduler, Any, None]:
+def get_scheduler_store() -> Generator[SQLAlchemyJobStore, Any, None]:
+    store = SQLAlchemyJobStore(engine=db.engine)
     try:
-        scheduler.add_jobstore(SQLAlchemyJobStore(url=str(app_settings.DATABASE_URI)))
-    except ValueError:
-        pass
-    scheduler.start(paused=True)
-
-    try:
-        yield scheduler
+        yield store
     finally:
-        scheduler.shutdown()
-        scheduler._jobstores["default"].engine.dispose()
+        store.shutdown()
+
+
+def get_all_scheduler_tasks() -> list[Job]:
+    with get_scheduler_store() as scheduler_store:
+        return scheduler_store.get_all_jobs()
+
+
+def get_scheduler_task(job_id: str) -> Job | None:
+    with get_scheduler_store() as scheduler_store:
+        return scheduler_store.lookup_job(job_id)
+
+
+@contextmanager
+def get_scheduler(paused: bool = False) -> Generator[BackgroundScheduler, Any, None]:
+    with get_scheduler_store() as store:
+        try:
+            scheduler.add_jobstore(store)
+        except ValueError:
+            pass
+        scheduler.start(paused=paused)
+
+        try:
+            yield scheduler
+        finally:
+            scheduler.shutdown()
 
 
 class ScheduledTask(BaseModel):
@@ -149,9 +168,7 @@ def get_scheduler_tasks(
     sort_by: list[Sort] | None = None,
     error_handler: CallableErrorHandler = default_error_handler,
 ) -> tuple[list[ScheduledTask], int]:
-    with get_paused_scheduler() as pauzed_scheduler:
-        scheduled_tasks = pauzed_scheduler.get_jobs()
-
+    scheduled_tasks = get_all_scheduler_tasks()
     scheduled_tasks = filter_scheduled_tasks(scheduled_tasks, error_handler, filter_by)
     scheduled_tasks = sort_scheduled_tasks(scheduled_tasks, error_handler, sort_by)
 
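Job-store access is now decoupled from the scheduler lifecycle, so read-only queries no longer spin the scheduler up and down. A minimal sketch of both access patterns (assumes a configured orchestrator database):

    # Sketch only: the two access patterns enabled by the refactor.
    from orchestrator.schedules.scheduler import (
        get_all_scheduler_tasks,
        get_scheduler,
        get_scheduler_task,
    )

    # Read-only: queries the SQLAlchemy job store directly, no scheduler started.
    for task in get_all_scheduler_tasks():
        print(task.id, task.next_run_time)

    # Full lifecycle: paused=True reproduces the old get_paused_scheduler() behaviour.
    with get_scheduler(paused=True) as scheduler:
        scheduler.resume()
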
orchestrator/schedules/validate_products.py CHANGED
@@ -29,7 +29,7 @@ def validate_products() -> None:
     uncompleted_products = db.session.scalar(
         select(func.count())
         .select_from(ProcessTable)
-        .filter(ProcessTable.workflow.name == "validate_products", ProcessTable.last_status != "completed")
+        .filter(ProcessTable.workflow.has(name="validate_products"), ProcessTable.last_status != "completed")
     )
     if not uncompleted_products:
         start_process("task_validate_products")
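
This one-line fix matters: `ProcessTable.workflow.name` accesses an attribute on the relationship descriptor in Python and never produces the intended SQL comparison, whereas `relationship.has(...)` renders a correlated EXISTS against the workflows table. A generic SQLAlchemy 2.0 sketch of the corrected pattern (models are illustrative, not orchestrator's):

    # Sketch only: shows the EXISTS subquery produced by relationship.has().
    from sqlalchemy import ForeignKey, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Workflow(Base):
        __tablename__ = "workflows"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    class Process(Base):
        __tablename__ = "processes"
        id: Mapped[int] = mapped_column(primary_key=True)
        workflow_id: Mapped[int] = mapped_column(ForeignKey("workflows.id"))
        workflow: Mapped[Workflow] = relationship()

    stmt = select(Process).filter(Process.workflow.has(name="validate_products"))
    print(stmt)  # ... WHERE EXISTS (SELECT 1 FROM workflows WHERE workflows.id = processes.workflow_id AND workflows.name = :name_1)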