dao-ai 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/cli.py CHANGED
@@ -47,6 +47,57 @@ def get_default_user_id() -> str:
47
47
  return local_user
48
48
 
49
49
 
50
+ def detect_cloud_provider(profile: Optional[str] = None) -> Optional[str]:
51
+ """
52
+ Detect the cloud provider from the Databricks workspace URL.
53
+
54
+ The cloud provider is determined by the workspace URL pattern:
55
+ - Azure: *.azuredatabricks.net
56
+ - AWS: *.cloud.databricks.com (without gcp subdomain)
57
+ - GCP: *.gcp.databricks.com
58
+
59
+ Args:
60
+ profile: Optional Databricks CLI profile name
61
+
62
+ Returns:
63
+ Cloud provider string ('azure', 'aws', 'gcp') or None if detection fails
64
+ """
65
+ try:
66
+ from databricks.sdk import WorkspaceClient
67
+
68
+ # Create workspace client with optional profile
69
+ if profile:
70
+ w = WorkspaceClient(profile=profile)
71
+ else:
72
+ w = WorkspaceClient()
73
+
74
+ # Get the workspace URL from config
75
+ host = w.config.host
76
+ if not host:
77
+ logger.warning("Could not determine workspace URL for cloud detection")
78
+ return None
79
+
80
+ host_lower = host.lower()
81
+
82
+ if "azuredatabricks.net" in host_lower:
83
+ logger.debug(f"Detected Azure cloud from workspace URL: {host}")
84
+ return "azure"
85
+ elif ".gcp.databricks.com" in host_lower:
86
+ logger.debug(f"Detected GCP cloud from workspace URL: {host}")
87
+ return "gcp"
88
+ elif ".cloud.databricks.com" in host_lower or "databricks.com" in host_lower:
89
+ # AWS uses *.cloud.databricks.com or regional patterns
90
+ logger.debug(f"Detected AWS cloud from workspace URL: {host}")
91
+ return "aws"
92
+ else:
93
+ logger.warning(f"Could not determine cloud provider from URL: {host}")
94
+ return None
95
+
96
+ except Exception as e:
97
+ logger.warning(f"Could not detect cloud provider: {e}")
98
+ return None
99
+
100
+
50
101
  env_path: str = find_dotenv()
51
102
  if env_path:
52
103
  logger.info(f"Loading environment variables from: {env_path}")
@@ -220,6 +271,13 @@ Examples:
220
271
  "-t",
221
272
  "--target",
222
273
  type=str,
274
+ help="Bundle target name (default: auto-generated from app name and cloud)",
275
+ )
276
+ bundle_parser.add_argument(
277
+ "--cloud",
278
+ type=str,
279
+ choices=["azure", "aws", "gcp"],
280
+ help="Cloud provider (auto-detected from workspace URL if not specified)",
223
281
  )
