orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +176 -14
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +515 -475
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +19 -10
- orca_sdk/datasource_test.py +24 -18
- orca_sdk/embedding_model.py +22 -13
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +14 -8
- orca_sdk/memoryset.py +513 -183
- orca_sdk/memoryset_test.py +130 -32
- orca_sdk/regression_model.py +21 -11
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +24 -13
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +1 -1
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +0 -0
orca_sdk/conftest.py
CHANGED
|
@@ -4,11 +4,13 @@ from typing import Generator
|
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
6
|
import pytest
|
|
7
|
+
import pytest_asyncio
|
|
7
8
|
from datasets import ClassLabel, Dataset, Features, Value
|
|
8
9
|
|
|
9
10
|
from ._utils.auth import _create_api_key, _delete_org
|
|
11
|
+
from .async_client import OrcaAsyncClient
|
|
10
12
|
from .classification_model import ClassificationModel
|
|
11
|
-
from .client import
|
|
13
|
+
from .client import OrcaClient
|
|
12
14
|
from .credentials import OrcaCredentials
|
|
13
15
|
from .datasource import Datasource
|
|
14
16
|
from .embedding_model import PretrainedEmbeddingModel
|
|
@@ -44,51 +46,51 @@ def _create_org_id():
|
|
|
44
46
|
return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
@pytest.fixture()
|
|
48
|
-
def
|
|
49
|
-
|
|
50
|
-
yield
|
|
51
|
-
orca_api.base_url = original_base_url
|
|
49
|
+
@pytest.fixture(scope="session")
|
|
50
|
+
def org_id():
|
|
51
|
+
return _create_org_id()
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
@pytest.fixture(scope="session")
|
|
55
|
-
def
|
|
55
|
+
def other_org_id():
|
|
56
56
|
return _create_org_id()
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
@pytest.fixture(autouse=True, scope="session")
|
|
60
60
|
def api_key(org_id) -> Generator[str, None, None]:
|
|
61
61
|
api_key = _create_api_key(org_id=org_id, name="orca_sdk_test")
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
with OrcaClient(api_key=api_key).use():
|
|
63
|
+
yield api_key
|
|
64
64
|
_delete_org(org_id)
|
|
65
65
|
|
|
66
66
|
|
|
67
|
+
# We cannot use a session scoped fixture because async pytest tears down the client after each test
|
|
67
68
|
@pytest.fixture(autouse=True)
|
|
68
|
-
def
|
|
69
|
-
|
|
69
|
+
def authenticate_async_client(api_key) -> Generator[None, None, None]:
|
|
70
|
+
with OrcaAsyncClient(api_key=api_key).use():
|
|
71
|
+
yield
|
|
70
72
|
|
|
71
73
|
|
|
72
|
-
@pytest.fixture()
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
yield
|
|
76
|
-
# Need to reset the api key to the original api key so following tests don't fail
|
|
77
|
-
OrcaCredentials.set_api_key(api_key, check_validity=False)
|
|
74
|
+
@pytest.fixture(scope="session")
|
|
75
|
+
def unauthenticated_client() -> OrcaClient:
|
|
76
|
+
return OrcaClient(api_key=str(uuid4()))
|
|
78
77
|
|
|
79
78
|
|
|
80
|
-
@
|
|
81
|
-
def
|
|
82
|
-
return
|
|
79
|
+
@pytest_asyncio.fixture()
|
|
80
|
+
def unauthenticated_async_client() -> OrcaAsyncClient:
|
|
81
|
+
return OrcaAsyncClient(api_key=str(uuid4()))
|
|
83
82
|
|
|
84
83
|
|
|
85
|
-
@pytest.fixture()
|
|
86
|
-
def
|
|
84
|
+
@pytest.fixture(scope="session")
|
|
85
|
+
def unauthorized_client(other_org_id):
|
|
87
86
|
different_api_key = _create_api_key(org_id=other_org_id, name="orca_sdk_test_other_org")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
87
|
+
return OrcaClient(api_key=different_api_key)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.fixture(scope="session")
|
|
91
|
+
def predict_only_client() -> OrcaClient:
|
|
92
|
+
predict_api_key = OrcaCredentials.create_api_key("orca_sdk_test_predict", scopes={"PREDICT"})
|
|
93
|
+
return OrcaClient(api_key=predict_api_key)
|
|
92
94
|
|
|
93
95
|
|
|
94
96
|
@pytest.fixture(scope="session")
|
|
@@ -209,18 +211,17 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
|
|
|
209
211
|
yield memoryset
|
|
210
212
|
finally:
|
|
211
213
|
# Restore the memoryset to a clean state for the next test.
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
memoryset.refresh()
|
|
214
|
+
with OrcaClient(api_key=api_key).use():
|
|
215
|
+
if LabeledMemoryset.exists("test_writable_memoryset"):
|
|
216
|
+
memoryset.refresh()
|
|
216
217
|
|
|
217
|
-
|
|
218
|
+
memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
|
|
218
219
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
220
|
+
if memory_ids:
|
|
221
|
+
memoryset.delete(memory_ids)
|
|
222
|
+
memoryset.refresh()
|
|
223
|
+
assert len(memoryset) == 0
|
|
224
|
+
memoryset.insert(SAMPLE_DATA)
|
|
224
225
|
# If the test dropped the memoryset, do nothing — it will be recreated on the next use.
|
|
225
226
|
|
|
226
227
|
|
orca_sdk/credentials.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from datetime import datetime
|
|
2
3
|
from typing import Literal, NamedTuple
|
|
3
4
|
|
|
4
5
|
import httpx
|
|
5
|
-
from httpx import ConnectError, Headers
|
|
6
|
+
from httpx import ConnectError, Headers, HTTPTransport
|
|
7
|
+
from typing_extensions import deprecated
|
|
6
8
|
|
|
7
|
-
from .
|
|
9
|
+
from .async_client import OrcaAsyncClient
|
|
10
|
+
from .client import OrcaClient
|
|
8
11
|
|
|
9
12
|
Scope = Literal["ADMINISTER", "PREDICT"]
|
|
10
13
|
"""
|
|
@@ -15,19 +18,31 @@ The scopes of an API key.
|
|
|
15
18
|
"""
|
|
16
19
|
|
|
17
20
|
|
|
18
|
-
class ApiKeyInfo
|
|
21
|
+
class ApiKeyInfo:
|
|
19
22
|
"""
|
|
20
|
-
|
|
23
|
+
Information about an API key
|
|
24
|
+
|
|
25
|
+
Note:
|
|
26
|
+
The value of the API key is only available at creation time.
|
|
21
27
|
|
|
22
28
|
Attributes:
|
|
23
29
|
name: Unique name of the API key
|
|
24
30
|
created_at: When the API key was created
|
|
31
|
+
scopes: The scopes of the API key
|
|
25
32
|
"""
|
|
26
33
|
|
|
27
34
|
name: str
|
|
28
35
|
created_at: datetime
|
|
29
36
|
scopes: set[Scope]
|
|
30
37
|
|
|
38
|
+
def __init__(self, name: str, created_at: datetime, scopes: set[Scope]):
|
|
39
|
+
self.name = name
|
|
40
|
+
self.created_at = created_at
|
|
41
|
+
self.scopes = scopes
|
|
42
|
+
|
|
43
|
+
def __repr__(self) -> str:
|
|
44
|
+
return "ApiKey({ " + f"name: '{self.name}', scopes: <{'|'.join(self.scopes)}>" + "})"
|
|
45
|
+
|
|
31
46
|
|
|
32
47
|
class OrcaCredentials:
|
|
33
48
|
"""
|
|
@@ -42,8 +57,9 @@ class OrcaCredentials:
|
|
|
42
57
|
Returns:
|
|
43
58
|
True if you are authenticated, False otherwise
|
|
44
59
|
"""
|
|
60
|
+
client = OrcaClient._resolve_client()
|
|
45
61
|
try:
|
|
46
|
-
return
|
|
62
|
+
return client.GET("/auth")
|
|
47
63
|
except ValueError as e:
|
|
48
64
|
if "Invalid API key" in str(e):
|
|
49
65
|
return False
|
|
@@ -57,8 +73,10 @@ class OrcaCredentials:
|
|
|
57
73
|
Returns:
|
|
58
74
|
True if the API is healthy, False otherwise
|
|
59
75
|
"""
|
|
76
|
+
client = OrcaClient._resolve_client()
|
|
60
77
|
try:
|
|
61
|
-
|
|
78
|
+
# we don't want a retry transport here, so we use httpx directly
|
|
79
|
+
httpx.get(f"{client.base_url}/check/healthy")
|
|
62
80
|
except Exception:
|
|
63
81
|
return False
|
|
64
82
|
return True
|
|
@@ -71,13 +89,14 @@ class OrcaCredentials:
|
|
|
71
89
|
Returns:
|
|
72
90
|
A list of named tuples, with the name and creation date time of the API key
|
|
73
91
|
"""
|
|
92
|
+
client = OrcaClient._resolve_client()
|
|
74
93
|
return [
|
|
75
94
|
ApiKeyInfo(
|
|
76
95
|
name=api_key["name"],
|
|
77
96
|
created_at=datetime.fromisoformat(api_key["created_at"]),
|
|
78
97
|
scopes=set(api_key["scope"]),
|
|
79
98
|
)
|
|
80
|
-
for api_key in
|
|
99
|
+
for api_key in client.GET("/auth/api_key")
|
|
81
100
|
]
|
|
82
101
|
|
|
83
102
|
@staticmethod
|
|
@@ -92,7 +111,8 @@ class OrcaCredentials:
|
|
|
92
111
|
Returns:
|
|
93
112
|
The secret value of the API key. Make sure to save this value as it will not be shown again.
|
|
94
113
|
"""
|
|
95
|
-
|
|
114
|
+
client = OrcaClient._resolve_client()
|
|
115
|
+
res = client.POST(
|
|
96
116
|
"/auth/api_key",
|
|
97
117
|
json={"name": name, "scope": list(scopes)},
|
|
98
118
|
)
|
|
@@ -109,8 +129,12 @@ class OrcaCredentials:
|
|
|
109
129
|
Raises:
|
|
110
130
|
ValueError: if the API key is not found
|
|
111
131
|
"""
|
|
112
|
-
|
|
132
|
+
client = OrcaClient._resolve_client()
|
|
133
|
+
client.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": name})
|
|
134
|
+
|
|
135
|
+
# TODO: remove deprecated methods after 2026-01-01
|
|
113
136
|
|
|
137
|
+
@deprecated("Use `OrcaClient.api_key` instead")
|
|
114
138
|
@staticmethod
|
|
115
139
|
def set_api_key(api_key: str, check_validity: bool = True):
|
|
116
140
|
"""
|
|
@@ -126,17 +150,24 @@ class OrcaCredentials:
|
|
|
126
150
|
Raises:
|
|
127
151
|
ValueError: if the API key is invalid and `check_validity` is True
|
|
128
152
|
"""
|
|
129
|
-
|
|
153
|
+
sync_client = OrcaClient._resolve_client()
|
|
154
|
+
sync_client.api_key = api_key
|
|
130
155
|
if check_validity:
|
|
131
|
-
|
|
156
|
+
sync_client.GET("/auth")
|
|
132
157
|
|
|
158
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
159
|
+
async_client.api_key = api_key
|
|
160
|
+
|
|
161
|
+
@deprecated("Use `OrcaClient.base_url` instead")
|
|
133
162
|
@staticmethod
|
|
134
163
|
def get_api_url() -> str:
|
|
135
164
|
"""
|
|
136
165
|
Get the base URL of the Orca API that is currently being used
|
|
137
166
|
"""
|
|
138
|
-
|
|
167
|
+
client = OrcaClient._resolve_client()
|
|
168
|
+
return str(client.base_url)
|
|
139
169
|
|
|
170
|
+
@deprecated("Use `OrcaClient.base_url` instead")
|
|
140
171
|
@staticmethod
|
|
141
172
|
def set_api_url(url: str, check_validity: bool = True):
|
|
142
173
|
"""
|
|
@@ -156,12 +187,17 @@ class OrcaCredentials:
|
|
|
156
187
|
except ConnectError as e:
|
|
157
188
|
raise ValueError(f"No API found at {url}") from e
|
|
158
189
|
|
|
159
|
-
|
|
190
|
+
sync_client = OrcaClient._resolve_client()
|
|
191
|
+
sync_client.base_url = url
|
|
192
|
+
|
|
193
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
194
|
+
async_client.base_url = url
|
|
160
195
|
|
|
161
196
|
# check if the api passes the health check
|
|
162
197
|
if check_validity:
|
|
163
198
|
OrcaCredentials.is_healthy()
|
|
164
199
|
|
|
200
|
+
@deprecated("Use `OrcaClient.headers` instead")
|
|
165
201
|
@staticmethod
|
|
166
202
|
def set_api_headers(headers: dict[str, str]):
|
|
167
203
|
"""
|
|
@@ -174,4 +210,8 @@ class OrcaCredentials:
|
|
|
174
210
|
New keys are merged into the existing headers, this will overwrite headers with the
|
|
175
211
|
same name, but leave other headers untouched.
|
|
176
212
|
"""
|
|
177
|
-
|
|
213
|
+
sync_client = OrcaClient._resolve_client()
|
|
214
|
+
sync_client.headers.update(Headers(headers))
|
|
215
|
+
|
|
216
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
217
|
+
async_client.headers.update(Headers(headers))
|
orca_sdk/credentials_test.py
CHANGED
|
@@ -2,56 +2,120 @@ from uuid import uuid4
|
|
|
2
2
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
|
-
from .client import
|
|
5
|
+
from .client import OrcaClient
|
|
6
6
|
from .credentials import OrcaCredentials
|
|
7
7
|
|
|
8
8
|
|
|
9
|
+
def test_is_authenticated():
|
|
10
|
+
assert OrcaCredentials.is_authenticated()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_is_authenticated_false(unauthenticated_client):
|
|
14
|
+
with unauthenticated_client.use():
|
|
15
|
+
assert not OrcaCredentials.is_authenticated()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_is_healthy():
|
|
19
|
+
assert OrcaCredentials.is_healthy()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_is_healthy_false(api_key):
|
|
23
|
+
with OrcaClient(api_key=api_key, base_url="http://localhost:1582").use():
|
|
24
|
+
assert not OrcaCredentials.is_healthy()
|
|
25
|
+
|
|
26
|
+
|
|
9
27
|
def test_list_api_keys():
|
|
10
28
|
api_keys = OrcaCredentials.list_api_keys()
|
|
11
29
|
assert len(api_keys) >= 1
|
|
12
30
|
assert "orca_sdk_test" in [api_key.name for api_key in api_keys]
|
|
13
31
|
|
|
14
32
|
|
|
15
|
-
def test_list_api_keys_unauthenticated(
|
|
16
|
-
with
|
|
17
|
-
|
|
33
|
+
def test_list_api_keys_unauthenticated(unauthenticated_client):
|
|
34
|
+
with unauthenticated_client.use():
|
|
35
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
36
|
+
OrcaCredentials.list_api_keys()
|
|
18
37
|
|
|
19
38
|
|
|
20
|
-
def
|
|
21
|
-
|
|
39
|
+
def test_manage_api_key():
|
|
40
|
+
api_key_name = f"orca_sdk_test_{uuid4().hex[:8]}"
|
|
41
|
+
api_key = OrcaCredentials.create_api_key(api_key_name)
|
|
42
|
+
assert api_key is not None
|
|
43
|
+
assert len(api_key) > 0
|
|
44
|
+
assert api_key_name in [aki.name for aki in OrcaCredentials.list_api_keys()]
|
|
45
|
+
OrcaCredentials.revoke_api_key(api_key_name)
|
|
46
|
+
assert api_key_name not in [aki.name for aki in OrcaCredentials.list_api_keys()]
|
|
22
47
|
|
|
23
48
|
|
|
24
|
-
def
|
|
25
|
-
|
|
49
|
+
def test_create_api_key_unauthenticated(unauthenticated_client):
|
|
50
|
+
with unauthenticated_client.use():
|
|
51
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
52
|
+
OrcaCredentials.create_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
26
53
|
|
|
27
54
|
|
|
28
|
-
def
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
55
|
+
def test_create_api_key_unauthorized(predict_only_client):
|
|
56
|
+
with predict_only_client.use():
|
|
57
|
+
with pytest.raises(PermissionError):
|
|
58
|
+
OrcaCredentials.create_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
32
59
|
|
|
33
60
|
|
|
34
|
-
def
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
assert not OrcaCredentials.is_authenticated()
|
|
61
|
+
def test_revoke_api_key_unauthenticated(unauthenticated_client):
|
|
62
|
+
with unauthenticated_client.use():
|
|
63
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
64
|
+
OrcaCredentials.revoke_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
39
65
|
|
|
40
66
|
|
|
41
|
-
def
|
|
42
|
-
|
|
43
|
-
|
|
67
|
+
def test_revoke_api_key_unauthorized(predict_only_client):
|
|
68
|
+
with predict_only_client.use():
|
|
69
|
+
with pytest.raises(PermissionError):
|
|
70
|
+
OrcaCredentials.revoke_api_key(f"orca_sdk_test_{uuid4().hex[:8]}")
|
|
44
71
|
|
|
45
72
|
|
|
46
|
-
def
|
|
47
|
-
with pytest.raises(ValueError, match="
|
|
48
|
-
OrcaCredentials.
|
|
73
|
+
def test_create_api_key_already_exists():
|
|
74
|
+
with pytest.raises(ValueError, match="API key with this name already exists"):
|
|
75
|
+
OrcaCredentials.create_api_key("orca_sdk_test")
|
|
49
76
|
|
|
50
77
|
|
|
51
|
-
def
|
|
52
|
-
|
|
78
|
+
def test_set_api_key(api_key):
|
|
79
|
+
client = OrcaClient(api_key=str(uuid4()))
|
|
80
|
+
with client.use():
|
|
81
|
+
assert not OrcaCredentials.is_authenticated()
|
|
82
|
+
client.api_key = api_key
|
|
83
|
+
assert client.api_key == api_key
|
|
84
|
+
assert OrcaCredentials.is_authenticated()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_set_base_url(api_key):
|
|
88
|
+
client = OrcaClient(base_url="http://localhost:1582")
|
|
89
|
+
assert client.base_url == "http://localhost:1582"
|
|
90
|
+
client.base_url = "http://localhost:1583"
|
|
91
|
+
assert client.base_url == "http://localhost:1583"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# deprecated methods:
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_deprecated_set_api_key(api_key):
|
|
98
|
+
with OrcaClient(api_key=str(uuid4())).use():
|
|
99
|
+
assert not OrcaCredentials.is_authenticated()
|
|
100
|
+
OrcaCredentials.set_api_key(api_key)
|
|
101
|
+
assert OrcaCredentials.is_authenticated()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_deprecated_set_invalid_api_key(api_key):
|
|
105
|
+
with OrcaClient(api_key=api_key).use():
|
|
106
|
+
assert OrcaCredentials.is_authenticated()
|
|
107
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
108
|
+
OrcaCredentials.set_api_key(str(uuid4()))
|
|
109
|
+
assert not OrcaCredentials.is_authenticated()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_deprecated_set_api_url(api_key):
|
|
113
|
+
with OrcaClient(api_key=api_key).use():
|
|
114
|
+
OrcaCredentials.set_api_url("http://api.orcadb.ai")
|
|
115
|
+
assert str(OrcaClient._resolve_client().base_url) == "http://api.orcadb.ai"
|
|
53
116
|
|
|
54
117
|
|
|
55
|
-
def
|
|
56
|
-
|
|
57
|
-
|
|
118
|
+
def test_deprecated_set_invalid_api_url(api_key):
|
|
119
|
+
with OrcaClient(api_key=api_key).use():
|
|
120
|
+
with pytest.raises(ValueError, match="No API found at http://localhost:1582"):
|
|
121
|
+
OrcaCredentials.set_api_url("http://localhost:1582")
|
orca_sdk/datasource.py
CHANGED
|
@@ -22,7 +22,7 @@ from tqdm.auto import tqdm
|
|
|
22
22
|
from ._utils.common import CreateMode, DropMode
|
|
23
23
|
from ._utils.data_parsing import hf_dataset_from_torch
|
|
24
24
|
from ._utils.tqdm_file_reader import TqdmFileReader
|
|
25
|
-
from .client import DatasourceMetadata,
|
|
25
|
+
from .client import DatasourceMetadata, OrcaClient
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def _upload_files_to_datasource(
|
|
@@ -55,7 +55,8 @@ def _upload_files_to_datasource(
|
|
|
55
55
|
files.append(("files", (file_path.name, cast(bytes, tqdm_reader))))
|
|
56
56
|
|
|
57
57
|
# Use manual HTTP request for file uploads
|
|
58
|
-
|
|
58
|
+
client = OrcaClient._resolve_client()
|
|
59
|
+
metadata = client.POST(
|
|
59
60
|
"/datasource/upload",
|
|
60
61
|
files=files,
|
|
61
62
|
data={"name": name, "description": description},
|
|
@@ -268,7 +269,8 @@ class Datasource:
|
|
|
268
269
|
if existing is not None:
|
|
269
270
|
return existing
|
|
270
271
|
|
|
271
|
-
|
|
272
|
+
client = OrcaClient._resolve_client()
|
|
273
|
+
metadata = client.POST(
|
|
272
274
|
"/datasource",
|
|
273
275
|
json={"name": name, "description": description, "content": data},
|
|
274
276
|
)
|
|
@@ -302,7 +304,8 @@ class Datasource:
|
|
|
302
304
|
if existing is not None:
|
|
303
305
|
return existing
|
|
304
306
|
|
|
305
|
-
|
|
307
|
+
client = OrcaClient._resolve_client()
|
|
308
|
+
metadata = client.POST(
|
|
306
309
|
"/datasource",
|
|
307
310
|
json={"name": name, "description": description, "content": data},
|
|
308
311
|
)
|
|
@@ -361,7 +364,8 @@ class Datasource:
|
|
|
361
364
|
parquet.write_table(pyarrow_table, buffer)
|
|
362
365
|
parquet_bytes = buffer.getvalue()
|
|
363
366
|
|
|
364
|
-
|
|
367
|
+
client = OrcaClient._resolve_client()
|
|
368
|
+
metadata = client.POST(
|
|
365
369
|
"/datasource/upload",
|
|
366
370
|
files=[("files", ("data.parquet", parquet_bytes))],
|
|
367
371
|
data={"name": name, "description": description},
|
|
@@ -429,7 +433,8 @@ class Datasource:
|
|
|
429
433
|
Raises:
|
|
430
434
|
LookupError: If the datasource does not exist
|
|
431
435
|
"""
|
|
432
|
-
|
|
436
|
+
client = OrcaClient._resolve_client()
|
|
437
|
+
return cls(client.GET("/datasource/{name_or_id}", params={"name_or_id": name_or_id}))
|
|
433
438
|
|
|
434
439
|
@classmethod
|
|
435
440
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -456,7 +461,8 @@ class Datasource:
|
|
|
456
461
|
Returns:
|
|
457
462
|
A list of all datasource handles in the OrcaCloud
|
|
458
463
|
"""
|
|
459
|
-
|
|
464
|
+
client = OrcaClient._resolve_client()
|
|
465
|
+
return [cls(metadata) for metadata in client.GET("/datasource")]
|
|
460
466
|
|
|
461
467
|
@classmethod
|
|
462
468
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error") -> None:
|
|
@@ -472,7 +478,8 @@ class Datasource:
|
|
|
472
478
|
LookupError: If the datasource does not exist and if_not_exists is `"error"`
|
|
473
479
|
"""
|
|
474
480
|
try:
|
|
475
|
-
|
|
481
|
+
client = OrcaClient._resolve_client()
|
|
482
|
+
client.DELETE("/datasource/{name_or_id}", params={"name_or_id": name_or_id})
|
|
476
483
|
logging.info(f"Deleted datasource {name_or_id}")
|
|
477
484
|
except LookupError:
|
|
478
485
|
if if_not_exists == "error":
|
|
@@ -497,7 +504,8 @@ class Datasource:
|
|
|
497
504
|
extension = "zip" if file_type == "hf_dataset" else file_type
|
|
498
505
|
output_path = Path(output_dir) / f"{self.name}.{extension}"
|
|
499
506
|
with open(output_path, "wb") as download_file:
|
|
500
|
-
|
|
507
|
+
client = OrcaClient._resolve_client()
|
|
508
|
+
with client.stream("GET", f"/datasource/{self.id}/download", params={"file_type": file_type}) as response:
|
|
501
509
|
total_chunks = int(response.headers["X-Total-Chunks"]) if "X-Total-Chunks" in response.headers else None
|
|
502
510
|
with tqdm(desc="Downloading", total=total_chunks, disable=total_chunks is None) as progress:
|
|
503
511
|
for chunk in response.iter_bytes():
|
|
@@ -521,4 +529,5 @@ class Datasource:
|
|
|
521
529
|
Returns:
|
|
522
530
|
A list of dictionaries representation of the datasource.
|
|
523
531
|
"""
|
|
524
|
-
|
|
532
|
+
client = OrcaClient._resolve_client()
|
|
533
|
+
return client.GET("/datasource/{name_or_id}/download", params={"name_or_id": self.id, "file_type": "json"})
|
orca_sdk/datasource_test.py
CHANGED
|
@@ -19,9 +19,10 @@ def test_create_datasource(datasource, hf_dataset):
|
|
|
19
19
|
assert datasource.length == len(hf_dataset)
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def test_create_datasource_unauthenticated(
|
|
23
|
-
with
|
|
24
|
-
|
|
22
|
+
def test_create_datasource_unauthenticated(unauthenticated_client, hf_dataset):
|
|
23
|
+
with unauthenticated_client.use():
|
|
24
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
25
|
+
Datasource.from_hf_dataset("test_datasource", hf_dataset)
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def test_create_datasource_already_exists_error(hf_dataset, datasource):
|
|
@@ -43,9 +44,10 @@ def test_open_datasource(datasource):
|
|
|
43
44
|
assert fetched_datasource.length == len(datasource)
|
|
44
45
|
|
|
45
46
|
|
|
46
|
-
def test_open_datasource_unauthenticated(
|
|
47
|
-
with
|
|
48
|
-
|
|
47
|
+
def test_open_datasource_unauthenticated(unauthenticated_client, datasource):
|
|
48
|
+
with unauthenticated_client.use():
|
|
49
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
50
|
+
Datasource.open("test_datasource")
|
|
49
51
|
|
|
50
52
|
|
|
51
53
|
def test_open_datasource_invalid_input():
|
|
@@ -58,9 +60,10 @@ def test_open_datasource_not_found():
|
|
|
58
60
|
Datasource.open(str(uuid4()))
|
|
59
61
|
|
|
60
62
|
|
|
61
|
-
def test_open_datasource_unauthorized(
|
|
62
|
-
with
|
|
63
|
-
|
|
63
|
+
def test_open_datasource_unauthorized(unauthorized_client, datasource):
|
|
64
|
+
with unauthorized_client.use():
|
|
65
|
+
with pytest.raises(LookupError):
|
|
66
|
+
Datasource.open(datasource.id)
|
|
64
67
|
|
|
65
68
|
|
|
66
69
|
def test_all_datasources(datasource):
|
|
@@ -69,9 +72,10 @@ def test_all_datasources(datasource):
|
|
|
69
72
|
assert any(datasource.name == datasource.name for datasource in datasources)
|
|
70
73
|
|
|
71
74
|
|
|
72
|
-
def test_all_datasources_unauthenticated(
|
|
73
|
-
with
|
|
74
|
-
|
|
75
|
+
def test_all_datasources_unauthenticated(unauthenticated_client):
|
|
76
|
+
with unauthenticated_client.use():
|
|
77
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
78
|
+
Datasource.all()
|
|
75
79
|
|
|
76
80
|
|
|
77
81
|
def test_drop_datasource(hf_dataset):
|
|
@@ -81,9 +85,10 @@ def test_drop_datasource(hf_dataset):
|
|
|
81
85
|
assert not Datasource.exists("datasource_to_delete")
|
|
82
86
|
|
|
83
87
|
|
|
84
|
-
def test_drop_datasource_unauthenticated(datasource,
|
|
85
|
-
with
|
|
86
|
-
|
|
88
|
+
def test_drop_datasource_unauthenticated(datasource, unauthenticated_client):
|
|
89
|
+
with unauthenticated_client.use():
|
|
90
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
91
|
+
Datasource.drop(datasource.id)
|
|
87
92
|
|
|
88
93
|
|
|
89
94
|
def test_drop_datasource_not_found():
|
|
@@ -93,9 +98,10 @@ def test_drop_datasource_not_found():
|
|
|
93
98
|
Datasource.drop(str(uuid4()), if_not_exists="ignore")
|
|
94
99
|
|
|
95
100
|
|
|
96
|
-
def test_drop_datasource_unauthorized(datasource,
|
|
97
|
-
with
|
|
98
|
-
|
|
101
|
+
def test_drop_datasource_unauthorized(datasource, unauthorized_client):
|
|
102
|
+
with unauthorized_client.use():
|
|
103
|
+
with pytest.raises(LookupError):
|
|
104
|
+
Datasource.drop(datasource.id)
|
|
99
105
|
|
|
100
106
|
|
|
101
107
|
def test_drop_datasource_invalid_input():
|