pixeltable 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.

Potentially problematic release: this version of pixeltable might be problematic.
Files changed (60)
  1. pixeltable/__init__.py +1 -0
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +9 -2
  4. pixeltable/catalog/column.py +1 -1
  5. pixeltable/catalog/dir.py +1 -1
  6. pixeltable/catalog/table.py +1 -1
  7. pixeltable/catalog/table_version.py +12 -2
  8. pixeltable/catalog/table_version_path.py +2 -2
  9. pixeltable/catalog/view.py +64 -20
  10. pixeltable/dataframe.py +14 -14
  11. pixeltable/env.py +20 -3
  12. pixeltable/exec/component_iteration_node.py +1 -2
  13. pixeltable/exec/expr_eval/evaluators.py +4 -2
  14. pixeltable/exec/expr_eval/expr_eval_node.py +4 -1
  15. pixeltable/exprs/comparison.py +8 -4
  16. pixeltable/exprs/data_row.py +5 -3
  17. pixeltable/exprs/expr.py +9 -2
  18. pixeltable/exprs/function_call.py +155 -313
  19. pixeltable/func/aggregate_function.py +29 -15
  20. pixeltable/func/callable_function.py +11 -8
  21. pixeltable/func/expr_template_function.py +3 -9
  22. pixeltable/func/function.py +148 -74
  23. pixeltable/func/signature.py +65 -30
  24. pixeltable/func/udf.py +1 -1
  25. pixeltable/functions/__init__.py +1 -0
  26. pixeltable/functions/deepseek.py +121 -0
  27. pixeltable/functions/image.py +7 -7
  28. pixeltable/functions/openai.py +49 -10
  29. pixeltable/functions/video.py +14 -7
  30. pixeltable/globals.py +14 -3
  31. pixeltable/index/embedding_index.py +4 -13
  32. pixeltable/io/globals.py +88 -77
  33. pixeltable/io/hf_datasets.py +34 -34
  34. pixeltable/io/pandas.py +75 -87
  35. pixeltable/io/parquet.py +19 -27
  36. pixeltable/io/utils.py +115 -0
  37. pixeltable/iterators/audio.py +2 -1
  38. pixeltable/iterators/video.py +1 -1
  39. pixeltable/metadata/__init__.py +2 -1
  40. pixeltable/metadata/converters/convert_15.py +18 -8
  41. pixeltable/metadata/converters/convert_27.py +31 -0
  42. pixeltable/metadata/converters/convert_28.py +15 -0
  43. pixeltable/metadata/converters/convert_29.py +111 -0
  44. pixeltable/metadata/converters/util.py +12 -1
  45. pixeltable/metadata/notes.py +3 -0
  46. pixeltable/metadata/schema.py +8 -0
  47. pixeltable/share/__init__.py +1 -0
  48. pixeltable/share/packager.py +246 -0
  49. pixeltable/share/publish.py +97 -0
  50. pixeltable/type_system.py +87 -42
  51. pixeltable/utils/__init__.py +41 -0
  52. pixeltable/utils/arrow.py +45 -12
  53. pixeltable/utils/formatter.py +1 -1
  54. pixeltable/utils/iceberg.py +14 -0
  55. pixeltable/utils/media_store.py +1 -1
  56. {pixeltable-0.3.3.dist-info → pixeltable-0.3.5.dist-info}/METADATA +37 -50
  57. {pixeltable-0.3.3.dist-info → pixeltable-0.3.5.dist-info}/RECORD +60 -51
  58. {pixeltable-0.3.3.dist-info → pixeltable-0.3.5.dist-info}/WHEEL +1 -1
  59. {pixeltable-0.3.3.dist-info → pixeltable-0.3.5.dist-info}/LICENSE +0 -0
  60. {pixeltable-0.3.3.dist-info → pixeltable-0.3.5.dist-info}/entry_points.txt +0 -0
pixeltable/functions/deepseek.py ADDED
@@ -0,0 +1,121 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
+
+import httpx
+
+import pixeltable as pxt
+from pixeltable import env
+from pixeltable.utils.code import local_public_names
+
+from .openai import _opt
+
+if TYPE_CHECKING:
+    import openai
+
+
+@env.register_client('deepseek')
+def _(api_key: str) -> 'openai.AsyncOpenAI':
+    import openai
+
+    return openai.AsyncOpenAI(
+        api_key=api_key,
+        base_url='https://api.deepseek.com',
+        http_client=httpx.AsyncClient(limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)),
+    )
+
+
+def _deepseek_client() -> 'openai.AsyncOpenAI':
+    return env.Env.get().get_client('deepseek')
+
+
+@pxt.udf
+async def chat_completions(
+    messages: list,
+    *,
+    model: str,
+    frequency_penalty: Optional[float] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: Optional[float] = None,
+    response_format: Optional[dict] = None,
+    stop: Optional[list[str]] = None,
+    temperature: Optional[float] = None,
+    tools: Optional[list[dict]] = None,
+    tool_choice: Optional[dict] = None,
+    top_p: Optional[float] = None,
+) -> dict:
+    """
+    Creates a model response for the given chat conversation.
+
+    Equivalent to the Deepseek `chat/completions` API endpoint.
+    For additional details, see: <https://api-docs.deepseek.com/api/create-chat-completion>
+
+    Deepseek uses the OpenAI SDK, so you will need to install the `openai` package to use this UDF.
+
+    __Requirements:__
+
+    - `pip install openai`
+
+    Args:
+        messages: A list of messages to use for chat completion, as described in the Deepseek API documentation.
+        model: The model to use for chat completion.
+
+    For details on the other parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
+
+    Returns:
+        A dictionary containing the response and other metadata.
+
+    Examples:
+        Add a computed column that applies the model `deepseek-chat` to an existing Pixeltable column `tbl.prompt`
+        of the table `tbl`:
+
+        >>> messages = [
+                {'role': 'system', 'content': 'You are a helpful assistant.'},
+                {'role': 'user', 'content': tbl.prompt}
+            ]
+            tbl.add_computed_column(response=chat_completions(messages, model='deepseek-chat'))
+    """
+    if tools is not None:
+        tools = [{'type': 'function', 'function': tool} for tool in tools]
+
+    tool_choice_: Union[str, dict, None] = None
+    if tool_choice is not None:
+        if tool_choice['auto']:
+            tool_choice_ = 'auto'
+        elif tool_choice['required']:
+            tool_choice_ = 'required'
+        else:
+            assert tool_choice['tool'] is not None
+            tool_choice_ = {'type': 'function', 'function': {'name': tool_choice['tool']}}
+
+    extra_body: Optional[dict[str, Any]] = None
+    if tool_choice is not None and not tool_choice['parallel_tool_calls']:
+        extra_body = {'parallel_tool_calls': False}
+
+    # cast(Any, ...): avoid mypy errors
+    result = await _deepseek_client().chat.completions.with_raw_response.create(
+        messages=messages,
+        model=model,
+        frequency_penalty=_opt(frequency_penalty),
+        logprobs=_opt(logprobs),
+        top_logprobs=_opt(top_logprobs),
+        max_tokens=_opt(max_tokens),
+        presence_penalty=_opt(presence_penalty),
+        response_format=_opt(cast(Any, response_format)),
+        stop=_opt(stop),
+        temperature=_opt(temperature),
+        tools=_opt(cast(Any, tools)),
+        tool_choice=_opt(cast(Any, tool_choice_)),
+        top_p=_opt(top_p),
+        extra_body=extra_body,
+    )
+
+    return json.loads(result.text)
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__():
+    return __all__
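The new module registers an OpenAI-compatible client against DeepSeek's base URL, so the UDF plugs into Pixeltable's computed columns the same way the OpenAI functions do. A minimal usage sketch following the docstring's example (the table and column names are illustrative, and a configured DeepSeek API key is assumed):

```python
import pixeltable as pxt
from pixeltable.functions.deepseek import chat_completions

# Hypothetical table with a string column `prompt`.
tbl = pxt.create_table('prompts', {'prompt': pxt.String})

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': tbl.prompt},
]
# Each inserted row triggers an async chat/completions request; the raw JSON
# response dict lands in the computed column.
tbl.add_computed_column(response=chat_completions(messages, model='deepseek-chat'))
tbl.insert([{'prompt': 'What is Pixeltable?'}])
```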
pixeltable/functions/image.py CHANGED
@@ -131,6 +131,13 @@ def getchannel(self: PIL.Image.Image, channel: int) -> PIL.Image.Image:
     pass
 
 
+@getchannel.conditional_return_type
+def _(self: Expr) -> pxt.ColumnType:
+    input_type = self.col_type
+    assert isinstance(input_type, pxt.ImageType)
+    return pxt.ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
+
+
 @pxt.udf(is_method=True)
 def get_metadata(self: PIL.Image.Image) -> dict:
     """
@@ -146,13 +153,6 @@ def get_metadata(self: PIL.Image.Image) -> dict:
     }
 
 
-@getchannel.conditional_return_type
-def _(self: Expr) -> pxt.ColumnType:
-    input_type = self.col_type
-    assert isinstance(input_type, pxt.ImageType)
-    return pxt.ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
-
-
 # Image.point()
 @pxt.udf(is_method=True)
 def point(self: PIL.Image.Image, lut: list[int], mode: Optional[str] = None) -> PIL.Image.Image:
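Moving the rule next to its UDF also makes the effect of `conditional_return_type` easier to spot: the declared type of a `getchannel()` call is narrowed to a single-band image that preserves the input's size and nullability. A sketch of what that buys at the expression level (table, column, and type parameters are hypothetical):

```python
import pixeltable as pxt

imgs = pxt.create_table('imgs', {'img': pxt.Image[(640, 480), 'RGB']})

# The expression's static type comes from the conditional_return_type rule:
# mode 'L', same size and nullability as the input column.
expr = imgs.img.getchannel(0)
print(expr.col_type)  # expected: an ImageType with size=(640, 480), mode='L'
```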
pixeltable/functions/openai.py CHANGED
@@ -14,7 +14,7 @@ import math
 import pathlib
 import re
 import uuid
-from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar, Union, cast
 
 import httpx
 import numpy as np
@@ -324,10 +324,37 @@ async def translations(
 # Chat Endpoints
 
 
+def _default_max_tokens(model: str) -> int:
+    if (
+        _is_model_family(model, 'gpt-4o-realtime')
+        or _is_model_family(model, 'gpt-4o-mini-realtime')
+        or _is_model_family(model, 'gpt-4-turbo')
+        or _is_model_family(model, 'gpt-3.5-turbo')
+    ):
+        return 4096
+    if _is_model_family(model, 'gpt-4'):
+        return 8192  # All other gpt-4 models (will not match on gpt-4o models)
+    if _is_model_family(model, 'gpt-4o') or _is_model_family(model, 'gpt-4.5-preview'):
+        return 16384  # All other gpt-4o / gpt-4.5 models
+    if _is_model_family(model, 'o1-preview'):
+        return 32768
+    if _is_model_family(model, 'o1-mini'):
+        return 65536
+    if _is_model_family(model, 'o1') or _is_model_family(model, 'o3'):
+        return 100000  # All other o1 / o3 models
+    return 100000  # global default
+
+
+def _is_model_family(model: str, family: str) -> bool:
+    # `model.startswith(family)` would be a simpler match, but increases the risk of false positives.
+    # We use a slightly more complicated criterion to make things a little less error prone.
+    return model == family or model.startswith(f'{family}-')
+
+
 def _chat_completions_get_request_resources(
-    messages: list, max_tokens: Optional[int], n: Optional[int]
+    messages: list, model: str, max_completion_tokens: Optional[int], max_tokens: Optional[int], n: Optional[int]
 ) -> dict[str, int]:
-    completion_tokens = n * max_tokens
+    completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
 
     num_tokens = 0.0
     for message in messages:
@@ -349,16 +376,18 @@ async def chat_completions(
     logit_bias: Optional[dict[str, int]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
-    max_tokens: Optional[int] = 1024,
-    n: Optional[int] = 1,
+    max_completion_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     presence_penalty: Optional[float] = None,
+    reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None,
     response_format: Optional[dict] = None,
     seed: Optional[int] = None,
     stop: Optional[list[str]] = None,
     temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
     tools: Optional[list[dict]] = None,
     tool_choice: Optional[dict] = None,
+    top_p: Optional[float] = None,
     user: Optional[str] = None,
     timeout: Optional[float] = None,
 ) -> dict:
@@ -426,16 +455,18 @@ async def chat_completions(
         logit_bias=_opt(logit_bias),
         logprobs=_opt(logprobs),
         top_logprobs=_opt(top_logprobs),
+        max_completion_tokens=_opt(max_completion_tokens),
         max_tokens=_opt(max_tokens),
         n=_opt(n),
         presence_penalty=_opt(presence_penalty),
+        reasoning_effort=_opt(reasoning_effort),
         response_format=_opt(cast(Any, response_format)),
         seed=_opt(seed),
         stop=_opt(stop),
         temperature=_opt(temperature),
-        top_p=_opt(top_p),
         tools=_opt(cast(Any, tools)),
         tool_choice=_opt(cast(Any, tool_choice_)),
+        top_p=_opt(top_p),
         user=_opt(user),
         timeout=_opt(timeout),
         extra_body=extra_body,
@@ -448,9 +479,14 @@ async def chat_completions(
 
 
 def _vision_get_request_resources(
-    prompt: str, image: PIL.Image.Image, max_tokens: Optional[int], n: Optional[int]
+    prompt: str,
+    image: PIL.Image.Image,
+    model: str,
+    max_completion_tokens: Optional[int],
+    max_tokens: Optional[int],
+    n: Optional[int],
 ) -> dict[str, int]:
-    completion_tokens = n * max_tokens
+    completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
     prompt_tokens = len(prompt) / 4
 
     # calculate image tokens based on
@@ -482,7 +518,8 @@ async def vision(
     image: PIL.Image.Image,
     *,
     model: str,
-    max_tokens: Optional[int] = 1024,
+    max_completion_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
     n: Optional[int] = 1,
     timeout: Optional[float] = None,
 ) -> str:
@@ -534,9 +571,11 @@ async def vision(
     rate_limits_info = env.Env.get().get_resource_pool_info(
        resource_pool, lambda: OpenAIRateLimitsInfo(_vision_get_request_resources)
    )
+
    result = await _openai_client().chat.completions.with_raw_response.create(
        messages=messages,  # type: ignore
        model=model,
+        max_completion_tokens=_opt(max_completion_tokens),
        max_tokens=_opt(max_tokens),
        n=_opt(n),
        timeout=_opt(timeout),
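The reason `_is_model_family` checks `model == family or model.startswith(f'{family}-')` rather than a bare prefix match is that model names nest: 'gpt-4o-mini' starts with 'gpt-4' but belongs to a different family with a different token limit. A self-contained restatement of the rule, plus the budget fallback it feeds (limits copied from the diff above):

```python
def is_model_family(model: str, family: str) -> bool:
    # Exact name, or the family name followed by a '-'-separated suffix.
    return model == family or model.startswith(f'{family}-')

assert is_model_family('gpt-4', 'gpt-4')
assert is_model_family('gpt-4-turbo', 'gpt-4')
assert not is_model_family('gpt-4o-mini', 'gpt-4')  # bare startswith() would wrongly match
assert not is_model_family('gpt-45', 'gpt-4')       # likewise

# The resource estimate then falls back per family:
# (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model)),
# so model='gpt-4o' with no explicit limits budgets 1 * 16384 completion tokens.
```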
pixeltable/functions/video.py CHANGED
@@ -14,9 +14,9 @@ t.select(pxt_video.extract_audio(t.video_col)).collect()
 import tempfile
 import uuid
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
-import av  # type: ignore[import-untyped]
+import av
 import numpy as np
 import PIL.Image
 
@@ -53,10 +53,14 @@ class make_video(pxt.Aggregator):
     Aggregator that creates a video from a sequence of images.
     """
 
+    container: Optional[av.container.OutputContainer]
+    stream: Optional[av.video.stream.VideoStream]
+    fps: int
+
     def __init__(self, fps: int = 25):
         """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
-        self.container: Optional[av.container.OutputContainer] = None
-        self.stream: Optional[av.stream.Stream] = None
+        self.container = None
+        self.stream = None
         self.fps = fps
 
     def update(self, frame: PIL.Image.Image) -> None:
@@ -107,9 +111,10 @@ def extract_audio(
 
     with av.open(output_filename, 'w', format=format) as output_container:
         output_stream = output_container.add_stream(codec or default_codec)
+        assert isinstance(output_stream, av.audio.stream.AudioStream)
         for packet in container.demux(audio_stream):
             for frame in packet.decode():
-                output_container.mux(output_stream.encode(frame))
+                output_container.mux(output_stream.encode(frame))  # type: ignore[arg-type]
 
     return output_filename
 
@@ -141,7 +146,7 @@ def __get_stream_metadata(stream: av.stream.Stream) -> dict:
         return {'type': stream.type}  # Currently unsupported
 
     codec_context = stream.codec_context
-    codec_context_md = {
+    codec_context_md: dict[str, Any] = {
         'name': codec_context.name,
         'codec_tag': codec_context.codec_tag.encode('unicode-escape').decode('utf-8'),
         'profile': codec_context.profile,
@@ -160,9 +165,11 @@ def __get_stream_metadata(stream: av.stream.Stream) -> dict:
 
     if stream.type == 'audio':
         # Additional metadata for audio
-        codec_context_md['channels'] = int(codec_context.channels) if codec_context.channels is not None else None
+        channels = getattr(stream.codec_context, 'channels', None)
+        codec_context_md['channels'] = int(channels) if channels is not None else None
     else:
         assert stream.type == 'video'
+        assert isinstance(stream, av.video.stream.VideoStream)
         # Additional metadata for video
         codec_context_md['pix_fmt'] = getattr(stream.codec_context, 'pix_fmt', None)
     metadata.update(
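The new class-level annotations (`av.container.OutputContainer`, `av.video.stream.VideoStream`) pin down the concrete types from the PyAV cookbook recipe that `make_video`'s docstring cites. For reference, that encode loop looks roughly like this as standalone PyAV code (a sketch, not Pixeltable code; output path and frame contents are made up):

```python
import av
import numpy as np

container = av.open('out.mp4', mode='w')        # av.container.OutputContainer
stream = container.add_stream('h264', rate=25)  # av.video.stream.VideoStream
stream.width, stream.height, stream.pix_fmt = 320, 240, 'yuv420p'

for i in range(50):
    img = np.full((240, 320, 3), i * 5, dtype=np.uint8)  # brightening grey frames
    frame = av.VideoFrame.from_ndarray(img, format='rgb24')
    for packet in stream.encode(frame):
        container.mux(packet)

for packet in stream.encode():  # flush buffered packets
    container.mux(packet)
container.close()
```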
pixeltable/globals.py CHANGED
@@ -1,5 +1,6 @@
 import dataclasses
 import logging
+import urllib.parse
 from typing import Any, Iterable, Literal, Optional, Union
 from uuid import UUID
 
@@ -10,7 +11,7 @@ from sqlalchemy.util.preloaded import orm
 
 import pixeltable.exceptions as excs
 import pixeltable.exprs as exprs
-from pixeltable import DataFrame, catalog, func
+from pixeltable import DataFrame, catalog, func, share
 from pixeltable.catalog import Catalog
 from pixeltable.dataframe import DataFrameResultSet
 from pixeltable.env import Env
@@ -279,14 +280,16 @@ def create_view(
         ... view = pxt.create_view('my_view', tbl.where(tbl.col1 > 100), if_exists='replace_force')
     """
     where: Optional[exprs.Expr] = None
+    select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None
     if isinstance(base, catalog.Table):
         tbl_version_path = base._tbl_version_path
     elif isinstance(base, DataFrame):
-        base._validate_mutable('create_view')
+        base._validate_mutable('create_view', allow_select=True)
         if len(base._from_clause.tbls) > 1:
             raise excs.Error('Cannot create a view of a join')
         tbl_version_path = base._from_clause.tbls[0]
         where = base.where_clause
+        select_list = base.select_list
     else:
         raise excs.Error('`base` must be an instance of `Table` or `DataFrame`')
     assert isinstance(base, catalog.Table) or isinstance(base, DataFrame)
@@ -322,6 +325,7 @@ def create_view(
         dir._id,
         path.name,
         base=tbl_version_path,
+        select_list=select_list,
         additional_columns=additional_columns,
         predicate=where,
         is_snapshot=is_snapshot,
@@ -630,7 +634,7 @@ def create_dir(
     parent = cat.paths[path.parent]
     assert parent is not None
     with orm.Session(Env.get().engine, future=True) as session:
-        dir_md = schema.DirMd(name=path.name)
+        dir_md = schema.DirMd(name=path.name, user=None, additional_md={})
         dir_record = schema.Dir(parent_id=parent._id, md=dataclasses.asdict(dir_md))
         session.add(dir_record)
         session.flush()
@@ -723,6 +727,13 @@ def drop_dir(path_str: str, force: bool = False, if_not_exists: Literal['error',
     _logger.info(f'Removed directory `{path_str}`.')
 
 
+def publish_snapshot(dest_uri: str, table: catalog.Table) -> None:
+    parsed_uri = urllib.parse.urlparse(dest_uri)
+    if parsed_uri.scheme != 'pxt':
+        raise excs.Error(f'Invalid Pixeltable URI (does not start with pxt://): {dest_uri}')
+    share.publish_snapshot(dest_uri, table)
+
+
 def list_dirs(path_str: str = '', recursive: bool = True) -> list[str]:
     """List the directories in a directory.
 
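`publish_snapshot` only validates the URI scheme before delegating to the new `pixeltable.share` package. Assuming it is exposed at the package level like the other functions in `globals.py`, usage would look like this (the URI and table path below are made up):

```python
import pixeltable as pxt

tbl = pxt.get_table('my_dir.my_table')

# The scheme must be pxt://; anything else raises
# "Invalid Pixeltable URI (does not start with pxt://): ..."
pxt.publish_snapshot('pxt://my-org/my-snapshot', tbl)
```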
pixeltable/index/embedding_index.py CHANGED
@@ -99,10 +99,10 @@ class EmbeddingIndex(IndexBase):
         # Now validate the return types of the embedding functions.
 
         if self.string_embed is not None:
-            self._validate_embedding_fn(self.string_embed, ts.ColumnType.Type.STRING)
+            self._validate_embedding_fn(self.string_embed)
 
         if self.image_embed is not None:
-            self._validate_embedding_fn(self.image_embed, ts.ColumnType.Type.IMAGE)
+            self._validate_embedding_fn(self.image_embed)
 
         if c.col_type.is_string_type() and self.string_embed is None:
             raise excs.Error(f"Text embedding function is required for column {c.name} (parameter 'string_embed')")
@@ -206,21 +206,12 @@ class EmbeddingIndex(IndexBase):
         return None
 
     @classmethod
-    def _validate_embedding_fn(cls, embed_fn: func.Function, expected_type: ts.ColumnType.Type) -> None:
+    def _validate_embedding_fn(cls, embed_fn: func.Function) -> None:
         """Validate the given embedding function."""
         assert not embed_fn.is_polymorphic
-        sig = embed_fn.signature
 
-        # validate return type
-        param_name = sig.parameters_by_pos[0].name
-        if expected_type == ts.ColumnType.Type.STRING:
-            return_type = embed_fn.call_return_type([], {param_name: 'dummy'})
-        else:
-            assert expected_type == ts.ColumnType.Type.IMAGE
-            img = PIL.Image.new('RGB', (512, 512))
-            return_type = embed_fn.call_return_type([], {param_name: img})
+        return_type = embed_fn.signature.return_type
 
-        assert return_type is not None
         if not isinstance(return_type, ts.ArrayType):
             raise excs.Error(
                 f'The function `{embed_fn.name}` is not a valid embedding: '
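Validation now reads the declared `signature.return_type` instead of probing the function with a dummy string or image, so an embedding UDF needs an explicit array return annotation. A sketch of a UDF that should pass the new check (the shape is illustrative; a real index would wrap an actual embedding model):

```python
import numpy as np
import pixeltable as pxt

@pxt.udf
def toy_embed(text: str) -> pxt.Array[(8,), pxt.Float]:
    # Deterministic stand-in for a real embedding model: the declared
    # ArrayType return annotation is what _validate_embedding_fn inspects.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random(8, dtype=np.float32)
```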
pixeltable/io/globals.py CHANGED
@@ -1,3 +1,7 @@
+import json
+import urllib.parse
+import urllib.request
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import pixeltable as pxt
@@ -5,11 +9,61 @@ import pixeltable.exceptions as excs
 from pixeltable import Table, exprs
 from pixeltable.env import Env
 from pixeltable.io.external_store import SyncStatus
+from pixeltable.utils import parse_local_file_path
 
 if TYPE_CHECKING:
     import fiftyone as fo  # type: ignore[import-untyped]
 
 
+from .utils import find_or_create_table, normalize_import_parameters, normalize_schema_names
+
+
+def _infer_schema_from_rows(
+    rows: list[dict[str, Any]], schema_overrides: dict[str, Any], primary_key: list[str]
+) -> dict[str, pxt.ColumnType]:
+    schema: dict[str, pxt.ColumnType] = {}
+    cols_with_nones: set[str] = set()
+
+    for n, row in enumerate(rows):
+        for col_name, value in row.items():
+            if col_name in schema_overrides:
+                # We do the insertion here; this will ensure that the column order matches the order
+                # in which the column names are encountered in the input data, even if `schema_overrides`
+                # is specified.
+                if col_name not in schema:
+                    schema[col_name] = schema_overrides[col_name]
+            elif value is not None:
+                # If `key` is not in `schema_overrides`, then we infer its type from the data.
+                # The column type will always be nullable by default.
+                col_type = pxt.ColumnType.infer_literal_type(value, nullable=col_name not in primary_key)
+                if col_type is None:
+                    raise excs.Error(
+                        f'Could not infer type for column `{col_name}`; the value in row {n} has an unsupported type: {type(value)}'
+                    )
+                if col_name not in schema:
+                    schema[col_name] = col_type
+                else:
+                    supertype = schema[col_name].supertype(col_type)
+                    if supertype is None:
+                        raise excs.Error(
+                            f'Could not infer type of column `{col_name}`; the value in row {n} does not match preceding type {schema[col_name]}: {value!r}\n'
+                            'Consider specifying the type explicitly in `schema_overrides`.'
+                        )
+                    schema[col_name] = supertype
+            else:
+                cols_with_nones.add(col_name)
+
+    entirely_none_cols = cols_with_nones - schema.keys()
+    if len(entirely_none_cols) > 0:
+        # A column can only end up in `entirely_none_cols` if it was not in `schema_overrides` and
+        # was not encountered in any row with a non-None value.
+        raise excs.Error(
+            f'The following columns have no non-null values: {", ".join(entirely_none_cols)}\n'
+            'Consider specifying the type(s) explicitly in `schema_overrides`.'
+        )
+    return schema
+
+
 def create_label_studio_project(
     t: Table,
     label_config: str,
@@ -140,7 +194,7 @@ def import_rows(
     tbl_path: str,
     rows: list[dict[str, Any]],
     *,
-    schema_overrides: Optional[dict[str, pxt.ColumnType]] = None,
+    schema_overrides: Optional[dict[str, Any]] = None,
     primary_key: Optional[Union[str, list[str]]] = None,
     num_retained_versions: int = 10,
     comment: str = '',
@@ -169,67 +223,22 @@ def import_rows(
     Returns:
         A handle to the newly created [`Table`][pixeltable.Table].
     """
-    if schema_overrides is None:
-        schema_overrides = {}
-    schema: dict[str, pxt.ColumnType] = {}
-    cols_with_nones: set[str] = set()
+    schema_overrides, primary_key = normalize_import_parameters(schema_overrides, primary_key)
+    row_schema = _infer_schema_from_rows(rows, schema_overrides, primary_key)
+    schema, pxt_pk, _ = normalize_schema_names(row_schema, primary_key, schema_overrides, True)
 
-    for n, row in enumerate(rows):
-        for col_name, value in row.items():
-            if col_name in schema_overrides:
-                # We do the insertion here; this will ensure that the column order matches the order
-                # in which the column names are encountered in the input data, even if `schema_overrides`
-                # is specified.
-                if col_name not in schema:
-                    schema[col_name] = schema_overrides[col_name]
-            elif value is not None:
-                # If `key` is not in `schema_overrides`, then we infer its type from the data.
-                # The column type will always be nullable by default.
-                col_type = pxt.ColumnType.infer_literal_type(value, nullable=True)
-                if col_type is None:
-                    raise excs.Error(
-                        f'Could not infer type for column `{col_name}`; the value in row {n} has an unsupported type: {type(value)}'
-                    )
-                if col_name not in schema:
-                    schema[col_name] = col_type
-                else:
-                    supertype = schema[col_name].supertype(col_type)
-                    if supertype is None:
-                        raise excs.Error(
-                            f'Could not infer type of column `{col_name}`; the value in row {n} does not match preceding type {schema[col_name]}: {value!r}\n'
-                            'Consider specifying the type explicitly in `schema_overrides`.'
-                        )
-                    schema[col_name] = supertype
-            else:
-                cols_with_nones.add(col_name)
-
-    extraneous_keys = schema_overrides.keys() - schema.keys()
-    if len(extraneous_keys) > 0:
-        raise excs.Error(
-            f'The following columns specified in `schema_overrides` are not present in the data: {", ".join(extraneous_keys)}'
-        )
-
-    entirely_none_cols = cols_with_nones - schema.keys()
-    if len(entirely_none_cols) > 0:
-        # A column can only end up in `entirely_null_cols` if it was not in `schema_overrides` and
-        # was not encountered in any row with a non-None value.
-        raise excs.Error(
-            f'The following columns have no non-null values: {", ".join(entirely_none_cols)}\n'
-            'Consider specifying the type(s) explicitly in `schema_overrides`.'
-        )
-
-    t = pxt.create_table(
-        tbl_path, schema, primary_key=primary_key, num_retained_versions=num_retained_versions, comment=comment
+    table = find_or_create_table(
+        tbl_path, schema, primary_key=pxt_pk, num_retained_versions=num_retained_versions, comment=comment
     )
-    t.insert(rows)
-    return t
+    table.insert(rows)
+    return table
 
 
 def import_json(
     tbl_path: str,
     filepath_or_url: str,
     *,
-    schema_overrides: Optional[dict[str, pxt.ColumnType]] = None,
+    schema_overrides: Optional[dict[str, Any]] = None,
    primary_key: Optional[Union[str, list[str]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
@@ -253,33 +262,35 @@ def import_json(
     Returns:
         A handle to the newly created [`Table`][pixeltable.Table].
     """
-    import json
-    import urllib.parse
-    import urllib.request
-
-    # TODO Consolidate this logic with other places where files/URLs are parsed
-    parsed = urllib.parse.urlparse(filepath_or_url)
-    if len(parsed.scheme) <= 1 or parsed.scheme == 'file':
-        # local file path
-        if len(parsed.scheme) <= 1:
-            filepath = filepath_or_url
-        else:
-            filepath = urllib.parse.unquote(urllib.request.url2pathname(parsed.path))
-        with open(filepath) as fp:
+    path = parse_local_file_path(filepath_or_url)
+    if path is None:  # it's a URL
+        # TODO: This should read from S3 as well.
+        contents = urllib.request.urlopen(filepath_or_url).read()
+    else:
+        with open(path) as fp:
             contents = fp.read()
+
+    rows = json.loads(contents, **kwargs)
+
+    schema_overrides, primary_key = normalize_import_parameters(schema_overrides, primary_key)
+    row_schema = _infer_schema_from_rows(rows, schema_overrides, primary_key)
+    schema, pxt_pk, col_mapping = normalize_schema_names(row_schema, primary_key, schema_overrides, False)
+
+    # Convert all rows to insertable format - not needed, misnamed columns and types are errors in the incoming row format
+    if col_mapping is not None:
+        tbl_rows = [
+            {field if col_mapping is None else col_mapping[field]: val for field, val in row.items()} for row in rows
+        ]
     else:
-        # URL
-        contents = urllib.request.urlopen(filepath_or_url).read()
-    data = json.loads(contents, **kwargs)
-    return import_rows(
-        tbl_path,
-        data,
-        schema_overrides=schema_overrides,
-        primary_key=primary_key,
-        num_retained_versions=num_retained_versions,
-        comment=comment,
+        tbl_rows = rows
+
+    table = find_or_create_table(
+        tbl_path, schema, primary_key=pxt_pk, num_retained_versions=num_retained_versions, comment=comment
     )
 
+    table.insert(tbl_rows)
+    return table
+
 
 def export_images_as_fo_dataset(
     tbl: pxt.Table,
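Taken together, the refactor gives `import_rows` and `import_json` the same inference rules: types come from the first non-None values and are widened to a common supertype across rows, primary-key columns are inferred as non-nullable, and a column that is None in every row is an error unless it appears in `schema_overrides`. A sketch of those rules in action (the table name is arbitrary, and `pixeltable.io.import_rows` is assumed to be the public entry point):

```python
import pixeltable as pxt
from pixeltable.io import import_rows

rows = [
    {'id': 1, 'score': 1,   'note': None},
    {'id': 2, 'score': 2.5, 'note': None},  # int then float -> Float supertype
]
t = import_rows(
    'demo_rows',
    rows,
    primary_key='id',                       # 'id' is inferred as non-nullable
    schema_overrides={'note': pxt.String},  # without this: "no non-null values" error
)
```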