224
282
  bundle_parser.add_argument(
225
283
  "--dry-run",
@@ -549,13 +607,6 @@ def handle_chat_command(options: Namespace) -> None:
549
607
  # Find the last AI message
550
608
  for msg in reversed(latest_messages):
551
609
  if isinstance(msg, AIMessage):
552
- logger.debug(f"AI message content: {msg.content}")
553
- logger.debug(
554
- f"AI message has tool_calls: {hasattr(msg, 'tool_calls')}"
555
- )
556
- if hasattr(msg, "tool_calls"):
557
- logger.debug(f"Tool calls: {msg.tool_calls}")
558
-
559
610
  if hasattr(msg, "content") and msg.content:
560
611
  response_content = msg.content
561
612
  print(response_content, end="", flush=True)
@@ -676,7 +727,7 @@ def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
676
727
  4. Returns the path to the generated file
677
728
 
678
729
  The generated databricks.yaml is overwritten on each deployment and is not tracked in git.
679
- Schema reference remains pointing to ./schemas/bundle_config_schema.json.
730
+ The template contains cloud-specific targets (azure, aws, gcp) with appropriate node types.
680
731
 
681
732
  Args:
682
733
  config_path: Path to the app config file
@@ -713,39 +764,59 @@ def run_databricks_command(
713
764
  profile: Optional[str] = None,
714
765
  config: Optional[str] = None,
715
766
  target: Optional[str] = None,
767
+ cloud: Optional[str] = None,
716
768
  dry_run: bool = False,
717
769
  ) -> None:
718
- """Execute a databricks CLI command with optional profile and target."""
770
+ """Execute a databricks CLI command with optional profile, target, and cloud.
771
+
772
+ Args:
773
+ command: The databricks CLI command to execute (e.g., ["bundle", "deploy"])
774
+ profile: Optional Databricks CLI profile name
775
+ config: Optional path to the configuration file
776
+ target: Optional bundle target name (if not provided, auto-generated from app name and cloud)
777
+ cloud: Optional cloud provider ('azure', 'aws', 'gcp'). Auto-detected if not specified.
778
+ dry_run: If True, print the command without executing
779
+ """
719
780
  config_path = Path(config) if config else None
720
781
 
721
782
  if config_path and not config_path.exists():
722
783
  logger.error(f"Configuration file {config_path} does not exist.")
723
784
  sys.exit(1)
724
785
 
725
- # Load app config and generate bundle from template
786
+ # Load app config
726
787
  app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
727
788
  normalized_name: str = normalize_name(app_config.app.name) if app_config else None
728
789
 
790
+ # Auto-detect cloud provider if not specified
791
+ if not cloud:
792
+ cloud = detect_cloud_provider(profile)
793
+ if cloud:
794
+ logger.info(f"Auto-detected cloud provider: {cloud}")
795
+ else:
796
+ logger.warning("Could not detect cloud provider. Defaulting to 'azure'.")
797
+ cloud = "azure"
798
+
729
799
  # Generate app-specific bundle from template (overwrites databricks.yaml temporarily)
730
800
  if config_path and app_config:
731
801
  generate_bundle_from_template(config_path, normalized_name)
732
802
 
733
- # Use app name as target if not explicitly provided
734
- # This ensures each app gets its own Terraform state in .databricks/bundle/<app-name>/
735
- if not target and normalized_name:
736
- target = normalized_name
737
- logger.debug(f"Using app-specific target: {target}")
803
+ # Use cloud as target (azure, aws, gcp) - can be overridden with explicit --target
804
+ if not target:
805
+ target = cloud
806
+ logger.debug(f"Using cloud-based target: {target}")
738
807
 
739
- # Build databricks command (no -c flag needed, uses databricks.yaml in current dir)
808
+ # Build databricks command
809
+ # --profile is a global flag, --target is a subcommand flag for 'bundle'
740
810
  cmd = ["databricks"]
741
811
  if profile:
742
812
  cmd.extend(["--profile", profile])
743
813
 
814
+ cmd.extend(command)
815
+
816
+ # --target must come after the bundle subcommand (it's a subcommand-specific flag)
744
817
  if target:
745
818
  cmd.extend(["--target", target])
746
819
 
747
- cmd.extend(command)
748
-
749
820
  # Add config_path variable for notebooks
750
821
  if config_path and app_config:
751
822
  # Calculate relative path from notebooks directory to config file
@@ -800,30 +871,38 @@ def handle_bundle_command(options: Namespace) -> None:
800
871
  profile: Optional[str] = options.profile
801
872
  config: Optional[str] = options.config
802
873
  target: Optional[str] = options.target
874
+ cloud: Optional[str] = options.cloud
803
875
  dry_run: bool = options.dry_run
804
876
 
805
877
  if options.deploy:
806
878
  logger.info("Deploying DAO AI asset bundle...")
807
879
  run_databricks_command(
808
- ["bundle", "deploy"], profile, config, target, dry_run=dry_run
880
+ ["bundle", "deploy"],
881
+ profile=profile,
882
+ config=config,
883
+ target=target,
884
+ cloud=cloud,
885
+ dry_run=dry_run,
809
886
  )
810
887
  if options.run:
811
888
  logger.info("Running DAO AI system with current configuration...")
812
889
  # Use static job resource key that matches databricks.yaml (resources.jobs.deploy_job)
813
890
  run_databricks_command(
814
891
  ["bundle", "run", "deploy_job"],
815
- profile,
816
- config,
817
- target,
892
+ profile=profile,
893
+ config=config,
894
+ target=target,
895
+ cloud=cloud,
818
896
  dry_run=dry_run,
819
897
  )
820
898
  if options.destroy:
821
899
  logger.info("Destroying DAO AI system with current configuration...")
822
900
  run_databricks_command(
823
901
  ["bundle", "destroy", "--auto-approve"],
824
- profile,
825
- config,
826
- target,
902
+ profile=profile,
903
+ config=config,
904
+ target=target,
905
+ cloud=cloud,
827
906
  dry_run=dry_run,
828
907
  )
829
908
  else:
dao_ai/config.py CHANGED
@@ -601,6 +601,8 @@ class VectorSearchEndpoint(BaseModel):
601
601
 
602
602
 
603
603
  class IndexModel(IsDatabricksResource, HasFullName):
604
+ """Model representing a Databricks Vector Search index."""
605
+
604
606
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
605
607
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
606
608
  name: str
@@ -624,6 +626,22 @@ class IndexModel(IsDatabricksResource, HasFullName):
624
626
  )
