pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (52) hide show
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +333 -99
  5. pixeltable/catalog/column.py +28 -26
  6. pixeltable/catalog/globals.py +12 -0
  7. pixeltable/catalog/insertable_table.py +8 -8
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +111 -116
  10. pixeltable/catalog/table_version.py +36 -50
  11. pixeltable/catalog/table_version_handle.py +4 -1
  12. pixeltable/catalog/table_version_path.py +28 -4
  13. pixeltable/catalog/view.py +10 -18
  14. pixeltable/config.py +4 -0
  15. pixeltable/dataframe.py +10 -9
  16. pixeltable/env.py +5 -11
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/exec_node.py +2 -0
  19. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  20. pixeltable/exec/sql_node.py +47 -30
  21. pixeltable/exprs/column_property_ref.py +2 -1
  22. pixeltable/exprs/column_ref.py +7 -6
  23. pixeltable/exprs/expr.py +4 -4
  24. pixeltable/func/__init__.py +1 -0
  25. pixeltable/func/mcp.py +74 -0
  26. pixeltable/func/query_template_function.py +4 -2
  27. pixeltable/func/tools.py +12 -2
  28. pixeltable/func/udf.py +2 -2
  29. pixeltable/functions/__init__.py +1 -0
  30. pixeltable/functions/groq.py +108 -0
  31. pixeltable/functions/huggingface.py +8 -6
  32. pixeltable/functions/mistralai.py +2 -13
  33. pixeltable/functions/openai.py +1 -6
  34. pixeltable/functions/replicate.py +2 -2
  35. pixeltable/functions/util.py +6 -1
  36. pixeltable/globals.py +0 -2
  37. pixeltable/io/external_store.py +2 -2
  38. pixeltable/io/label_studio.py +4 -4
  39. pixeltable/io/table_data_conduit.py +1 -1
  40. pixeltable/metadata/__init__.py +1 -1
  41. pixeltable/metadata/converters/convert_37.py +15 -0
  42. pixeltable/metadata/notes.py +1 -0
  43. pixeltable/metadata/schema.py +5 -0
  44. pixeltable/plan.py +37 -121
  45. pixeltable/share/packager.py +2 -2
  46. pixeltable/type_system.py +30 -0
  47. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/METADATA +1 -1
  48. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/RECORD +51 -49
  49. pixeltable/utils/sample.py +0 -25
  50. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/LICENSE +0 -0
  51. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/WHEEL +0 -0
  52. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/entry_points.txt +0 -0
@@ -64,8 +64,9 @@ class ColumnPropertyRef(Expr):
64
64
  # perform runtime checks and update state
65
65
  tv = self._col_ref.tbl_version.get()
66
66
  assert tv.is_validated
67
+ # we can assume at this point during query execution that the column exists
68
+ assert self._col_ref.col_id in tv.cols_by_id
67
69
  col = tv.cols_by_id[self._col_ref.col_id]
68
- # TODO: check for column being dropped
69
70
 
70
71
  # the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
71
72
  if (
@@ -239,7 +239,6 @@ class ColumnRef(Expr):
239
239
  return helper
240
240
 
241
241
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
242
- # return None if self.perform_validation else self.col.sa_col
243
242
  if self.perform_validation:
244
243
  return None
245
244
  # we need to reestablish that we have the correct Column instance, there could have been a metadata
@@ -248,13 +247,10 @@ class ColumnRef(Expr):
248
247
  # perform runtime checks and update state
249
248
  tv = self.tbl_version.get()
250
249
  assert tv.is_validated
250
+ # we can assume at this point during query execution that the column exists
251
+ assert self.col_id in tv.cols_by_id
251
252
  self.col = tv.cols_by_id[self.col_id]
252
253
  assert self.col.tbl is tv
253
- # TODO: check for column being dropped
254
- # print(
255
- # f'ColumnRef.sql_expr: tbl={tv.id}:{tv.effective_version} sa_tbl={id(self.col.tbl.store_tbl.sa_tbl):x} '
256
- # f'tv={id(tv):x}'
257
- # )
258
254
  return self.col.sa_col
259
255
 
260
256
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -315,6 +311,11 @@ class ColumnRef(Expr):
315
311
  'perform_validation': self.perform_validation,
316
312
  }
