futurehouse-client 0.4.0__py3-none-any.whl → 0.4.1.dev95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/__init__.py +8 -0
- futurehouse_client/clients/data_storage_methods.py +1876 -0
- futurehouse_client/clients/rest_client.py +110 -29
- futurehouse_client/models/data_storage_methods.py +333 -0
- futurehouse_client/utils/general.py +34 -0
- futurehouse_client/utils/world_model_tools.py +69 -0
- futurehouse_client/version.py +3 -16
- {futurehouse_client-0.4.0.dist-info → futurehouse_client-0.4.1.dev95.dist-info}/METADATA +6 -1
- futurehouse_client-0.4.1.dev95.dist-info/RECORD +23 -0
- futurehouse_client-0.4.0.dist-info/RECORD +0 -20
- {futurehouse_client-0.4.0.dist-info → futurehouse_client-0.4.1.dev95.dist-info}/WHEEL +0 -0
- {futurehouse_client-0.4.0.dist-info → futurehouse_client-0.4.1.dev95.dist-info}/licenses/LICENSE +0 -0
- {futurehouse_client-0.4.0.dist-info → futurehouse_client-0.4.1.dev95.dist-info}/top_level.txt +0 -0
@@ -26,21 +26,14 @@ from httpx import (
|
|
26
26
|
AsyncClient,
|
27
27
|
Client,
|
28
28
|
CloseError,
|
29
|
-
ConnectError,
|
30
|
-
ConnectTimeout,
|
31
29
|
HTTPStatusError,
|
32
|
-
NetworkError,
|
33
|
-
ReadError,
|
34
|
-
ReadTimeout,
|
35
30
|
RemoteProtocolError,
|
36
31
|
codes,
|
37
32
|
)
|
38
33
|
from ldp.agent import AgentConfig
|
39
|
-
from requests.exceptions import RequestException, Timeout
|
40
34
|
from tenacity import (
|
41
35
|
before_sleep_log,
|
42
36
|
retry,
|
43
|
-
retry_if_exception_type,
|
44
37
|
stop_after_attempt,
|
45
38
|
wait_exponential,
|
46
39
|
)
|
@@ -48,6 +41,7 @@ from tqdm import tqdm as sync_tqdm
|
|
48
41
|
from tqdm.asyncio import tqdm
|
49
42
|
|
50
43
|
from futurehouse_client.clients import JobNames
|
44
|
+
from futurehouse_client.clients.data_storage_methods import DataStorageMethods
|
51
45
|
from futurehouse_client.models.app import (
|
52
46
|
AuthType,
|
53
47
|
JobDeploymentConfig,
|
@@ -68,7 +62,10 @@ from futurehouse_client.models.rest import (
|
|
68
62
|
WorldModelResponse,
|
69
63
|
)
|
70
64
|
from futurehouse_client.utils.auth import RefreshingJWT
|
71
|
-
from futurehouse_client.utils.general import
|
65
|
+
from futurehouse_client.utils.general import (
|
66
|
+
create_retry_if_connection_error,
|
67
|
+
gather_with_concurrency,
|
68
|
+
)
|
72
69
|
from futurehouse_client.utils.module_utils import (
|
73
70
|
OrganizationSelector,
|
74
71
|
fetch_environment_function_docstring,
|
@@ -136,6 +133,10 @@ class WorldModelCreationError(RestClientError):
|
|
136
133
|
"""Raised when there's an error creating a world model."""
|
137
134
|
|
138
135
|
|
136
|
+
class WorldModelDeletionError(RestClientError):
    """Raised when there's an error deleting a world model.

    Covers both HTTP status failures and unexpected errors during deletion.
    """
|
138
|
+
|
139
|
+
|
139
140
|
class ProjectError(RestClientError):
|
140
141
|
"""Raised when there's an error with trajectory group operations."""
|
141
142
|
|
@@ -156,28 +157,15 @@ class FileUploadError(RestClientError):
|
|
156
157
|
"""Raised when there's an error uploading a file."""
|
157
158
|
|
158
159
|
|
159
|
-
retry_if_connection_error =
|
160
|
-
# From requests
|
161
|
-
Timeout,
|
162
|
-
ConnectionError,
|
163
|
-
RequestException,
|
164
|
-
# From httpx
|
165
|
-
ConnectError,
|
166
|
-
ConnectTimeout,
|
167
|
-
ReadTimeout,
|
168
|
-
ReadError,
|
169
|
-
NetworkError,
|
170
|
-
RemoteProtocolError,
|
171
|
-
CloseError,
|
172
|
-
FileUploadError,
|
173
|
-
))
|
160
|
+
retry_if_connection_error = create_retry_if_connection_error(FileUploadError)
|
174
161
|
|
175
162
|
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
176
163
|
|
177
164
|
|
178
165
|
# pylint: disable=too-many-public-methods
|
179
|
-
class RestClient:
|
180
|
-
REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
|
166
|
+
class RestClient(DataStorageMethods):
|
167
|
+
REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec - for general API calls
|
168
|
+
FILE_UPLOAD_TIMEOUT: ClassVar[float] = 600.0 # 10 minutes - for file uploads
|
181
169
|
MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
|
182
170
|
RETRY_MULTIPLIER: ClassVar[int] = 1
|
183
171
|
MAX_RETRY_WAIT: ClassVar[int] = 10
|
@@ -235,11 +223,35 @@ class RestClient:
|
|
235
223
|
"""Authenticated HTTP client for multipart uploads."""
|
236
224
|
return cast(Client, self.get_client(None, authenticated=True))
|
237
225
|
|
226
|
+
@property
def file_upload_client(self) -> Client:
    """Return the authenticated sync client used for file uploads.

    The client is requested with ``FILE_UPLOAD_TIMEOUT`` rather than the
    default request timeout.
    """
    upload_client = self.get_client(
        "application/json",
        authenticated=True,
        timeout=self.FILE_UPLOAD_TIMEOUT,
    )
    # get_client is typed as Client | AsyncClient; narrow for callers.
    return cast(Client, upload_client)