625
627
  ]
626
628
 
629
+ def exists(self) -> bool:
630
+ """Check if this vector search index exists.
631
+
632
+ Returns:
633
+ True if the index exists, False otherwise.
634
+ """
635
+ try:
636
+ self.workspace_client.vector_search_indexes.get_index(self.full_name)
637
+ return True
638
+ except NotFound:
639
+ logger.debug(f"Index not found: {self.full_name}")
640
+ return False
641
+ except Exception as e:
642
+ logger.warning(f"Error checking index existence for {self.full_name}: {e}")
643
+ return False
644
+
627
645
 
628
646
  class FunctionModel(IsDatabricksResource, HasFullName):
629
647
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1009,27 +1027,92 @@ class VolumePathModel(BaseModel, HasFullName):
1009
1027
 
1010
1028
 
1011
1029
  class VectorStoreModel(IsDatabricksResource):
1030
+ """
1031
+ Configuration model for a Databricks Vector Search store.
1032
+
1033
+ Supports two modes:
1034
+ 1. **Use Existing Index**: Provide only `index` (fully qualified name).
1035
+ Used for querying an existing vector search index at runtime.
1036
+ 2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
1037
+ Used for creating a new vector search index.
1038
+
1039
+ Examples:
1040
+ Minimal configuration (use existing index):
1041
+ ```yaml
1042
+ vector_stores:
1043
+ products_search:
1044
+ index:
1045
+ name: catalog.schema.my_index
1046
+ ```
1047
+
1048
+ Full provisioning configuration:
1049
+ ```yaml
1050
+ vector_stores:
1051
+ products_search:
1052
+ source_table:
1053
+ schema: *my_schema
1054
+ name: products
1055
+ embedding_source_column: description
1056
+ endpoint:
1057
+ name: my_endpoint
1058
+ ```
1059
+ """
1060
+
1012
1061
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1013
- embedding_model: Optional[LLMModel] = None
1062
+
1063
+ # RUNTIME: Only index is truly required for querying existing indexes
1014
1064
  index: Optional[IndexModel] = None
1065
+
1066
+ # PROVISIONING ONLY: Required when creating a new index
1067
+ source_table: Optional[TableModel] = None
1068
+ embedding_source_column: Optional[str] = None
1069
+ embedding_model: Optional[LLMModel] = None
1015
1070
  endpoint: Optional[VectorSearchEndpoint] = None
1016
- source_table: TableModel
1071
+
1072
+ # OPTIONAL: For both modes
1017
1073
  source_path: Optional[VolumePathModel] = None
1018
1074
  checkpoint_path: Optional[VolumePathModel] = None
1019
1075
  primary_key: Optional[str] = None
1020
1076
  columns: Optional[list[str]] = Field(default_factory=list)
