pyspiral 0.6.9__cp312-abi3-macosx_11_0_arm64.whl → 0.7.12__cp312-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/METADATA +9 -8
  2. {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/RECORD +53 -45
  3. {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/entry_points.txt +1 -0
  4. spiral/__init__.py +20 -0
  5. spiral/_lib.abi3.so +0 -0
  6. spiral/api/__init__.py +1 -1
  7. spiral/api/client.py +1 -1
  8. spiral/api/types.py +1 -0
  9. spiral/cli/admin.py +2 -2
  10. spiral/cli/app.py +8 -4
  11. spiral/cli/fs.py +4 -4
  12. spiral/cli/iceberg.py +1 -1
  13. spiral/cli/key_spaces.py +15 -1
  14. spiral/cli/login.py +4 -3
  15. spiral/cli/orgs.py +8 -7
  16. spiral/cli/projects.py +4 -4
  17. spiral/cli/state.py +5 -3
  18. spiral/cli/tables.py +59 -36
  19. spiral/cli/telemetry.py +1 -1
  20. spiral/cli/types.py +2 -2
  21. spiral/cli/workloads.py +3 -3
  22. spiral/client.py +69 -22
  23. spiral/core/client/__init__.pyi +48 -13
  24. spiral/core/config/__init__.pyi +47 -0
  25. spiral/core/expr/__init__.pyi +15 -0
  26. spiral/core/expr/images/__init__.pyi +3 -0
  27. spiral/core/expr/list_/__init__.pyi +4 -0
  28. spiral/core/expr/refs/__init__.pyi +4 -0
  29. spiral/core/expr/str_/__init__.pyi +3 -0
  30. spiral/core/expr/struct_/__init__.pyi +6 -0
  31. spiral/core/expr/text/__init__.pyi +5 -0
  32. spiral/core/expr/udf/__init__.pyi +14 -0
  33. spiral/core/expr/video/__init__.pyi +3 -0
  34. spiral/core/table/__init__.pyi +37 -2
  35. spiral/core/table/spec/__init__.pyi +6 -4
  36. spiral/dataloader.py +52 -38
  37. spiral/dataset.py +10 -1
  38. spiral/enrichment.py +304 -0
  39. spiral/expressions/__init__.py +21 -23
  40. spiral/expressions/base.py +9 -4
  41. spiral/expressions/file.py +17 -0
  42. spiral/expressions/http.py +11 -80
  43. spiral/expressions/s3.py +16 -0
  44. spiral/expressions/tiff.py +2 -3
  45. spiral/expressions/udf.py +38 -24
  46. spiral/iceberg.py +3 -3
  47. spiral/project.py +34 -6
  48. spiral/scan.py +80 -33
  49. spiral/settings.py +19 -97
  50. spiral/streaming_/stream.py +1 -1
  51. spiral/table.py +40 -10
  52. spiral/transaction.py +99 -2
  53. spiral/expressions/io.py +0 -100
  54. spiral/expressions/mp4.py +0 -62
  55. spiral/expressions/png.py +0 -18
  56. spiral/expressions/qoi.py +0 -18
  57. spiral/expressions/refs.py +0 -58
  58. {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/WHEEL +0 -0
spiral/expressions/http.py CHANGED
@@ -1,86 +1,17 @@
-import hishel
-import httpx
-import pyarrow as pa
-
+from spiral import _lib
 from spiral.expressions.base import Expr, ExprLike
-from spiral.expressions.struct import pack
-from spiral.expressions.udf import UDF
-from spiral.settings import APP_DIR
-
-
-def get(url: ExprLike, headers: ExprLike = None, force_cache: bool = False) -> Expr:
-    """Submit a GET request to either a scalar of vector of URLs."""
-    to_pack = {"url": url}
-    if headers is not None:
-        to_pack["headers"] = headers
-    return HttpGet(force_cache)(pack(to_pack))
-
-
-class HttpGet(UDF):
-    RES_DTYPE: pa.DataType = pa.struct(
-        [
-            pa.field("bytes", pa.large_binary()),
-            pa.field("status", pa.int32()),
-            pa.field("headers", pa.map_(pa.string(), pa.string())),
-        ]
-    )
-
-    def __init__(self, force_cache: bool = False):
-        super().__init__("http.get")
-        self._force_cache = force_cache
-
-    def return_type(self, *input_types: pa.DataType) -> pa.DataType:
-        return HttpGet.RES_DTYPE
-
-    def invoke(self, *input_args: pa.Array) -> pa.Array:
-        if len(input_args) != 1:
-            raise ValueError(f"Expected 1 argument, got {len(input_args)}")
-        result = _http_request(input_args[0], self._force_cache)
-        if isinstance(result, pa.ChunkedArray):
-            result = result.combine_chunks()
-        return result
-
-
-def _http_request(arg: pa.Array, force_cache: bool) -> pa.Array:
-    client = _HttpClient()
-
-    if isinstance(arg, pa.StructArray):
-        # We assume a vector of requests, but with potentially many arguments
-        return pa.array(
-            [
-                _response_dict(
-                    client.request(
-                        req.get("method", "GET").upper(),
-                        req["url"],
-                        headers=req.get("headers", {}),
-                        extensions={"force_cache": force_cache},
-                    )
-                )
-                for req in arg.to_pylist()
-            ],
-            type=HttpGet.RES_DTYPE,
-        )
-
-    raise TypeError(f"Unsupported argument: {arg} ({type(arg)})")
-
 
-def _response_dict(response: httpx.Response) -> dict:
-    if response.status_code != 200:
-        raise ValueError(f"Request failed with status {response.status_code}")
-    return {
-        "bytes": response.read(),
-        "status": response.status_code,
-        "headers": dict(response.headers),
-    }
 
+def get(expr: ExprLike, abort_on_error: bool = False) -> Expr:
+    """Read data from the URL.
 
-class _HttpClient(hishel.CacheClient):
-    _instance: "_HttpClient" = None
+    Args:
+        expr: URLs of the data that needs to be read.
+        abort_on_error: Should the expression abort on errors or just collect them.
+    """
+    from spiral import expressions as se
 
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            cls._instance = super().__new__(cls)
-        return cls._instance
+    expr = se.lift(expr)
 
-    def __init__(self):
-        super().__init__(storage=hishel.FileStorage(base_path=APP_DIR / "http.cache", ttl=3600))
+    # This just works :)
+    return Expr(_lib.expr.s3.get(expr.__expr__, abort_on_error))
spiral/expressions/s3.py ADDED
@@ -0,0 +1,16 @@
+from spiral import _lib
+from spiral.expressions.base import Expr, ExprLike
+
+
+def get(expr: ExprLike, abort_on_error: bool = False) -> Expr:
+    """Read data from object storage by the s3:// URL.
+
+    Args:
+        expr: URLs of the data that needs to be read from object storage.
+        abort_on_error: Should the expression abort on errors or just collect them.
+    """
+    from spiral import expressions as se
+
+    expr = se.lift(expr)
+
+    return Expr(_lib.expr.s3.get(expr.__expr__, abort_on_error))
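Both the rewritten `http.get` and the new `s3.get` are thin wrappers over the native `_lib.expr.s3.get` expression, replacing the old Python-side `httpx`/`hishel` UDF. A minimal usage sketch, assuming the `s3` module is re-exported under `spiral.expressions` and that a plain list of URLs is accepted as an `ExprLike`:

```python
from spiral import expressions as se

# Hypothetical s3:// URLs; in practice these usually come from a column of an existing table.
urls = [
    "s3://my-bucket/images/0001.tiff",
    "s3://my-bucket/images/0002.tiff",
]

# Build an expression that reads each object's bytes from object storage.
# abort_on_error=False collects per-row errors instead of failing the whole read.
blobs = se.s3.get(urls, abort_on_error=False)
```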
spiral/expressions/tiff.py CHANGED
@@ -2,7 +2,6 @@ import numpy as np
 import pyarrow as pa
 
 from spiral.expressions.base import Expr, ExprLike
-from spiral.expressions.udf import RefUDF
 
 _TIFF_RES_DTYPE: pa.DataType = pa.struct(
     [
@@ -78,7 +77,7 @@ def select(
     return TiffSelectUDF()(expr, shape, indexes)
 
 
-class TiffReadUDF(RefUDF):
+class TiffReadUDF:
     def __init__(self):
         super().__init__("tiff.read")
 
@@ -122,7 +121,7 @@ class TiffReadUDF(RefUDF):
         return _return_result(result, indexes)
 
 
-class TiffSelectUDF(RefUDF):
+class TiffSelectUDF:
     def __init__(self):
         super().__init__("tiff.select")
 
spiral/expressions/udf.py CHANGED
@@ -3,44 +3,58 @@ import abc
 import pyarrow as pa
 
 from spiral import _lib
-from spiral.expressions.base import Expr
+from spiral.expressions.base import Expr, ExprLike
 
 
-class BaseUDF:
-    def __init__(self, udf):
-        self._udf = udf
+class UDF(abc.ABC):
+    """A User-Defined Function (UDF). This class should be subclassed to define custom UDFs.
 
-    def __call__(self, *args) -> Expr:
-        """Create an expression that calls this UDF with the given arguments."""
-        from spiral import expressions as se
+    Example:
 
-        args = [se.lift(arg).__expr__ for arg in args]
-        return Expr(self._udf(args))
+    ```python
+    from spiral import expressions as se
+    import pyarrow as pa
 
-    @abc.abstractmethod
-    def return_type(self, *input_types: pa.DataType) -> pa.DataType: ...
+    class MyAdd(se.UDF):
+        def __init__(self):
+            super().__init__("my_add")
 
+        def return_type(self, scope: pa.DataType):
+            if not isinstance(scope, pa.StructType):
+                raise ValueError("Expected struct type as input")
+            return scope.field(0).type
 
-class UDF(BaseUDF):
-    """A User-Defined Function (UDF)."""
+        def invoke(self, scope: pa.Array):
+            if not isinstance(scope, pa.StructArray):
+                raise ValueError("Expected struct array as input")
+            return pa.compute.add(scope.field(0), scope.field(1))
 
-    def __init__(self, name: str):
-        super().__init__(_lib.expr.udf.create(name, return_type=self.return_type, invoke=self.invoke))
+    my_add = MyAdd()
 
-    @abc.abstractmethod
-    def invoke(self, *input_args: pa.Array) -> pa.Array: ...
+    expr = my_add(table.select("first_arg", "second_arg"))
+    ```
+    """
 
+    def __init__(self, name: str):
+        self._udf = _lib.expr.udf.create(name, return_type=self.return_type, invoke=self.invoke)
 
-class RefUDF(BaseUDF):
-    """A UDF over a single ref cell, and therefore can access the file object."""
+    def __call__(self, scope: ExprLike) -> Expr:
+        """Create an expression that calls this UDF with the given arguments."""
+        from spiral import expressions as se
 
-    def __init__(self, name: str):
-        super().__init__(_lib.expr.udf.create(name, return_type=self.return_type, invoke=self.invoke, scope="ref"))
+        return Expr(self._udf(se.lift(scope).__expr__))
 
     @abc.abstractmethod
-    def invoke(self, fp, *input_args: pa.Array) -> pa.Array:
-        """Invoke the UDF with the given arguments.
+    def return_type(self, scope: pa.DataType) -> pa.DataType:
+        """Must return the return type of the UDF given the input scope type.
 
-        NOTE: The first argument is always the ref cell. All array input args will be sliced to the appropriate row.
+        IMPORTANT: All expressions in Spiral must return nullable (Arrow default) types,
+        including nested structs, meaning that all fields in structs must also be nullable,
+        and if those fields are structs, their fields must also be nullable, and so on.
         """
         ...
+
+    @abc.abstractmethod
+    def invoke(self, scope: pa.Array) -> pa.Array:
+        """Must implement the UDF logic given the input scope array."""
+        ...
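The new `UDF` base class receives a single struct-typed scope rather than positional arguments, and `return_type` must produce nullable (Arrow default) types all the way down. A minimal sketch of a custom UDF respecting that rule, using a hypothetical `Thumbnail` operator purely for illustration and the `se.UDF` re-export shown in the docstring above:

```python
import pyarrow as pa
from spiral import expressions as se


class Thumbnail(se.UDF):  # hypothetical UDF, for illustration only
    def __init__(self):
        super().__init__("thumbnail")

    def return_type(self, scope: pa.DataType) -> pa.DataType:
        # Every field is left nullable (the Arrow default), as required.
        return pa.struct(
            [
                pa.field("bytes", pa.large_binary()),
                pa.field("width", pa.int32()),
                pa.field("height", pa.int32()),
            ]
        )

    def invoke(self, scope: pa.Array) -> pa.Array:
        # Placeholder logic: echo the input bytes and report fixed dimensions.
        blobs = scope.field("bytes")
        side = pa.array([64] * len(blobs), pa.int32())
        return pa.StructArray.from_arrays([blobs, side, side], names=["bytes", "width", "height"])
```

As in the docstring example, the UDF would be invoked on a struct scope, e.g. `Thumbnail()(table.select("bytes"))`.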
spiral/iceberg.py CHANGED
@@ -15,7 +15,7 @@ class Iceberg:
 
     def __init__(self, spiral: "Spiral"):
         self._spiral = spiral
-        self._api = self._spiral.config.api
+        self._api = self._spiral.api
 
     def catalog(self) -> "Catalog":
         """Open the Iceberg catalog."""
@@ -25,7 +25,7 @@
             "default",
             **{
                 "type": "rest",
-                "uri": self._spiral.config.spiraldb.uri + "/iceberg",
-                "token": self._spiral.config.authn.token().expose_secret(),
+                "uri": self._spiral.config.server_url + "/iceberg",
+                "token": self._spiral.authn.token().expose_secret(),
             },
         )
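For the Iceberg wrapper the only change is where the endpoint and token come from (`config.server_url` and the client-level `authn`). A rough sketch of opening the catalog, assuming `Spiral()` can be constructed from ambient settings, that `Iceberg` is importable from `spiral.iceberg`, and that the returned object follows the pyiceberg `Catalog` API:

```python
from spiral import Spiral
from spiral.iceberg import Iceberg  # assumed import path, matching the file name

sp = Spiral()  # assumed: configured via ~/.spiral.toml or SPIRAL__* env vars
catalog = Iceberg(sp).catalog()  # REST catalog served at <server_url>/iceberg
print(catalog.list_namespaces())  # assumed pyiceberg Catalog method
```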
spiral/project.py CHANGED
@@ -53,7 +53,7 @@ class Project:
         res = res[0]
 
         return Table(
-            self._spiral, self._spiral._core.table(res.id), identifier=f"{res.project_id}.{res.dataset}.{res.table}"
+            self._spiral, self._spiral.core.table(res.id), identifier=f"{res.project_id}.{res.dataset}.{res.table}"
         )
 
     def create_table(
@@ -78,7 +78,7 @@
             key_schema = pa.schema(key_schema)
             key_schema = Schema.from_arrow(key_schema)
 
-        core_table = self._spiral._core.create_table(
+        core_table = self._spiral.core.create_table(
             project_id=self._id,
             dataset=dataset,
             table=table,
@@ -89,6 +89,34 @@
 
         return Table(self._spiral, core_table, identifier=f"{self._id}.{dataset}.{table}")
 
+    def move_table(self, identifier: str, new_dataset: str):
+        """Move a table to a new dataset in the project.
+
+        Args:
+            identifier: The table identifier, in the form `dataset.table` or `table`.
+            new_dataset: The dataset into which to move this table.
+        """
+        table = self.table(identifier)
+
+        self._spiral.core.move_table(
+            table_id=table.table_id,
+            new_dataset=new_dataset,
+        )
+
+    def rename_table(self, identifier: str, new_table: str):
+        """Move a table to a new dataset in the project.
+
+        Args:
+            identifier: The table identifier, in the form `dataset.table` or `table`.
+            new_dataset: The dataset into which to move this table.
+        """
+        table = self.table(identifier)
+
+        self._spiral.core.rename_table(
+            table_id=table.table_id,
+            new_table=new_table,
+        )
+
 
     def _parse_table_identifier(self, identifier: str) -> tuple[str, str]:
         parts = identifier.split(".")
@@ -105,7 +133,7 @@
             raise ValueError(f"Index not found: {name}")
         res = res[0]
 
-        return TextIndex(self._spiral._core.text_index(res.id), name=name)
+        return TextIndex(self._spiral.core.text_index(res.id), name=name)
 
     def create_text_index(
         self,
@@ -135,7 +163,7 @@
         if where is not None:
            where = se.lift(where)
 
-        core_index = self._spiral._core.create_text_index(
+        core_index = self._spiral.core.create_text_index(
            project_id=self._id,
            name=name,
            projection=projection.__expr__,
@@ -154,7 +182,7 @@
             raise ValueError(f"Index not found: {name}")
         res = res[0]
 
-        return KeySpaceIndex(self._spiral._core.key_space_index(res.id), name=name)
+        return KeySpaceIndex(self._spiral.core.key_space_index(res.id), name=name)
 
     def create_key_space_index(
         self,
@@ -185,7 +213,7 @@
         if where is not None:
            where = se.lift(where)
 
-        core_index = self._spiral._core.create_key_space_index(
+        core_index = self._spiral.core.create_key_space_index(
            project_id=self._id,
            name=name,
            granularity=granularity,
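`Project` gains `move_table` and `rename_table`, both of which resolve the table by identifier and then call into the core client. A minimal sketch of how they might be used, assuming a project handle is obtained through something like `sp.project(...)` (the accessor itself is not shown in this diff):

```python
from spiral import Spiral

sp = Spiral()                       # assumed: configured via ~/.spiral.toml or SPIRAL__* env vars
project = sp.project("my-project")  # assumed accessor name, for illustration only

# Move `raw.events` into the `curated` dataset, then give it a clearer name.
project.move_table("raw.events", new_dataset="curated")
project.rename_table("curated.events", new_table="events_v2")
```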
spiral/scan.py CHANGED
@@ -1,8 +1,10 @@
+from functools import partial
 from typing import TYPE_CHECKING, Any, Optional
 
 import pyarrow as pa
 
 from spiral.core.client import Shard, ShuffleConfig
+from spiral.core.table import KeyRange
 from spiral.core.table import Scan as CoreScan
 from spiral.core.table.spec import Schema
 from spiral.settings import CI, DEV
@@ -15,13 +17,15 @@ if TYPE_CHECKING:
     import streaming # noqa
     import torch.utils.data as torchdata # noqa
 
+    from spiral.client import Spiral
     from spiral.dataloader import SpiralDataLoader, World # noqa
 
 
 class Scan:
     """Scan object."""
 
-    def __init__(self, core: CoreScan):
+    def __init__(self, spiral: "Spiral", core: CoreScan):
+        self.spiral = spiral
         self.core = core
 
     @property
@@ -34,6 +38,11 @@
         """Returns the schema of the scan."""
         return self.core.schema()
 
+    @property
+    def key_schema(self) -> Schema:
+        """Returns the key schema of the scan."""
+        return self.core.key_schema()
+
     def is_empty(self) -> bool:
         """Check if the Spiral is empty for the given key range.
 
@@ -44,20 +53,30 @@
 
     def to_record_batches(
         self,
+        *,
+        key_range: KeyRange | None = None,
         key_table: pa.Table | pa.RecordBatchReader | None = None,
         batch_size: int | None = None,
         batch_readahead: int | None = None,
+        hide_progress_bar: bool = False,
     ) -> pa.RecordBatchReader:
         """Read as a stream of RecordBatches.
 
         Args:
+            key_range: Optional key range to filter the scan.
+                If provided, the scan will only return rows within the key range.
+                Only one of key_range or key_table can be provided.
            key_table: a table of keys to "take" (including aux columns for cell-push-down).
                If None, the scan will be executed without a key table.
            batch_size: the maximum number of rows per returned batch.
                IMPORTANT: This is currently only respected when the key_table is used. If key table is a
                RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
            batch_readahead: the number of batches to prefetch in the background.
+            hide_progress_bar: If True, disables the progress bar during reading.
         """
+        if key_range is not None and key_table is not None:
+            raise ValueError("Only one of key_range or key_table can be provided.")
+
         if isinstance(key_table, pa.RecordBatchReader):
             if batch_size is not None:
                 raise ValueError(
@@ -66,46 +85,56 @@
         elif isinstance(key_table, pa.Table):
             key_table = key_table.to_reader(max_chunksize=batch_size)
 
-        return self.core.to_record_batches(key_table=key_table, batch_readahead=batch_readahead)
+        return self.core.to_record_batches(
+            key_range=key_range, key_table=key_table, batch_readahead=batch_readahead, progress=(not hide_progress_bar)
+        )
 
     def to_table(
         self,
+        *,
+        key_range: KeyRange | None = None,
         key_table: pa.Table | pa.RecordBatchReader | None = None,
     ) -> pa.Table:
         """Read into a single PyArrow Table.
 
         Args:
+            key_range: Optional key range to filter the scan.
+                If provided, the scan will only return rows within the key range.
+                Only one of key_range or key_table can be provided.
            key_table: a table of keys to "take" (including aux columns for cell-push-down).
                If None, the scan will be executed without a key table.
         """
         # NOTE: Evaluates fully on Rust side which improved debuggability.
-        if DEV and not CI and key_table is None:
+        if DEV and not CI and key_table is None and key_range is None:
             rb = self.core.to_record_batch()
             return pa.Table.from_batches([rb])
 
-        return self.to_record_batches(key_table=key_table).read_all()
+        return self.to_record_batches(key_range=key_range, key_table=key_table).read_all()
 
     def to_dask(self) -> "dd.DataFrame":
         """Read into a Dask DataFrame.
 
         Requires the `dask` package to be installed.
+
+        IMPORTANT: Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
+        usually manifest as serialization errors when Dask workers attempt to serialize the state. If you are
+        encountering such issues, please reach out to the support for assistance.
         """
         import dask.dataframe as dd
-        import pandas as pd
-
-        def _read_shard(shard: Shard) -> pd.DataFrame:
-            # TODO(ngates): we need a way to preserve the existing asofs?
-            raise NotImplementedError()
 
-        # Fetch a set of partition ranges
+        _read_shard = partial(
+            _read_shard_task,
+            settings_json=self.spiral.config.to_json(),
+            state_json=self.core.plan_state().to_json(),
+        )
         return dd.from_map(_read_shard, self.shards())
 
-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self, *, key_range: KeyRange | None = None) -> "pd.DataFrame":
         """Read into a Pandas DataFrame.
 
         Requires the `pandas` package to be installed.
         """
-        return self.to_table().to_pandas()
+        return self.to_table(key_range=key_range).to_pandas()
 
     def to_polars(self) -> "pl.DataFrame":
         """Read into a Polars DataFrame.
@@ -160,16 +189,18 @@
 
         Returns:
             SpiralDataLoader with shards partitioned for this rank.
-        """
-        # Example usage:
-        #
-        # Auto-detect from PyTorch distributed:
-        # loader: SpiralDataLoader = scan.to_distributed_data_loader(batch_size=32)
-        #
-        # Explicit world configuration:
-        # world = World(rank=0, world_size=4)
-        # loader: SpiralDataLoader = scan.to_distributed_data_loader(world=world, batch_size=32)
 
+        Auto-detect from PyTorch distributed:
+        ```python
+        loader: SpiralDataLoader = scan.to_distributed_data_loader(batch_size=32)
+        ```
+
+        Explicit world configuration:
+        ```python
+        world = World(rank=0, world_size=4)
+        loader: SpiralDataLoader = scan.to_distributed_data_loader(world=world, batch_size=32)
+        ```
+        """
         from spiral.dataloader import SpiralDataLoader, World
 
         if world is None:
@@ -203,19 +234,21 @@
 
         Returns:
             New SpiralDataLoader instance configured to resume from the checkpoint.
+
+        Save checkpoint during training:
+        ```python
+        loader = scan.to_distributed_data_loader(batch_size=32, seed=42)
+        checkpoint = loader.state_dict()
+        ```
+
+        Resume later - uses same shards from checkpoint:
+        ```python
+        resumed_loader = scan.resume_data_loader(
+            checkpoint,
+            batch_size=32,
+            transform_fn=my_transform,
+        )
         """
-        # Example usage:
-        #
-        # Save checkpoint during training:
-        # loader = scan.to_distributed_data_loader(batch_size=32, seed=42)
-        # checkpoint = loader.state_dict()
-        #
-        # Resume later - uses same shards from checkpoint:
-        # resumed_loader = scan.resume_data_loader(
-        #     checkpoint,
-        #     batch_size=32,
-        #     transform_fn=my_transform,
-        # )
 
         from spiral.dataloader import SpiralDataLoader
         return SpiralDataLoader.from_state_dict(self, state, **kwargs)
@@ -283,3 +316,17 @@
         from spiral.debug.metrics import display_metrics
 
         display_metrics(self.metrics)
+
+
+# NOTE(marko): This function must be picklable!
+def _read_shard_task(shard: Shard, *, settings_json: str, state_json: str) -> "pd.DataFrame":
+    from spiral import Spiral
+    from spiral.core.table import ScanState
+    from spiral.settings import ClientSettings
+
+    settings = ClientSettings.from_json(settings_json)
+    sp = Spiral(config=settings)
+    state = ScanState.from_json(state_json)
+    task_scan = Scan(sp, sp.core.load_scan(state))
+
+    return task_scan.to_record_batches(key_range=shard.key_range, hide_progress_bar=True).read_all().to_pandas()
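`Scan` now holds the `Spiral` client (so `_read_shard_task` can rebuild the scan on Dask workers from the serialized settings and plan state), and its read methods take keyword-only `key_range`/`key_table` arguments plus `hide_progress_bar`. A minimal sketch of the take-by-key path, assuming `scan` was produced elsewhere (e.g. via a table's scan method) and that the key column is named `key`:

```python
import pyarrow as pa

# Hypothetical key values to "take"; the column name and type depend on the table's key schema.
keys = pa.table({"key": pa.array([b"user/1", b"user/7"], pa.large_binary())})

# Stream only the requested rows; suppress the progress bar in batch jobs.
reader = scan.to_record_batches(key_table=keys, batch_size=1_000, hide_progress_bar=True)
df = reader.read_all().to_pandas()
```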
spiral/settings.py CHANGED
@@ -1,22 +1,16 @@
+"""Configuration module using Rust ClientSettings via PyO3.
+
+This module provides a simple settings() function that returns a cached
+ClientSettings instance loaded from ~/.spiral.toml and environment variables.
+"""
+
 import functools
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated
 
 import typer
-from pydantic import Field, ValidatorFunctionWrapHandler, WrapValidator
-from pydantic_settings import (
-    BaseSettings,
-    InitSettingsSource,
-    PydanticBaseSettingsSource,
-    SettingsConfigDict,
-)
-
-from spiral.core.authn import Authn, DeviceCodeAuth, Token
-from spiral.core.client import Spiral
 
-if TYPE_CHECKING:
-    from spiral.api import SpiralAPI
+from spiral.core.config import ClientSettings
 
 DEV = "PYTEST_VERSION" in os.environ or bool(os.environ.get("SPIRAL_DEV", None))
 CI = "GITHUB_ACTIONS" in os.environ
@@ -27,88 +21,16 @@ LOG_DIR = APP_DIR / "logs"
 PACKAGE_NAME = "pyspiral"
 
 
-def validate_token(v, handler: ValidatorFunctionWrapHandler):
-    if isinstance(v, str):
-        return Token(v)
-    else:
-        raise ValueError("Token value must be a string")
-
-
-TokenType = Annotated[Token, WrapValidator(validate_token)]
-
-
-class SpiralDBSettings(BaseSettings):
-    model_config = SettingsConfigDict(frozen=True)
-
-    host: str = "localhost" if DEV else "api.spiraldb.com"
-    port: int = 4279 if DEV else 443
-    ssl: bool = not DEV
-    token: TokenType | None = None
-
-    @property
-    def uri(self) -> str:
-        return f"{'https' if self.ssl else 'http'}://{self.host}:{self.port}"
-
-
-class SpfsSettings(BaseSettings):
-    model_config = SettingsConfigDict(frozen=True)
-
-    host: str = "localhost" if DEV else "spfs.spiraldb.dev"
-    port: int = 4295 if DEV else 443
-    ssl: bool = not DEV
-
-    @property
-    def uri(self) -> str:
-        return f"{'https' if self.ssl else 'http'}://{self.host}:{self.port}"
-
-
-class Settings(BaseSettings):
-    model_config = SettingsConfigDict(
-        env_nested_delimiter="__",
-        env_prefix="SPIRAL__",
-        frozen=True,
-    )
-
-    spiraldb: SpiralDBSettings = Field(default_factory=SpiralDBSettings)
-    spfs: SpfsSettings = Field(default_factory=SpfsSettings)
-    file_format: str = Field(default="vortex")
-
-    @functools.cached_property
-    def api(self) -> "SpiralAPI":
-        from spiral.api import SpiralAPI
-
-        return SpiralAPI(self.authn, base_url=self.spiraldb.uri)
-
-    @functools.cached_property
-    def core(self) -> Spiral:
-        return Spiral(
-            api_url=self.spiraldb.uri,
-            spfs_url=self.spfs.uri,
-            authn=self.authn,
-        )
-
-    @functools.cached_property
-    def authn(self):
-        if self.spiraldb.token:
-            return Authn.from_token(self.spiraldb.token)
-        return Authn.from_fallback(self.spiraldb.uri)
-
-    @functools.cached_property
-    def device_code_auth(self) -> DeviceCodeAuth:
-        return DeviceCodeAuth.default()
-
-    @classmethod
-    def settings_customise_sources(
-        cls,
-        settings_cls: type[BaseSettings],
-        env_settings: PydanticBaseSettingsSource,
-        dotenv_settings: PydanticBaseSettingsSource,
-        init_settings: InitSettingsSource,
-        **kwargs,
-    ) -> tuple[PydanticBaseSettingsSource, ...]:
-        return env_settings, dotenv_settings, init_settings
-
-
 @functools.cache
-def settings() -> Settings:
-    return Settings()
+def settings() -> ClientSettings:
+    """Get the global ClientSettings instance.
+
+    Configuration is loaded with the following priority (highest to lowest):
+    1. Environment variables (SPIRAL__*)
+    2. Config file (~/.spiral.toml)
+    3. Default values
+
+    Returns:
+        ClientSettings: The global configuration instance
+    """
+    return ClientSettings.load()
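The pydantic-settings stack is gone; configuration now comes from the Rust `ClientSettings` exposed through `spiral.core.config`, and `settings()` simply caches one loaded instance. A minimal sketch of working with it, using only the calls that appear elsewhere in this diff (`load`, `to_json`, `from_json`):

```python
from spiral.core.config import ClientSettings
from spiral.settings import settings

cfg = settings()          # loaded once from ~/.spiral.toml and SPIRAL__* env vars
assert cfg is settings()  # functools.cache returns the same instance

# Settings round-trip through JSON, which is how scan.py ships them to Dask workers.
blob = cfg.to_json()
same = ClientSettings.from_json(blob)
```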
spiral/streaming_/stream.py CHANGED
@@ -101,7 +101,7 @@ class SpiralStream:
             return 0
 
         # Prepare the shard, writing it to disk.
-        self._sp._ops().prepare_shard(
+        self._sp.internal.prepare_shard(
            shard_path, self._scan.core, shard.shard, row_block_size=self._shard_row_block_size
        )