oxenai-0.39.1-cp313-cp313-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of oxenai might be problematic.

oxen/__init__.py ADDED
@@ -0,0 +1,62 @@
+ """Core Oxen Functionality"""
+
+ # Rust wrappers
+ from .oxen import (
+     PyRepo,
+     PyStagedData,
+     PyCommit,
+     PyRemoteRepo,
+     PyDataset,
+     PyWorkspace,
+     PyWorkspaceDataFrame,
+     PyColumn,
+     __version__,
+ )
+ from .oxen import util
+ from .oxen import py_notebooks
+
+ # Python classes
+ from oxen.repo import Repo
+ from oxen.remote_repo import RemoteRepo
+ from oxen.workspace import Workspace
+ from oxen.data_frame import DataFrame
+ from oxen import auth
+ from oxen import datasets
+ from oxen.notebooks import start as start_notebook
+ from oxen.notebooks import stop as stop_notebook
+ from oxen.clone import clone
+ from oxen.diff.diff import diff
+ from oxen.init import init
+ from oxen.config import is_configured
+ from oxen.oxen_fs import OxenFS
+
+ # Names of public modules we want to expose
+ __all__ = [
+     "auth",
+     "DataFrame",
+     "Dataset",
+     "diff",
+     "init",
+     "is_configured",
+     "notebooks",
+     "start_notebook",
+     "stop_notebook",
+     "OxenFS",
+     "PyColumn",
+     "PyCommit",
+     "PyDataset",
+     "PyRemoteRepo",
+     "PyRepo",
+     "PyStagedData",
+     "PyWorkspace",
+     "PyWorkspaceDataFrame",
+     "RemoteRepo",
+     "Repo",
+     "util",
+     "py_notebooks",
+     "clone",
+     "datasets",
+     "Workspace",
+ ]
+
+ __version__ = __version__
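
The package root thus exposes two layers: the raw `Py*` bindings compiled from Rust, and the Python wrappers (`Repo`, `RemoteRepo`, `DataFrame`, ...) built on top of them, with `__version__` re-exported from the compiled extension. A minimal sketch of importing from the top level (the repo id is a placeholder):

```python
import oxen
from oxen import RemoteRepo

# __version__ comes from the compiled Rust extension module
print(oxen.__version__)

# High-level wrappers are importable straight from the package root;
# "ox/chatbot" is a placeholder repo id for illustration
repo = RemoteRepo("ox/chatbot")
```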
oxen/auth.py ADDED
@@ -0,0 +1,40 @@
+ from .oxen import auth, util
+ from oxen.user import config_user
+ from typing import Optional
+ import os
+ import requests
+
+
+ def config_auth(token: str, host: str = "hub.oxen.ai", path: Optional[str] = None):
+     """
+     Configures authentication for a host.
+
+     Args:
+         token: `str`
+             The token to use for authentication.
+         host: `str`
+             The host to configure authentication for. Default: 'hub.oxen.ai'
+         path: `Optional[str]`
+             The path to save the authentication config to.
+             Defaults to $HOME/.config/oxen/auth_config.toml
+     """
+     if path is None:
+         path = os.path.join(util.get_oxen_config_dir(), "auth_config.toml")
+     if not path.endswith(".toml"):
+         raise ValueError("Path must end with .toml")
+     auth.config_auth(host, token, path)
+
+     # Only fetch user if the host is the hub
+     if "hub.oxen.ai" == host:
+         # Fetch the user from the hub and save it to the config
+         url = f"https://{host}/api/authorize"
+         # make request with token
+         headers = {"Authorization": f"Bearer {token}"}
+         r = requests.get(url, headers=headers)
+         if r.status_code != 200:
+             raise Exception(f"Failed to fetch user from {host}.")
+         user = r.json()["user"]
+         name = user["name"]
+         email = user["email"]
+         # save user to config
+         config_user(name, email)
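
Note that the user lookup only runs against `hub.oxen.ai`; for any other host only the token is written. A usage sketch (the token and the self-hosted domain are placeholders):

```python
from oxen.auth import config_auth

# Hub: saves the token and fetches your name/email from /api/authorize
config_auth("YOUR_AUTH_TOKEN")

# Self-hosted oxen-server: the token is saved, but no user lookup happens
config_auth("YOUR_AUTH_TOKEN", host="oxen.my-company.com")
```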
oxen/clone.py ADDED
@@ -0,0 +1,58 @@
+ from typing import Optional
+ from oxen.repo import Repo
+
+
+ def clone(
+     repo_id: str,
+     path: Optional[str] = None,
+     host: str = "hub.oxen.ai",
+     branch: str = "main",
+     scheme: str = "https",
+     filters: Optional[str | list[str]] = None,
+     all: bool = False,
+ ):
+     """
+     Clone a repository.
+
+     Args:
+         repo_id: `str`
+             Name of the repository in the format 'namespace/repo_name'.
+             For example 'ox/chatbot'
+         path: `Optional[str]`
+             The path to clone the repo to. Defaults to the name of the repository.
+         host: `str`
+             The host to connect to. Defaults to 'hub.oxen.ai'
+         branch: `str`
+             The name of the branch to clone. Defaults to 'main'
+         scheme: `str`
+             The scheme to use. Defaults to 'https'
+         filters: `str | list[str] | None`
+             Filter down the set of directories you want to clone. Useful if
+             you have a large repository and only want to make changes to a
+             specific subset of files. Default: None
+         all: `bool`
+             Whether to clone the full commit history or not. Default: False
+     Returns:
+         [Repo](/python-api/repo)
+         A Repo object that can be used to interact with the cloned repo.
+     """
+     # Get the repo name from the repo_id
+     # and use it as the path if none was provided
+     repo_name = repo_id.split("/")[-1]
+     if path is None:
+         path = repo_name
+
+     if repo_id.startswith("http://") or repo_id.startswith("https://"):
+         # Clone repo from the full URL
+         repo = Repo(path)
+         repo.clone(repo_id, branch=branch, all=all, filters=filters)
+     else:
+         # Verify repo_id format
+         if "/" not in repo_id:
+             raise ValueError(f"Invalid repo_id format: {repo_id}")
+         # Build the repo url from the scheme, host, and repo_id
+         repo_url = f"{scheme}://{host}/{repo_id}"
+         # Clone repo
+         repo = Repo(path)
+         repo.clone(repo_url, branch=branch, all=all, filters=filters)
+     return repo
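
As the two branches show, `clone` accepts either a full URL or a `namespace/repo_name` id that gets resolved against `host` and `scheme`. A usage sketch (repo ids and paths are placeholders):

```python
from oxen import clone

# id form: resolved to https://hub.oxen.ai/ox/chatbot
repo = clone("ox/chatbot")

# URL form: the scheme/host arguments are ignored since the URL is complete
repo = clone("https://hub.oxen.ai/ox/chatbot", path="my-chatbot")

# Partial clone: only pull the listed directories
repo = clone("ox/chatbot", filters=["annotations"])
```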
oxen/config.py ADDED
@@ -0,0 +1,16 @@
+ from .oxen import util
+ import os
+
+
+ def is_configured():
+     """
+     Checks if the user and auth are configured.
+
+     Returns:
+         `bool`: True if both the user and auth are configured, False otherwise.
+     """
+
+     auth_path = os.path.join(util.get_oxen_config_dir(), "auth_config.toml")
+     user_path = os.path.join(util.get_oxen_config_dir(), "user_config.toml")
+
+     return os.path.exists(auth_path) and os.path.exists(user_path)
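
Since `is_configured` only checks that both TOML files exist, it works as a guard for one-time setup. For `hub.oxen.ai`, `config_auth` writes both files: the auth config directly, and the user config via `config_user`. A sketch (the environment variable is hypothetical):

```python
import os
from oxen import is_configured
from oxen.auth import config_auth

if not is_configured():
    # OXEN_API_TOKEN is a hypothetical environment variable for illustration
    config_auth(os.environ["OXEN_API_TOKEN"])
```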
oxen/data_frame.py ADDED
@@ -0,0 +1,462 @@
+ from oxen.workspace import Workspace
+ from oxen.remote_repo import RemoteRepo
+ from .oxen import PyWorkspaceDataFrame, PyColumn
+ import json
+ from typing import List, Union, Optional
+ import os
+ import polars as pl
+ from oxen import df_utils
+
+
+ class DataFrame:
+     """
+     The DataFrame class allows you to perform CRUD operations on a remote data frame.
+
+     If you pass in a [Workspace](/concepts/workspaces) or a [RemoteRepo](/concepts/remote-repos), the data is indexed into DuckDB on an oxen-server without downloading the data locally.
+
+     ## Examples
+
+     ### CRUD Operations
+
+     Index a data frame in a workspace.
+
+     ```python
+     from oxen import DataFrame
+
+     # Connect to and index the data frame
+     # Note: This must be an existing file committed to the repo
+     # indexing may take a while for large files
+     data_frame = DataFrame("datasets/SpamOrHam", "data.tsv")
+
+     # Add a row
+     row_id = data_frame.insert_row({"category": "spam", "message": "Hello, do I have an offer for you!"})
+
+     # Get a row by id
+     row = data_frame.get_row_by_id(row_id)
+     print(row)
+
+     # Update a row
+     row = data_frame.update_row(row_id, {"category": "ham"})
+     print(row)
+
+     # Delete a row
+     data_frame.delete_row(row_id)
+
+     # Get the current changes to the data frame
+     status = data_frame.diff()
+     print(status.added_files())
+
+     # Commit the changes
+     data_frame.commit("Updating data.tsv")
+     ```
+     """
+
+     def __init__(
+         self,
+         remote: Union[str, RemoteRepo, Workspace],
+         path: str,
+         host: str = "hub.oxen.ai",
+         branch: Optional[str] = None,
+         scheme: str = "https",
+         workspace_name: Optional[str] = None,
+     ):
+         """
+         Initialize the DataFrame class. Will index the data frame
+         into duckdb on init.
+
+         Will throw an error if the data frame does not exist.
+
+         Args:
+             remote: `str`, `RemoteRepo`, or `Workspace`
+                 The workspace or remote repo the data frame is in.
+             path: `str`
+                 The path of the data frame file in the repository.
+             host: `str`
+                 The host of the oxen-server. Defaults to "hub.oxen.ai".
+             branch: `Optional[str]`
+                 The branch of the remote repo. Defaults to None.
+             scheme: `str`
+                 The scheme of the remote repo. Defaults to "https".
+         """
+         if isinstance(remote, str):
+             remote_repo = RemoteRepo(remote, host=host, scheme=scheme)
+             if branch is None:
+                 branch = remote_repo.branch().name
+             self._workspace = Workspace(
+                 remote_repo, branch, path=path, workspace_name=workspace_name
+             )
+         elif isinstance(remote, RemoteRepo):
+             if branch is None:
+                 branch = remote.branch().name
+             self._workspace = Workspace(
+                 remote, branch, path=path, workspace_name=workspace_name
+             )
+         elif isinstance(remote, Workspace):
+             self._workspace = remote
+         else:
+             raise ValueError(
+                 "Invalid remote type. Must be a string, RemoteRepo, or Workspace"
+             )
+         self._path = path
+         # this will return an error if the data frame file does not exist
+         try:
+             self.data_frame = PyWorkspaceDataFrame(self._workspace._workspace, path)
+         except Exception as e:
+             print(e)
+             self.data_frame = None
+         self.filter_keys = ["_oxen_diff_hash", "_oxen_diff_status", "_oxen_row_id"]
+
+     def __repr__(self):
+         name = f"{self._workspace._repo.namespace}/{self._workspace._repo.name}"
+         return f"DataFrame(repo={name}, path={self._path})"
+
+     def workspace_url(self, host: str = "oxen.ai", scheme: str = "https") -> str:
+         """
+         Get the url of the data frame.
+         """
+         return f"{scheme}://{host}/{self._workspace._repo.namespace}/{self._workspace._repo.name}/workspaces/{self._workspace.id}/file/{self._path}"
+
+     def size(self) -> tuple[int, int]:
+         """
+         Get the size of the data frame. Returns a tuple of (rows, columns)
+         """
+         return self.data_frame.size()
+
+     def page_size(self) -> int:
+         """
+         Get the page size of the data frame for pagination in the list_page() command.
+
+         Returns:
+             The page size of the data frame.
+         """
+         return self.data_frame.page_size()
+
+     def total_pages(self) -> int:
+         """
+         Get the total number of pages in the data frame for pagination in the list_page() command.
+
+         Returns:
+             The total number of pages in the data frame.
+         """
+         return self.data_frame.total_pages()
+
+     def list_page(self, page_num: int = 1) -> List[dict]:
+         """
+         List the rows within the data frame.
+
+         Args:
+             page_num: `int`
+                 The page number of the data frame to list. We default to a page size of 100 for now.
+
+         Returns:
+             A list of rows from the data frame.
+         """
+         results = self.data_frame.list(page_num)
+         # convert string to dict
+         # this is not the most efficient but gets it working
+         data = json.loads(results)
+         data = self._filter_keys_arr(data)
+         return data
+
+     def insert_row(self, data: dict, workspace: Optional[Workspace] = None):
+         """
+         Insert a single row of data into the data frame.
+
+         Args:
+             data: `dict`
+                 A dictionary representing a single row of data.
+                 The keys must match a subset of the columns in the data frame.
+                 If a column is not present in the dictionary,
+                 it will be set to an empty value.
+
+         Returns:
+             The id of the row that was inserted.
+         """
+
+         repo = self._workspace.repo
+         if not repo.file_exists(self._path):
+             tmp_file_path = self._write_first_row(data)
+             # Add the file to the repo
+             dirname = os.path.dirname(self._path)
+             repo.add(tmp_file_path, dst=dirname)
+             repo.commit("Adding data frame at " + self._path)
+             # This is a temporary hack that allows us to reference the resulting workspace by the
+             # same name as the original workspace. Ideally, we should just be able to create a df
+             # inside a workspace without a commit
+             if workspace is None:
+                 self._workspace = Workspace(
+                     repo, self._workspace.branch, path=self._path
+                 )
+             else:
+                 if workspace.status().is_clean():
+                     workspace.delete()
+                 else:
+                     workspace.commit("commit data to open new workspace")
+                 self._workspace = Workspace(
+                     repo,
+                     workspace.branch,
+                     path=self._path,
+                     workspace_name=workspace.name,
+                 )
+             self.data_frame = PyWorkspaceDataFrame(
+                 self._workspace._workspace, self._path
+             )
+             results = self.data_frame.list(1)
+             results = json.loads(results)
+             print(results)
+             return results[0]["_oxen_id"]
+         else:
+             # convert dict to json string
+             # this is not the most efficient but gets it working
+             data = json.dumps(data)
+             return self.data_frame.insert_row(data)
+
+     def get_columns(self) -> List[PyColumn]:
+         """
+         Get the columns of the data frame.
+         """
+         # filter out the columns that are in the filter_keys list
+         columns = [
+             c for c in self.data_frame.get_columns() if c.name not in self.filter_keys
+         ]
+         return columns
+
+     def add_column(self, name: str, data_type: str):
+         """
+         Add a column to the data frame.
+         """
+         return self.data_frame.add_column(name, data_type)
+
+     def _write_first_row(self, data: dict):
+         """
+         Write the first row of the data frame to disk, based on the file extension and the input data.
+         """
+         # get the filename from the path logs/data_frame_name.csv -> data_frame_name.csv
+         basename = os.path.basename(self._path)
+         # write the data to a temp file that we will add to the repo
+         tmp_file_path = os.path.join("/tmp", basename)
+         # Create a polars data frame from the input data
+         df = pl.DataFrame(data)
+         # Save the data frame to disk
+         df_utils.save(df, tmp_file_path)
+         # Return the path to the file
+         return tmp_file_path
+
+     # TODO: Allow `where_from_str` to be passed in so user could write their own where clause
+     def where_sql_from_dict(self, attributes: dict, operator: str = "AND") -> str:
+         """
+         Generate the SQL from the attributes.
+         """
+         # df is the name of the data frame
+         sql = ""
+         i = 0
+         for key, value in attributes.items():
+             # only accept string and numeric values
+             if not isinstance(value, (str, int, float, bool)):
+                 raise ValueError(f"Invalid value type for {key}: {type(value)}")
+
+             # if the value is a str put it in quotes
+             if isinstance(value, str):
+                 value = f"'{value}'"
+             sql += f"{key} = {value}"
+             if i < len(attributes) - 1:
+                 sql += f" {operator} "
+             i += 1
+         return sql
+
+     def select_sql_from_dict(
+         self, attributes: dict, columns: Optional[List[str]] = None
+     ) -> str:
+         """
+         Generate the SQL from the attributes.
+         """
+         # df is the name of the data frame
+         sql = "SELECT "
+         if columns is not None:
+             sql += ", ".join(columns)
+         else:
+             sql += "*"
+         sql += " FROM df WHERE "
+         sql += self.where_sql_from_dict(attributes)
+         return sql
+
+     def get_embeddings(
+         self, attributes: dict, column: str = "embedding"
+     ) -> List[float]:
+         """
+         Get the embeddings from the data frame.
+         """
+         sql = self.select_sql_from_dict(attributes, columns=[column])
+         result = self.data_frame.sql_query(sql)
+         result = json.loads(result)
+         embeddings = [r[column] for r in result]
+         return embeddings
+
+     def is_nearest_neighbors_enabled(self, column="embedding"):
+         """
+         Check if the embeddings column is indexed in the data frame.
+         """
+         return self.data_frame.is_nearest_neighbors_enabled(column)
+
+     def enable_nearest_neighbors(self, column: str = "embedding"):
+         """
+         Index the embeddings in the data frame.
+         """
+         self.data_frame.enable_nearest_neighbors(column)
+
+     def query(
+         self,
+         sql: Optional[str] = None,
+         find_embedding_where: Optional[dict] = None,
+         embedding: Optional[list[float]] = None,
+         sort_by_similarity_to: Optional[str] = None,
+         page_num: int = 1,
+         page_size: int = 10,
+     ):
+         """
+         Query the data frame with raw SQL, or sort it by similarity to an embedding.
+         """
+
+         if sql is not None:
+             result = self.data_frame.sql_query(sql)
+         elif find_embedding_where is not None and sort_by_similarity_to is not None:
+             find_embedding_where = self.where_sql_from_dict(find_embedding_where)
+             result = self.data_frame.nearest_neighbors_search(
+                 find_embedding_where, sort_by_similarity_to, page_num, page_size
+             )
+         elif embedding is not None and sort_by_similarity_to is not None:
+             result = self.data_frame.sort_by_embedding(
+                 sort_by_similarity_to, embedding, page_num, page_size
+             )
+         else:
+             raise ValueError(
+                 "Must provide either sql, or find_embedding_where or embedding along with sort_by_similarity_to"
+             )
+
+         return json.loads(result)
+
+     def nearest_neighbors_search(
+         self, find_embedding_where: dict, sort_by_similarity_to: str = "embedding"
+     ):
+         """
+         Get the nearest neighbors to the embedding.
+         """
+         result = self.data_frame.nearest_neighbors_search(
+             find_embedding_where, sort_by_similarity_to
+         )
+         result = json.loads(result)
+         return result
+
+     def get_by(self, attributes: dict):
+         """
+         Get a single row of data by attributes.
+         """
+         # Write the SQL from the attributes
+         sql = self.select_sql_from_dict(attributes)
+
+         # Run the query; the result comes back as a JSON string
+         data = self.data_frame.sql_query(sql)
+         data = json.loads(data)
+         return data
+
+     def get_row(self, idx: int):
+         """
+         Get a single row of data by index.
+
+         Args:
+             idx: `int`
+                 The index of the row to get.
+
+         Returns:
+             A dictionary representing the row.
+         """
+         result = self.data_frame.get_row_by_idx(idx)
+         result = json.loads(result)
+         return result
+
+     def get_row_by_id(self, id: str):
+         """
+         Get a single row of data by id.
+
+         Args:
+             id: `str`
+                 The id of the row to get.
+
+         Returns:
+             A dictionary representing the row.
+         """
+         data = self.data_frame.get_row_by_id(id)
+         # convert string to dict
+         # this is not the most efficient but gets it working
+         data = json.loads(data)
+         # filter out _oxen_diff_hash, _oxen_diff_status, and _oxen_row_id
+         data = self._filter_keys_arr(data)
+
+         if len(data) == 0:
+             return None
+         return data[0]
+
+     def update_row(self, id: str, data: dict):
+         """
+         Update a single row of data by id.
+
+         Args:
+             id: `str`
+                 The id of the row to update.
+             data: `dict`
+                 A dictionary representing a single row of data.
+                 The keys must match a subset of the columns in the data frame.
+                 If a column is not present in the dictionary,
+                 it will be set to an empty value.
+
+         Returns:
+             The updated row as a dictionary.
+         """
+         data = json.dumps(data)
+         result = self.data_frame.update_row(id, data)
+         result = json.loads(result)
+         result = self._filter_keys_arr(result)
+         return result
+
+     def delete_row(self, id: str):
+         """
+         Delete a single row of data by id.
+
+         Args:
+             id: `str`
+                 The id of the row to delete.
+         """
+         return self.data_frame.delete_row(id)
+
+     def restore(self):
+         """
+         Unstage any changes to the schema or contents of a data frame.
+         """
+         self.data_frame.restore()
+
+     def commit(self, message: str, branch: Optional[str] = None):
+         """
+         Commit the current changes to the data frame.
+
+         Args:
+             message: `str`
+                 The message to commit the changes.
+             branch: `str`
+                 The branch to commit the changes to. Defaults to the current branch.
+         """
+         self._workspace.commit(message, branch)
+
+     def _filter_keys(self, data: dict):
+         """
+         Filter out the keys that are not needed in the dataset.
+         """
+         # TODO: why do we use periods vs underscores...?
+         # filter out _oxen_diff_hash, _oxen_diff_status, and _oxen_row_id
+         # from each element in the list of dicts
+         return {k: v for k, v in data.items() if k not in self.filter_keys}
+
+     def _filter_keys_arr(self, data: List[dict]):
+         """
+         Filter out the keys that are not needed in the dataset.
+         """
+         return [self._filter_keys(d) for d in data]
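
The SQL helpers above build simple equality predicates over the server-side DuckDB table, which is always named `df`. A sketch of the strings they generate, using the repo and file from the class docstring (the attribute values and the `n_words` column are illustrative):

```python
from oxen import DataFrame

data_frame = DataFrame("datasets/SpamOrHam", "data.tsv")

# where_sql_from_dict joins key = value pairs with AND (or a custom operator)
where = data_frame.where_sql_from_dict({"category": "spam", "n_words": 7})
# -> "category = 'spam' AND n_words = 7"

# select_sql_from_dict wraps the predicate in a SELECT against the table df
sql = data_frame.select_sql_from_dict({"category": "spam"}, columns=["message"])
# -> "SELECT message FROM df WHERE category = 'spam'"

# get_by runs the generated SELECT * and returns the matching rows as dicts
rows = data_frame.get_by({"category": "spam"})
```

Note that `where_sql_from_dict` does not escape quotes inside string values, so the attributes should come from trusted input.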
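The embedding methods chain into a nearest-neighbors workflow: index the column once, then either look up the query vector by attributes or pass one in directly. A sketch assuming the frame has an `embedding` column (the attribute values are illustrative):

```python
from oxen import DataFrame

data_frame = DataFrame("datasets/SpamOrHam", "data.tsv")

# One-time: build the nearest-neighbors index over the embedding column
if not data_frame.is_nearest_neighbors_enabled():
    data_frame.enable_nearest_neighbors()

# Option 1: find the query embedding by attributes, then sort by similarity
results = data_frame.query(
    find_embedding_where={"message": "Hello, do I have an offer for you!"},
    sort_by_similarity_to="embedding",
    page_num=1,
    page_size=10,
)

# Option 2: bring your own query vector
vector = data_frame.get_embeddings({"category": "spam"})[0]
results = data_frame.query(embedding=vector, sort_by_similarity_to="embedding")
```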