orchestrator-core 4.5.3__py3-none-any.whl → 4.6.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/agentic_app.py +1 -21
  3. orchestrator/api/api_v1/api.py +5 -0
  4. orchestrator/api/api_v1/endpoints/agent.py +50 -0
  5. orchestrator/api/api_v1/endpoints/search.py +120 -201
  6. orchestrator/cli/database.py +3 -0
  7. orchestrator/cli/generate.py +11 -4
  8. orchestrator/cli/generator/generator/migration.py +7 -3
  9. orchestrator/cli/search/resize_embedding.py +28 -22
  10. orchestrator/cli/search/speedtest.py +4 -6
  11. orchestrator/db/__init__.py +6 -0
  12. orchestrator/db/models.py +75 -0
  13. orchestrator/migrations/helpers.py +46 -38
  14. orchestrator/schedules/validate_products.py +1 -1
  15. orchestrator/schemas/search.py +8 -85
  16. orchestrator/search/agent/__init__.py +2 -2
  17. orchestrator/search/agent/agent.py +25 -29
  18. orchestrator/search/agent/json_patch.py +51 -0
  19. orchestrator/search/agent/prompts.py +35 -9
  20. orchestrator/search/agent/state.py +28 -2
  21. orchestrator/search/agent/tools.py +192 -53
  22. orchestrator/search/core/exceptions.py +6 -0
  23. orchestrator/search/core/types.py +1 -0
  24. orchestrator/search/export.py +199 -0
  25. orchestrator/search/indexing/indexer.py +13 -4
  26. orchestrator/search/indexing/registry.py +14 -1
  27. orchestrator/search/llm_migration.py +55 -0
  28. orchestrator/search/retrieval/__init__.py +3 -2
  29. orchestrator/search/retrieval/builder.py +5 -1
  30. orchestrator/search/retrieval/engine.py +66 -23
  31. orchestrator/search/retrieval/pagination.py +46 -56
  32. orchestrator/search/retrieval/query_state.py +61 -0
  33. orchestrator/search/retrieval/retrievers/base.py +26 -40
  34. orchestrator/search/retrieval/retrievers/fuzzy.py +10 -9
  35. orchestrator/search/retrieval/retrievers/hybrid.py +11 -8
  36. orchestrator/search/retrieval/retrievers/semantic.py +9 -8
  37. orchestrator/search/retrieval/retrievers/structured.py +6 -6
  38. orchestrator/search/schemas/parameters.py +17 -13
  39. orchestrator/search/schemas/results.py +4 -1
  40. orchestrator/settings.py +1 -0
  41. orchestrator/utils/auth.py +3 -2
  42. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0rc1.dist-info}/METADATA +3 -3
  43. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0rc1.dist-info}/RECORD +45 -41
  44. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0rc1.dist-info}/WHEEL +0 -0
  45. {orchestrator_core-4.5.3.dist-info → orchestrator_core-4.6.0rc1.dist-info}/licenses/LICENSE +0 -0
orchestrator/cli/search/resize_embedding.py CHANGED
@@ -4,7 +4,7 @@ from sqlalchemy import text
 from sqlalchemy.exc import SQLAlchemyError
 
 from orchestrator.db import db
-from orchestrator.db.models import AiSearchIndex
+from orchestrator.db.models import AiSearchIndex, SearchQueryTable
 from orchestrator.llm_settings import llm_settings
 
 logger = structlog.get_logger(__name__)
@@ -40,17 +40,20 @@ def get_current_embedding_dimension() -> int | None:
     return None
 
 
-def drop_all_embeddings() -> int:
-    """Drop all records from the ai_search_index table.
+def drop_all_embeddings() -> tuple[int, int]:
+    """Drop all records from ai_search_index and search_queries tables.
 
     Returns:
-        Number of records deleted
+        Tuple of (ai_search_index records deleted, search_queries records deleted)
     """
     try:
-        result = db.session.query(AiSearchIndex).delete()
+        index_deleted = db.session.query(AiSearchIndex).delete()
+        query_deleted = db.session.query(SearchQueryTable).delete()
         db.session.commit()
-        logger.info(f"Deleted {result} records from ai_search_index")
-        return result
+        logger.info(
+            f"Deleted {index_deleted} records from ai_search_index and {query_deleted} records from search_queries"
+        )
+        return index_deleted, query_deleted
 
     except SQLAlchemyError as e:
         db.session.rollback()
@@ -59,34 +62,34 @@ def drop_all_embeddings() -> int:
 
 
 def alter_embedding_column_dimension(new_dimension: int) -> None:
-    """Alter the embedding column to use the new dimension size.
+    """Alter the embedding columns in both ai_search_index and search_queries tables.
 
     Args:
         new_dimension: New vector dimension size
     """
     try:
-        drop_query = text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding")
-        db.session.execute(drop_query)
+        db.session.execute(text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding"))
+        db.session.execute(text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})"))
 
-        add_query = text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})")
-        db.session.execute(add_query)
+        db.session.execute(text("ALTER TABLE search_queries DROP COLUMN IF EXISTS query_embedding"))
+        db.session.execute(text(f"ALTER TABLE search_queries ADD COLUMN query_embedding vector({new_dimension})"))
 
         db.session.commit()
-        logger.info(f"Altered embedding column to dimension {new_dimension}")
+        logger.info(f"Altered embedding columns to dimension {new_dimension} in ai_search_index and search_queries")
 
     except SQLAlchemyError as e:
         db.session.rollback()
-        logger.error("Failed to alter embedding column dimension", error=str(e))
+        logger.error("Failed to alter embedding column dimensions", error=str(e))
         raise
 
 
 @app.command("resize")
 def resize_embeddings_command() -> None:
-    """Resize vector dimensions of the ai_search_index embedding column.
+    """Resize vector dimensions of embedding columns in ai_search_index and search_queries tables.
 
     Compares the current embedding dimension in the database with the configured
-    dimension in llm_settings. If they differ, drops all records and alters the
-    column to match the new dimension.
+    dimension in llm_settings. If they differ, drops all records and alters both
+    embedding columns to match the new dimension.
     """
     new_dimension = llm_settings.EMBEDDING_DIMENSION
 
@@ -107,22 +110,25 @@ def resize_embeddings_command() -> None:
 
     logger.info("Dimension mismatch detected", current_dimension=current_dimension, new_dimension=new_dimension)
 
-    if not typer.confirm("This will DELETE ALL RECORDS from ai_search_index and alter the embedding column. Continue?"):
+    if not typer.confirm(
+        "This will DELETE ALL RECORDS from ai_search_index and search_queries tables and alter embedding columns. Continue?"
+    ):
        logger.info("Operation cancelled by user")
        return
 
    try:
        # Drop all records first.
        logger.info("Dropping all embedding records...")
-        deleted_count = drop_all_embeddings()
+        index_deleted, query_deleted = drop_all_embeddings()
 
-        # Then alter column dimension.
-        logger.info(f"Altering embedding column to dimension {new_dimension}...")
+        # Then alter column dimensions.
+        logger.info(f"Altering embedding columns to dimension {new_dimension}...")
        alter_embedding_column_dimension(new_dimension)
 
        logger.info(
            "Embedding dimension resize completed successfully",
-            records_deleted=deleted_count,
+            index_records_deleted=index_deleted,
+            query_records_deleted=query_deleted,
            new_dimension=new_dimension,
        )
 
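The resize flow above now touches two tables. A minimal sketch of the same sequence, assuming the module path implied by the file list (orchestrator.cli.search.resize_embedding) and a configured llm_settings.EMBEDDING_DIMENSION:

    from orchestrator.cli.search.resize_embedding import (
        alter_embedding_column_dimension,
        drop_all_embeddings,
    )
    from orchestrator.llm_settings import llm_settings

    # drop_all_embeddings() now returns a tuple: rows removed from each table.
    index_deleted, query_deleted = drop_all_embeddings()
    # Recreate both embedding columns at the configured dimension.
    alter_embedding_column_dimension(llm_settings.EMBEDDING_DIMENSION)
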
orchestrator/cli/search/speedtest.py CHANGED
@@ -13,7 +13,6 @@ from orchestrator.search.core.embedding import QueryEmbedder
 from orchestrator.search.core.types import EntityType
 from orchestrator.search.core.validators import is_uuid
 from orchestrator.search.retrieval.engine import execute_search
-from orchestrator.search.retrieval.pagination import PaginationParams
 from orchestrator.search.schemas.parameters import BaseSearchParameters
 
 logger = structlog.get_logger(__name__)
@@ -54,17 +53,16 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[float]]:
 async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]:
     search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30)
 
+    query_embedding = None
+
     if is_uuid(query):
-        pagination_params = PaginationParams()
         logger.debug("Using fuzzy-only ranking for full UUID", query=query)
     else:
-
-        cached_embedding = embedding_lookup[query]
-        pagination_params = PaginationParams(q_vec_override=cached_embedding)
+        query_embedding = embedding_lookup[query]
 
     with db.session as session:
         start_time = time.perf_counter()
-        response = await execute_search(search_params, session, pagination_params=pagination_params)
+        response = await execute_search(search_params, session, cursor=None, query_embedding=query_embedding)
         end_time = time.perf_counter()
 
     return {
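With PaginationParams gone from the speedtest, the call shape of execute_search changes; a hedged sketch of the new invocation, using only names visible in this hunk:

    from typing import Any

    from orchestrator.db import db
    from orchestrator.search.core.types import EntityType
    from orchestrator.search.retrieval.engine import execute_search
    from orchestrator.search.schemas.parameters import BaseSearchParameters

    async def timed_search(query: str, query_embedding: list[float] | None) -> Any:
        """Run one search with the new execute_search signature."""
        search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30)
        with db.session as session:
            # cursor=None requests the first page; query_embedding is a
            # precomputed vector, or None for fuzzy-only ranking (full UUIDs).
            return await execute_search(search_params, session, cursor=None, query_embedding=query_embedding)
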
orchestrator/db/__init__.py CHANGED
@@ -17,6 +17,7 @@ from structlog import get_logger
 from orchestrator.db.database import BaseModel as DbBaseModel
 from orchestrator.db.database import Database, transactional
 from orchestrator.db.models import (  # noqa: F401
+    AgentRunTable,
     EngineSettingsTable,
     FixedInputTable,
     InputStateTable,
@@ -26,6 +27,7 @@ from orchestrator.db.models import (  # noqa: F401
     ProductBlockTable,
     ProductTable,
     ResourceTypeTable,
+    SearchQueryTable,
     SubscriptionCustomerDescriptionTable,
     SubscriptionInstanceRelationTable,
     SubscriptionInstanceTable,
@@ -74,6 +76,8 @@ def init_database(settings: AppSettings) -> Database:
 
 __all__ = [
     "transactional",
+    "SearchQueryTable",
+    "AgentRunTable",
     "SubscriptionTable",
     "ProcessSubscriptionTable",
     "ProcessTable",
@@ -97,6 +101,8 @@ __all__ = [
 ]
 
 ALL_DB_MODELS: list[type[DbBaseModel]] = [
+    SearchQueryTable,
+    AgentRunTable,
     FixedInputTable,
     ProcessStepTable,
     ProcessSubscriptionTable,
orchestrator/db/models.py CHANGED
@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import enum
 from datetime import datetime, timezone
+from typing import TYPE_CHECKING
 from uuid import UUID
 
 import sqlalchemy
@@ -58,6 +59,9 @@ from orchestrator.targets import Target
 from orchestrator.utils.datetime import nowtz
 from orchestrator.version import GIT_COMMIT_HASH
 
+if TYPE_CHECKING:
+    from orchestrator.search.retrieval.query_state import SearchQueryState
+
 logger = structlog.get_logger(__name__)
 
 TAG_LENGTH = 20
@@ -674,6 +678,76 @@ class SubscriptionSearchView(BaseModel):
     subscription = relationship("SubscriptionTable", foreign_keys=[subscription_id])
 
 
+class AgentRunTable(BaseModel):
+    """Agent conversation/session tracking."""
+
+    __tablename__ = "agent_runs"
+
+    run_id = mapped_column("run_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    agent_type = mapped_column(String(50), nullable=False)
+    created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    queries = relationship("SearchQueryTable", back_populates="run", cascade="delete", passive_deletes=True)
+
+    __table_args__ = (Index("ix_agent_runs_created_at", "created_at"),)
+
+
+class SearchQueryTable(BaseModel):
+    """Search query execution - used by both agent runs and regular API searches.
+
+    When run_id is NULL: standalone API search query
+    When run_id is NOT NULL: query belongs to an agent conversation run
+    """
+
+    __tablename__ = "search_queries"
+
+    query_id = mapped_column("query_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    run_id = mapped_column(
+        "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=True, index=True
+    )
+    query_number = mapped_column(Integer, nullable=False)
+
+    # Search parameters as JSONB (maps to BaseSearchParameters subclasses)
+    parameters = mapped_column(pg.JSONB, nullable=False)
+
+    # Query embedding for semantic search (pgvector)
+    query_embedding = mapped_column(Vector(1536), nullable=True)
+
+    executed_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    run = relationship("AgentRunTable", back_populates="queries")
+
+    __table_args__ = (
+        Index("ix_search_queries_run_id", "run_id"),
+        Index("ix_search_queries_executed_at", "executed_at"),
+        Index("ix_search_queries_query_id", "query_id"),
+    )
+
+    @classmethod
+    def from_state(
+        cls,
+        state: "SearchQueryState",
+        run_id: "UUID | None" = None,
+        query_number: int = 1,
+    ) -> "SearchQueryTable":
+        """Create a SearchQueryTable instance from a SearchQueryState.
+
+        Args:
+            state: The search query state with parameters and embedding
+            run_id: Optional agent run ID (NULL for regular API searches)
+            query_number: Query number within the run (default=1)
+
+        Returns:
+            SearchQueryTable instance ready to be added to the database.
+        """
+        return cls(
+            run_id=run_id,
+            query_number=query_number,
+            parameters=state.parameters.model_dump(),
+            query_embedding=state.query_embedding,
+        )
+
+
 class EngineSettingsTable(BaseModel):
     __tablename__ = "engine_settings"
     global_lock = mapped_column(Boolean(), default=False, nullable=False, primary_key=True)
@@ -705,6 +779,7 @@ class AiSearchIndex(BaseModel):
         UUIDType,
         nullable=False,
     )
+    entity_title = mapped_column(TEXT, nullable=True)
 
     # Ltree path for hierarchical data
     path = mapped_column(LtreeType, nullable=False, index=True)
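A hedged sketch of how the new from_state helper might be used to persist a query row; SearchQueryState comes from orchestrator.search.retrieval.query_state as imported above, and the session handling mirrors other code in this diff:

    from uuid import UUID

    from orchestrator.db import SearchQueryTable, db
    from orchestrator.search.retrieval.query_state import SearchQueryState

    def persist_search_query(state: SearchQueryState, run_id: UUID | None = None, query_number: int = 1) -> UUID:
        """Store one executed query; run_id stays None for plain API searches."""
        record = SearchQueryTable.from_state(state, run_id=run_id, query_number=query_number)
        db.session.add(record)
        db.session.commit()
        return record.query_id
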
orchestrator/migrations/helpers.py CHANGED
@@ -170,49 +170,57 @@ def create_workflow(conn: sa.engine.Connection, workflow: dict) -> None:
     }
     >>> create_workflow(conn, workflow)
     """
-    if not workflow.get("is_task", False):
-        workflow["is_task"] = False
+    params = workflow.copy()
+    params.setdefault("is_task", False)
+    params.setdefault("product_tag", None)
 
-    if not workflow.get("product_tag"):
-        workflow["product_tag"] = None
+    query_parts = []
 
     if has_table_column(table_name="workflows", column_name="is_task", conn=conn):
-        query = """
-            WITH new_workflow AS (
-                INSERT INTO workflows (name, target, is_task, description)
-                VALUES (:name, :target, :is_task, :description)
-                ON CONFLICT DO NOTHING
-                RETURNING workflow_id)
-            INSERT
-            INTO products_workflows (product_id, workflow_id)
-            SELECT p.product_id,
-                   nw.workflow_id
-            FROM products AS p
-                CROSS JOIN new_workflow AS nw
-            WHERE p.product_type = :product_type
-              AND (:product_tag IS NULL OR p.tag = :product_tag)
-            ON CONFLICT DO NOTHING
-            """
+        query_parts.append(
+            """
+            WITH new_workflow AS (
+                INSERT INTO workflows (name, target, is_task, description)
+                VALUES (:name, :target, :is_task, :description)
+                ON CONFLICT DO NOTHING
+                RETURNING workflow_id
+            )
+            """
+        )
     else:
-        # Remove is_task from workflow dict and insert SQL
-        workflow = {k: v for k, v in workflow.items() if k != "is_task"}
-        query = """
-            WITH new_workflow AS (
-                INSERT INTO workflows (name, target, description)
-                VALUES (:name, :target, :description)
-                ON CONFLICT DO NOTHING
-                RETURNING workflow_id)
-            INSERT
-            INTO products_workflows (product_id, workflow_id)
-            SELECT p.product_id, nw.workflow_id
-            FROM products AS p
-                CROSS JOIN new_workflow AS nw
-            WHERE p.product_type = :product_type
-              AND (:product_tag IS NULL OR p.tag = :product_tag)
-            ON CONFLICT DO NOTHING
-            """
+        params.pop("is_task", None)
+        query_parts.append(
+            """
+            WITH new_workflow AS (
+                INSERT INTO workflows (name, target, description)
+                VALUES (:name, :target, :description)
+                ON CONFLICT DO NOTHING
+                RETURNING workflow_id
+            )
+            """
+        )
+
+    query_parts.append(
+        """
+        INSERT INTO products_workflows (product_id, workflow_id)
+        SELECT p.product_id, nw.workflow_id
+        FROM products AS p
+        CROSS JOIN new_workflow AS nw
+        """
+    )
+
+    query_parts.append("WHERE p.product_type = :product_type")
+
+    if params.get("product_tag") is not None:
+        query_parts.append("AND p.tag = :product_tag")
+    else:
+        params.pop("product_tag", None)
+
+    query_parts.append("ON CONFLICT DO NOTHING")
+
+    query = "\n".join(query_parts)
 
-    conn.execute(sa.text(query), workflow)
+    conn.execute(sa.text(query), params)
 
 
 def create_task(conn: sa.engine.Connection, task: dict) -> None:
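For reference, a hedged example of invoking the reworked helper from a migration; the dict keys match the bind parameters in the SQL above, while the concrete values are illustrative only. product_tag can now simply be omitted, since the tag filter is appended only when a tag is provided:

    # Illustrative values; only the keys are dictated by the SQL above.
    workflow = {
        "name": "create_node_enrollment",
        "target": "CREATE",
        "description": "Create node enrollment",
        "product_type": "Node",
        # "product_tag": "NODE",  # optional: omit to attach to every product of this type
    }
    create_workflow(conn, workflow)
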
orchestrator/schedules/validate_products.py CHANGED
@@ -29,7 +29,7 @@ def validate_products() -> None:
     uncompleted_products = db.session.scalar(
         select(func.count())
         .select_from(ProcessTable)
-        .filter(ProcessTable.workflow.name == "validate_products", ProcessTable.last_status != "completed")
+        .filter(ProcessTable.workflow.has(name="validate_products"), ProcessTable.last_status != "completed")
     )
     if not uncompleted_products:
         start_process("task_validate_products")
orchestrator/schemas/search.py CHANGED
@@ -11,14 +11,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from datetime import datetime
-from typing import Any, Generic, TypeVar
-from uuid import UUID
+from typing import Generic, TypeVar
 
 from pydantic import BaseModel, ConfigDict, Field
 
 from orchestrator.search.core.types import SearchMetadata
-from orchestrator.search.schemas.results import ComponentInfo, LeafInfo, MatchingField
+from orchestrator.search.schemas.results import ComponentInfo, LeafInfo
 
 T = TypeVar("T")
 
@@ -36,95 +34,20 @@ class ProductSchema(BaseModel):
     product_type: str
 
 
-class SubscriptionSearchResult(BaseModel):
-    score: float
-    perfect_match: int
-    matching_field: MatchingField | None = None
-    subscription: dict[str, Any]
-
-
 class SearchResultsSchema(BaseModel, Generic[T]):
     data: list[T] = Field(default_factory=list)
     page_info: PageInfoSchema = Field(default_factory=PageInfoSchema)
     search_metadata: SearchMetadata | None = None
 
 
-class WorkflowProductSchema(BaseModel):
-    """Product associated with a workflow."""
-
-    model_config = ConfigDict(from_attributes=True)
-
-    product_type: str
-    product_id: UUID
-    name: str
-
-
-class WorkflowSearchSchema(BaseModel):
-    """Schema for workflow search results."""
-
-    model_config = ConfigDict(from_attributes=True)
-
-    name: str
-    products: list[WorkflowProductSchema]
-    description: str | None = None
-    created_at: datetime | None = None
-
-
-class ProductSearchSchema(BaseModel):
-    """Schema for product search results."""
-
-    model_config = ConfigDict(from_attributes=True)
-
-    product_id: UUID
-    name: str
-    product_type: str
-    tag: str | None = None
-    description: str | None = None
-    status: str | None = None
-    created_at: datetime | None = None
-
-
-class ProcessSearchSchema(BaseModel):
-    """Schema for process search results."""
-
-    model_config = ConfigDict(from_attributes=True)
-
-    process_id: UUID
-    workflow_name: str
-    workflow_id: UUID
-    last_status: str
-    is_task: bool
-    created_by: str | None = None
-    started_at: datetime
-    last_modified_at: datetime
-    last_step: str | None = None
-    failed_reason: str | None = None
-    subscription_ids: list[UUID] | None = None
-
-
-class WorkflowSearchResult(BaseModel):
-    score: float
-    perfect_match: int
-    matching_field: MatchingField | None = None
-    workflow: WorkflowSearchSchema
-
-
-class ProductSearchResult(BaseModel):
-    score: float
-    perfect_match: int
-    matching_field: MatchingField | None = None
-    product: ProductSearchSchema
-
-
-class ProcessSearchResult(BaseModel):
-    score: float
-    perfect_match: int
-    matching_field: MatchingField | None = None
-    process: ProcessSearchSchema
-
-
 class PathsResponse(BaseModel):
     leaves: list[LeafInfo]
     components: list[ComponentInfo]
 
     model_config = ConfigDict(extra="forbid", use_enum_values=True)
+
+
+class ExportResponse(BaseModel):
+    page: list[dict]
+
+    model_config = ConfigDict(extra="forbid")
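A minimal sketch of the new ExportResponse model; the record keys shown are placeholders, since the schema only constrains page to be a list of dicts:

    from orchestrator.schemas.search import ExportResponse

    response = ExportResponse(page=[{"subscription_id": "abc-123", "status": "active"}])
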
orchestrator/search/agent/__init__.py CHANGED
@@ -14,8 +14,8 @@
 # This module requires: pydantic-ai==0.7.0, ag-ui-protocol>=0.1.8
 
 
-from orchestrator.search.agent.agent import build_agent_router
+from orchestrator.search.agent.agent import build_agent_instance
 
 __all__ = [
-    "build_agent_router",
+    "build_agent_instance",
 ]
orchestrator/search/agent/agent.py CHANGED
@@ -14,13 +14,11 @@
 from typing import Any
 
 import structlog
-from fastapi import APIRouter, HTTPException, Request
-from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request
+from pydantic_ai.ag_ui import StateDeps
 from pydantic_ai.agent import Agent
 from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.settings import ModelSettings
 from pydantic_ai.toolsets import FunctionToolset
-from starlette.responses import Response
 
 from orchestrator.search.agent.prompts import get_base_instructions, get_dynamic_instructions
 from orchestrator.search.agent.state import SearchState
@@ -29,34 +27,32 @@ from orchestrator.search.agent.tools import search_toolset
 logger = structlog.get_logger(__name__)
 
 
-def build_agent_router(model: str | OpenAIModel, toolsets: list[FunctionToolset[Any]] | None = None) -> APIRouter:
-    router = APIRouter()
+def build_agent_instance(
+    model: str | OpenAIModel, agent_tools: list[FunctionToolset[Any]] | None = None
+) -> Agent[StateDeps[SearchState], str]:
+    """Build and configure the search agent instance.
 
-    try:
-        toolsets = toolsets + [search_toolset] if toolsets else [search_toolset]
+    Args:
+        model: The LLM model to use (string or OpenAIModel instance)
+        agent_tools: Optional list of additional toolsets to include
 
-        agent = Agent(
-            model=model,
-            deps_type=StateDeps[SearchState],
-            model_settings=ModelSettings(
-                parallel_tool_calls=False,
-            ),  # https://github.com/pydantic/pydantic-ai/issues/562
-            toolsets=toolsets,
-        )
-        agent.instructions(get_base_instructions)
-        agent.instructions(get_dynamic_instructions)
+    Returns:
+        Configured Agent instance with StateDeps[SearchState] dependencies
 
-        @router.post("/")
-        async def agent_endpoint(request: Request) -> Response:
-            return await handle_ag_ui_request(agent, request, deps=StateDeps(SearchState()))
+    Raises:
+        Exception: If agent initialization fails
+    """
+    toolsets = agent_tools + [search_toolset] if agent_tools else [search_toolset]
 
-        return router
-    except Exception as e:
-        logger.error("Agent init failed; serving disabled stub.", error=str(e))
-        error_msg = f"Agent disabled: {str(e)}"
+    agent = Agent(
+        model=model,
+        deps_type=StateDeps[SearchState],
+        model_settings=ModelSettings(
+            parallel_tool_calls=False,
+        ),  # https://github.com/pydantic/pydantic-ai/issues/562
+        toolsets=toolsets,
+    )
+    agent.instructions(get_base_instructions)
+    agent.instructions(get_dynamic_instructions)
 
-    @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"])
-    async def _disabled(path: str) -> None:
-        raise HTTPException(status_code=503, detail=error_msg)
-
-    return router
+    return agent
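Since build_agent_instance now returns the Agent rather than an APIRouter, the AG-UI endpoint has to be mounted by the caller; a hedged sketch that mirrors the route removed above (the real wiring presumably lives in the new orchestrator/api/api_v1/endpoints/agent.py):

    from fastapi import APIRouter, Request
    from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request
    from starlette.responses import Response

    from orchestrator.search.agent import build_agent_instance
    from orchestrator.search.agent.state import SearchState

    router = APIRouter()
    agent = build_agent_instance("openai:gpt-4o")  # model identifier is illustrative

    @router.post("/")
    async def agent_endpoint(request: Request) -> Response:
        # Same delegation the old build_agent_router() registered internally.
        return await handle_ag_ui_request(agent, request, deps=StateDeps(SearchState()))
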
orchestrator/search/agent/json_patch.py ADDED
@@ -0,0 +1,51 @@
+# Copyright 2019-2025 SURF, GÉANT.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+
+
+class JSONPatchOp(BaseModel):
+    """A JSON Patch operation (RFC 6902).
+
+    Docs reference: https://docs.ag-ui.com/concepts/state
+    """
+
+    op: Literal["add", "remove", "replace", "move", "copy", "test"] = Field(
+        description="The operation to perform: add, remove, replace, move, copy, or test"
+    )
+    path: str = Field(description="JSON Pointer (RFC 6901) to the target location")
+    value: Any | None = Field(
+        default=None,
+        description="The value to apply (for add, replace operations)",
+    )
+    from_: str | None = Field(
+        default=None,
+        alias="from",
+        description="Source path (for move, copy operations)",
+    )
+
+    @classmethod
+    def upsert(cls, path: str, value: Any, existed: bool) -> "JSONPatchOp":
+        """Create an add or replace operation depending on whether the path existed.
+
+        Args:
+            path: JSON Pointer path to the target location
+            value: The value to set
+            existed: True if the path already exists (use replace), False otherwise (use add)
+
+        Returns:
+            JSONPatchOp with 'replace' if existed is True, 'add' otherwise
+        """
+        return cls(op="replace" if existed else "add", path=path, value=value)
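A brief hedged usage sketch for JSONPatchOp; the pointer paths are made up for illustration:

    # upsert() picks "replace" when the target already exists, "add" otherwise.
    op = JSONPatchOp.upsert("/parameters/limit", 30, existed=True)
    assert op.op == "replace"

    # "from" is a Python keyword, so the field is populated via its alias.
    move = JSONPatchOp(op="move", path="/filters/0", **{"from": "/filters/1"})
    print(move.model_dump(by_alias=True, exclude_none=True))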