317
313
 
314
+ @classmethod
315
+ def get_column_id(cls, d: dict) -> catalog.QColumnId:
316
+ tbl_id, col_id = UUID(d['tbl_id']), d['col_id']
317
+ return catalog.QColumnId(tbl_id, col_id)
318
+
318
319
  @classmethod
319
320
  def get_column(cls, d: dict) -> catalog.Column:
320
321
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
pixeltable/exprs/expr.py CHANGED
@@ -394,17 +394,17 @@ class Expr(abc.ABC):
394
394
  return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
395
395
 
396
396
  @classmethod
397
- def get_refd_columns(cls, expr_dict: dict[str, Any]) -> list[catalog.Column]:
397
+ def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
398
398
  """Return the qualified ids (QColumnId) of the columns referenced by expr_dict."""
399
- result: list[catalog.Column] = []
399
+ result: set[catalog.QColumnId] = set()
400
400
  assert '_classname' in expr_dict
401
401
  from .column_ref import ColumnRef
402
402
 
403
403
  if expr_dict['_classname'] == 'ColumnRef':
404
- result.append(ColumnRef.get_column(expr_dict))
404
+ result.add(ColumnRef.get_column_id(expr_dict))
405
405
  if 'components' in expr_dict:
406
406
  for component_dict in expr_dict['components']:
407
- result.extend(cls.get_refd_columns(component_dict))
407
+ result.update(cls.get_refd_column_ids(component_dict))
408
408
  return result
409
409
 
410
410
  def as_literal(self) -> Optional[Expr]:
@@ -5,6 +5,7 @@ from .callable_function import CallableFunction
5
5
  from .expr_template_function import ExprTemplateFunction
6
6
  from .function import Function, InvalidFunction
7
7
  from .function_registry import FunctionRegistry
8
+ from .mcp import mcp_udfs
8
9
  from .query_template_function import QueryTemplateFunction, query, retrieval_udf
9
10
  from .signature import Batch, Parameter, Signature
10
11
  from .tools import Tool, ToolChoice, Tools