|
235
|
+
|
236
|
+
@property
def async_file_upload_client(self) -> AsyncClient:
    """Return the authenticated async client used for file uploads.

    The client is requested with ``FILE_UPLOAD_TIMEOUT`` rather than the
    default request timeout.
    """
    upload_client = self.get_client(
        "application/json",
        authenticated=True,
        async_client=True,
        timeout=self.FILE_UPLOAD_TIMEOUT,
    )
    # get_client is typed as Client | AsyncClient; narrow for callers.
    return cast(AsyncClient, upload_client)
|
248
|
+
|
238
249
|
def get_client(
|
239
250
|
self,
|
240
251
|
content_type: str | None = "application/json",
|
241
252
|
authenticated: bool = True,
|
242
253
|
async_client: bool = False,
|
254
|
+
timeout: float | None = None,
|
243
255
|
) -> Client | AsyncClient:
|
244
256
|
"""Return a cached HTTP client or create one if needed.
|
245
257
|
|
@@ -247,12 +259,13 @@ class RestClient:
|
|
247
259
|
content_type: The desired content type header. Use None for multipart uploads.
|
248
260
|
authenticated: Whether the client should include authentication.
|
249
261
|
async_client: Whether to use an async client.
|
262
|
+
timeout: Custom timeout in seconds. Uses REQUEST_TIMEOUT if not provided.
|
250
263
|
|
251
264
|
Returns:
|
252
265
|
An HTTP client configured with the appropriate headers.
|
253
266
|
"""
|
254
|
-
|
255
|
-
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
|
267
|
+
client_timeout = timeout or self.REQUEST_TIMEOUT
|
268
|
+
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}_{client_timeout}"
|
256
269
|
|
257
270
|
if key not in self._clients:
|
258
271
|
headers = copy.deepcopy(self.headers)
|
@@ -278,14 +291,14 @@ class RestClient:
|
|
278
291
|
AsyncClient(
|
279
292
|
base_url=self.base_url,
|
280
293
|
headers=headers,
|
281
|
-
timeout=
|
294
|
+
timeout=client_timeout,
|
282
295
|
auth=auth,
|
283
296
|
)
|
284
297
|
if async_client
|
285
298
|
else Client(
|
286
299
|
base_url=self.base_url,
|
287
300
|
headers=headers,
|
288
|
-
timeout=
|
301
|
+
timeout=client_timeout,
|
289
302
|
auth=auth,
|
290
303
|
)
|
291
304
|
)
|
@@ -1593,6 +1606,48 @@ class RestClient:
|
|
1593
1606
|
except Exception as e:
|
1594
1607
|
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1595
1608
|
|
1609
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
)
def search_world_models(
    self,
    query: str,
    size: int = 10,
    total_search_size: int = 50,
    search_all_versions: bool = False,
) -> list[str]:
    """Search for world models matching a query.

    Args:
        query: The search query.
        size: The number of results to return.
        total_search_size: The number of results to search for.
        search_all_versions: Whether to search all versions of the world model
            or just the latest one.

    Returns:
        A list of world model names (the decoded JSON body of the response).

    Raises:
        WorldModelFetchError: If the request returns an error status or any
            other exception occurs while performing it.
    """
    search_params = {
        "query": query,
        "size": size,
        "total_search_size": total_search_size,
        "search_all_versions": search_all_versions,
    }
    try:
        response = self.client.get(
            "/v0.1/world-models/search/", params=search_params
        )
        response.raise_for_status()
        # Decoded inside the try so malformed JSON is also wrapped below.
        return response.json()
    except HTTPStatusError as status_error:
        failed = status_error.response
        raise WorldModelFetchError(
            f"Error searching world models: {failed.status_code} - {failed.text}"
        ) from status_error
    except Exception as unexpected:
        raise WorldModelFetchError(
            f"An unexpected error occurred: {unexpected!r}."
        ) from unexpected
