futurehouse-client 0.3.18.dev109__tar.gz → 0.3.18.dev184__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/PKG-INFO +1 -1
  2. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/clients/rest_client.py +89 -103
  3. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/models/__init__.py +10 -0
  4. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/models/app.py +93 -0
  5. futurehouse_client-0.3.18.dev184/futurehouse_client/utils/__init__.py +0 -0
  6. futurehouse_client-0.3.18.dev184/futurehouse_client/utils/auth.py +107 -0
  7. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client.egg-info/PKG-INFO +1 -1
  8. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client.egg-info/SOURCES.txt +1 -1
  9. futurehouse_client-0.3.18.dev184/tests/test_rest.py +260 -0
  10. futurehouse_client-0.3.18.dev109/futurehouse_client/utils/__init__.py +0 -3
  11. futurehouse_client-0.3.18.dev109/futurehouse_client/utils/context.py +0 -16
  12. futurehouse_client-0.3.18.dev109/tests/test_rest.py +0 -214
  13. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/LICENSE +0 -0
  14. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/README.md +0 -0
  15. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/docs/__init__.py +0 -0
  16. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/docs/client_notebook.ipynb +0 -0
  17. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/__init__.py +0 -0
  18. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/clients/__init__.py +0 -0
  19. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/clients/job_client.py +0 -0
  20. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/models/client.py +0 -0
  21. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/models/rest.py +0 -0
  22. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/utils/general.py +0 -0
  23. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/utils/module_utils.py +0 -0
  24. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client/utils/monitoring.py +0 -0
  25. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client.egg-info/dependency_links.txt +0 -0
  26. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client.egg-info/requires.txt +0 -0
  27. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/futurehouse_client.egg-info/top_level.txt +0 -0
  28. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/pyproject.toml +0 -0
  29. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/setup.cfg +0 -0
  30. {futurehouse_client-0.3.18.dev109 → futurehouse_client-0.3.18.dev184}/uv.lock +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: futurehouse-client
- Version: 0.3.18.dev109
+ Version: 0.3.18.dev184
  Summary: A client for interacting with endpoints of the FutureHouse service.
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
  Classifier: Operating System :: OS Independent
@@ -12,8 +12,7 @@ import sys
  import tempfile
  import time
  import uuid
- from collections.abc import Collection, Mapping
- from datetime import datetime
+ from collections.abc import Collection
  from pathlib import Path
  from types import ModuleType
  from typing import Any, ClassVar, assert_never, cast
@@ -34,7 +33,6 @@ from httpx import (
  RemoteProtocolError,
  )
  from ldp.agent import AgentConfig
- from pydantic import BaseModel, ConfigDict, model_validator
  from requests.exceptions import RequestException, Timeout
  from tenacity import (
  retry,
@@ -50,10 +48,18 @@ from futurehouse_client.models.app import (
  APIKeyPayload,
  AuthType,
  JobDeploymentConfig,
+ PQATaskResponse,
  Stage,
  TaskRequest,
+ TaskResponse,
+ TaskResponseVerbose,
  )
  from futurehouse_client.models.rest import ExecutionStatus
+ from futurehouse_client.utils.auth import (
+ AUTH_ERRORS_TO_RETRY_ON,
+ AuthError,
+ refresh_token_on_auth_error,
+ )
  from futurehouse_client.utils.general import gather_with_concurrency
  from futurehouse_client.utils.module_utils import (
  OrganizationSelector,
@@ -65,7 +71,7 @@ from futurehouse_client.utils.monitoring import (

  logger = logging.getLogger(__name__)
  logging.basicConfig(
- level=logging.INFO,
+ level=logging.WARNING,
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  stream=sys.stdout,
  )
@@ -122,103 +128,11 @@ retry_if_connection_error = retry_if_exception_type((
  FileUploadError,
  ))

-
- class SimpleOrganization(BaseModel):
- id: int
- name: str
- display_name: str
-
-
  # 5 minute default for JWTs
  JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
  DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds


- class TaskResponse(BaseModel):
- """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
-
- model_config = ConfigDict(extra="ignore")
-
- status: str
- query: str
- user: str | None = None
- created_at: datetime
- job_name: str
- public: bool
- shared_with: list[SimpleOrganization] | None = None
- build_owner: str | None = None
- environment_name: str | None = None
- agent_name: str | None = None
- task_id: UUID | None = None
-
- @model_validator(mode="before")
- @classmethod
- def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
- # Extract fields from environment frame state
- if not isinstance(data, dict):
- return data
- # TODO: We probably want to remove these two once we define the final names.
- data["job_name"] = data.get("crow")
- data["query"] = data.get("task")
- data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
- if not (metadata := data.get("metadata", {})):
- return data
- data["environment_name"] = metadata.get("environment_name")
- data["agent_name"] = metadata.get("agent_name")
- return data
-
-
- class PQATaskResponse(TaskResponse):
- model_config = ConfigDict(extra="ignore")
-
- answer: str | None = None
- formatted_answer: str | None = None
- answer_reasoning: str | None = None
- has_successful_answer: bool | None = None
- total_cost: float | None = None
- total_queries: int | None = None
-
- @model_validator(mode="before")
- @classmethod
- def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
- if not isinstance(data, dict):
- return data
- if not (env_frame := data.get("environment_frame", {})):
- return data
- state = env_frame.get("state", {}).get("state", {})
- response = state.get("response", {})
- answer = response.get("answer", {})
- usage = state.get("info", {}).get("usage", {})
-
- # Add additional PQA specific fields to data so that pydantic can validate the model
- data["answer"] = answer.get("answer")
- data["formatted_answer"] = answer.get("formatted_answer")
- data["answer_reasoning"] = answer.get("answer_reasoning")
- data["has_successful_answer"] = answer.get("has_successful_answer")
- data["total_cost"] = cast(float, usage.get("total_cost"))
- data["total_queries"] = cast(int, usage.get("total_queries"))
-
- return data
-
- def clean_verbose(self) -> "TaskResponse":
- """Clean the verbose response from the server."""
- self.request = None
- self.response = None
- return self
-
-
- class TaskResponseVerbose(TaskResponse):
- """Class for responses to include all the fields of a task response."""
-
- model_config = ConfigDict(extra="allow")
-
- public: bool
- agent_state: list[dict[str, Any]] | None = None
- environment_frame: dict[str, Any] | None = None
- metadata: dict[str, Any] | None = None
- shared_with: list[SimpleOrganization] | None = None
-
-
  class RestClient:
  REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
  MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
@@ -236,7 +150,13 @@ class RestClient:
  api_key: str | None = None,
  jwt: str | None = None,
  headers: dict[str, str] | None = None,
+ verbose_logging: bool = False,
  ):
+ if verbose_logging:
+ logger.setLevel(logging.INFO)
+ else:
+ logger.setLevel(logging.WARNING)
+
  self.base_url = service_uri or stage.value
  self.stage = stage
  self.auth_type = auth_type
@@ -360,6 +280,7 @@ class RestClient:
  except Exception as e:
  raise RestClientError(f"Error authenticating: {e!s}") from e

+ @refresh_token_on_auth_error()
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
  try:
  response = self.client.get(
@@ -367,9 +288,19 @@ class RestClient:
  )
  response.raise_for_status()
  return response.json()
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  except Exception as e:
  raise JobFetchError(f"Error checking job: {e!s}") from e

+ @refresh_token_on_auth_error()
  def _fetch_my_orgs(self) -> list[str]:
  response = self.client.get(f"/v0.1/organizations?filter={True}")
  response.raise_for_status()
@@ -432,6 +363,7 @@ class RestClient:
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
  retry=retry_if_connection_error,
  )
+ @refresh_token_on_auth_error()
  def get_task(
  self, task_id: str | None = None, history: bool = False, verbose: bool = False
  ) -> "TaskResponse":
@@ -467,8 +399,15 @@ class RestClient:
  ):
  return PQATaskResponse(**data)
  return TaskResponse(**data)
- except ValueError as e:
- raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  except Exception as e:
  raise TaskFetchError(f"Error getting task: {e!s}") from e

@@ -477,6 +416,7 @@ class RestClient:
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
  retry=retry_if_connection_error,
  )
+ @refresh_token_on_auth_error()
  async def aget_task(
  self, task_id: str | None = None, history: bool = False, verbose: bool = False
  ) -> "TaskResponse":
@@ -515,11 +455,19 @@ class RestClient:
  ):
  return PQATaskResponse(**data)
  return TaskResponse(**data)
- except ValueError as e:
- raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  except Exception as e:
  raise TaskFetchError(f"Error getting task: {e!s}") from e

+ @refresh_token_on_auth_error()
  @retry(
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -543,6 +491,15 @@ class RestClient:
  response.raise_for_status()
  trajectory_id = response.json()["trajectory_id"]
  self.trajectory_id = trajectory_id
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  except Exception as e:
  raise TaskFetchError(f"Error creating task: {e!s}") from e
  return trajectory_id
@@ -552,6 +509,7 @@ class RestClient:
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
  retry=retry_if_connection_error,
  )
+ @refresh_token_on_auth_error()
  async def acreate_task(self, task_data: TaskRequest | dict[str, Any]):
  """Create a new futurehouse task."""
  if isinstance(task_data, dict):
@@ -570,6 +528,15 @@ class RestClient:
  response.raise_for_status()
  trajectory_id = response.json()["trajectory_id"]
  self.trajectory_id = trajectory_id
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  except Exception as e:
  raise TaskFetchError(f"Error creating task: {e!s}") from e
  return trajectory_id
@@ -720,11 +687,22 @@ class RestClient:
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
  retry=retry_if_connection_error,
  )
+ @refresh_token_on_auth_error()
  def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
  """Get the status of a build."""
- build_id = build_id or self.build_id
- response = self.client.get(f"/v0.1/builds/{build_id}")
- response.raise_for_status()
+ try:
+ build_id = build_id or self.build_id
+ response = self.client.get(f"/v0.1/builds/{build_id}")
+ response.raise_for_status()
+ except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
+ raise
  return response.json()

  # TODO: Refactor later so we don't have to ignore PLR0915
@@ -733,6 +711,7 @@ class RestClient:
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
  retry=retry_if_connection_error,
  )
+ @refresh_token_on_auth_error()
  def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]: # noqa: PLR0915
  """Creates a futurehouse job deployment from the environment and environment files.

@@ -907,6 +886,13 @@ class RestClient:
  build_context = response.json()
  self.build_id = build_context["build_id"]
  except HTTPStatusError as e:
+ if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
+ raise AuthError(
+ e.response.status_code,
+ f"Authentication failed: {e}",
+ request=e.request,
+ response=e.response,
+ ) from e
  error_detail = response.json()
  error_message = error_detail.get("detail", str(e))
  raise JobCreationError(
@@ -3,10 +3,15 @@ from .app import (
  DockerContainerConfiguration,
  FramePath,
  JobDeploymentConfig,
+ PQATaskResponse,
  RuntimeConfig,
  Stage,
  Step,
+ TaskQueue,
+ TaskQueuesConfig,
  TaskRequest,
+ TaskResponse,
+ TaskResponseVerbose,
  )

  __all__ = [
@@ -14,8 +19,13 @@ __all__ = [
  "DockerContainerConfiguration",
  "FramePath",
  "JobDeploymentConfig",
+ "PQATaskResponse",
  "RuntimeConfig",
  "Stage",
  "Step",
+ "TaskQueue",
+ "TaskQueuesConfig",
  "TaskRequest",
+ "TaskResponse",
+ "TaskResponseVerbose",
  ]
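
As a hedged illustration (not part of the diff itself): with the re-exports added in this version, the response models can be imported straight from the models package. The snippet below assumes the 0.3.18.dev184 layout shown above.

    # Sketch only: imports rely on the re-exports added in futurehouse_client/models/__init__.py.
    from futurehouse_client.models import (
        PQATaskResponse,
        TaskResponse,
        TaskResponseVerbose,
    )

    def is_pqa(task: TaskResponse) -> bool:
        # PQATaskResponse subclasses TaskResponse, per futurehouse_client/models/app.py.
        return isinstance(task, PQATaskResponse)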
@@ -1,6 +1,8 @@
  import json
  import os
  import re
+ from collections.abc import Mapping
+ from datetime import datetime
  from enum import StrEnum, auto
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
@@ -646,3 +648,94 @@ class TaskRequest(BaseModel):
  runtime_config: RuntimeConfig | None = Field(
  default=None, description="All optional runtime parameters for the job"
  )
+
+
+ class SimpleOrganization(BaseModel):
+ id: int
+ name: str
+ display_name: str
+
+
+ class TaskResponse(BaseModel):
+ """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
+
+ model_config = ConfigDict(extra="ignore")
+
+ status: str
+ query: str
+ user: str | None = None
+ created_at: datetime
+ job_name: str
+ public: bool
+ shared_with: list[SimpleOrganization] | None = None
+ build_owner: str | None = None
+ environment_name: str | None = None
+ agent_name: str | None = None
+ task_id: UUID | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
+ # Extract fields from environment frame state
+ if not isinstance(data, dict):
+ return data
+ # TODO: We probably want to remove these two once we define the final names.
+ data["job_name"] = data.get("crow")
+ data["query"] = data.get("task")
+ data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
+ if not (metadata := data.get("metadata", {})):
+ return data
+ data["environment_name"] = metadata.get("environment_name")
+ data["agent_name"] = metadata.get("agent_name")
+ return data
+
+
+ class PQATaskResponse(TaskResponse):
+ model_config = ConfigDict(extra="ignore")
+
+ answer: str | None = None
+ formatted_answer: str | None = None
+ answer_reasoning: str | None = None
+ has_successful_answer: bool | None = None
+ total_cost: float | None = None
+ total_queries: int | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
+ if not isinstance(data, dict):
+ return data
+ if not (env_frame := data.get("environment_frame", {})):
+ return data
+ state = env_frame.get("state", {}).get("state", {})
+ response = state.get("response", {})
+ answer = response.get("answer", {})
+ usage = state.get("info", {}).get("usage", {})
+
+ # Add additional PQA specific fields to data so that pydantic can validate the model
+ data["answer"] = answer.get("answer")
+ data["formatted_answer"] = answer.get("formatted_answer")
+ data["answer_reasoning"] = answer.get("answer_reasoning")
+ data["has_successful_answer"] = answer.get("has_successful_answer")
+ data["total_cost"] = cast(float, usage.get("total_cost"))
+ data["total_queries"] = cast(int, usage.get("total_queries"))
+
+ return data
+
+ def clean_verbose(self) -> "TaskResponse":
+ """Clean the verbose response from the server."""
+ self.request = None
+ self.response = None
+ return self
+
+
+ class TaskResponseVerbose(TaskResponse):
+ """Class for responses to include all the fields of a task response."""
+
+ model_config = ConfigDict(extra="allow")
+
+ public: bool
+ agent_state: list[dict[str, Any]] | None = None
+ environment_frame: dict[str, Any] | None = None
+ metadata: dict[str, Any] | None = None
+ shared_with: list[SimpleOrganization] | None = None
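
For orientation, a minimal sketch of how the `validate_pqa_fields` pre-validator flattens a nested server payload into the top-level PQA fields. The payload shape is inferred from the key paths the validator reads (`environment_frame.state.state.response.answer` and `...info.usage`); it is not a captured API response, and the `crow` value is a placeholder.

    # Hypothetical payload, assuming the key paths read by the validators above.
    payload = {
        "id": "00000000-0000-0000-0000-000000000000",
        "status": "success",
        "task": "How many moons does earth have?",
        "crow": "example-job",  # placeholder job name; mapped to job_name by validate_fields
        "public": True,
        "created_at": "2024-01-01T00:00:00Z",
        "environment_frame": {
            "state": {
                "state": {
                    "response": {"answer": {"answer": "One."}},
                    "info": {"usage": {"total_cost": 0.01, "total_queries": 3}},
                }
            }
        },
    }
    task = PQATaskResponse(**payload)
    # answer, total_cost, and total_queries are lifted into top-level model fields.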
@@ -0,0 +1,107 @@
+ import asyncio
+ import logging
+ from collections.abc import Callable, Coroutine
+ from functools import wraps
+ from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
+
+ import httpx
+ from httpx import HTTPStatusError
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T")
+ P = ParamSpec("P")
+
+ AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
+ httpx.codes.UNAUTHORIZED,
+ httpx.codes.FORBIDDEN,
+ }
+
+
+ class AuthError(Exception):
+ """Raised when authentication fails with 401/403 status."""
+
+ def __init__(self, status_code: int, message: str, request=None, response=None):
+ self.status_code = status_code
+ self.request = request
+ self.response = response
+ super().__init__(message)
+
+
+ def is_auth_error(e: Exception) -> bool:
+ if isinstance(e, AuthError):
+ return True
+ if isinstance(e, HTTPStatusError):
+ return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
+ return False
+
+
+ def get_status_code(e: Exception) -> Optional[int]:
+ if isinstance(e, AuthError):
+ return e.status_code
+ if isinstance(e, HTTPStatusError):
+ return e.response.status_code
+ return None
+
+
+ @overload
+ def refresh_token_on_auth_error(
+ func: Callable[P, Coroutine[Any, Any, T]],
+ ) -> Callable[P, Coroutine[Any, Any, T]]: ...
+
+
+ @overload
+ def refresh_token_on_auth_error(
+ func: None = None, *, max_retries: int = ...
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
+
+
+ def refresh_token_on_auth_error(func=None, max_retries=1):
+ """Decorator that refreshes JWT token on 401/403 auth errors."""
+
+ def decorator(fn):
+ @wraps(fn)
+ def sync_wrapper(self, *args, **kwargs):
+ retries = 0
+ while True:
+ try:
+ return fn(self, *args, **kwargs)
+ except Exception as e:
+ if is_auth_error(e) and retries < max_retries:
+ retries += 1
+ status = get_status_code(e) or "Unknown"
+ logger.info(
+ f"Received auth error {status}, "
+ f"refreshing token and retrying (attempt {retries}/{max_retries})..."
+ )
+ self.auth_jwt = self._run_auth()
+ self._clients = {}
+ continue
+ raise
+
+ @wraps(fn)
+ async def async_wrapper(self, *args, **kwargs):
+ retries = 0
+ while True:
+ try:
+ return await fn(self, *args, **kwargs)
+ except Exception as e:
+ if is_auth_error(e) and retries < max_retries:
+ retries += 1
+ status = get_status_code(e) or "Unknown"
+ logger.info(
+ f"Received auth error {status}, "
+ f"refreshing token and retrying (attempt {retries}/{max_retries})..."
+ )
+ self.auth_jwt = self._run_auth()
+ self._clients = {}
+ continue
+ raise
+
+ if asyncio.iscoroutinefunction(fn):
+ return async_wrapper
+ return sync_wrapper
+
+ if callable(func):
+ return decorator(func)
+ return decorator
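
A brief usage sketch of the decorator added above, mirroring how rest_client.py applies it in this release (`@refresh_token_on_auth_error()` on client methods). The example client below is illustrative only; it just exposes the hooks the wrapper relies on (`self._run_auth()`, `self.auth_jwt`, `self._clients`).

    # Sketch only: a minimal client showing the retry-on-auth-error flow.
    from futurehouse_client.utils.auth import AuthError, refresh_token_on_auth_error

    class ExampleClient:
        def __init__(self) -> None:
            self.calls = 0
            self._clients: dict = {}

        def _run_auth(self) -> str:
            return "refreshed-jwt"  # placeholder for the real token refresh

        @refresh_token_on_auth_error()
        def whoami(self) -> str:
            self.calls += 1
            if self.calls == 1:
                # Simulate an expired token on the first attempt.
                raise AuthError(401, "Authentication failed: token expired")
            return "ok"

    client = ExampleClient()
    assert client.whoami() == "ok"  # retried once after refreshing auth_jwt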
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: futurehouse-client
- Version: 0.3.18.dev109
+ Version: 0.3.18.dev184
  Summary: A client for interacting with endpoints of the FutureHouse service.
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
  Classifier: Operating System :: OS Independent
@@ -18,7 +18,7 @@ futurehouse_client/models/app.py
  futurehouse_client/models/client.py
  futurehouse_client/models/rest.py
  futurehouse_client/utils/__init__.py
- futurehouse_client/utils/context.py
+ futurehouse_client/utils/auth.py
  futurehouse_client/utils/general.py
  futurehouse_client/utils/module_utils.py
  futurehouse_client/utils/monitoring.py
@@ -0,0 +1,260 @@
+ # ruff: noqa: ARG001
+ import asyncio
+ import os
+ import time
+ from unittest.mock import patch
+
+ import pytest
+ from futurehouse_client.clients import (
+ JobNames,
+ PQATaskResponse,
+ TaskResponseVerbose,
+ )
+ from futurehouse_client.clients.rest_client import RestClient
+ from futurehouse_client.models.app import Stage, TaskRequest
+ from futurehouse_client.models.rest import ExecutionStatus
+ from futurehouse_client.utils.auth import AuthError, refresh_token_on_auth_error
+ from pytest_subtests import SubTests
+
+ ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
+ PUBLIC_API_KEY = os.environ["PLAYWRIGHT_PUBLIC_API_KEY"]
+ TEST_MAX_POLLS = 100
+
+
+ @pytest.fixture
+ def admin_client():
+ """Create a RestClient for testing."""
+ return RestClient(
+ stage=Stage.DEV,
+ api_key=ADMIN_API_KEY,
+ )
+
+
+ @pytest.fixture
+ def pub_client():
+ """Create a RestClient for testing."""
+ return RestClient(
+ stage=Stage.DEV,
+ api_key=PUBLIC_API_KEY,
+ )
+
+
+ @pytest.fixture
+ def task_data():
+ """Create a sample task request."""
+ return TaskRequest(
+ name=JobNames.from_string("dummy"),
+ query="How many moons does earth have?",
+ )
+
+
+ @pytest.fixture
+ def pqa_task_data():
+ return TaskRequest(
+ name=JobNames.from_string("crow"),
+ query="How many moons does earth have?",
+ )
+
+
+ @pytest.mark.timeout(300)
+ @pytest.mark.flaky(reruns=3)
+ def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_data: TaskRequest):
+ admin_client.create_task(task_data)
+ while (task_status := admin_client.get_task().status) in {"queued", "in progress"}:
+ time.sleep(5)
+ assert task_status == "success"
+
+
+ def test_insufficient_permissions_request(
+ pub_client: RestClient, task_data: TaskRequest
+ ):
+ # Create a new instance so that cached credentials aren't reused
+ with pytest.raises(AuthError) as exc_info:
+ pub_client.create_task(task_data)
+
+ assert "403 Forbidden" in str(exc_info.value)
+
+
+ @pytest.mark.timeout(300)
+ @pytest.mark.asyncio
+ async def test_job_response( # noqa: PLR0915
+ subtests: SubTests, admin_client: RestClient, pqa_task_data: TaskRequest
+ ):
+ task_id = admin_client.create_task(pqa_task_data)
+ atask_id = await admin_client.acreate_task(pqa_task_data)
+
+ with subtests.test("Test TaskResponse with queued task"):
+ task_response = admin_client.get_task(task_id)
+ assert task_response.status in {"queued", "in progress"}
+ assert task_response.job_name == pqa_task_data.name
+ assert task_response.query == pqa_task_data.query
+ task_response = await admin_client.aget_task(atask_id)
+ assert task_response.status in {"queued", "in progress"}
+ assert task_response.job_name == pqa_task_data.name
+ assert task_response.query == pqa_task_data.query
+
+ for _ in range(TEST_MAX_POLLS):
+ task_response = admin_client.get_task(task_id)
+ if task_response.status in ExecutionStatus.terminal_states():
+ break
+ await asyncio.sleep(5)
+
+ for _ in range(TEST_MAX_POLLS):
+ task_response = await admin_client.aget_task(atask_id)
+ if task_response.status in ExecutionStatus.terminal_states():
+ break
+ await asyncio.sleep(5)
+
+ with subtests.test("Test PQA job response"):
+ task_response = admin_client.get_task(task_id)
+ assert isinstance(task_response, PQATaskResponse)
+ # assert it has general fields
+ assert task_response.status == "success"
+ assert task_response.task_id is not None
+ assert pqa_task_data.name in task_response.job_name
+ assert pqa_task_data.query in task_response.query
+ # assert it has PQA specific fields
+ assert task_response.answer is not None
+ # assert it's not verbose
+ assert not hasattr(task_response, "environment_frame")
+ assert not hasattr(task_response, "agent_state")
+
+ with subtests.test("Test async PQA job response"):
+ task_response = await admin_client.aget_task(atask_id)
+ assert isinstance(task_response, PQATaskResponse)
+ # assert it has general fields
+ assert task_response.status == "success"
+ assert task_response.task_id is not None
+ assert pqa_task_data.name in task_response.job_name
+ assert pqa_task_data.query in task_response.query
+ # assert it has PQA specific fields
+ assert task_response.answer is not None
+ # assert it's not verbose
+ assert not hasattr(task_response, "environment_frame")
+ assert not hasattr(task_response, "agent_state")
+
+ with subtests.test("Test task response with verbose"):
+ task_response = admin_client.get_task(task_id, verbose=True)
+ assert isinstance(task_response, TaskResponseVerbose)
+ assert task_response.status == "success"
+ assert task_response.environment_frame is not None
+ assert task_response.agent_state is not None
+
+ with subtests.test("Test task async response with verbose"):
+ task_response = await admin_client.aget_task(atask_id, verbose=True)
+ assert isinstance(task_response, TaskResponseVerbose)
+ assert task_response.status == "success"
+ assert task_response.environment_frame is not None
+ assert task_response.agent_state is not None
+
+
+ @pytest.mark.timeout(300)
+ @pytest.mark.flaky(reruns=3)
+ def test_run_until_done_futurehouse_dummy_env_crow(
+ admin_client: RestClient, task_data: TaskRequest
+ ):
+ tasks_to_do = [task_data, task_data]
+
+ results = admin_client.run_tasks_until_done(tasks_to_do)
+
+ assert len(results) == len(tasks_to_do), "Should return 2 tasks."
+ assert all(task.status == "success" for task in results)
+
+
+ @pytest.mark.timeout(300)
+ @pytest.mark.flaky(reruns=3)
+ @pytest.mark.asyncio
+ async def test_arun_until_done_futurehouse_dummy_env_crow(
+ admin_client: RestClient, task_data: TaskRequest
+ ):
+ tasks_to_do = [task_data, task_data]
+
+ results = await admin_client.arun_tasks_until_done(tasks_to_do)
+
+ assert len(results) == len(tasks_to_do), "Should return 2 tasks."
+ assert all(task.status == "success" for task in results)
+
+
+ @pytest.mark.timeout(300)
+ @pytest.mark.flaky(reruns=3)
+ @pytest.mark.asyncio
+ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
+ admin_client: RestClient, task_data: TaskRequest
+ ):
+ tasks_to_do = [task_data, task_data]
+
+ results = await admin_client.arun_tasks_until_done(
+ tasks_to_do, verbose=True, timeout=5, progress_bar=True
+ )
+
+ assert len(results) == len(tasks_to_do), "Should return 2 tasks."
+ assert all(task.status != "success" for task in results), "Should not be success."
+ assert all(not isinstance(task, PQATaskResponse) for task in results), (
+ "Should be verbose."
+ )
+
+ results = admin_client.run_tasks_until_done(
+ tasks_to_do, verbose=True, timeout=5, progress_bar=True
+ )
+
+ assert len(results) == len(tasks_to_do), "Should return 2 tasks."
+ assert all(task.status != "success" for task in results), "Should not be success."
+ assert all(not isinstance(task, PQATaskResponse) for task in results), (
+ "Should be verbose."
+ )
+
+
+ def test_auth_refresh_flow(admin_client: RestClient):
+ refresh_calls = 0
+ func_calls = 0
+
+ def mock_run_auth(*args, **kwargs):
+ nonlocal refresh_calls
+ refresh_calls += 1
+ return f"fresh-token-{refresh_calls}"
+
+ @refresh_token_on_auth_error()
+ def test_func(self, *args):
+ nonlocal func_calls
+ func_calls += 1
+
+ if func_calls == 1:
+ raise AuthError(401, "Auth failed", None, None)
+ return "success"
+
+ with patch.object(admin_client, "_run_auth", mock_run_auth):
+ result = test_func(admin_client)
+
+ assert result == "success"
+ assert func_calls == 2, "Function should be called twice"
+ assert refresh_calls == 1, "Auth should be refreshed once"
+ assert admin_client.auth_jwt == "fresh-token-1"
+
+
+ @pytest.mark.asyncio
+ async def test_async_auth_refresh_flow(admin_client: RestClient):
+ refresh_calls = 0
+ func_calls = 0
+
+ def mock_run_auth(*args, **kwargs):
+ nonlocal refresh_calls
+ refresh_calls += 1
+ return f"fresh-token-{refresh_calls}"
+
+ @refresh_token_on_auth_error()
+ async def test_async_func(self, *args):
+ nonlocal func_calls
+ func_calls += 1
+
+ if func_calls == 1:
+ raise AuthError(401, "Auth failed", None, None)
+ await asyncio.sleep(1)
+ return "success"
+
+ with patch.object(admin_client, "_run_auth", mock_run_auth):
+ result = await test_async_func(admin_client)
+
+ assert result == "success"
+ assert func_calls == 2, "Function should be called twice"
+ assert refresh_calls == 1, "Auth should be refreshed once"
+ assert admin_client.auth_jwt == "fresh-token-1"
@@ -1,3 +0,0 @@
- from .context import UserContext
-
- __all__ = ["UserContext"]
@@ -1,16 +0,0 @@
- class UserContext:
- """A context manager for storing user information from the initial request."""
-
- _user_jwt = None
-
- @classmethod
- def set_user_jwt(cls, jwt: str) -> None:
- cls._user_jwt = jwt
-
- @classmethod
- def get_user_jwt(cls) -> str | None:
- return cls._user_jwt
-
- @classmethod
- def clear_user_jwt(cls) -> None:
- cls._user_jwt = None
@@ -1,214 +0,0 @@
- import asyncio
- import os
- import time
-
- import pytest
- from futurehouse_client.clients import (
- JobNames,
- PQATaskResponse,
- TaskResponseVerbose,
- )
- from futurehouse_client.clients.rest_client import RestClient, TaskFetchError
- from futurehouse_client.models.app import Stage, TaskRequest
- from futurehouse_client.models.rest import ExecutionStatus
- from pytest_subtests import SubTests
-
- ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
- PUBLIC_API_KEY = os.environ["PLAYWRIGHT_PUBLIC_API_KEY"]
- TEST_MAX_POLLS = 100
-
-
- @pytest.mark.timeout(300)
- @pytest.mark.flaky(reruns=3)
- def test_futurehouse_dummy_env_crow():
- client = RestClient(
- stage=Stage.DEV,
- api_key=ADMIN_API_KEY,
- )
-
- task_data = TaskRequest(
- name=JobNames.from_string("dummy"),
- query="How many moons does earth have?",
- )
- client.create_task(task_data)
-
- while (task_status := client.get_task().status) in {"queued", "in progress"}:
- time.sleep(5)
-
- assert task_status == "success"
-
-
- def test_insufficient_permissions_request():
- # Create a new instance so that cached credentials aren't reused
- client = RestClient(
- stage=Stage.DEV,
- api_key=PUBLIC_API_KEY,
- )
- task_data = TaskRequest(
- name=JobNames.from_string("dummy"),
- query="How many moons does earth have?",
- )
-
- with pytest.raises(TaskFetchError) as exc_info:
- client.create_task(task_data)
-
- assert "Error creating task" in str(exc_info.value)
-
-
- @pytest.mark.timeout(300)
- @pytest.mark.asyncio
- async def test_job_response(subtests: SubTests): # noqa: PLR0915
- client = RestClient(
- stage=Stage.DEV,
- api_key=ADMIN_API_KEY,
- )
- task_data = TaskRequest(
- name=JobNames.from_string("crow"),
- query="How many moons does earth have?",
- )
- task_id = client.create_task(task_data)
- atask_id = await client.acreate_task(task_data)
-
- with subtests.test("Test TaskResponse with queued task"):
- task_response = client.get_task(task_id)
- assert task_response.status in {"queued", "in progress"}
- assert task_response.job_name == task_data.name
- assert task_response.query == task_data.query
- task_response = await client.aget_task(atask_id)
- assert task_response.status in {"queued", "in progress"}
- assert task_response.job_name == task_data.name
- assert task_response.query == task_data.query
-
- for _ in range(TEST_MAX_POLLS):
- task_response = client.get_task(task_id)
- if task_response.status in ExecutionStatus.terminal_states():
- break
- await asyncio.sleep(5)
-
- for _ in range(TEST_MAX_POLLS):
- task_response = await client.aget_task(atask_id)
- if task_response.status in ExecutionStatus.terminal_states():
- break
- await asyncio.sleep(5)
-
- with subtests.test("Test PQA job response"):
- task_response = client.get_task(task_id)
- assert isinstance(task_response, PQATaskResponse)
- # assert it has general fields
- assert task_response.status == "success"
- assert task_response.task_id is not None
- assert task_data.name in task_response.job_name
- assert task_data.query in task_response.query
- # assert it has PQA specific fields
- assert task_response.answer is not None
- # assert it's not verbose
- assert not hasattr(task_response, "environment_frame")
- assert not hasattr(task_response, "agent_state")
-
- with subtests.test("Test async PQA job response"):
- task_response = await client.aget_task(atask_id)
- assert isinstance(task_response, PQATaskResponse)
- # assert it has general fields
- assert task_response.status == "success"
- assert task_response.task_id is not None
- assert task_data.name in task_response.job_name
- assert task_data.query in task_response.query
- # assert it has PQA specific fields
- assert task_response.answer is not None
- # assert it's not verbose
- assert not hasattr(task_response, "environment_frame")
- assert not hasattr(task_response, "agent_state")
-
- with subtests.test("Test task response with verbose"):
- task_response = client.get_task(task_id, verbose=True)
- assert isinstance(task_response, TaskResponseVerbose)
- assert task_response.status == "success"
- assert task_response.environment_frame is not None
- assert task_response.agent_state is not None
-
- with subtests.test("Test task async response with verbose"):
- task_response = await client.aget_task(atask_id, verbose=True)
- assert isinstance(task_response, TaskResponseVerbose)
- assert task_response.status == "success"
- assert task_response.environment_frame is not None
- assert task_response.agent_state is not None
-
-
- @pytest.mark.timeout(300)
- @pytest.mark.flaky(reruns=3)
- def test_run_until_done_futurehouse_dummy_env_crow():
- client = RestClient(
- stage=Stage.DEV,
- api_key=ADMIN_API_KEY,
- )
-
- task_data = TaskRequest(
- name=JobNames.from_string("dummy"),
- query="How many moons does earth have?",
- )
-
- tasks_to_do = [task_data, task_data]
-
- results = client.run_tasks_until_done(tasks_to_do)
-
- assert len(results) == len(tasks_to_do), "Should return 2 tasks."
- assert all(task.status == "success" for task in results)
-
-
- @pytest.mark.timeout(300)
- @pytest.mark.flaky(reruns=3)
- @pytest.mark.asyncio
- async def test_arun_until_done_futurehouse_dummy_env_crow():
- client = RestClient(
- stage=Stage.DEV,
- api_key=ADMIN_API_KEY,
- )
-
- task_data = TaskRequest(
- name=JobNames.from_string("dummy"),
- query="How many moons does earth have?",
- )
-
- tasks_to_do = [task_data, task_data]
-
- results = await client.arun_tasks_until_done(tasks_to_do)
-
- assert len(results) == len(tasks_to_do), "Should return 2 tasks."
- assert all(task.status == "success" for task in results)
-
-
- @pytest.mark.timeout(300)
- @pytest.mark.flaky(reruns=3)
- @pytest.mark.asyncio
- async def test_timeout_run_until_done_futurehouse_dummy_env_crow():
- client = RestClient(
- stage=Stage.DEV,
- api_key=ADMIN_API_KEY,
- )
-
- task_data = TaskRequest(
- name=JobNames.from_string("dummy"),
- query="How many moons does earth have?",
- )
-
- tasks_to_do = [task_data, task_data]
-
- results = await client.arun_tasks_until_done(
- tasks_to_do, verbose=True, timeout=5, progress_bar=True
- )
-
- assert len(results) == len(tasks_to_do), "Should return 2 tasks."
- assert all(task.status != "success" for task in results), "Should not be success."
- assert all(not isinstance(task, PQATaskResponse) for task in results), (
- "Should be verbose."
- )
-
- results = client.run_tasks_until_done(
- tasks_to_do, verbose=True, timeout=5, progress_bar=True
- )
-
- assert len(results) == len(tasks_to_do), "Should return 2 tasks."
- assert all(task.status != "success" for task in results), "Should not be success."
- assert all(not isinstance(task, PQATaskResponse) for task in results), (
- "Should be verbose."
- )