pixeltable/func/mcp.py ADDED
@@ -0,0 +1,74 @@
1
+ import asyncio
2
+ import inspect
3
+ from typing import TYPE_CHECKING, Any, Optional
4
+
5
+ import pixeltable as pxt
6
+ from pixeltable import exceptions as excs, type_system as ts
7
+ from pixeltable.func.signature import Parameter
8
+
9
+ if TYPE_CHECKING:
10
+ import mcp
11
+
12
+
13
+ def mcp_udfs(url: str) -> list['pxt.func.Function']:
14
+ return asyncio.run(mcp_udfs_async(url))
15
+
16
+
17
+ async def mcp_udfs_async(url: str) -> list['pxt.func.Function']:
18
+ import mcp
19
+ from mcp.client.streamable_http import streamablehttp_client
20
+
21
+ list_tools_result: Optional[mcp.types.ListToolsResult] = None
22
+ async with (
23
+ streamablehttp_client(url) as (read_stream, write_stream, _),
24
+ mcp.ClientSession(read_stream, write_stream) as session,
25
+ ):
26
+ await session.initialize()
27
+ list_tools_result = await session.list_tools()
28
+ assert list_tools_result is not None
29
+
30
+ return [mcp_tool_to_udf(url, tool) for tool in list_tools_result.tools]
31
+
32
+
33
+ def mcp_tool_to_udf(url: str, mcp_tool: 'mcp.types.Tool') -> 'pxt.func.Function':
34
+ import mcp
35
+ from mcp.client.streamable_http import streamablehttp_client
36
+
37
+ async def invoke(**kwargs: Any) -> str:
38
+ # TODO: Cache session objects rather than creating a new one each time?
39
+ async with (
40
+ streamablehttp_client(url) as (read_stream, write_stream, _),
41
+ mcp.ClientSession(read_stream, write_stream) as session,
42
+ ):
43
+ await session.initialize()
44
+ res = await session.call_tool(name=mcp_tool.name, arguments=kwargs)
45
+ # TODO Handle image/audio responses?
46
+ return res.content[0].text # type: ignore[union-attr]
47
+
48
+ if mcp_tool.description is not None:
49
+ invoke.__doc__ = mcp_tool.description
50
+
51
+ input_schema = mcp_tool.inputSchema
52
+ params = {
53
+ name: __mcp_param_to_pxt_type(mcp_tool.name, name, param) for name, param in input_schema.get('properties', {}).items()
54
+ }
55
+ required = input_schema.get('required', [])
56
+
57
+ # Ensure that any params not appearing in `required` are nullable.
58
+ # (A required param might or might not be nullable, since its type might be an 'anyOf' containing a null.)
59
+ for name in params.keys() - required:
60
+ params[name] = params[name].copy(nullable=True)
61
+
62
+ signature = pxt.func.Signature(
63
+ return_type=ts.StringType(), # Return type is always string
64
+ parameters=[Parameter(name, col_type, inspect.Parameter.KEYWORD_ONLY) for name, col_type in params.items()],
65
+ )
66
+
67
+ return pxt.func.CallableFunction(signatures=[signature], py_fns=[invoke], self_name=mcp_tool.name)
68
+
69
+
70
+ def __mcp_param_to_pxt_type(tool_name: str, name: str, param: dict[str, Any]) -> ts.ColumnType:
71
+ pxt_type = ts.ColumnType.from_json_schema(param)
72
+ if pxt_type is None:
73
+ raise excs.Error(f'Unknown type schema for MCP parameter {name!r} of tool {tool_name!r}: {param}')
74
+ return pxt_type
@@ -157,11 +157,13 @@ def retrieval_udf(
157
157
  """
158
158
  # Argument validation
159
159
  col_refs: list[exprs.ColumnRef]
160
+ # TODO: get rid of references to ColumnRef internals and replace instead with a public interface
161
+ col_names = table.columns()
160
162
  if parameters is None:
161
- col_refs = [table[col_name] for col_name in table.columns if not table[col_name].col.is_computed]
163
+ col_refs = [table[col_name] for col_name in col_names if not table[col_name].col.is_computed]
162
164
  else:
163
165
  for param in parameters:
164
- if isinstance(param, str) and param not in table.columns:
166
+ if isinstance(param, str) and param not in col_names:
165
167
  raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path()!r}')
166
168
  col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
167
169
 
pixeltable/func/tools.py CHANGED
@@ -1,8 +1,9 @@
1
+ import json
1
2
  from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
2
3
 
3
4
  import pydantic
4
5
 
5
- from pixeltable import exceptions as excs
6
+ from pixeltable import exceptions as excs, type_system as ts
6
7
 
7
8
  from .function import Function
8
9
  from .signature import Parameter
@@ -69,7 +70,9 @@ class Tool(pydantic.BaseModel):
69
70
  return _extract_float_tool_arg(kwargs, param_name=param.name)
70
71
  if param.col_type.is_bool_type():
71
72
  return _extract_bool_tool_arg(kwargs, param_name=param.name)
72
- raise AssertionError()
73
+ if param.col_type.is_json_type():
74
+ return _extract_json_tool_arg(kwargs, param_name=param.name)
75
+ raise AssertionError(param.col_type)
73
76
 
74
77
 
75
78
  class ToolChoice(pydantic.BaseModel):
@@ -137,6 +140,13 @@ def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[
137
140
  return _extract_arg(bool, kwargs, param_name)
138
141
 
139
142
 
143
+ @udf
144
+ def _extract_json_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[ts.Json]:
145
+ if param_name in kwargs:
146
+ return json.loads(kwargs[param_name])
147
+ return None
148
+
149
+
140
150
  T = TypeVar('T')
141
151
 
142
152
 
pixeltable/func/udf.py CHANGED
@@ -262,7 +262,7 @@ def from_table(
262
262
  """
263
263
  from pixeltable import exprs
264
264
 
265
- ancestors = [tbl, *tbl._base_tables]
265
+ ancestors = [tbl, *tbl._get_base_tables()]
266
266
  ancestors.reverse() # We must traverse the ancestors in order from base to derived
267
267
 
268
268
  subst: dict[exprs.Expr, exprs.Expr] = {}
@@ -297,7 +297,7 @@ def from_table(
297
297
 
298
298
  if description is None:
299
299
  # Default description is the table comment
300
- description = tbl._comment
300
+ description = tbl._get_comment()
301
301
  if len(description) == 0:
302
302
  description = f"UDF for table '{tbl._name}'"
303
303
 
@@ -10,6 +10,7 @@ from . import (
10
10
  deepseek,
11
11
  fireworks,
12
12
  gemini,
13
+ groq,
13
14
  huggingface,
14
15
  image,
15
16
  json,
@@ -0,0 +1,108 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Groq API. In order to use them, you must
4
+ first `pip install groq` and configure your Groq credentials, as described in
5
+ the [Working with Groq](https://pixeltable.readme.io/docs/working-with-groq) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Optional
9
+
10
+ import pixeltable as pxt
11
+ from pixeltable import exprs
12
+ from pixeltable.env import Env, register_client
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ from .openai import _openai_response_to_pxt_tool_calls
16
+
17
+ if TYPE_CHECKING:
18
+ import groq
19
+
20
+
21
+ @register_client('groq')
22
+ def _(api_key: str) -> 'groq.AsyncGroq':
23
+ import groq
24
+
25
+ return groq.AsyncGroq(api_key=api_key)
26
+
27
+
28
+ def _groq_client() -> 'groq.AsyncGroq':
29
+ return Env.get().get_client('groq')
30
+
31
+
32
+ @pxt.udf(resource_pool='request-rate:groq')
33
+ async def chat_completions(
34
+ messages: list[dict[str, str]],
35
+ *,
36
+ model: str,
37
+ model_kwargs: Optional[dict[str, Any]] = None,
38
+ tools: Optional[list[dict[str, Any]]] = None,
39
+ tool_choice: Optional[dict[str, Any]] = None,
40
+ ) -> dict:
41
+ """
42
+ Chat Completion API.
43
+
44
+ Equivalent to the Groq `chat/completions` API endpoint.
45
+ For additional details, see: <https://console.groq.com/docs/api-reference#chat-create>
46
+
47
+ Request throttling:
48
+ Applies the rate limit set in the config (section `groq`, key `rate_limit`). If no rate
49
+ limit is configured, uses a default of 600 RPM.
50
+
51
+ __Requirements:__
52
+
53
+ - `pip install groq`
54
+
55
+ Args:
56
+ messages: A list of messages comprising the conversation so far.
57
+ model: ID of the model to use. (See overview here: <https://console.groq.com/docs/models>)
58
+ model_kwargs: Additional keyword args for the Groq `chat/completions` API.
59
+ For details on the available parameters, see: <https://console.groq.com/docs/api-reference#chat-create>
60
+
61
+ Returns:
62
+ A dictionary containing the response and other metadata.
63
+
64
+ Examples:
65
+ Add a computed column that applies the model `llama3-8b-8192`
66
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
67
+
68
+ >>> messages = [{'role': 'user', 'content': tbl.prompt}]
69
+ ... tbl.add_computed_column(response=chat_completions(messages, model='llama3-8b-8192'))
70
+ """
71
+ if model_kwargs is None:
72
+ model_kwargs = {}
73
+
74
+ Env.get().require_package('groq')
75
+
76
+ if tools is not None:
77
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
78
+
79
+ if tool_choice is not None:
80
+ if tool_choice['auto']:
81
+ model_kwargs['tool_choice'] = 'auto'
82
+ elif tool_choice['required']:
83
+ model_kwargs['tool_choice'] = 'required'
84
+ else:
85
+ assert tool_choice['tool'] is not None
86
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
87
+
88
+ if tool_choice is not None and not tool_choice['parallel_tool_calls']:
89
+ model_kwargs['parallel_tool_calls'] = False
90
+
91
+ result = await _groq_client().chat.completions.create(
92
+ messages=messages, # type: ignore[arg-type]
93
+ model=model,
94
+ **model_kwargs,
95
+ )
96
+ return result.model_dump()
97
+
98
+
99
+ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDict:
100
+ """Converts a Groq response dict (parsed with the OpenAI-format helper) to Pixeltable tool invocation format and calls `tools._invoke()`."""
101
+ return tools._invoke(_openai_response_to_pxt_tool_calls(response))
102
+
103
+
104
+ __all__ = local_public_names(__name__)
105
+
106
+
107
+ def __dir__() -> list[str]:
108
+ return __all__
@@ -51,7 +51,7 @@ def sentence_transformer(
51
51
  """
52
52
  env.Env.get().require_package('sentence_transformers')
53
53
  device = resolve_torch_device('auto')
54
- from sentence_transformers import SentenceTransformer # type: ignore
54
+ from sentence_transformers import SentenceTransformer
55
55
 
56
56
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
57
57
  model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
@@ -170,7 +170,7 @@ def clip(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Flo
170
170
  env.Env.get().require_package('transformers')
171
171
  device = resolve_torch_device('auto')
172
172
  import torch
173
- from transformers import CLIPModel, CLIPProcessor # type: ignore
173
+ from transformers import CLIPModel, CLIPProcessor
174
174
 
175
175
  model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
176
176
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
@@ -395,19 +395,21 @@ def speech2text_for_conditional_generation(audio: pxt.Audio, *, model_id: str, l
395
395
  device = resolve_torch_device('auto', allow_mps=False) # Doesn't seem to work on 'mps'; use 'cpu' instead
396
396
  import torch
397
397
  import torchaudio # type: ignore[import-untyped]
398
- from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
398
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor, Speech2TextTokenizer
399
399
 
400
400
  model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
401
401
  processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
402
+ tokenizer = processor.tokenizer
402
403
  assert isinstance(processor, Speech2TextProcessor)
404
+ assert isinstance(tokenizer, Speech2TextTokenizer)
403
405
 
404
- if language is not None and language not in processor.tokenizer.lang_code_to_id:
406
+ if language is not None and language not in tokenizer.lang_code_to_id:
405
407
  raise excs.Error(
406
408
  f"Language code '{language}' is not supported by the model '{model_id}'. "
407
- f'Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}'
409
+ f'Supported languages are: {list(tokenizer.lang_code_to_id.keys())}'
408
410
  )
409
411
 
410
- forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
412
+ forced_bos_token_id: Optional[int] = None if language is None else tokenizer.lang_code_to_id[language]
411
413
 
412
414
  # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
413
415
  model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
@@ -5,7 +5,7 @@ first `pip install mistralai` and configure your Mistral AI credentials, as desc
5
5
  the [Working with Mistral AI](https://pixeltable.readme.io/docs/working-with-mistralai) tutorial.
6
6
  """
7
7
 
8
- from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
8
+ from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
10
  import numpy as np
11
11
 
@@ -16,7 +16,7 @@ from pixeltable.func.signature import Batch
16
16
  from pixeltable.utils.code import local_public_names
17
17
 
18
18
  if TYPE_CHECKING:
19
- import mistralai.types.basemodel
19
+ import mistralai
20
20
 
21
21
 
22
22
  @register_client('mistral')
@@ -54,8 +54,6 @@ async def chat_completions(
54
54
  model_kwargs: Additional keyword args for the Mistral `chat/completions` API.
55
55
  For details on the available parameters, see: <https://docs.mistral.ai/api/#tag/chat>
56
56
 
57
- For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/chat>
58
-
59
57
  Returns:
60
58
  A dictionary containing the response and other metadata.
61
59
 
@@ -156,15 +154,6 @@ def _(model: str) -> ts.ArrayType:
156
154
  return ts.ArrayType((dimensions,), dtype=ts.FloatType())
157
155
 
158
156
 
159
- _T = TypeVar('_T')
160
-
161
-
162
- def _opt(arg: Optional[_T]) -> Union[_T, 'mistralai.types.basemodel.Unset']:
163
- from mistralai.types import UNSET
164
-
165
- return arg if arg is not None else UNSET
166
-
167
-
168
157
  __all__ = local_public_names(__name__)
169
158
 
170
159
 
@@ -205,12 +205,7 @@ async def speech(input: str, *, model: str, voice: str, model_kwargs: Optional[d
205
205
  if model_kwargs is None:
206
206
  model_kwargs = {}
207
207
 
208
- content = await _openai_client().audio.speech.create(
209
- input=input,
210
- model=model,
211
- voice=voice, # type: ignore
212
- **model_kwargs,
213
- )
208
+ content = await _openai_client().audio.speech.create(input=input, model=model, voice=voice, **model_kwargs)
214
209
  ext = model_kwargs.get('response_format', 'mp3')
215
210
  output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
216
211
  content.write_to_file(output_filename)
@@ -12,7 +12,7 @@ from pixeltable.env import Env, register_client
12
12
  from pixeltable.utils.code import local_public_names
13
13
 
14
14
  if TYPE_CHECKING:
15
- import replicate # type: ignore[import-untyped]
15
+ import replicate
16
16
 
17
17
 
18
18
  @register_client('replicate')
@@ -27,7 +27,7 @@ def _replicate_client() -> 'replicate.Client':
27
27
 
28
28
 
29
29
  @pxt.udf(resource_pool='request-rate:replicate')
30
- async def run(input: dict[str, Any], *, ref: str) -> dict[str, Any]:
30
+ async def run(input: dict[str, Any], *, ref: str) -> pxt.Json:
31
31
  """
32
32
  Run a model on Replicate.
33
33
 
@@ -1,5 +1,6 @@
1
1
  import PIL.Image
2
2
 
3
+ from pixeltable.config import Config
3
4
  from pixeltable.env import Env
4
5
 
5
6
 
@@ -7,10 +8,14 @@ def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
7
8
  Env.get().require_package('torch')
8
9
  import torch
9
10
 
11
+ mps_enabled = Config.get().get_bool_value('enable_mps')
12
+ if mps_enabled is None:
13
+ mps_enabled = True # Default to True if not set in config
14
+
10
15
  if device == 'auto':
11
16
  if torch.cuda.is_available():
12
17
  return 'cuda'
13
- if allow_mps and torch.backends.mps.is_available():
18
+ if mps_enabled and allow_mps and torch.backends.mps.is_available():
14
19
  return 'mps'
15
20
  return 'cpu'
16
21
  return device
pixeltable/globals.py CHANGED
@@ -428,8 +428,6 @@ def get_table(path: str) -> catalog.Table:
428
428
  """
429
429
  path_obj = catalog.Path(path)
430
430
  tbl = Catalog.get().get_table(path_obj)
431
- tv = tbl._tbl_version.get()
432
- _logger.debug(f'get_table(): tbl={tv.id}:{tv.effective_version} sa_tbl={id(tv.store_tbl.sa_tbl):x} tv={id(tv):x}')
433
431
  return tbl
434
432
 
435
433
 
@@ -202,7 +202,7 @@ class Project(ExternalStore, abc.ABC):
202
202
  resolved_col_mapping: dict[Column, str] = {}
203
203
 
204
204
  # Validate names
205
- t_cols = set(table._schema.keys())
205
+ t_cols = set(table._get_schema().keys())
206
206
  for t_col, ext_col in col_mapping.items():
207
207
  if t_col not in t_cols:
208
208
  if is_user_specified_col_mapping:
@@ -225,7 +225,7 @@ class Project(ExternalStore, abc.ABC):
225
225
  assert isinstance(col_ref, exprs.ColumnRef)
226
226
  resolved_col_mapping[col_ref.col] = ext_col
227
227
  # Validate column specs
228
- t_col_types = table._schema
228
+ t_col_types = table._get_schema()
229
229
  for t_col, ext_col in col_mapping.items():
230
230
  t_col_type = t_col_types[t_col]
231
231
  if ext_col in export_cols:
@@ -412,8 +412,8 @@ class LabelStudioProject(Project):
412
412
  # TODO(aaron-siegel): Simplify this once propagation is properly implemented in batch_update
413
413
  ancestor = t
414
414
  while local_annotations_col not in ancestor._tbl_version.get().cols:
415
- assert ancestor._base_table is not None
416
- ancestor = ancestor._base_table
415
+ assert ancestor._get_base_table() is not None
416
+ ancestor = ancestor._get_base_table()
417
417
  update_status = ancestor.batch_update(updates)
418
418
  env.Env.get().console_logger.info(f'Updated annotation(s) from {len(updates)} task(s) in {self}.')
419
419
  return SyncStatus(pxt_rows_updated=update_status.num_rows, num_excs=update_status.num_excs)
@@ -560,7 +560,7 @@ class LabelStudioProject(Project):
560
560
 
561
561
  if name is None:
562
562
  # Create a default name that's unique to the table
563
- all_stores = t.external_stores
563
+ all_stores = t.external_stores()
564
564
  n = 0
565
565
  while f'ls_project_{n}' in all_stores:
566
566
  n += 1
@@ -576,7 +576,7 @@ class LabelStudioProject(Project):
576
576
  local_annotations_column = ANNOTATIONS_COLUMN
577
577
  else:
578
578
  local_annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
579
- if local_annotations_column not in t._schema:
579
+ if local_annotations_column not in t._get_schema():
580
580
  t.add_columns({local_annotations_column: ts.Json})
581
581
 
582
582
  resolved_col_mapping = cls.validate_columns(
@@ -101,7 +101,7 @@ class TableDataConduit:
101
101
  def add_table_info(self, table: pxt.Table) -> None:
102
102
  """Add information about the table into which we are inserting data"""
103
103
  assert isinstance(table, pxt.Table)
104
- self.pxt_schema = table._schema
104
+ self.pxt_schema = table._get_schema()
105
105
  self.pxt_pk = table._tbl_version.get().primary_key
106
106
  for col in table._tbl_version_path.columns():
107
107
  if col.is_required_for_insert:
@@ -18,7 +18,7 @@ _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
18
18
  _logger = logging.getLogger('pixeltable')
19
19
 
20
20
  # current version of the metadata; this is incremented whenever the metadata schema changes
21
- VERSION = 37
21
+ VERSION = 38
22
22
 
23
23
 
24
24
  def create_system_info(engine: sql.engine.Engine) -> None:
@@ -0,0 +1,15 @@
1
+ from uuid import UUID
2
+
3
+ import sqlalchemy as sql
4
+
5
+ from pixeltable.metadata import register_converter
6
+ from pixeltable.metadata.converters.util import convert_table_md
7
+
8
+
9
+ @register_converter(version=37)
10
+ def _(engine: sql.engine.Engine) -> None:
11
+ convert_table_md(engine, table_md_updater=__update_table_md)
12
+
13
+
14
+ def __update_table_md(table_md: dict, _: UUID) -> None:
15
+ table_md['view_sn'] = 0
@@ -2,6 +2,7 @@
2
2
  # rather than as a comment, so that the existence of a description can be enforced by
3
3
  # the unit tests when new versions are added.
4
4
  VERSION_NOTES = {
5
+ 38: 'Added TableMd.view_sn',
5
6
  37: 'Add support for the sample() method on DataFrames',
6
7
  36: 'Added Table.lock_dummy',
7
8
  35: 'Track reference_tbl in ColumnRef',
@@ -177,6 +177,11 @@ class TableMd:
177
177
  # - every row is assigned a unique and immutable rowid on insertion
178
178
  next_row_id: int
179
179
 
180
+ # sequence number to track changes in the set of mutable views of this table (ie, this table = the view base)
181
+ # - incremented for each add/drop of a mutable view
182
+ # - only maintained for mutable tables
183
+ view_sn: int
184
+
180
185
  # Metadata format for external stores:
181
186
  # {'class': 'pixeltable.io.label_studio.LabelStudioProject', 'md': {'project_id': 3}}
182
187
  external_stores: list[dict[str, Any]]