|
1650
|
+
|
1596
1651
|
@retry(
|
1597
1652
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1598
1653
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1668,6 +1723,32 @@ class RestClient:
|
|
1668
1723
|
f"An unexpected error occurred during world model creation: {e!r}."
|
1669
1724
|
) from e
|
1670
1725
|
|
1726
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
)
async def delete_world_model(self, world_model_id: UUID) -> None:
    """Delete a world model snapshot by its ID.

    Args:
        world_model_id: The unique ID of the world model snapshot to delete.

    Raises:
        WorldModelDeletionError: If the API call fails, either with an HTTP
            error status or with any other exception.
    """
    # NOTE(review): every exception raised in the try block is wrapped in
    # WorldModelDeletionError, so the retry above only fires if that type
    # matches retry_if_connection_error -- confirm this is intended.
    try:
        response = await self.async_client.delete(
            f"/v0.1/world-models/{world_model_id}"
        )
        response.raise_for_status()
    except HTTPStatusError as e:
        raise WorldModelDeletionError(
            f"Error deleting world model: {e.response.status_code} - {e.response.text}"
        ) from e
    except Exception as e:
        # Use repr so exceptions with an empty message (common for timeouts)
        # still identify themselves; matches the sibling world-model methods.
        raise WorldModelDeletionError(
            f"An unexpected error occurred: {e!r}."
        ) from e