1021
1077
  doc_uri: Optional[str] = None
1022
- embedding_source_column: str
1078
+
1079
+ @model_validator(mode="after")
1080
+ def validate_configuration_mode(self) -> Self:
1081
+ """
1082
+ Validate that configuration is valid for either:
1083
+ - Use existing mode: index is provided
1084
+ - Provisioning mode: source_table + embedding_source_column provided
1085
+ """
1086
+ has_index = self.index is not None
1087
+ has_source_table = self.source_table is not None
1088
+ has_embedding_col = self.embedding_source_column is not None
1089
+
1090
+ # Must have at least index OR source_table
1091
+ if not has_index and not has_source_table:
1092
+ raise ValueError(
1093
+ "Either 'index' (for existing indexes) or 'source_table' "
1094
+ "(for provisioning) must be provided"
1095
+ )
1096
+
1097
+ # If provisioning mode, need embedding_source_column
1098
+ if has_source_table and not has_embedding_col:
1099
+ raise ValueError(
1100
+ "embedding_source_column is required when source_table is provided (provisioning mode)"
1101
+ )
1102
+
1103
+ return self
1023
1104
 
1024
1105
  @model_validator(mode="after")
1025
1106
  def set_default_embedding_model(self) -> Self:
1026
- if not self.embedding_model:
1107
+ # Only set default embedding model in provisioning mode
1108
+ if self.source_table is not None and not self.embedding_model:
1027
1109
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
1028
1110
  return self
1029
1111
 
1030
1112
  @model_validator(mode="after")
1031
1113
  def set_default_primary_key(self) -> Self:
1032
- if self.primary_key is None:
1114
+ # Only auto-discover primary key in provisioning mode
1115
+ if self.primary_key is None and self.source_table is not None:
1033
1116
  from dao_ai.providers.databricks import DatabricksProvider
1034
1117
 
1035
1118
  provider: DatabricksProvider = DatabricksProvider()
@@ -1050,14 +1133,16 @@ class VectorStoreModel(IsDatabricksResource):
1050
1133
 
1051
1134
  @model_validator(mode="after")
1052
1135
  def set_default_index(self) -> Self:
1053
- if self.index is None:
1136
+ # Only generate index from source_table in provisioning mode
1137
+ if self.index is None and self.source_table is not None:
1054
1138
  name: str = f"{self.source_table.name}_index"
1055
1139
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
1056
1140
  return self
1057
1141
 
1058
1142
  @model_validator(mode="after")
1059
1143
  def set_default_endpoint(self) -> Self:
