optexity 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. optexity/examples/__init__.py +0 -0
  2. optexity/examples/add_example.py +88 -0
  3. optexity/examples/download_pdf_url.py +29 -0
  4. optexity/examples/extract_price_stockanalysis.py +44 -0
  5. optexity/examples/file_upload.py +59 -0
  6. optexity/examples/i94.py +126 -0
  7. optexity/examples/i94_travel_history.py +126 -0
  8. optexity/examples/peachstate_medicaid.py +201 -0
  9. optexity/examples/supabase_login.py +75 -0
  10. optexity/inference/__init__.py +0 -0
  11. optexity/inference/agents/__init__.py +0 -0
  12. optexity/inference/agents/error_handler/__init__.py +0 -0
  13. optexity/inference/agents/error_handler/error_handler.py +39 -0
  14. optexity/inference/agents/error_handler/prompt.py +60 -0
  15. optexity/inference/agents/index_prediction/__init__.py +0 -0
  16. optexity/inference/agents/index_prediction/action_prediction_locator_axtree.py +45 -0
  17. optexity/inference/agents/index_prediction/prompt.py +14 -0
  18. optexity/inference/agents/select_value_prediction/__init__.py +0 -0
  19. optexity/inference/agents/select_value_prediction/prompt.py +20 -0
  20. optexity/inference/agents/select_value_prediction/select_value_prediction.py +39 -0
  21. optexity/inference/agents/two_fa_extraction/__init__.py +0 -0
  22. optexity/inference/agents/two_fa_extraction/prompt.py +23 -0
  23. optexity/inference/agents/two_fa_extraction/two_fa_extraction.py +47 -0
  24. optexity/inference/child_process.py +251 -0
  25. optexity/inference/core/__init__.py +0 -0
  26. optexity/inference/core/interaction/__init__.py +0 -0
  27. optexity/inference/core/interaction/handle_agentic_task.py +79 -0
  28. optexity/inference/core/interaction/handle_check.py +57 -0
  29. optexity/inference/core/interaction/handle_click.py +79 -0
  30. optexity/inference/core/interaction/handle_command.py +261 -0
  31. optexity/inference/core/interaction/handle_input.py +76 -0
  32. optexity/inference/core/interaction/handle_keypress.py +16 -0
  33. optexity/inference/core/interaction/handle_select.py +109 -0
  34. optexity/inference/core/interaction/handle_select_utils.py +132 -0
  35. optexity/inference/core/interaction/handle_upload.py +59 -0
  36. optexity/inference/core/interaction/utils.py +81 -0
  37. optexity/inference/core/logging.py +406 -0
  38. optexity/inference/core/run_assertion.py +55 -0
  39. optexity/inference/core/run_automation.py +463 -0
  40. optexity/inference/core/run_extraction.py +240 -0
  41. optexity/inference/core/run_interaction.py +254 -0
  42. optexity/inference/core/run_python_script.py +20 -0
  43. optexity/inference/core/run_two_fa.py +120 -0
  44. optexity/inference/core/two_factor_auth/__init__.py +0 -0
  45. optexity/inference/infra/__init__.py +0 -0
  46. optexity/inference/infra/browser.py +455 -0
  47. optexity/inference/infra/browser_extension.py +20 -0
  48. optexity/inference/models/__init__.py +22 -0
  49. optexity/inference/models/gemini.py +113 -0
  50. optexity/inference/models/human.py +20 -0
  51. optexity/inference/models/llm_model.py +210 -0
  52. optexity/inference/run_local.py +200 -0
  53. optexity/schema/__init__.py +0 -0
  54. optexity/schema/actions/__init__.py +0 -0
  55. optexity/schema/actions/assertion_action.py +66 -0
  56. optexity/schema/actions/extraction_action.py +143 -0
  57. optexity/schema/actions/interaction_action.py +330 -0
  58. optexity/schema/actions/misc_action.py +18 -0
  59. optexity/schema/actions/prompts.py +27 -0
  60. optexity/schema/actions/two_fa_action.py +24 -0
  61. optexity/schema/automation.py +432 -0
  62. optexity/schema/callback.py +16 -0
  63. optexity/schema/inference.py +87 -0
  64. optexity/schema/memory.py +100 -0
  65. optexity/schema/task.py +212 -0
  66. optexity/schema/token_usage.py +48 -0
  67. optexity/utils/__init__.py +0 -0
  68. optexity/utils/settings.py +54 -0
  69. optexity/utils/utils.py +76 -0
  70. {optexity-0.1.2.dist-info → optexity-0.1.3.dist-info}/METADATA +1 -1
  71. optexity-0.1.3.dist-info/RECORD +80 -0
  72. optexity-0.1.2.dist-info/RECORD +0 -11
  73. {optexity-0.1.2.dist-info → optexity-0.1.3.dist-info}/WHEEL +0 -0
  74. {optexity-0.1.2.dist-info → optexity-0.1.3.dist-info}/entry_points.txt +0 -0
  75. {optexity-0.1.2.dist-info → optexity-0.1.3.dist-info}/licenses/LICENSE +0 -0
  76. {optexity-0.1.2.dist-info → optexity-0.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,210 @@
1
+ import ast
2
+ import logging
3
+ import re
4
+ import time
5
+ from enum import Enum, unique
6
+ from typing import Optional
7
+
8
+ import tokencost.costs
9
+ from pydantic import BaseModel, ValidationError
10
+
11
+ from optexity.schema.token_usage import TokenUsage
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@unique
class HumanModels(Enum):
    """Pseudo-model identifiers for human-in-the-loop responses."""

    # Answers are typed by a human at the terminal instead of an API call.
    TERMINAL_INPUT = "terminal-input"
19
+
20
+
21
@unique
class GeminiModels(Enum):
    """Supported Google Gemini model identifiers (API model-name strings)."""

    GEMINI_1_5_FLASH = "gemini-1.5-flash"
    GEMINI_2_0_FLASH = "gemini-2.0-flash"
    GEMINI_2_5_FLASH = "gemini-2.5-flash"
    GEMINI_2_5_FLASH_LITE = "gemini-2.5-flash-lite-preview-06-17"
    GEMINI_2_5_PRO = "gemini-2.5-pro"
28
+
29
+
30
@unique
class OpenAIModels(Enum):
    """Supported OpenAI model identifiers (API model-name strings)."""

    GPT_4O = "gpt-4o"
    GPT_4O_MINI = "gpt-4o-mini"
    GPT_4_1 = "gpt-4.1"
    GPT_4_1_MINI = "gpt-4.1-mini"
36
+
37
+
38
class LLMModel:
    """Base class for LLM backends.

    Subclasses implement the raw calls (`_get_model_response*`); this class
    adds retry loops, completion parsing, and token-cost accounting.
    """

    def __init__(
        self,
        model_name: GeminiModels | HumanModels | OpenAIModels,
        use_structured_output: bool,
    ):
        # Which concrete model to call and whether the provider's native
        # structured-output mode should be used.
        self.model_name = model_name
        self.use_structured_output = use_structured_output

    def _get_model_response(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> tuple[str, TokenUsage]:
        """Single raw text completion; implemented by subclasses."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    def _get_model_response_with_structured_output(
        self,
        prompt: str,
        response_schema: BaseModel,
        screenshot: Optional[str] = None,
        pdf_url: Optional[str] = None,
        system_instruction: Optional[str] = None,
    ) -> tuple[BaseModel, TokenUsage]:
        """Single raw structured-output completion; implemented by subclasses."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    def get_model_response(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> tuple[str, TokenUsage]:
        """Call the model with up to 3 attempts, sleeping 5s between retries.

        Raises:
            Exception: when every attempt fails.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                return self._get_model_response(prompt, system_instruction)
            except Exception as e:
                logger.error(f"LLM Error during inference: {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying... {attempt + 1}/{max_retries}")
                    time.sleep(5)
        raise Exception("Max retries exceeded for LLM")

    def get_model_response_with_structured_output(
        self,
        prompt: str,
        response_schema: BaseModel,
        screenshot: Optional[str] = None,
        pdf_url: Optional[str] = None,
        system_instruction: Optional[str] = None,
    ) -> tuple[BaseModel, TokenUsage]:
        """Call the model expecting `response_schema`, retrying up to 3 times.

        Token usage is accumulated across attempts so failed calls are still
        accounted for.

        Raises:
            Exception: with the last error text, when every attempt fails or
                returns no parsed response.
        """
        total_token_usage = TokenUsage()
        max_retries = 3
        last_exception = ""
        for attempt in range(max_retries):
            try:
                parsed_response, token_usage = (
                    self._get_model_response_with_structured_output(
                        prompt=prompt,
                        response_schema=response_schema,
                        screenshot=screenshot,
                        pdf_url=pdf_url,
                        system_instruction=system_instruction,
                    )
                )
                total_token_usage += token_usage
                if parsed_response is not None:
                    return parsed_response, total_token_usage
            except Exception as e:
                logger.error(f"LLM with structured output Error during inference: {e}")
                last_exception = str(e)
                if attempt < max_retries - 1:
                    logger.info(f"Retrying... {attempt + 1}/{max_retries}")
                    time.sleep(20)

        raise Exception(
            "Max retries exceeded for LLM with structured output"
            + "\n"
            + last_exception
        )

    def extract_json_objects(self, text):
        """Return every brace-balanced substring of `text`, innermost first."""
        open_positions = []  # indices of unmatched '{'
        candidates = []  # potential JSON substrings
        for i, char in enumerate(text):
            if char == "{":
                open_positions.append(i)
            elif char == "}" and open_positions:
                start = open_positions.pop()  # last unmatched '{'
                candidates.append(text[start : i + 1])
        return candidates

    def parse_from_completion(
        self, content: str, response_schema: BaseModel
    ) -> BaseModel:
        """Parse the first substring of `content` that validates as `response_schema`.

        Fenced ```json blocks are tried first, then any brace-balanced
        substring; each candidate is validated as JSON and, failing that,
        as a Python literal.

        Raises:
            ValueError: when no candidate validates. (Bug fix: previously
                raised pydantic's ValidationError, which cannot be constructed
                from a bare message in pydantic v2 and itself raised a
                TypeError at the raise site.)
        """
        patterns = [r"```json\n(.*?)\n```"]
        json_blocks = []
        for pattern in patterns:
            json_blocks += re.findall(pattern, content, re.DOTALL)
        json_blocks += self.extract_json_objects(content)
        for block in json_blocks:
            block = block.strip()
            try:
                return response_schema.model_validate_json(block)
            except Exception:
                try:
                    # Fall back to Python-literal syntax (single quotes, etc.).
                    return response_schema.model_validate(ast.literal_eval(block))
                except Exception:
                    continue

        raise ValueError("Could not parse response from completion.")

    def get_token_usage(
        self,
        input_tokens: int | None = None,
        output_tokens: int | None = None,
        tool_use_tokens: int | None = None,
        thoughts_tokens: int | None = None,
        total_tokens: Optional[int] = None,
    ) -> TokenUsage:
        """Build a TokenUsage with per-category costs from tokencost pricing.

        Missing counts are treated as 0. `total_tokens` is the provider-
        reported total (kept as given), while `calculated_total_tokens` is the
        sum of the four categories. Tool-use and thoughts tokens are priced at
        the output-token rate.
        """
        input_tokens = input_tokens or 0
        output_tokens = output_tokens or 0
        tool_use_tokens = tool_use_tokens or 0
        thoughts_tokens = thoughts_tokens or 0
        total_tokens = total_tokens or 0

        def _cost(num_tokens: int, token_type: str):
            # One pricing lookup per token category.
            return tokencost.costs.calculate_cost_by_tokens(
                model=self.model_name.value,
                num_tokens=num_tokens,
                token_type=token_type,
            )

        input_cost = _cost(input_tokens, "input")
        output_cost = _cost(output_tokens, "output")
        tool_use_cost = _cost(tool_use_tokens, "output")
        thoughts_cost = _cost(thoughts_tokens, "output")
        calculated_total_tokens = (
            input_tokens + output_tokens + tool_use_tokens + thoughts_tokens
        )
        total_cost = input_cost + output_cost + tool_use_cost + thoughts_cost
        return TokenUsage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            tool_use_tokens=tool_use_tokens,
            thoughts_tokens=thoughts_tokens,
            total_tokens=total_tokens,
            calculated_total_tokens=calculated_total_tokens,
            input_cost=input_cost,
            output_cost=output_cost,
            tool_use_cost=tool_use_cost,
            thoughts_cost=thoughts_cost,
            total_cost=total_cost,
        )
@@ -0,0 +1,200 @@
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from datetime import datetime, timezone
6
+
7
+ from dotenv import load_dotenv
8
+
9
+ from optexity.examples.fadv import fadv_test
10
+ from optexity.examples.i94 import automation
11
+ from optexity.examples.pshpgeorgia_medicaid import (
12
+ pshpgeorgia_login_test,
13
+ pshpgeorgia_medicaid_test,
14
+ )
15
+ from optexity.examples.shein import shein_test
16
+ from optexity.examples.supabase_login import supabase_login_test
17
+ from optexity.inference.core.run_automation import run_automation
18
+ from optexity.inference.infra.browser import Browser
19
+ from optexity.schema.memory import Memory, Variables
20
+ from optexity.schema.task import Task
21
+
22
# Load environment variables (credentials, member data) from a local .env file.
load_dotenv()


logger = logging.getLogger(__name__)
# Reuse the module logger instead of a redundant second getLogger(__name__) lookup.
logger.setLevel(logging.DEBUG)
27
+
28
+
29
async def run_supabase_login_test():
    """Drive the Supabase login automation end-to-end in a fresh browser.

    Uses hard-coded demo credentials. The browser is now stopped in a
    finally block so it is released even when the automation fails.
    """
    logger.debug("Starting Supabase login test")
    browser = Browser()
    memory = Memory(
        variables=Variables(
            input_variables={
                "username": ["test@test.com"],
                "password": ["password"],
            }
        )
    )

    try:
        await browser.start()
        logger.info("Browser started")
        logger.info("Navigating to Supabase")
        await browser.go_to_url("https://supabase.com")
        logger.info("Navigated to Supabase")
        # Bug fix: the old message claimed 5 seconds but the sleep is 2.
        logger.info("Sleeping for 2 seconds")
        await asyncio.sleep(2)

        logger.info("Running automation")
        await run_automation(supabase_login_test, memory, browser)
        logger.info("Automation finished")
        await asyncio.sleep(5)
    finally:
        await browser.stop()
+
56
+
57
async def run_pshpgeorgia_test():
    """Log in to the PSHP Georgia provider portal and run the Medicaid lookup.

    Credentials and member details come from environment variables. Bug fixes:
    the browser was stopped twice on success, and the finally clause could hit
    an unbound `browser` if construction failed — it is now stopped exactly
    once, guarded against early failure.
    """
    browser = None
    try:
        logger.debug("Starting PSHP Georgia test")
        browser = Browser()
        memory = Memory(
            variables=Variables(
                input_variables={
                    "username": [os.environ.get("USERNAME")],
                    "password": [os.environ.get("PASSWORD")],
                    "plan_type": [os.environ.get("PLAN_TYPE")],
                    "member_id": [os.environ.get("MEMBER_ID")],
                    "dob": [os.environ.get("DOB")],
                }
            )
        )

        await browser.start()
        logger.debug("Browser started")
        logger.debug("Navigating to PSHP Georgia")
        await browser.go_to_url(
            "https://sso.entrykeyid.com/as/authorization.oauth2?response_type=code&client_id=f6a6219c-be42-421b-b86c-e4fc509e2e87&scope=openid%20profile&state=_igWklSsnrkO5DQfjBMMuN41ksMJePZQ_SM_61wTJlA%3D&redirect_uri=https://provider.pshpgeorgia.com/careconnect/login/oauth2/code/pingcloud&code_challenge_method=S256&nonce=xG41TJjco_x7Vs_MQgcS3bw5njLiJsXCqvO-V8THmY0&code_challenge=ZTaVHaZCNFTejXNJo51RlJ3Kv9dH0tMODPTqO7hiP3A&app_origin=https://provider.pshpgeorgia.com/careconnect/login/oauth2/code/pingcloud&brand=pshpgeorgia"
        )
        logger.debug("Navigated to PSHP Georgia")

        logger.debug("Running login test")
        await run_automation(pshpgeorgia_login_test, memory, browser)
        logger.debug("Login test finished")

        logger.debug("Running Medicaid test")
        await run_automation(pshpgeorgia_medicaid_test, memory, browser)
        logger.debug("Medicaid test finished")

        await asyncio.sleep(5)
    except Exception as e:
        logger.error(f"Error running PSHP Georgia test: {e}")
        # Bare re-raise preserves the original traceback.
        raise
    finally:
        if browser is not None:
            await browser.stop()
96
+
97
+
98
async def run_i94_test():
    """Run the I-94 automation with identity details from the environment.

    Uses a stealth browser since the site is bot-protected (presumably — the
    stealth flag is the only evidence here). Bug fixes: the browser was
    stopped twice on success, and the finally clause could hit an unbound
    `browser` if construction failed — it is now stopped exactly once,
    guarded against early failure.
    """
    browser = None
    try:
        logger.debug("Starting I-94 test")
        browser = Browser(stealth=True)
        memory = Memory(
            variables=Variables(
                input_variables={
                    "last_name": [os.environ.get("LAST_NAME")],
                    "first_name": [os.environ.get("FIRST_NAME")],
                    "nationality": [os.environ.get("NATIONALITY")],
                    "date_of_birth": [os.environ.get("DATE_OF_BIRTH")],
                    "document_number": [os.environ.get("DOCUMENT_NUMBER")],
                }
            )
        )

        await browser.start()
        logger.debug("Browser started")
        logger.debug("Navigating to I-94")
        await browser.go_to_url(automation.url)
        logger.debug("Navigated to I-94")

        logger.debug("Running I-94 test")
        await asyncio.sleep(5)
        await run_automation(automation, memory, browser)
        logger.debug("I-94 test finished")

        await asyncio.sleep(5)
    except Exception as e:
        logger.error(f"Error running I-94 test: {e}")
        # Bare re-raise preserves the original traceback.
        raise
    finally:
        if browser is not None:
            await browser.stop()
132
+
133
+
134
async def run_shein_test():
    """Run the Shein automation wrapped in a freshly minted Task."""
    try:
        logger.debug("Starting Shein test")
        shein_task = Task(
            task_id=str(uuid.uuid4()),
            user_id=str(uuid.uuid4()),
            recording_id=str(uuid.uuid4()),
            automation=shein_test,
            input_parameters={},
            unique_parameter_names=[],
            created_at=datetime.now(timezone.utc),
            status="queued",
        )
        await run_automation(shein_task, 0)
    except Exception as e:
        logger.error(f"Error running Shein test: {e}")
        raise e
    finally:
        # Report any asyncio tasks still pending besides this one.
        logger.debug("Remaining tasks:")
        current = asyncio.current_task()
        for pending in asyncio.all_tasks():
            if pending is not current:
                logger.debug(f"Remaining task: {pending.get_coro()}")

        logger.debug("Shein test finished")
160
+
161
+
162
async def run_fadv_test():
    """Run the FADV automation as a Task built from environment credentials."""
    try:
        logger.debug("Starting FADV test task")
        fadv_task = Task(
            task_id=str(uuid.uuid4()),
            user_id=str(uuid.uuid4()),
            recording_id=str(uuid.uuid4()),
            automation=fadv_test,
            input_parameters={
                "client_id": [os.environ.get("client_id")],
                "user_id": [os.environ.get("user_id")],
                "password": [os.environ.get("password")],
                "secret_answer": [os.environ.get("secret_answer")],
                "start_date": [os.environ.get("start_date")],
            },
            unique_parameter_names=[],
            created_at=datetime.now(timezone.utc),
            status="queued",
        )
        await run_automation(fadv_task, 0)
        await asyncio.sleep(5)
    except Exception as e:
        logger.error(f"Error running FADV test: {e}")
        raise e
    finally:
        # Report any asyncio tasks still pending besides this one.
        logger.debug("Remaining tasks:")
        current = asyncio.current_task()
        for pending in asyncio.all_tasks():
            if pending is not current:
                logger.debug(f"Remaining task: {pending.get_coro()}")
        logger.debug("FADV test finished")
192
+
193
+
194
+ if __name__ == "__main__":
195
+
196
+ # asyncio.run(run_supabase_login_test())
197
+ # asyncio.run(run_pshpgeorgia_test())
198
+ # asyncio.run(run_i94_test())
199
+ asyncio.run(run_fadv_test())
200
+ # asyncio.run(run_shein_test())
File without changes
File without changes
@@ -0,0 +1,66 @@
1
+ from typing import Literal, Optional
2
+
3
+ from pydantic import BaseModel, field_validator, model_validator
4
+
5
+ from optexity.schema.actions.extraction_action import LLMExtraction
6
+
7
+
8
class LLMAssertion(LLMExtraction):
    # Assertions default to screenshot-based checks and a fixed result schema
    # of a boolean verdict plus a free-text reason.
    source: list[Literal["axtree", "screenshot"]] = ["screenshot"]
    extraction_format: dict = {"assertion_result": "bool", "assertion_reason": "str"}

    @model_validator(mode="after")
    def validate_output_var_in_format(self):
        """Force "screenshot" into the source list so assertions always see the page.

        NOTE(review): this deliberately (?) shares its name with the parent
        LLMExtraction validator, which appears to replace the parent's
        output-variable check for assertions — confirm that is intentional.
        """
        if "screenshot" not in self.source:
            self.source.append("screenshot")

        return self
18
+
19
+
20
class NetworkCallAssertion(BaseModel):
    """Assert that a network call matching the given filters was observed."""

    # Both filters optional; a None filter matches any call on that dimension.
    url_pattern: Optional[str] = None
    header_filter: Optional[dict[str, str]] = None
23
+
24
+
25
class PythonScriptAssertion(BaseModel):
    """Assertion evaluated by running a user-supplied Python script."""

    script: str
    ## TODO: add output to memory variables

    @field_validator("script")
    @classmethod
    def validate_script(cls, v: str):
        """Reject scripts that are empty or whitespace-only."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("Script cannot be empty")
        return v
35
+
36
+
37
class AssertionAction(BaseModel):
    """Container holding exactly one assertion variant.

    Exactly one of `network_call`, `llm`, or `python_script` must be set;
    the model validator enforces this.
    """

    network_call: Optional[NetworkCallAssertion] = None
    llm: Optional[LLMAssertion] = None
    python_script: Optional[PythonScriptAssertion] = None

    @model_validator(mode="after")
    def validate_one_assertion(cls, model: "AssertionAction"):
        """Ensure exactly one of the assertion types is set."""
        provided = {
            "llm": model.llm,
            "network_call": model.network_call,
            "python_script": model.python_script,
        }
        non_null = [k for k, v in provided.items() if v is not None]

        if len(non_null) != 1:
            # Bug fix: the old message named nonexistent fields
            # ("networkcall", "python") instead of the real field names.
            raise ValueError(
                "Exactly one of llm, network_call, or python_script must be provided"
            )

        return model

    def replace(self, pattern: str, replacement: str):
        """Propagate variable substitution to the LLM assertion.

        Network-call and python-script assertions carry no substitutable text
        here, so they are left untouched.
        """
        if self.llm:
            self.llm.replace(pattern, replacement)
        return self
@@ -0,0 +1,143 @@
1
+ from typing import Any, List, Literal, Optional
2
+ from uuid import uuid4
3
+
4
+ from pydantic import BaseModel, Field, field_validator, model_validator
5
+
6
+ from optexity.utils.utils import build_model
7
+
8
+
9
class LLMExtraction(BaseModel):
    # Which page artifacts to feed the model: accessibility tree and/or screenshot.
    source: list[Literal["axtree", "screenshot"]] = ["axtree"]
    # Mapping of output field name -> type-name string; must be accepted by build_model.
    extraction_format: dict
    extraction_instructions: str
    # Optional subset of extraction_format keys to write into memory variables.
    output_variable_names: list[str] | None = None
    llm_provider: Literal["gemini"] = "gemini"
    llm_model_name: str = "gemini-2.5-flash"

    def build_model(self):
        """Build a pydantic model class from the extraction_format spec."""
        return build_model(self.extraction_format)

    @field_validator("extraction_format")
    def validate_extraction_format(cls, v):
        """Accept only dicts that build_model can turn into a model."""
        if isinstance(v, dict):
            try:
                build_model(v)
            except Exception as e:
                raise ValueError(f"Invalid extraction_format dict: {e}")
            return v
        # NOTE(review): the message mentions strings, but only dicts are
        # accepted here — confirm whether string formats were ever supported.
        raise ValueError("extraction_format must be either a string or a dict")

    @model_validator(mode="after")
    def validate_output_var_in_format(self):
        """Every requested output variable must be a key of extraction_format.

        NOTE(review): LLMAssertion defines a validator with this same name,
        which appears to replace this check for assertions — confirm.
        """
        if self.output_variable_names is not None:
            for key in self.output_variable_names:
                if key not in self.extraction_format:
                    raise ValueError(
                        f"Output variable {key} not found in extraction_format"
                    )
                ## TODO: fix this
                # if eval(self.extraction_format[key]) not in [
                #     int,
                #     float,
                #     bool,
                #     str,
                #     None,
                #     list[str | int | float | bool | None],
                #     List[str | int | float | bool | None],
                # ]:
                #     raise ValueError(
                #         f"Output variable {key} must be a string, int, float, bool, or a list of strings, ints, floats, or bools"
                #     )

        return self

    def replace(self, pattern: str, replacement: str):
        """No-op hook; kept so every extraction type shares the replace() API."""
        return self
+ return self
57
+
58
+
59
class NetworkCallExtraction(BaseModel):
    """Extract or download the body of a network call matched by url_pattern."""

    url_pattern: Optional[str] = None
    extract_from: None | Literal["request", "response"] = None
    download_from: None | Literal["request", "response"] = None
    download_filename: str | None = None

    @model_validator(mode="before")
    def download_filename_if_download_from_is_set(cls, data: dict[str, Any]):
        """Default download_filename to a random UUID when download_from is set.

        Bug fix: the key was previously misspelled ("downlowd_from"), so the
        condition never matched and the default filename was never generated.
        """
        if isinstance(data, dict) and data.get("download_from") is not None:
            if data.get("download_filename") is None:
                data["download_filename"] = str(uuid4())

        return data

    def replace(self, pattern: str, replacement: str):
        """No-op hook; kept so every extraction type shares the replace() API."""
        return self
78
+
79
+
80
class PythonScriptExtraction(BaseModel):
    """Extraction performed by executing a user-supplied Python script."""

    script: str
    ## TODO: add output to memory variables

    @field_validator("script")
    @classmethod
    def validate_script(cls, v: str):
        """Reject scripts that are empty or whitespace-only."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("Script cannot be empty")
        return v

    def replace(self, pattern: str, replacement: str):
        """Substitute every occurrence of `pattern` in the script text."""
        self.script = self.script.replace(pattern, replacement)
        return self
+ return self
94
+
95
+
96
class ScreenshotExtraction(BaseModel):
    """Capture a screenshot to `filename`, full-page unless disabled."""

    filename: str
    full_page: bool = True
99
+
100
+
101
class StateExtraction(BaseModel):
    """Marker type: capture the current page state; takes no parameters."""
103
+
104
+
105
class ExtractionAction(BaseModel):
    """Container holding exactly one extraction variant.

    Exactly one of `network_call`, `llm`, `python_script`, `screenshot`, or
    `state` must be set; the model validator enforces this.
    """

    unique_identifier: str | None = None
    network_call: Optional[NetworkCallExtraction] = None
    llm: Optional[LLMExtraction] = None
    python_script: Optional[PythonScriptExtraction] = None
    screenshot: Optional[ScreenshotExtraction] = None
    state: Optional[StateExtraction] = None

    @model_validator(mode="after")
    def validate_one_extraction(cls, model: "ExtractionAction"):
        """Ensure exactly one of the extraction types is set."""
        provided = {
            "llm": model.llm,
            "network_call": model.network_call,
            "python_script": model.python_script,
            "screenshot": model.screenshot,
            "state": model.state,
        }
        non_null = [k for k, v in provided.items() if v is not None]

        if len(non_null) != 1:
            # Bug fix: the old message misspelled network_call ("networkcall")
            # and omitted "state" even though it is a valid variant.
            raise ValueError(
                "Exactly one of llm, network_call, python_script, screenshot, "
                "or state must be provided"
            )

        return model

    def replace(self, pattern: str, replacement: str):
        """Apply variable substitution to every sub-action that carries text."""
        if self.network_call:
            self.network_call.replace(pattern, replacement)
        if self.llm:
            self.llm.replace(pattern, replacement)
        if self.python_script:
            self.python_script.replace(pattern, replacement)
        if self.unique_identifier:
            self.unique_identifier = self.unique_identifier.replace(
                pattern, replacement
            )
        return self
+ return self