|
1751
|
+
|
1671
1752
|
@retry(
|
1672
1753
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1673
1754
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -0,0 +1,333 @@
|
|
1
|
+
import contextlib
|
2
|
+
from datetime import datetime
|
3
|
+
from enum import StrEnum, auto
|
4
|
+
from os import PathLike
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any
|
7
|
+
from uuid import UUID
|
8
|
+
|
9
|
+
from pydantic import BaseModel, Field, JsonValue
|
10
|
+
|
11
|
+
|
12
|
+
class DataStorageEntry(BaseModel):
    """Model representing a data storage entry.

    ``id``, ``name``, ``user_id``, ``created_at`` and ``modified_at`` are
    required; every other field has a default.
    """

    id: UUID = Field(description="Unique identifier for the data storage entry")
    name: str = Field(description="Name of the data storage entry")
    description: str | None = Field(
        default=None, description="Description of the data storage entry"
    )
    content: str | None = Field(
        default=None, description="Content of the data storage entry"
    )
    embedding: list[float] | None = Field(
        default=None, description="Embedding vector for the content"
    )
    is_collection: bool = Field(
        default=False, description="Whether this entry is a collection"
    )
    tags: list[str] | None = Field(
        default=None,
        description="List of tags associated with the data storage entry",
    )
    parent_id: UUID | None = Field(
        default=None,
        description="ID of the parent entry if this is a sub-entry for hierarchical storage",
    )
    dataset_id: UUID | None = Field(
        default=None,
        description="ID of the dataset this entry belongs to",
    )
    path: str | None = Field(
        default=None,
        description="Path in the storage system where this entry is located, if a file.",
    )
    # Typed as Any: this model does not constrain the schema's structure.
    bigquery_schema: Any | None = Field(
        default=None, description="Target BigQuery schema for the data storage entry"
    )
    user_id: str = Field(description="ID of the user who created this entry")
    created_at: datetime = Field(description="Timestamp when the entry was created")
    modified_at: datetime = Field(
        description="Timestamp when the entry was last updated"
    )
|
53
|
+
|
54
|
+
|
55
|
+
class DataStorageType(StrEnum):
|
56
|
+
BIGQUERY = auto()
|
57
|
+
GCS = auto()
|
58
|
+
PG_TABLE = auto()
|
59
|
+
RAW_CONTENT = auto()
|
60
|
+
ELASTIC_SEARCH = auto()
|
61
|
+
|
62
|
+
|
63
|
+
class DataContentType(StrEnum):
|
64
|
+
BQ_DATASET = auto()
|
65
|
+
BQ_TABLE = auto()
|
66
|
+
TEXT = auto()
|
67
|
+
TEXT_W_EMBEDDINGS = auto()
|
68
|
+
DIRECTORY = auto()
|
69
|
+
FILE = auto()
|
70
|
+
INDEX = auto()
|
71
|
+
INDEX_W_EMBEDDINGS = auto()
|
72
|
+
|
73
|
+
|
74
|
+
class DataStorageLocationPayload(BaseModel):
    """Payload describing a storage location.

    ``storage_type`` and ``content_type`` are required; the schema, metadata
    and location default to ``None``.
    """

    storage_type: DataStorageType
    content_type: DataContentType
    content_schema: JsonValue | None = None
    metadata: JsonValue | None = None
    location: str | None = None
|
80
|
+
|
81
|
+
|
82
|
+
class DataStorageLocationDetails(BaseModel):
    """Model representing the location details within a DataStorageLocations object.

    Here ``storage_type`` and ``content_type`` are plain strings, not the
    ``DataStorageType`` / ``DataContentType`` enums.
    """

    storage_type: str = Field(description="Type of storage (e.g., 'gcs', 'pg_table')")
    content_type: str = Field(description="Type of content stored")
    content_schema: JsonValue | None = Field(default=None, description="Content schema")
    metadata: JsonValue | None = Field(default=None, description="Location metadata")
    location: str | None = Field(
        default=None, description="Location path or identifier"
    )
|
92
|
+
|
93
|
+
|
94
|
+
class DataStorageLocations(BaseModel):
    """Model representing storage locations for a data storage entry.

    All four fields are required.
    """

    id: UUID = Field(description="Unique identifier for the storage locations")
    data_storage_id: UUID = Field(description="ID of the associated data storage entry")
    storage_config: DataStorageLocationDetails = Field(
        description="Storage configuration details"
    )
    created_at: datetime = Field(description="Timestamp when the location was created")
|
103
|
+
|
104
|
+
|
105
|
+
class DataStorageResponse(BaseModel):
    """Response model for data storage operations.

    Bundles the entry with its storage location; ``signed_url`` is optional.
    """

    data_storage: DataStorageEntry = Field(description="The created data storage entry")
    storage_location: DataStorageLocations = Field(
        description="Storage location for this data entry"
    )
    signed_url: str | None = Field(
        default=None,
        description="Signed URL for uploading/downloading the file to/from GCS",
    )
|
116
|
+
|
117
|
+
|
118
|
+
class DataStorageRequestPayload(BaseModel):
    """Payload for creating a data storage entry.

    Only ``name`` is required; every other field has a default.
    """

    name: str = Field(description="Name of the data storage entry")
    description: str | None = Field(
        default=None, description="Description of the data storage entry"
    )
    content: str | None = Field(
        default=None, description="Content of the data storage entry"
    )
    is_collection: bool = Field(
        default=False, description="Whether this entry is a collection"
    )
    parent_id: UUID | None = Field(
        default=None, description="ID of the parent entry for hierarchical storage"
    )
    dataset_id: UUID | None = Field(
        default=None,
        description="ID of existing dataset to add entry to, or None to create new dataset",
    )
    # Accepts either a path-like object or a plain string.
    path: PathLike | str | None = Field(
        default=None,
        description="Path to store in the GCS bucket.",
    )
    existing_location: DataStorageLocationPayload | None = Field(
        default=None, description="Target storage metadata"
    )
|
145
|
+
|
146
|
+
|
147
|
+
class ManifestEntry(BaseModel):
    """Model representing a single entry in a manifest file.

    Both fields are optional, so an entry may carry a description, metadata,
    both, or neither.
    """

    description: str | None = Field(
        default=None, description="Description of the file or directory"
    )
    metadata: dict[str, Any] | None = Field(
        default=None, description="Additional metadata for the entry"
    )
|
156
|
+
|
157
|
+
|
158
|
+
class DirectoryManifest(BaseModel):
    """Model representing the structure of a manifest file.

    A manifest maps file/directory names either to a ``ManifestEntry`` (leaf)
    or to a nested ``DirectoryManifest`` (sub-directory).
    """

    entries: dict[str, "ManifestEntry | DirectoryManifest"] = Field(
        default_factory=dict,
        description="Map of file/directory names to their manifest entries",
    )

    def get_entry_description(self, name: str) -> str | None:
        """Get description for a specific entry.

        Returns ``None`` for unknown names and for nested directories, which
        carry no description of their own.
        """
        entry = self.entries.get(name)
        if isinstance(entry, ManifestEntry):
            return entry.description
        return None

    def get_entry_metadata(self, name: str) -> dict[str, Any] | None:
        """Get metadata for a specific entry.

        Returns ``None`` for unknown names and for nested directories.
        """
        entry = self.entries.get(name)
        if isinstance(entry, ManifestEntry):
            return entry.metadata
        return None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DirectoryManifest":
        """Create DirectoryManifest from a dictionary (loaded from JSON/YAML).

        Each value is interpreted as follows:

        * a dict containing a ``description`` or ``metadata`` key becomes a
          ``ManifestEntry``;
        * any other dict is parsed recursively as a nested directory;
        * any non-dict value becomes a ``ManifestEntry`` whose description is
          ``str(value)``.
        """
        # NOTE(review): a sub-directory that itself contains a child named
        # "description" or "metadata" is indistinguishable from a leaf entry
        # in this format and will be parsed as a ManifestEntry.
        entries: dict[str, ManifestEntry | DirectoryManifest] = {}
        for name, value in data.items():
            if isinstance(value, dict):
                if "description" in value or "metadata" in value:
                    # This looks like a ManifestEntry
                    entries[name] = ManifestEntry(**value)
                else:
                    # This looks like a nested directory
                    entries[name] = cls.from_dict(value)
            else:
                # Simple string description
                entries[name] = ManifestEntry(description=str(value))

        return cls(entries=entries)

    def to_dict(self) -> dict[str, Any]:
        """Convert back to the dictionary format accepted by ``from_dict``.

        Leaf entries are emitted as ``{"description": ..., "metadata": ...}``
        with unset (``None``) fields omitted. Metadata is kept under its own
        ``"metadata"`` key so that ``from_dict`` recognises the entry and
        restores it. Merging the metadata keys into the entry dict would lose
        them on reload and could overwrite ``"description"``. Entries with
        neither field set are omitted.
        """
        result: dict[str, Any] = {}
        for name, entry in self.entries.items():
            if isinstance(entry, DirectoryManifest):
                result[name] = entry.to_dict()
            elif isinstance(entry, ManifestEntry):
                entry_dict: dict[str, Any] = {}
                if entry.description is not None:
                    entry_dict["description"] = entry.description
                if entry.metadata is not None:
                    entry_dict["metadata"] = entry.metadata
                if entry_dict:
                    result[name] = entry_dict
        return result
|
216
|
+
|
217
|
+
|
218
|
+
class FileMetadata(BaseModel):
    """Metadata describing a file (or directory) that is being processed."""

    path: Path = Field(description="Path to the file")
    name: str = Field(description="Name of the file")
    size: int | None = Field(default=None, description="Size of the file in bytes")
    description: str | None = Field(
        default=None, description="Description from manifest or generated"
    )
    is_directory: bool = Field(default=False, description="Whether this is a directory")
    parent_id: UUID | None = Field(
        default=None, description="Parent directory ID in the storage system"
    )
    dataset_id: UUID | None = Field(
        default=None, description="Dataset ID this file belongs to"
    )

    @classmethod
    def from_path(
        cls,
        path: Path,
        description: str | None = None,
        parent_id: UUID | None = None,
        dataset_id: UUID | None = None,
    ) -> "FileMetadata":
        """Build a FileMetadata by inspecting ``path`` on disk.

        Directories get ``size=None``. For anything else the size is read via
        ``stat()``; if that raises ``OSError`` the size stays ``None``.
        """
        size_in_bytes = None
        path_is_dir = path.is_dir()

        if not path_is_dir:
            # Best effort: a failed stat() leaves the size unknown.
            with contextlib.suppress(OSError):
                size_in_bytes = path.stat().st_size

        field_values = {
            "path": path,
            "name": path.name,
            "size": size_in_bytes,
            "description": description,
            "is_directory": path_is_dir,
            "parent_id": parent_id,
            "dataset_id": dataset_id,
        }
        return cls(**field_values)
|
260
|
+
|
261
|
+
|
262
|
+
class UploadProgress(BaseModel):
    """Mutable counters describing how far an upload has progressed."""

    total_files: int = Field(description="Total number of files to upload")
    uploaded_files: int = Field(default=0, description="Number of files uploaded")
    total_bytes: int | None = Field(default=None, description="Total bytes to upload")
    uploaded_bytes: int = Field(default=0, description="Number of bytes uploaded")
    current_file: str | None = Field(
        default=None, description="Currently uploading file"
    )
    errors: list[str] = Field(
        default_factory=list, description="List of error messages"
    )

    @property
    def progress_percentage(self) -> float:
        """Percentage of files uploaded; 0.0 when there are no files."""
        file_count = self.total_files
        return 0.0 if file_count == 0 else (self.uploaded_files / file_count) * 100.0

    @property
    def bytes_percentage(self) -> float | None:
        """Percentage of bytes uploaded, or None when the total is unknown or zero."""
        byte_total = self.total_bytes
        if byte_total:
            return (self.uploaded_bytes / byte_total) * 100.0
        # Covers both None (unknown) and 0 (nothing to divide by).
        return None

    def add_error(self, error: str) -> None:
        """Record an error message."""
        self.errors.append(error)

    def increment_files(self, bytes_uploaded: int = 0) -> None:
        """Count one more uploaded file and add its bytes to the running total."""
        self.uploaded_files += 1
        self.uploaded_bytes += bytes_uploaded
|
298
|
+
|
299
|
+
|
300
|
+
class DirectoryUploadConfig(BaseModel):
    """Settings that control how a directory is uploaded."""

    name: str = Field(description="Name for the directory upload")
    description: str | None = Field(
        default=None, description="Description for the directory"
    )
    as_collection: bool = Field(
        default=False, description="Upload as single collection or hierarchically"
    )
    manifest_filename: str | None = Field(
        default=None, description="Name of manifest file to use"
    )
    ignore_patterns: list[str] = Field(
        default_factory=list, description="Patterns to ignore"
    )
    ignore_filename: str = Field(
        default=".gitignore", description="Name of ignore file to read"
    )
    base_path: str | None = Field(default=None, description="Base path for storage")
    parent_id: UUID | None = Field(default=None, description="Parent directory ID")
    dataset_id: UUID | None = Field(default=None, description="Dataset ID to use")

    def with_parent(
        self, parent_id: UUID, dataset_id: UUID | None = None
    ) -> "DirectoryUploadConfig":
        """Return a copy of this config with the parent (and dataset) IDs replaced.

        When ``dataset_id`` is not given, the copy keeps this config's
        ``dataset_id``.
        """
        effective_dataset_id = dataset_id if dataset_id else self.dataset_id
        overrides = {"parent_id": parent_id, "dataset_id": effective_dataset_id}
        return self.model_copy(update=overrides)
|
330
|
+
|
331
|
+
|
332
|
+
# Forward reference resolution for DirectoryManifest
|
333
|
+
DirectoryManifest.model_rebuild()
|
@@ -2,11 +2,45 @@ import asyncio
|
|
2
2
|
from collections.abc import Awaitable, Iterable
|
3
3
|
from typing import TypeVar
|
4
4
|
|
5
|
+
from httpx import (
|
6
|
+
CloseError,
|
7
|
+
ConnectError,
|
8
|
+
ConnectTimeout,
|
9
|
+
NetworkError,
|
10
|
+
ReadError,
|
11
|
+
ReadTimeout,
|
12
|
+
RemoteProtocolError,
|
13
|
+
)
|
14
|
+
from requests.exceptions import RequestException, Timeout
|
15
|
+
from tenacity import retry_if_exception_type
|
5
16
|
from tqdm.asyncio import tqdm
|
6
17
|
|
7
18
|
T = TypeVar("T")
|
8
19
|
|
9
20
|
|
21
|
+
# Exception types treated as transient connection failures for retries.
_BASE_CONNECTION_ERRORS = (
    # From requests
    # NOTE(review): ConnectionError is not among the names imported from
    # requests.exceptions above, so this looks like the Python builtin rather
    # than requests' class -- confirm that is intended.
    Timeout,
    ConnectionError,
    RequestException,
    # From httpx
    ConnectError,
    ConnectTimeout,
    ReadTimeout,
    ReadError,
    NetworkError,
    RemoteProtocolError,
    CloseError,
)

# Retry condition matching any of the base connection errors.
retry_if_connection_error = retry_if_exception_type(_BASE_CONNECTION_ERRORS)
|
37
|
+
|
38
|
+
|
39
|
+
def create_retry_if_connection_error(
    *additional_exceptions: type[BaseException],
) -> retry_if_exception_type:
    """Create a retry condition with base connection errors plus additional exceptions.

    Args:
        *additional_exceptions: Extra exception classes to retry on, appended
            to ``_BASE_CONNECTION_ERRORS``.

    Returns:
        A tenacity ``retry_if_exception_type`` condition covering the
        combined tuple of exception types.
    """
    return retry_if_exception_type(_BASE_CONNECTION_ERRORS + additional_exceptions)
|
42
|
+
|
43
|
+
|
10
44
|
async def gather_with_concurrency(
|
11
45
|
n: int | asyncio.Semaphore, coros: Iterable[Awaitable[T]], progress: bool = False
|
12
46
|
) -> list[T]:
|
@@ -0,0 +1,69 @@
|
|
1
|
+
import os
|
2
|
+
from uuid import UUID
|
3
|
+
|
4
|
+
from aviary.core import Tool
|
5
|
+
|
6
|
+
from futurehouse_client.clients.rest_client import RestClient
|
7
|
+
from futurehouse_client.models.app import Stage
|
8
|
+
from futurehouse_client.models.rest import WorldModel
|
9
|
+
|
10
|
+
|
11
|
+
class WorldModelTools:
    # Shared client, created on first use (see _get_client).
    _client: RestClient | None = None

    @classmethod
    def _get_client(cls) -> RestClient:
        """Return the shared RestClient, building it from the environment on first use.

        Reads ``FH_PLATFORM_API_KEY`` (required) and ``CROW_ENV`` (defaults to
        ``"dev"``). Deferred until first call to avoid validation errors at
        import time.
        """
        if cls._client is not None:
            return cls._client

        api_key = os.getenv("FH_PLATFORM_API_KEY")
        if not api_key:
            raise ValueError(
                "FH_PLATFORM_API_KEY environment variable is required for WorldModelTools"
            )

        stage = Stage.from_string(os.getenv("CROW_ENV", "dev"))
        cls._client = RestClient(stage=stage, api_key=api_key)
        return cls._client

    # The docstrings of the two static methods below are left verbatim:
    # the functions are wrapped with Tool.from_function, which may use them.
    @staticmethod
    def create_world_model(name: str, description: str, content: str) -> UUID:
        """Create a new world model.

        Args:
            name: The name of the world model.
            description: A description of the world model.
            content: The content/data of the world model.

        Returns:
            UUID: The ID of the newly created world model.
        """
        new_model = WorldModel(name=name, description=description, content=content)
        rest_client = WorldModelTools._get_client()
        return rest_client.create_world_model(new_model)

    @staticmethod
    def search_world_models(query: str) -> list[str]:
        """Search for world models using a text query.

        Args:
            query: The search query string to match against world model content.

        Returns:
            list[str]: A list of world model IDs that match the search query.
        """
        rest_client = WorldModelTools._get_client()
        return rest_client.search_world_models(query, size=1)
|
59
|
+
|
60
|
+
|
61
|
+
# Module-level Tool wrappers around the WorldModelTools static methods.
create_world_model_tool = Tool.from_function(WorldModelTools.create_world_model)
search_world_model_tool = Tool.from_function(WorldModelTools.search_world_models)


def make_world_model_tools() -> list[Tool]:
    """Return the world-model tools: the search tool first, then the create tool."""
    return [
        search_world_model_tool,
        create_world_model_tool,
    ]
|