1060
- if self.endpoint is None:
1144
+ # Only find/create endpoint in provisioning mode
1145
+ if self.endpoint is None and self.source_table is not None:
1061
1146
  from dao_ai.providers.databricks import (
1062
1147
  DatabricksProvider,
1063
1148
  with_available_indexes,
@@ -1092,18 +1177,60 @@ class VectorStoreModel(IsDatabricksResource):
1092
1177
  return self.index.as_resources()
1093
1178
 
1094
1179
  def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
1095
- from dao_ai.providers.base import ServiceProvider
1096
1180
  from dao_ai.providers.databricks import DatabricksProvider
1097
1181
 
1098
- provider: ServiceProvider = DatabricksProvider(vsc=vsc)
1182
+ provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
1099
1183
  index: VectorSearchIndex = provider.get_vector_index(self)
1100
1184
  return index
1101
1185
 
1102
1186
  def create(self, vsc: VectorSearchClient | None = None) -> None:
1103
- from dao_ai.providers.base import ServiceProvider
1187
+ """
1188
+ Create or validate the vector search index.
1189
+
1190
+ Behavior depends on configuration mode:
1191
+ - **Provisioning Mode** (source_table provided): Creates the index
1192
+ - **Use Existing Mode** (only index provided): Validates the index exists
1193
+
1194
+ Args:
1195
+ vsc: Optional VectorSearchClient instance
1196
+
1197
+ Raises:
1198
+ ValueError: If configuration is invalid or index doesn't exist
1199
+ """
1104
1200
  from dao_ai.providers.databricks import DatabricksProvider
1105
1201
 
1106
- provider: ServiceProvider = DatabricksProvider(vsc=vsc)
1202
+ provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
1203
+
1204
+ if self.source_table is not None:
1205
+ self._create_new_index(provider)
1206
+ else:
1207
+ self._validate_existing_index(provider)
1208
+
1209
+ def _validate_existing_index(self, provider: Any) -> None:
1210
+ """Validate that an existing index is accessible."""
1211
+ if self.index is None:
1212
+ raise ValueError("index is required for 'use existing' mode")
1213
+
1214
+ if self.index.exists():
1215
+ logger.info(
1216
+ "Vector search index exists and ready",
1217
+ index_name=self.index.full_name,
1218
+ )
1219
+ else:
1220
+ raise ValueError(
1221
+ f"Index '{self.index.full_name}' does not exist. "
1222
+ "Provide 'source_table' to provision it."
1223
+ )
1224
+
1225
+ def _create_new_index(self, provider: Any) -> None:
1226
+ """Create a new vector search index from source table."""
1227
+ if self.embedding_source_column is None:
1228
+ raise ValueError("embedding_source_column is required for provisioning")
1229
+ if self.endpoint is None:
1230
+ raise ValueError("endpoint is required for provisioning")
1231
+ if self.index is None:
1232
+ raise ValueError("index is required for provisioning")
1233
+
1107
1234
  provider.create_vector_store(self)
1108
1235
 
1109
1236
 
@@ -1266,32 +1393,12 @@ class DatabaseModel(IsDatabricksResource):
1266
1393
 
1267
1394
  @model_validator(mode="after")
1268
1395
  def update_host(self) -> Self:
1269
- if self.host is not None:
1396
+ # Lakebase uses instance_name directly via databricks_langchain - host not needed
1397
+ if self.is_lakebase:
1270
1398
  return self
1271
1399
 
1272
- # If instance_name is provided (Lakebase), try to fetch host from existing instance
1273
- # This may fail for OBO/ambient auth during model logging (before deployment)
1274
- if self.is_lakebase:
1275
- try:
1276
- existing_instance: DatabaseInstance = (
1277
- self.workspace_client.database.get_database_instance(
1278
- name=self.instance_name
1279
- )
1280
- )
1281
- self.host = existing_instance.read_write_dns
1282
- except Exception as e:
1283
- # For Lakebase with OBO/ambient auth, we can't fetch at config time
1284
- # The host will need to be provided explicitly or fetched at runtime
1285
- if self.on_behalf_of_user:
1286
- logger.debug(
1287
- f"Could not fetch host for database {self.instance_name} "
1288
- f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
1289
- )
1290
- else:
1291
- raise ValueError(
1292
- f"Could not fetch host for database {self.instance_name}. "
1293
- f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
1294
- )
1400
+ # For standard PostgreSQL, host must be provided by the user
1401
+ # (enforced by validate_connection_type)
1295
1402
  return self
1296
1403
 
1297
1404
  @model_validator(mode="after")
@@ -1549,11 +1656,13 @@ class RerankParametersModel(BaseModel):
1549
1656
  top_n: 5 # Return top 5 after reranking
1550
1657
  ```
1551
1658
 
1552
- Available models (from fastest to most accurate):
1553
- - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
1554
- - "ms-marco-MiniLM-L-6-v2"
1555
- - "ms-marco-MiniLM-L-12-v2" (default, good balance)
1556
- - "rank-T5-flan" (most accurate, slower)
1659
+ Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
1660
+ - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
1661
+ - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
1662
+ - "rank-T5-flan" (~110MB, best non cross-encoder)
1663
+ - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
1664
+ - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
1665
+ - "miniReranker_arabic_v1" (Arabic language)
1557
1666
  """
1558
1667
 
1559
1668
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -3,8 +3,15 @@
3
3
 
4
4
  # Re-export LangChain built-in middleware
5
5
  from langchain.agents.middleware import (
6
+ ClearToolUsesEdit,
7
+ ContextEditingMiddleware,
6
8
  HumanInTheLoopMiddleware,
9
+ ModelCallLimitMiddleware,
10
+ ModelRetryMiddleware,
11
+ PIIMiddleware,
7
12
  SummarizationMiddleware,
13
+ ToolCallLimitMiddleware,
14
+ ToolRetryMiddleware,
8
15
  after_agent,
9
16
  after_model,
10
17
  before_agent,
@@ -37,6 +44,10 @@ from dao_ai.middleware.base import (
37
44
  ModelRequest,
38
45
  ModelResponse,
39
46
  )
47
+ from dao_ai.middleware.context_editing import (
48
+ create_clear_tool_uses_edit,
49
+ create_context_editing_middleware,
50
+ )
40
51
  from dao_ai.middleware.core import create_factory_middleware
41
52
  from dao_ai.middleware.guardrails import (
42
53
  ContentFilterMiddleware,
@@ -62,10 +73,15 @@ from dao_ai.middleware.message_validation import (
62
73
  create_thread_id_validation_middleware,
63
74
  create_user_id_validation_middleware,
64
75
  )
76
+ from dao_ai.middleware.model_call_limit import create_model_call_limit_middleware
77
+ from dao_ai.middleware.model_retry import create_model_retry_middleware
78
+ from dao_ai.middleware.pii import create_pii_middleware
65
79
  from dao_ai.middleware.summarization import (
66
80
  LoggingSummarizationMiddleware,
67
81
  create_summarization_middleware,
68
82
  )
83
+ from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
84
+ from dao_ai.middleware.tool_retry import create_tool_retry_middleware
69
85
 
70
86
  __all__ = [
71
87
  # Base class (from LangChain)
@@ -85,6 +101,13 @@ __all__ = [
85
101
  "SummarizationMiddleware",
86
102
  "LoggingSummarizationMiddleware",
87
103
  "HumanInTheLoopMiddleware",
104
+ "ToolCallLimitMiddleware",
105
+ "ModelCallLimitMiddleware",
106
+ "ToolRetryMiddleware",
107
+ "ModelRetryMiddleware",
108
+ "ContextEditingMiddleware",
109
+ "ClearToolUsesEdit",
110
+ "PIIMiddleware",
88
111
  # Core factory function
89
112
  "create_factory_middleware",
90
113
  # DAO AI middleware implementations
@@ -122,4 +145,14 @@ __all__ = [
122
145
  "create_assert_middleware",
123
146
  "create_suggest_middleware",
124
147
  "create_refine_middleware",
148
+ # Limit and retry middleware factory functions
149
+ "create_tool_call_limit_middleware",
150
+ "create_model_call_limit_middleware",
151
+ "create_tool_retry_middleware",
152
+ "create_model_retry_middleware",
153
+ # Context editing middleware factory functions
154
+ "create_context_editing_middleware",
155
+ "create_clear_tool_uses_edit",
156
+ # PII middleware factory functions
157
+ "create_pii_middleware",
125
158
  ]
@@ -688,7 +688,7 @@ def create_assert_middleware(
688
688
  name: Name for function constraints
689
689
 
690
690
  Returns:
691
- AssertMiddleware configured with the constraint
691
+ List containing AssertMiddleware configured with the constraint
692
692
 
693
693
  Example:
694
694
  # Using a Constraint class
@@ -737,7 +737,7 @@ def create_suggest_middleware(
737
737
  name: Name for function constraints
738
738
 
739
739
  Returns:
740
- SuggestMiddleware configured with the constraint
740
+ List containing SuggestMiddleware configured with the constraint
741
741
 
742
742
  Example:
743
743
  def is_professional(response: str, ctx: dict) -> ConstraintResult:
@@ -783,7 +783,7 @@ def create_refine_middleware(
783
783
  select_best: Track and return best response across iterations
784
784
 
785
785
  Returns:
786
- RefineMiddleware configured with the reward function
786
+ List containing RefineMiddleware configured with the reward function
787
787
 
788
788
  Example:
789
789
  def evaluate_completeness(response: str, ctx: dict) -> float: