orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3942 -0
- orca_sdk/classification_model.py +218 -20
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +899 -712
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +64 -12
- orca_sdk/datasource_test.py +144 -18
- orca_sdk/embedding_model.py +54 -37
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +27 -21
- orca_sdk/memoryset.py +823 -205
- orca_sdk/memoryset_test.py +315 -33
- orca_sdk/regression_model.py +59 -15
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +76 -26
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/METADATA +1 -1
- orca_sdk-0.1.4.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/WHEEL +0 -0
orca_sdk/conftest.py
CHANGED
|
@@ -4,11 +4,13 @@ from typing import Generator
|
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
6
|
import pytest
|
|
7
|
+
import pytest_asyncio
|
|
7
8
|
from datasets import ClassLabel, Dataset, Features, Value
|
|
8
9
|
|
|
9
10
|
from ._utils.auth import _create_api_key, _delete_org
|
|
11
|
+
from .async_client import OrcaAsyncClient
|
|
10
12
|
from .classification_model import ClassificationModel
|
|
11
|
-
from .client import
|
|
13
|
+
from .client import OrcaClient
|
|
12
14
|
from .credentials import OrcaCredentials
|
|
13
15
|
from .datasource import Datasource
|
|
14
16
|
from .embedding_model import PretrainedEmbeddingModel
|
|
@@ -44,51 +46,51 @@ def _create_org_id():
|
|
|
44
46
|
return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
@pytest.fixture()
|
|
48
|
-
def
|
|
49
|
-
|
|
50
|
-
yield
|
|
51
|
-
orca_api.base_url = original_base_url
|
|
49
|
+
@pytest.fixture(scope="session")
|
|
50
|
+
def org_id():
|
|
51
|
+
return _create_org_id()
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
@pytest.fixture(scope="session")
|
|
55
|
-
def
|
|
55
|
+
def other_org_id():
|
|
56
56
|
return _create_org_id()
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
@pytest.fixture(autouse=True, scope="session")
|
|
60
60
|
def api_key(org_id) -> Generator[str, None, None]:
|
|
61
61
|
api_key = _create_api_key(org_id=org_id, name="orca_sdk_test")
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
with OrcaClient(api_key=api_key).use():
|
|
63
|
+
yield api_key
|
|
64
64
|
_delete_org(org_id)
|
|
65
65
|
|
|
66
66
|
|
|
67
|
+
# We cannot use a session scoped fixture because async pytest tears down the client after each test
|
|
67
68
|
@pytest.fixture(autouse=True)
|
|
68
|
-
def
|
|
69
|
-
|
|
69
|
+
def authenticate_async_client(api_key) -> Generator[None, None, None]:
|
|
70
|
+
with OrcaAsyncClient(api_key=api_key).use():
|
|
71
|
+
yield
|
|
70
72
|
|
|
71
73
|
|
|
72
|
-
@pytest.fixture()
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
yield
|
|
76
|
-
# Need to reset the api key to the original api key so following tests don't fail
|
|
77
|
-
OrcaCredentials.set_api_key(api_key, check_validity=False)
|
|
74
|
+
@pytest.fixture(scope="session")
|
|
75
|
+
def unauthenticated_client() -> OrcaClient:
|
|
76
|
+
return OrcaClient(api_key=str(uuid4()))
|
|
78
77
|
|
|
79
78
|
|
|
80
|
-
@
|
|
81
|
-
def
|
|
82
|
-
return
|
|
79
|
+
@pytest_asyncio.fixture()
|
|
80
|
+
def unauthenticated_async_client() -> OrcaAsyncClient:
|
|
81
|
+
return OrcaAsyncClient(api_key=str(uuid4()))
|
|
83
82
|
|
|
84
83
|
|
|
85
|
-
@pytest.fixture()
|
|
86
|
-
def
|
|
84
|
+
@pytest.fixture(scope="session")
|
|
85
|
+
def unauthorized_client(other_org_id):
|
|
87
86
|
different_api_key = _create_api_key(org_id=other_org_id, name="orca_sdk_test_other_org")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
87
|
+
return OrcaClient(api_key=different_api_key)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.fixture(scope="session")
|
|
91
|
+
def predict_only_client() -> OrcaClient:
|
|
92
|
+
predict_api_key = OrcaCredentials.create_api_key("orca_sdk_test_predict", scopes={"PREDICT"})
|
|
93
|
+
return OrcaClient(api_key=predict_api_key)
|
|
92
94
|
|
|
93
95
|
|
|
94
96
|
@pytest.fixture(scope="session")
|
|
@@ -209,18 +211,17 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
|
|
|
209
211
|
yield memoryset
|
|
210
212
|
finally:
|
|
211
213
|
# Restore the memoryset to a clean state for the next test.
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
memoryset.refresh()
|
|
214
|
+
with OrcaClient(api_key=api_key).use():
|
|
215
|
+
if LabeledMemoryset.exists("test_writable_memoryset"):
|
|
216
|
+
memoryset.refresh()
|
|
216
217
|
|
|
217
|
-
|
|
218
|
+
memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
|
|
218
219
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
220
|
+
if memory_ids:
|
|
221
|
+
memoryset.delete(memory_ids)
|
|
222
|
+
memoryset.refresh()
|
|
223
|
+
assert len(memoryset) == 0
|
|
224
|
+
memoryset.insert(SAMPLE_DATA)
|
|
224
225
|
# If the test dropped the memoryset, do nothing — it will be recreated on the next use.
|
|
225
226
|
|
|
226
227
|
|
orca_sdk/credentials.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from datetime import datetime
|
|
2
3
|
from typing import Literal, NamedTuple
|
|
3
4
|
|
|
4
5
|
import httpx
|
|
5
|
-
from httpx import ConnectError, Headers
|
|
6
|
+
from httpx import ConnectError, Headers, HTTPTransport
|
|
7
|
+
from typing_extensions import deprecated
|
|
6
8
|
|
|
7
|
-
from .
|
|
9
|
+
from .async_client import OrcaAsyncClient
|
|
10
|
+
from .client import OrcaClient
|
|
8
11
|
|
|
9
12
|
Scope = Literal["ADMINISTER", "PREDICT"]
|
|
10
13
|
"""
|
|
@@ -15,19 +18,31 @@ The scopes of an API key.
|
|
|
15
18
|
"""
|
|
16
19
|
|
|
17
20
|
|
|
18
|
-
class ApiKeyInfo
|
|
21
|
+
class ApiKeyInfo:
|
|
19
22
|
"""
|
|
20
|
-
|
|
23
|
+
Information about an API key
|
|
24
|
+
|
|
25
|
+
Note:
|
|
26
|
+
The value of the API key is only available at creation time.
|
|
21
27
|
|
|
22
28
|
Attributes:
|
|
23
29
|
name: Unique name of the API key
|
|
24
30
|
created_at: When the API key was created
|
|
31
|
+
scopes: The scopes of the API key
|
|
25
32
|
"""
|
|
26
33
|
|
|
27
34
|
name: str
|
|
28
35
|
created_at: datetime
|
|
29
36
|
scopes: set[Scope]
|
|
30
37
|
|
|
38
|
+
def __init__(self, name: str, created_at: datetime, scopes: set[Scope]):
|
|
39
|
+
self.name = name
|
|
40
|
+
self.created_at = created_at
|
|
41
|
+
self.scopes = scopes
|
|
42
|
+
|
|
43
|
+
def __repr__(self) -> str:
|
|
44
|
+
return "ApiKey({ " + f"name: '{self.name}', scopes: <{'|'.join(self.scopes)}>" + "})"
|
|
45
|
+
|
|
31
46
|
|
|
32
47
|
class OrcaCredentials:
|
|
33
48
|
"""
|
|
@@ -42,8 +57,9 @@ class OrcaCredentials:
|
|
|
42
57
|
Returns:
|
|
43
58
|
True if you are authenticated, False otherwise
|
|
44
59
|
"""
|
|
60
|
+
client = OrcaClient._resolve_client()
|
|
45
61
|
try:
|
|
46
|
-
return
|
|
62
|
+
return client.GET("/auth")
|
|
47
63
|
except ValueError as e:
|
|
48
64
|
if "Invalid API key" in str(e):
|
|
49
65
|
return False
|
|
@@ -57,8 +73,10 @@ class OrcaCredentials:
|
|
|
57
73
|
Returns:
|
|
58
74
|
True if the API is healthy, False otherwise
|
|
59
75
|
"""
|
|
76
|
+
client = OrcaClient._resolve_client()
|
|
60
77
|
try:
|
|
61
|
-
|
|
78
|
+
# we don't want a retry transport here, so we use httpx directly
|
|
79
|
+
httpx.get(f"{client.base_url}/check/healthy")
|
|
62
80
|
except Exception:
|
|
63
81
|
return False
|
|
64
82
|
return True
|
|
@@ -71,13 +89,14 @@ class OrcaCredentials:
|
|
|
71
89
|
Returns:
|
|
72
90
|
A list of named tuples, with the name and creation date time of the API key
|
|
73
91
|
"""
|
|
92
|
+
client = OrcaClient._resolve_client()
|
|
74
93
|
return [
|
|
75
94
|
ApiKeyInfo(
|
|
76
95
|
name=api_key["name"],
|
|
77
96
|
created_at=datetime.fromisoformat(api_key["created_at"]),
|
|
78
97
|
scopes=set(api_key["scope"]),
|
|
79
98
|
)
|
|
80
|
-
for api_key in
|
|
99
|
+
for api_key in client.GET("/auth/api_key")
|
|
81
100
|
]
|
|
82
101
|
|
|
83
102
|
@staticmethod
|
|
@@ -92,7 +111,8 @@ class OrcaCredentials:
|
|
|
92
111
|
Returns:
|
|
93
112
|
The secret value of the API key. Make sure to save this value as it will not be shown again.
|
|
94
113
|
"""
|
|
95
|
-
|
|
114
|
+
client = OrcaClient._resolve_client()
|
|
115
|
+
res = client.POST(
|
|
96
116
|
"/auth/api_key",
|
|
97
117
|
json={"name": name, "scope": list(scopes)},
|
|
98
118
|
)
|
|
@@ -109,8 +129,12 @@ class OrcaCredentials:
|
|
|
109
129
|
Raises:
|
|
110
130
|
ValueError: if the API key is not found
|
|
111
131
|
"""
|
|
112
|
-
|
|
132
|
+
client = OrcaClient._resolve_client()
|
|
133
|
+
client.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": name})
|
|
134
|
+
|
|
135
|
+
# TODO: remove deprecated methods after 2026-01-01
|
|
113
136
|
|
|
137
|
+
@deprecated("Use `OrcaClient.api_key` instead")
|
|
114
138
|
@staticmethod
|
|
115
139
|
def set_api_key(api_key: str, check_validity: bool = True):
|
|
116
140
|
"""
|
|
@@ -126,17 +150,24 @@ class OrcaCredentials:
|
|
|
126
150
|
Raises:
|
|
127
151
|
ValueError: if the API key is invalid and `check_validity` is True
|
|
128
152
|
"""
|
|
129
|
-
|
|
153
|
+
sync_client = OrcaClient._resolve_client()
|
|
154
|
+
sync_client.api_key = api_key
|
|
130
155
|
if check_validity:
|
|
131
|
-
|
|
156
|
+
sync_client.GET("/auth")
|
|
132
157
|
|
|
158
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
159
|
+
async_client.api_key = api_key
|
|
160
|
+
|
|
161
|
+
@deprecated("Use `OrcaClient.base_url` instead")
|
|
133
162
|
@staticmethod
|
|
134
163
|
def get_api_url() -> str:
|
|
135
164
|
"""
|
|
136
165
|
Get the base URL of the Orca API that is currently being used
|
|
137
166
|
"""
|
|
138
|
-
|
|
167
|
+
client = OrcaClient._resolve_client()
|
|
168
|
+
return str(client.base_url)
|
|
139
169
|
|
|
170
|
+
@deprecated("Use `OrcaClient.base_url` instead")
|
|
140
171
|
@staticmethod
|
|
141
172
|
def set_api_url(url: str, check_validity: bool = True):
|
|
142
173
|
"""
|
|
@@ -156,12 +187,17 @@ class OrcaCredentials:
|
|
|
156
187
|
except ConnectError as e:
|
|
157
188
|
raise ValueError(f"No API found at {url}") from e
|
|
158
189
|
|
|
159
|
-
|
|
190
|
+
sync_client = OrcaClient._resolve_client()
|
|
191
|
+
sync_client.base_url = url
|
|
192
|
+
|
|
193
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
194
|
+
async_client.base_url = url
|
|
160
195
|
|
|
161
196
|
# check if the api passes the health check
|
|
162
197
|
if check_validity:
|
|
163
198
|
OrcaCredentials.is_healthy()
|
|
164
199
|
|
|
200
|
+
@deprecated("Use `OrcaClient.headers` instead")
|
|
165
201
|
@staticmethod
|
|
166
202
|
def set_api_headers(headers: dict[str, str]):
|
|
167
203
|
"""
|
|
@@ -174,4 +210,8 @@ class OrcaCredentials:
|
|
|
174
210
|
New keys are merged into the existing headers, this will overwrite headers with the
|
|
175
211
|
same name, but leave other headers untouched.
|
|
176
212
|
"""
|
|
177
|
-
|
|
213
|
+
sync_client = OrcaClient._resolve_client()
|
|
214
|
+
sync_client.headers.update(Headers(headers))
|
|
215
|
+
|
|
216
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
217
|
+
async_client.headers.update(Headers(headers))
|
orca_sdk/credentials_test.py
CHANGED
|
@@ -2,56 +2,120 @@ from uuid import uuid4
|
|
|
2
2
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
|
-
from .client import
|
|
5
|
+
from .client import OrcaClient
|
|
6
6
|
from .credentials import OrcaCredentials
|
|
7
7
|
|
|
8
8
|
|
|
9
|
+
def test_is_authenticated():
|
|
10
|
+
assert OrcaCredentials.is_authenticated()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_is_authenticated_false(unauthenticated_client):
|
|
14
|
+
with unauthenticated_client.use():
|
|
15
|
+
assert not OrcaCredentials.is_authenticated()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_is_healthy():
|
|
19
|
+
assert OrcaCredentials.is_healthy()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_is_healthy_false(api_key):
|
|
23
|
+
with OrcaClient(api_key=api_key, base_url="http://localhost:1582").use():
|
|
24
|
+
assert not OrcaCredentials.is_healthy()
|
|
25
|
+
|
|
26
|
+
|
|
9
27
|
def test_list_api_keys():
|
|
10
28
|
api_keys = OrcaCredentials.list_api_keys()
|
|
11
29
|
assert len(api_keys) >= 1
|
|
12
30
|
assert "orca_sdk_test" in [api_key.name for api_key in api_keys]
|
|
13
31
|
|
|
14
32
|
|
|
15
|
-
def test_list_api_keys_unauthenticated(
|
|
16
|
-
with
|
|
17
|
-
|
|
33
|
+
def test_list_api_keys_unauthenticated(unauthenticated_client):
|
|
34
|
+
with unauthenticated_client.use():
|
|
35
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
36
|
+
OrcaCredentials.list_api_keys()
|
|
18
37
|
|
|
19
38
|
|
|
20
|
-
def
|
|
21
|
-
|
|
39
|
+
def test_manage_api_key():
|
|
40
|
+
api_key_name = f"orca_sdk_test_{uuid4().hex[:8]}"
|
|
41
|
+
api_key = OrcaCredentials.create_api_key(api_key_name)
|
|
42
|
+
assert api_key is not None
|
|
43
|
+
assert len(api_key) > 0
|
|
44
|
+
assert api_key_name in [aki.name for aki in OrcaCredentials.list_api_keys()]
|
|
45
|
+
OrcaCredentials.revoke_api_key(api_key_name)
|
|
46
|
+
assert api_key_name not in [aki.name for aki in OrcaCredentials.list_api_keys()]
|
|
22
47
|
|
|
23
48
|
|
|
24
|
-
def
|
|
25
|
-
|
|
49
|
+
def test_create_api_key_unauthenticated(unauthenticated_client):
|
|
50
|
+
with unauthenticated_client.use():
|
|
51
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
52
|
+
OrcaCredentials.create_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
26
53
|
|
|
27
54
|
|
|
28
|
-
def
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
55
|
+
def test_create_api_key_unauthorized(predict_only_client):
|
|
56
|
+
with predict_only_client.use():
|
|
57
|
+
with pytest.raises(PermissionError):
|
|
58
|
+
OrcaCredentials.create_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
32
59
|
|
|
33
60
|
|
|
34
|
-
def
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
assert not OrcaCredentials.is_authenticated()
|
|
61
|
+
def test_revoke_api_key_unauthenticated(unauthenticated_client):
|
|
62
|
+
with unauthenticated_client.use():
|
|
63
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
64
|
+
OrcaCredentials.revoke_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
39
65
|
|
|
40
66
|
|
|
41
|
-
def
|
|
42
|
-
|
|
43
|
-
|
|
67
|
+
def test_revoke_api_key_unauthorized(predict_only_client):
|
|
68
|
+
with predict_only_client.use():
|
|
69
|
+
with pytest.raises(PermissionError):
|
|
70
|
+
OrcaCredentials.revoke_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
44
71
|
|
|
45
72
|
|
|
46
|
-
def
|
|
47
|
-
with pytest.raises(ValueError, match="
|
|
48
|
-
OrcaCredentials.
|
|
73
|
+
def test_create_api_key_already_exists():
|
|
74
|
+
with pytest.raises(ValueError, match="API key with this name already exists"):
|
|
75
|
+
OrcaCredentials.create_api_key("orca_sdk_test")
|
|
49
76
|
|
|
50
77
|
|
|
51
|
-
def
|
|
52
|
-
|
|
78
|
+
def test_set_api_key(api_key):
|
|
79
|
+
client = OrcaClient(api_key=str(uuid4()))
|
|
80
|
+
with client.use():
|
|
81
|
+
assert not OrcaCredentials.is_authenticated()
|
|
82
|
+
client.api_key = api_key
|
|
83
|
+
assert client.api_key == api_key
|
|
84
|
+
assert OrcaCredentials.is_authenticated()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_set_base_url(api_key):
|
|
88
|
+
client = OrcaClient(base_url="http://localhost:1582")
|
|
89
|
+
assert client.base_url == "http://localhost:1582"
|
|
90
|
+
client.base_url = "http://localhost:1583"
|
|
91
|
+
assert client.base_url == "http://localhost:1583"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# deprecated methods:
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_deprecated_set_api_key(api_key):
|
|
98
|
+
with OrcaClient(api_key=str(uuid4())).use():
|
|
99
|
+
assert not OrcaCredentials.is_authenticated()
|
|
100
|
+
OrcaCredentials.set_api_key(api_key)
|
|
101
|
+
assert OrcaCredentials.is_authenticated()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_deprecated_set_invalid_api_key(api_key):
|
|
105
|
+
with OrcaClient(api_key=api_key).use():
|
|
106
|
+
assert OrcaCredentials.is_authenticated()
|
|
107
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
108
|
+
OrcaCredentials.set_api_key(str(uuid4()))
|
|
109
|
+
assert not OrcaCredentials.is_authenticated()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_deprecated_set_api_url(api_key):
|
|
113
|
+
with OrcaClient(api_key=api_key).use():
|
|
114
|
+
OrcaCredentials.set_api_url("http://api.orcadb.ai")
|
|
115
|
+
assert str(OrcaClient._resolve_client().base_url) == "http://api.orcadb.ai"
|
|
53
116
|
|
|
54
117
|
|
|
55
|
-
def
|
|
56
|
-
|
|
57
|
-
|
|
118
|
+
def test_deprecated_set_invalid_api_url(api_key):
|
|
119
|
+
with OrcaClient(api_key=api_key).use():
|
|
120
|
+
with pytest.raises(ValueError, match="No API found at http://localhost:1582"):
|
|
121
|
+
OrcaCredentials.set_api_url("http://localhost:1582")
|
orca_sdk/datasource.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import os
|
|
5
4
|
import tempfile
|
|
6
5
|
import zipfile
|
|
7
6
|
from datetime import datetime
|
|
8
7
|
from io import BytesIO
|
|
9
8
|
from os import PathLike
|
|
10
9
|
from pathlib import Path
|
|
11
|
-
from typing import Literal, Union, cast
|
|
10
|
+
from typing import Any, Literal, Union, cast
|
|
12
11
|
|
|
13
12
|
import pandas as pd
|
|
14
13
|
import pyarrow as pa
|
|
@@ -22,7 +21,7 @@ from tqdm.auto import tqdm
|
|
|
22
21
|
from ._utils.common import CreateMode, DropMode
|
|
23
22
|
from ._utils.data_parsing import hf_dataset_from_torch
|
|
24
23
|
from ._utils.tqdm_file_reader import TqdmFileReader
|
|
25
|
-
from .client import DatasourceMetadata,
|
|
24
|
+
from .client import DatasourceMetadata, OrcaClient
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
def _upload_files_to_datasource(
|
|
@@ -55,7 +54,8 @@ def _upload_files_to_datasource(
|
|
|
55
54
|
files.append(("files", (file_path.name, cast(bytes, tqdm_reader))))
|
|
56
55
|
|
|
57
56
|
# Use manual HTTP request for file uploads
|
|
58
|
-
|
|
57
|
+
client = OrcaClient._resolve_client()
|
|
58
|
+
metadata = client.POST(
|
|
59
59
|
"/datasource/upload",
|
|
60
60
|
files=files,
|
|
61
61
|
data={"name": name, "description": description},
|
|
@@ -268,7 +268,8 @@ class Datasource:
|
|
|
268
268
|
if existing is not None:
|
|
269
269
|
return existing
|
|
270
270
|
|
|
271
|
-
|
|
271
|
+
client = OrcaClient._resolve_client()
|
|
272
|
+
metadata = client.POST(
|
|
272
273
|
"/datasource",
|
|
273
274
|
json={"name": name, "description": description, "content": data},
|
|
274
275
|
)
|
|
@@ -302,7 +303,8 @@ class Datasource:
|
|
|
302
303
|
if existing is not None:
|
|
303
304
|
return existing
|
|
304
305
|
|
|
305
|
-
|
|
306
|
+
client = OrcaClient._resolve_client()
|
|
307
|
+
metadata = client.POST(
|
|
306
308
|
"/datasource",
|
|
307
309
|
json={"name": name, "description": description, "content": data},
|
|
308
310
|
)
|
|
@@ -361,7 +363,8 @@ class Datasource:
|
|
|
361
363
|
parquet.write_table(pyarrow_table, buffer)
|
|
362
364
|
parquet_bytes = buffer.getvalue()
|
|
363
365
|
|
|
364
|
-
|
|
366
|
+
client = OrcaClient._resolve_client()
|
|
367
|
+
metadata = client.POST(
|
|
365
368
|
"/datasource/upload",
|
|
366
369
|
files=[("files", ("data.parquet", parquet_bytes))],
|
|
367
370
|
data={"name": name, "description": description},
|
|
@@ -429,7 +432,8 @@ class Datasource:
|
|
|
429
432
|
Raises:
|
|
430
433
|
LookupError: If the datasource does not exist
|
|
431
434
|
"""
|
|
432
|
-
|
|
435
|
+
client = OrcaClient._resolve_client()
|
|
436
|
+
return cls(client.GET("/datasource/{name_or_id}", params={"name_or_id": name_or_id}))
|
|
433
437
|
|
|
434
438
|
@classmethod
|
|
435
439
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -456,7 +460,8 @@ class Datasource:
|
|
|
456
460
|
Returns:
|
|
457
461
|
A list of all datasource handles in the OrcaCloud
|
|
458
462
|
"""
|
|
459
|
-
|
|
463
|
+
client = OrcaClient._resolve_client()
|
|
464
|
+
return [cls(metadata) for metadata in client.GET("/datasource")]
|
|
460
465
|
|
|
461
466
|
@classmethod
|
|
462
467
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error") -> None:
|
|
@@ -472,7 +477,8 @@ class Datasource:
|
|
|
472
477
|
LookupError: If the datasource does not exist and if_not_exists is `"error"`
|
|
473
478
|
"""
|
|
474
479
|
try:
|
|
475
|
-
|
|
480
|
+
client = OrcaClient._resolve_client()
|
|
481
|
+
client.DELETE("/datasource/{name_or_id}", params={"name_or_id": name_or_id})
|
|
476
482
|
logging.info(f"Deleted datasource {name_or_id}")
|
|
477
483
|
except LookupError:
|
|
478
484
|
if if_not_exists == "error":
|
|
@@ -481,6 +487,50 @@ class Datasource:
|
|
|
481
487
|
def __len__(self) -> int:
|
|
482
488
|
return self.length
|
|
483
489
|
|
|
490
|
+
def query(
|
|
491
|
+
self,
|
|
492
|
+
offset: int = 0,
|
|
493
|
+
limit: int = 100,
|
|
494
|
+
shuffle: bool = False,
|
|
495
|
+
shuffle_seed: int | None = None,
|
|
496
|
+
filters: list[tuple[str, Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "like"], Any]] = [],
|
|
497
|
+
) -> list[dict[str, Any]]:
|
|
498
|
+
"""
|
|
499
|
+
Query the datasource for rows with pagination and filtering support.
|
|
500
|
+
|
|
501
|
+
Params:
|
|
502
|
+
offset: Number of rows to skip
|
|
503
|
+
limit: Maximum number of rows to return
|
|
504
|
+
shuffle: Whether to shuffle the dataset before pagination
|
|
505
|
+
shuffle_seed: Seed for shuffling (for reproducible results)
|
|
506
|
+
filters: List of filter tuples. Each tuple contains:
|
|
507
|
+
- field (str): Column name to filter on
|
|
508
|
+
- op (str): Operator ("==", "!=", ">", ">=", "<", "<=", "in", "not in", "like")
|
|
509
|
+
- value: Value to compare against
|
|
510
|
+
|
|
511
|
+
Returns:
|
|
512
|
+
List of rows from the datasource
|
|
513
|
+
|
|
514
|
+
Examples:
|
|
515
|
+
>>> datasource.query(filters=[("age", ">", 25)])
|
|
516
|
+
>>> datasource.query(filters=[("city", "in", ["NYC", "LA"])])
|
|
517
|
+
>>> datasource.query(filters=[("name", "like", "John")])
|
|
518
|
+
"""
|
|
519
|
+
|
|
520
|
+
client = OrcaClient._resolve_client()
|
|
521
|
+
response = client.POST(
|
|
522
|
+
"/datasource/{name_or_id}/rows",
|
|
523
|
+
params={"name_or_id": self.id},
|
|
524
|
+
json={
|
|
525
|
+
"limit": limit,
|
|
526
|
+
"offset": offset,
|
|
527
|
+
"shuffle": shuffle,
|
|
528
|
+
"shuffle_seed": shuffle_seed,
|
|
529
|
+
"filters": [{"field": field, "op": op, "value": value} for field, op, value in filters],
|
|
530
|
+
},
|
|
531
|
+
)
|
|
532
|
+
return response
|
|
533
|
+
|
|
484
534
|
def download(
|
|
485
535
|
self, output_dir: str | PathLike, file_type: Literal["hf_dataset", "json", "csv"] = "hf_dataset"
|
|
486
536
|
) -> None:
|
|
@@ -497,7 +547,8 @@ class Datasource:
|
|
|
497
547
|
extension = "zip" if file_type == "hf_dataset" else file_type
|
|
498
548
|
output_path = Path(output_dir) / f"{self.name}.{extension}"
|
|
499
549
|
with open(output_path, "wb") as download_file:
|
|
500
|
-
|
|
550
|
+
client = OrcaClient._resolve_client()
|
|
551
|
+
with client.stream("GET", f"/datasource/{self.id}/download", params={"file_type": file_type}) as response:
|
|
501
552
|
total_chunks = int(response.headers["X-Total-Chunks"]) if "X-Total-Chunks" in response.headers else None
|
|
502
553
|
with tqdm(desc="Downloading", total=total_chunks, disable=total_chunks is None) as progress:
|
|
503
554
|
for chunk in response.iter_bytes():
|
|
@@ -521,4 +572,5 @@ class Datasource:
|
|
|
521
572
|
Returns:
|
|
522
573
|
A list of dictionaries representation of the datasource.
|
|
523
574
|
"""
|
|
524
|
-
|
|
575
|
+
client = OrcaClient._resolve_client()
|
|
576
|
+
return client.GET("/datasource/{name_or_id}/download", params={"name_or_id": self.id, "file_type": "json"})
|