oxenai-0.39.1-cp313-cp313-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of oxenai might be problematic.
- oxen/__init__.py +62 -0
- oxen/auth.py +40 -0
- oxen/clone.py +58 -0
- oxen/config.py +16 -0
- oxen/data_frame.py +462 -0
- oxen/datasets.py +106 -0
- oxen/df_utils.py +54 -0
- oxen/diff/__init__.py +0 -0
- oxen/diff/change_type.py +12 -0
- oxen/diff/diff.py +143 -0
- oxen/diff/line_diff.py +41 -0
- oxen/diff/tabular_diff.py +22 -0
- oxen/diff/text_diff.py +48 -0
- oxen/features.py +58 -0
- oxen/fs.py +57 -0
- oxen/init.py +19 -0
- oxen/notebooks.py +97 -0
- oxen/oxen.cpython-313-x86_64-linux-gnu.so +0 -0
- oxen/oxen_fs.py +351 -0
- oxen/providers/__init__.py +0 -0
- oxen/providers/dataset_path_provider.py +26 -0
- oxen/providers/mock_provider.py +73 -0
- oxen/providers/oxen_data_frame_provider.py +61 -0
- oxen/remote_repo.py +626 -0
- oxen/repo.py +239 -0
- oxen/streaming_dataset.py +242 -0
- oxen/user.py +40 -0
- oxen/util/__init__.py +0 -0
- oxen/workspace.py +210 -0
- oxen.libs/libcrypto-0787ff19.so.3 +0 -0
- oxen.libs/libssl-ec2edb95.so.3 +0 -0
- oxenai-0.39.1.dist-info/METADATA +92 -0
- oxenai-0.39.1.dist-info/RECORD +35 -0
- oxenai-0.39.1.dist-info/WHEEL +4 -0
- oxenai-0.39.1.dist-info/entry_points.txt +2 -0
oxen/__init__.py
ADDED
@@ -0,0 +1,62 @@
+"""Core Oxen Functionality"""
+
+# Rust wrappers
+from .oxen import (
+    PyRepo,
+    PyStagedData,
+    PyCommit,
+    PyRemoteRepo,
+    PyDataset,
+    PyWorkspace,
+    PyWorkspaceDataFrame,
+    PyColumn,
+    __version__,
+)
+from .oxen import util
+from .oxen import py_notebooks
+
+# Python classes
+from oxen.repo import Repo
+from oxen.remote_repo import RemoteRepo
+from oxen.workspace import Workspace
+from oxen.data_frame import DataFrame
+from oxen import auth
+from oxen import datasets
+from oxen.notebooks import start as start_notebook
+from oxen.notebooks import stop as stop_notebook
+from oxen.clone import clone
+from oxen.diff.diff import diff
+from oxen.init import init
+from oxen.config import is_configured
+from oxen.oxen_fs import OxenFS
+
+# Names of public modules we want to expose
+__all__ = [
+    "auth",
+    "DataFrame",
+    "Dataset",
+    "diff",
+    "init",
+    "is_configured",
+    "notebooks",
+    "start_notebook",
+    "stop_notebook",
+    "OxenFS",
+    "PyColumn",
+    "PyCommit",
+    "PyDataset",
+    "PyRemoteRepo",
+    "PyRepo",
+    "PyStagedData",
+    "PyWorkspace",
+    "PyWorkspaceDataFrame",
+    "RemoteRepo",
+    "Repo",
+    "util",
+    "py_notebooks",
+    "clone",
+    "datasets",
+    "Workspace",
+]
+
+__version__ = __version__
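With the re-exports above, everything needed for a basic workflow is available from the top-level `oxen` namespace. A minimal orientation sketch, assuming a repo id and file path like the ones used as examples in the docstrings below (both are placeholders):

```python
import oxen

# Configure auth once if the config files are missing
if not oxen.is_configured():
    oxen.auth.config_auth("YOUR_API_TOKEN")  # placeholder token

# Clone a repo and open a remote data frame, both re-exported above
repo = oxen.clone("ox/chatbot")                # placeholder repo id
df = oxen.DataFrame("ox/chatbot", "data.tsv")  # placeholder file path
```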
oxen/auth.py
ADDED
@@ -0,0 +1,40 @@
+from .oxen import auth, util
+from oxen.user import config_user
+from typing import Optional
+import os
+import requests
+
+
+def config_auth(token: str, host: str = "hub.oxen.ai", path: Optional[str] = None):
+    """
+    Configures authentication for a host.
+
+    Args:
+        token: `str`
+            The token to use for authentication.
+        host: `str`
+            The host to configure authentication for. Default: 'hub.oxen.ai'
+        path: `Optional[str]`
+            The path to save the authentication config to.
+            Defaults to $HOME/.config/oxen/auth_config.toml
+    """
+    if path is None:
+        path = os.path.join(util.get_oxen_config_dir(), "auth_config.toml")
+    if not path.endswith(".toml"):
+        raise ValueError("Path must end with .toml")
+    auth.config_auth(host, token, path)
+
+    # Only fetch user if the host is the hub
+    if "hub.oxen.ai" == host:
+        # Fetch the user from the hub and save it to the config
+        url = f"https://{host}/api/authorize"
+        # make request with token
+        headers = {"Authorization": f"Bearer {token}"}
+        r = requests.get(url, headers=headers)
+        if r.status_code != 200:
+            raise Exception(f"Failed to fetch user from {host}.")
+        user = r.json()["user"]
+        name = user["name"]
+        email = user["email"]
+        # save user to config
+        config_user(name, email)
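Note that the user profile is only fetched automatically when `host` is `hub.oxen.ai`; against a self-hosted oxen-server the user has to be configured separately via `config_user`, whose `(name, email)` call shape is taken from the code above. A sketch under that assumption (host, token, and identity are placeholders):

```python
from oxen.auth import config_auth
from oxen.user import config_user

# Hub: the token is validated against /api/authorize and the user is saved
config_auth("YOUR_API_TOKEN")  # placeholder token

# Self-hosted server: no user lookup happens, so set the user yourself
config_auth("YOUR_API_TOKEN", host="oxen.example.com")  # placeholder host
config_user("Jane Doe", "jane@example.com")             # placeholder identity
```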
oxen/clone.py
ADDED
@@ -0,0 +1,58 @@
+from typing import Optional
+from oxen.repo import Repo
+
+
+def clone(
+    repo_id: str,
+    path: Optional[str] = None,
+    host: str = "hub.oxen.ai",
+    branch: str = "main",
+    scheme: str = "https",
+    filters: Optional[str | list[str]] = None,
+    all=False,
+):
+    """
+    Clone a repository.
+
+    Args:
+        repo_id: `str`
+            Name of the repository in the format 'namespace/repo_name'.
+            For example 'ox/chatbot'
+        path: `Optional[str]`
+            The path to clone the repo to. Defaults to the name of the repository.
+        host: `str`
+            The host to connect to. Defaults to 'hub.oxen.ai'
+        branch: `str`
+            The name of the branch to clone. Defaults to 'main'
+        scheme: `str`
+            The scheme to use. Defaults to 'https'
+        all: `bool`
+            Whether to clone the full commit history or not. Default: False
+        filters: `str | list[str] | None`
+            Filter down the set of directories you want to clone. Useful if
+            you have a large repository and only want to make changes to a
+            specific subset of files. Default: None
+    Returns:
+        [Repo](/python-api/repo)
+            A Repo object that can be used to interact with the cloned repo.
+    """
+    # Get path from repo_name if not provided
+    # Get repo name from repo_id
+    repo_name = repo_id.split("/")[-1]
+    if path is None:
+        path = repo_name
+
+    if repo_id.startswith("http://") or repo_id.startswith("https://"):
+        # Clone repo
+        repo = Repo(path)
+        repo.clone(repo_id, branch=branch, all=all, filters=filters)
+    else:
+        # Verify repo_id format
+        if "/" not in repo_id:
+            raise ValueError(f"Invalid repo_id format: {repo_id}")
+        # Get repo url
+        repo_url = f"{scheme}://{host}/{repo_id}"
+        # Clone repo
+        repo = Repo(path)
+        repo.clone(repo_url, branch=branch, all=all, filters=filters)
+    return repo
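Given the branching above, `clone` accepts either a bare `namespace/repo_name` id, which is resolved against `host` and `scheme`, or a full URL passed straight through. A brief sketch using the docstring's example repo id; the destination path and filter directories are hypothetical:

```python
from oxen import clone

# Bare id: resolved to https://hub.oxen.ai/ox/chatbot
repo = clone("ox/chatbot")

# Full URL, full commit history, cloning only a subset of directories
repo = clone(
    "https://hub.oxen.ai/ox/chatbot",
    path="chatbot-data",                # hypothetical destination
    all=True,
    filters=["annotations", "images"],  # hypothetical subdirectories
)
```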
oxen/config.py
ADDED
@@ -0,0 +1,16 @@
+from .oxen import util
+import os
+
+
+def is_configured():
+    """
+    Checks if the user and auth are configured.
+
+    Returns:
+        `bool`: True if the user and auth are configured, False otherwise.
+    """
+
+    auth_path = os.path.join(util.get_oxen_config_dir(), "auth_config.toml")
+    user_path = os.path.join(util.get_oxen_config_dir(), "user_config.toml")
+
+    return os.path.exists(auth_path) and os.path.exists(user_path)
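Since `is_configured` only checks that the two TOML files exist under the oxen config dir, a typical startup guard is a sketch like:

```python
from oxen import is_configured
from oxen.auth import config_auth

# Writes auth_config.toml (and user_config.toml when host is hub.oxen.ai),
# after which is_configured() returns True
if not is_configured():
    config_auth("YOUR_API_TOKEN")  # placeholder token
```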
oxen/data_frame.py
ADDED
@@ -0,0 +1,462 @@
+from oxen.workspace import Workspace
+from oxen.remote_repo import RemoteRepo
+from .oxen import PyWorkspaceDataFrame, PyColumn
+import json
+from typing import List, Union, Optional
+import os
+import polars as pl
+from oxen import df_utils
+
+
+class DataFrame:
+    """
+    The DataFrame class allows you to perform CRUD operations on a remote data frame.
+
+    If you pass in a [Workspace](/concepts/workspaces) or a [RemoteRepo](/concepts/remote-repos) the data is indexed into DuckDB on an oxen-server without downloading the data locally.
+
+    ## Examples
+
+    ### CRUD Operations
+
+    Index a data frame in a workspace.
+
+    ```python
+    from oxen import DataFrame
+
+    # Connect to and index the data frame
+    # Note: This must be an existing file committed to the repo
+    # indexing may take a while for large files
+    data_frame = DataFrame("datasets/SpamOrHam", "data.tsv")
+
+    # Add a row
+    row_id = data_frame.insert_row({"category": "spam", "message": "Hello, do I have an offer for you!"})
+
+    # Get a row by id
+    row = data_frame.get_row_by_id(row_id)
+    print(row)
+
+    # Update a row
+    row = data_frame.update_row(row_id, {"category": "ham"})
+    print(row)
+
+    # Delete a row
+    data_frame.delete_row(row_id)
+
+    # Get the current changes to the data frame
+    status = data_frame.diff()
+    print(status.added_files())
+
+    # Commit the changes
+    data_frame.commit("Updating data.tsv")
+    ```
+    """
+
+    def __init__(
+        self,
+        remote: Union[str, RemoteRepo, Workspace],
+        path: str,
+        host: str = "hub.oxen.ai",
+        branch: Optional[str] = None,
+        scheme: str = "https",
+        workspace_name: Optional[str] = None,
+    ):
+        """
+        Initialize the DataFrame class. Will index the data frame
+        into duckdb on init.
+
+        Will throw an error if the data frame does not exist.
+
+        Args:
+            remote: `str`, `RemoteRepo`, or `Workspace`
+                The workspace or remote repo the data frame is in.
+            path: `str`
+                The path of the data frame file in the repository.
+            host: `str`
+                The host of the oxen-server. Defaults to "hub.oxen.ai".
+            branch: `Optional[str]`
+                The branch of the remote repo. Defaults to None.
+            scheme: `str`
+                The scheme of the remote repo. Defaults to "https".
+        """
+        if isinstance(remote, str):
+            remote_repo = RemoteRepo(remote, host=host, scheme=scheme)
+            if branch is None:
+                branch = remote_repo.branch().name
+            self._workspace = Workspace(
+                remote_repo, branch, path=path, workspace_name=workspace_name
+            )
+        elif isinstance(remote, RemoteRepo):
+            if branch is None:
+                branch = remote.branch().name
+            self._workspace = Workspace(
+                remote, branch, path=path, workspace_name=workspace_name
+            )
+        elif isinstance(remote, Workspace):
+            self._workspace = remote
+        else:
+            raise ValueError(
+                "Invalid remote type. Must be a string, RemoteRepo, or Workspace"
+            )
+        self._path = path
+        # this will return an error if the data frame file does not exist
+        try:
+            self.data_frame = PyWorkspaceDataFrame(self._workspace._workspace, path)
+        except Exception as e:
+            print(e)
+            self.data_frame = None
+        self.filter_keys = ["_oxen_diff_hash", "_oxen_diff_status", "_oxen_row_id"]
+
+    def __repr__(self):
+        name = f"{self._workspace._repo.namespace}/{self._workspace._repo.name}"
+        return f"DataFrame(repo={name}, path={self._path})"
+
+    def workspace_url(self, host: str = "oxen.ai", scheme: str = "https") -> str:
+        """
+        Get the url of the data frame.
+        """
+        return f"{scheme}://{host}/{self._workspace._repo.namespace}/{self._workspace._repo.name}/workspaces/{self._workspace.id}/file/{self._path}"
+
+    def size(self) -> tuple[int, int]:
+        """
+        Get the size of the data frame. Returns a tuple of (rows, columns)
+        """
+        return self.data_frame.size()
+
+    def page_size(self) -> int:
+        """
+        Get the page size of the data frame for pagination in list() command.
+
+        Returns:
+            The page size of the data frame.
+        """
+        return self.data_frame.page_size()
+
+    def total_pages(self) -> int:
+        """
+        Get the total number of pages in the data frame for pagination in list() command.
+
+        Returns:
+            The total number of pages in the data frame.
+        """
+        return self.data_frame.total_pages()
+
+    def list_page(self, page_num: int = 1) -> List[dict]:
+        """
+        List the rows within the data frame.
+
+        Args:
+            page_num: `int`
+                The page number of the data frame to list. We default to page size of 100 for now.
+
+        Returns:
+            A list of rows from the data frame.
+        """
+        results = self.data_frame.list(page_num)
+        # convert string to dict
+        # this is not the most efficient but gets it working
+        data = json.loads(results)
+        data = self._filter_keys_arr(data)
+        return data
+
+    def insert_row(self, data: dict, workspace: Optional[Workspace] = None):
+        """
+        Insert a single row of data into the data frame.
+
+        Args:
+            data: `dict`
+                A dictionary representing a single row of data.
+                The keys must match a subset of the columns in the data frame.
+                If a column is not present in the dictionary,
+                it will be set to an empty value.
+
+        Returns:
+            The id of the row that was inserted.
+        """
+
+        repo = self._workspace.repo
+        if not repo.file_exists(self._path):
+            tmp_file_path = self._write_first_row(data)
+            # Add the file to the repo
+            dirname = os.path.dirname(self._path)
+            repo.add(tmp_file_path, dst=dirname)
+            repo.commit("Adding data frame at " + self._path)
+            # This is a temporary hack that allows us to reference the resulting workspace by the
+            # same name as the original workspace. Ideally, we should just be able to create a df
+            # inside a workspace without a commit
+            if workspace is None:
+                self._workspace = Workspace(
+                    repo, self._workspace.branch, path=self._path
+                )
+            else:
+                if workspace.status().is_clean():
+                    workspace.delete()
+                else:
+                    workspace.commit("commit data to open new workspace")
+                self._workspace = Workspace(
+                    repo,
+                    workspace.branch,
+                    path=self._path,
+                    workspace_name=workspace.name,
+                )
+            self.data_frame = PyWorkspaceDataFrame(
+                self._workspace._workspace, self._path
+            )
+            results = self.data_frame.list(1)
+            results = json.loads(results)
+            print(results)
+            return results[0]["_oxen_id"]
+        else:
+            # convert dict to json string
+            # this is not the most efficient but gets it working
+            data = json.dumps(data)
+            return self.data_frame.insert_row(data)
+
+    def get_columns(self) -> List[PyColumn]:
+        """
+        Get the columns of the data frame.
+        """
+        # filter out the columns that are in the filter_keys list
+        columns = [
+            c for c in self.data_frame.get_columns() if c.name not in self.filter_keys
+        ]
+        return columns
+
+    def add_column(self, name: str, data_type: str):
+        """
+        Add a column to the data frame.
+        """
+        return self.data_frame.add_column(name, data_type)
+
+    def _write_first_row(self, data: dict):
+        """
+        Write the first row of the data frame to disk, based on the file extension and the input data.
+        """
+        # get the filename from the path logs/data_frame_name.csv -> data_frame_name.csv
+        basename = os.path.basename(self._path)
+        # write the data to a temp file that we will add to the repo
+        tmp_file_path = os.path.join("/tmp", basename)
+        # Create a polars data frame from the input data
+        df = pl.DataFrame(data)
+        # Save the data frame to disk
+        df_utils.save(df, tmp_file_path)
+        # Return the path to the file
+        return tmp_file_path
+
+    # TODO: Allow `where_from_str` to be passed in so user could write their own where clause
+    def where_sql_from_dict(self, attributes: dict, operator: str = "AND") -> str:
+        """
+        Generate the SQL from the attributes.
+        """
+        # df is the name of the data frame
+        sql = ""
+        i = 0
+        for key, value in attributes.items():
+            # only accept string and numeric values
+            if not isinstance(value, (str, int, float, bool)):
+                raise ValueError(f"Invalid value type for {key}: {type(value)}")
+
+            # if the value is a str put it in quotes
+            if isinstance(value, str):
+                value = f"'{value}'"
+            sql += f"{key} = {value}"
+            if i < len(attributes) - 1:
+                sql += f" {operator} "
+            i += 1
+        return sql
+
+    def select_sql_from_dict(
+        self, attributes: dict, columns: Optional[List[str]] = None
+    ) -> str:
+        """
+        Generate the SQL from the attributes.
+        """
+        # df is the name of the data frame
+        sql = "SELECT "
+        if columns is not None:
+            sql += ", ".join(columns)
+        else:
+            sql += "*"
+        sql += " FROM df WHERE "
+        sql += self.where_sql_from_dict(attributes)
+        return sql
+
+    def get_embeddings(
+        self, attributes: dict, column: str = "embedding"
+    ) -> List[float]:
+        """
+        Get the embedding from the data frame.
+        """
+        sql = self.select_sql_from_dict(attributes, columns=[column])
+        result = self.data_frame.sql_query(sql)
+        result = json.loads(result)
+        embeddings = [r[column] for r in result]
+        return embeddings
+
+    def is_nearest_neighbors_enabled(self, column="embedding"):
+        """
+        Check if the embeddings column is indexed in the data frame.
+        """
+        return self.data_frame.is_nearest_neighbors_enabled(column)
+
+    def enable_nearest_neighbors(self, column: str = "embedding"):
+        """
+        Index the embeddings in the data frame.
+        """
+        self.data_frame.enable_nearest_neighbors(column)
+
+    def query(
+        self,
+        sql: Optional[str] = None,
+        find_embedding_where: Optional[dict] = None,
+        embedding: Optional[list[float]] = None,
+        sort_by_similarity_to: Optional[str] = None,
+        page_num: int = 1,
+        page_size: int = 10,
+    ):
+        """
+        Query the data frame with raw SQL, or sort it by similarity to an embedding.
+        """
+
+        if sql is not None:
+            result = self.data_frame.sql_query(sql)
+        elif find_embedding_where is not None and sort_by_similarity_to is not None:
+            find_embedding_where = self.where_sql_from_dict(find_embedding_where)
+            result = self.data_frame.nearest_neighbors_search(
+                find_embedding_where, sort_by_similarity_to, page_num, page_size
+            )
+        elif embedding is not None and sort_by_similarity_to is not None:
+            result = self.data_frame.sort_by_embedding(
+                sort_by_similarity_to, embedding, page_num, page_size
+            )
+        else:
+            raise ValueError(
+                "Must provide either sql or find_embedding_where as well as sort_by_similarity_to"
+            )
+
+        return json.loads(result)
+
+    def nearest_neighbors_search(
+        self, find_embedding_where: dict, sort_by_similarity_to: str = "embedding"
+    ):
+        """
+        Get the nearest neighbors to the embedding.
+        """
+        result = self.data_frame.nearest_neighbors_search(
+            find_embedding_where, sort_by_similarity_to
+        )
+        result = json.loads(result)
+        return result
+
+    def get_by(self, attributes: dict):
+        """
+        Get a single row of data by attributes.
+        """
+        # Write the SQL from the attributes
+        sql = self.select_sql_from_dict(attributes)
+
+        # convert dict to json string
+        data = self.data_frame.sql_query(sql)
+        data = json.loads(data)
+        return data
+
+    def get_row(self, idx: int):
+        """
+        Get a single row of data by index.
+
+        Args:
+            idx: `int`
+                The index of the row to get.
+
+        Returns:
+            A dictionary representing the row.
+        """
+        result = self.data_frame.get_row_by_idx(idx)
+        result = json.loads(result)
+        return result
+
+    def get_row_by_id(self, id: str):
+        """
+        Get a single row of data by id.
+
+        Args:
+            id: `str`
+                The id of the row to get.
+
+        Returns:
+            A dictionary representing the row.
+        """
+        data = self.data_frame.get_row_by_id(id)
+        # convert string to dict
+        # this is not the most efficient but gets it working
+        data = json.loads(data)
+        # filter out .oxen.diff.hash and .oxen.diff.status and _oxen_row_id
+        data = self._filter_keys_arr(data)
+
+        if len(data) == 0:
+            return None
+        return data[0]
+
+    def update_row(self, id: str, data: dict):
+        """
+        Update a single row of data by id.
+
+        Args:
+            id: `str`
+                The id of the row to update.
+            data: `dict`
+                A dictionary representing a single row of data.
+                The keys must match a subset of the columns in the data frame.
+                If a column is not present in the dictionary,
+                it will be set to an empty value.
+
+        Returns:
+            The updated row as a dictionary.
+        """
+        data = json.dumps(data)
+        result = self.data_frame.update_row(id, data)
+        result = json.loads(result)
+        result = self._filter_keys_arr(result)
+        return result
+
+    def delete_row(self, id: str):
+        """
+        Delete a single row of data by id.
+
+        Args:
+            id: `str`
+                The id of the row to delete.
+        """
+        return self.data_frame.delete_row(id)
+
+    def restore(self):
+        """
+        Unstage any changes to the schema or contents of a data frame.
+        """
+        self.data_frame.restore()
+
+    def commit(self, message: str, branch: Optional[str] = None):
+        """
+        Commit the current changes to the data frame.
+
+        Args:
+            message: `str`
+                The message to commit the changes.
+            branch: `str`
+                The branch to commit the changes to. Defaults to the current branch.
+        """
+        self._workspace.commit(message, branch)
+
+    def _filter_keys(self, data: dict):
+        """
+        Filter out the keys that are not needed in the dataset.
+        """
+        # TODO: why do we use periods vs underscores...?
+        # filter out .oxen.diff.hash and .oxen.diff.status and _oxen_row_id
+        # from each element in the list of dicts
+        return {k: v for k, v in data.items() if k not in self.filter_keys}
+
+    def _filter_keys_arr(self, data: List[dict]):
+        """
+        Filter out the keys that are not needed in the dataset.
+        """
+        return [self._filter_keys(d) for d in data]
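To make the query helpers above concrete: `where_sql_from_dict` joins `key = value` pairs with the given operator (quoting string values), and `select_sql_from_dict` wraps the result in a `SELECT ... FROM df WHERE ...` statement run against the server-side DuckDB index. A sketch of the query and embedding workflow; the repo id, file path, and column values are hypothetical placeholders, and the nearest-neighbor calls assume the file has an "embedding" column:

```python
from oxen import DataFrame

df = DataFrame("ox/chatbot", "data.tsv")  # placeholder repo id and file

# where_sql_from_dict({"category": "spam"}) -> "category = 'spam'"
sql = df.select_sql_from_dict({"category": "spam"}, columns=["message"])
rows = df.query(sql=sql)

# Nearest-neighbor search over the default "embedding" column
if not df.is_nearest_neighbors_enabled():
    df.enable_nearest_neighbors()
neighbors = df.query(
    find_embedding_where={"message": "hello"},  # row whose embedding to match
    sort_by_similarity_to="embedding",
    page_size=5,
)
```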