pixeltable 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the package's registry page for more details.

Files changed (67)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +5 -0
  4. pixeltable/catalog/globals.py +8 -0
  5. pixeltable/catalog/insertable_table.py +2 -2
  6. pixeltable/catalog/table.py +27 -9
  7. pixeltable/catalog/table_version.py +41 -68
  8. pixeltable/catalog/view.py +3 -3
  9. pixeltable/dataframe.py +7 -6
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/expr_eval_node.py +8 -1
  12. pixeltable/exec/row_update_node.py +61 -0
  13. pixeltable/exec/{sql_scan_node.py → sql_node.py} +120 -56
  14. pixeltable/exprs/__init__.py +1 -2
  15. pixeltable/exprs/comparison.py +5 -5
  16. pixeltable/exprs/compound_predicate.py +12 -12
  17. pixeltable/exprs/expr.py +67 -22
  18. pixeltable/exprs/function_call.py +60 -29
  19. pixeltable/exprs/globals.py +2 -0
  20. pixeltable/exprs/in_predicate.py +3 -3
  21. pixeltable/exprs/inline_array.py +18 -11
  22. pixeltable/exprs/is_null.py +5 -5
  23. pixeltable/exprs/method_ref.py +63 -0
  24. pixeltable/ext/__init__.py +9 -0
  25. pixeltable/ext/functions/__init__.py +8 -0
  26. pixeltable/ext/functions/whisperx.py +45 -5
  27. pixeltable/ext/functions/yolox.py +60 -14
  28. pixeltable/func/aggregate_function.py +10 -4
  29. pixeltable/func/callable_function.py +16 -4
  30. pixeltable/func/expr_template_function.py +1 -1
  31. pixeltable/func/function.py +12 -2
  32. pixeltable/func/function_registry.py +26 -9
  33. pixeltable/func/udf.py +32 -4
  34. pixeltable/functions/__init__.py +1 -1
  35. pixeltable/functions/fireworks.py +33 -0
  36. pixeltable/functions/globals.py +36 -1
  37. pixeltable/functions/huggingface.py +155 -7
  38. pixeltable/functions/image.py +242 -40
  39. pixeltable/functions/openai.py +214 -0
  40. pixeltable/functions/string.py +600 -8
  41. pixeltable/functions/timestamp.py +210 -0
  42. pixeltable/functions/together.py +106 -0
  43. pixeltable/functions/video.py +28 -10
  44. pixeltable/functions/whisper.py +32 -0
  45. pixeltable/globals.py +3 -3
  46. pixeltable/io/__init__.py +1 -1
  47. pixeltable/io/globals.py +186 -5
  48. pixeltable/io/label_studio.py +42 -2
  49. pixeltable/io/pandas.py +70 -34
  50. pixeltable/metadata/__init__.py +1 -1
  51. pixeltable/metadata/converters/convert_18.py +39 -0
  52. pixeltable/metadata/notes.py +10 -0
  53. pixeltable/plan.py +82 -7
  54. pixeltable/tool/create_test_db_dump.py +4 -5
  55. pixeltable/tool/doc_plugins/griffe.py +81 -0
  56. pixeltable/tool/doc_plugins/mkdocstrings.py +6 -0
  57. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +135 -0
  58. pixeltable/type_system.py +15 -14
  59. pixeltable/utils/s3.py +1 -1
  60. pixeltable-0.2.14.dist-info/METADATA +206 -0
  61. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/RECORD +64 -56
  62. pixeltable-0.2.14.dist-info/entry_points.txt +3 -0
  63. pixeltable/exprs/image_member_access.py +0 -96
  64. pixeltable/exprs/predicate.py +0 -44
  65. pixeltable-0.2.12.dist-info/METADATA +0 -137
  66. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/LICENSE +0 -0
  67. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/WHEEL +0 -0
pixeltable/plan.py CHANGED
@@ -40,7 +40,7 @@ class Analyzer:
40
40
 
41
41
  def __init__(
42
42
  self, tbl: catalog.TableVersionPath, select_list: List[exprs.Expr],
43
- where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
43
+ where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
44
44
  order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None):
45
45
  if group_by_clause is None:
46
46
  group_by_clause = []
@@ -58,7 +58,7 @@ class Analyzer:
58
58
  # Where clause of the Select stmt of the SQL scan
59
59
  self.sql_where_clause: Optional[exprs.Expr] = None
60
60
  # filter predicate applied to output rows of the SQL scan
61
- self.filter: Optional[exprs.Predicate] = None
61
+ self.filter: Optional[exprs.Expr] = None
62
62
  # not executable
63
63
  #self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
64
64
  if where_clause is not None:
@@ -107,7 +107,7 @@ class Analyzer:
107
107
  for e in self.group_by_clause:
108
108
  if e.sql_expr() is None:
109
109
  raise excs.Error(f'Invalid grouping expression, needs to be expressible in SQL: {e}')
110
- if e.contains(filter=lambda e: _is_agg_fn_call(e)):
110
+ if e._contains(filter=lambda e: _is_agg_fn_call(e)):
111
111
  raise excs.Error(f'Grouping expression contains aggregate function: {e}')
112
112
 
113
113
  # check that agg fn calls don't have contradicting ordering requirements
@@ -183,7 +183,7 @@ class Planner:
183
183
  # TODO: create an exec.CountNode and change this to create_count_plan()
184
184
  @classmethod
185
185
  def create_count_stmt(
186
- cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Predicate] = None
186
+ cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None
187
187
  ) -> sql.Select:
188
188
  stmt = sql.select(sql.func.count('*'))
189
189
  refd_tbl_ids: Set[UUID] = set()
@@ -239,7 +239,7 @@ class Planner:
239
239
  cls, tbl: catalog.TableVersionPath,
240
240
  update_targets: dict[catalog.Column, exprs.Expr],
241
241
  recompute_targets: List[catalog.Column],
242
- where_clause: Optional[exprs.Predicate], cascade: bool
242
+ where_clause: Optional[exprs.Expr], cascade: bool
243
243
  ) -> Tuple[exec.ExecNode, List[str], List[catalog.Column]]:
244
244
  """Creates a plan to materialize updated rows.
245
245
  The plan:
@@ -288,6 +288,81 @@ class Planner:
288
288
  recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
289
289
  return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
290
290
 
291
@classmethod
def create_batch_update_plan(
    cls, tbl: catalog.TableVersionPath,
    batch: list[dict[catalog.Column, exprs.Expr]], rowids: list[tuple[int, ...]],
    cascade: bool
) -> Tuple[exec.ExecNode, exec.RowUpdateNode, sql.ClauseElement, List[catalog.Column], List[catalog.Column]]:
    """Creates a plan to update a batch of rows, where each row supplies its own new column values.

    Rows are located by `rowids` if it is non-empty; otherwise by their primary key values, which must then be
    present in every row of `batch`. The set of updated columns is taken from `batch[0]` (minus the primary key
    columns), and every row of `batch` must supply values for all of those columns.

    Args:
        tbl: path of the table being updated; rows are updated in `tbl.tbl_version`
        batch: one dict per row, mapping a column to an expr whose `.val` is used as the value
            (presumably an `exprs.Literal` -- confirm against callers)
        rowids: rowids of the rows to update, aligned positionally with `batch`; empty to use primary keys
        cascade: if True, also recompute stored columns that depend on the updated columns

    Returns:
        - root node of the plan to produce the updated rows
        - RowUpdateNode of plan
        - Where clause for deleting the current versions of updated rows
        - list of updated columns, followed by the user-visible columns that are being recomputed
        - list of user-visible columns that are being recomputed

    Raises:
        excs.Error: if `batch` is empty
    """
    assert isinstance(tbl, catalog.TableVersionPath)
    if len(batch) == 0:
        # batch[0] below defines the set of updated columns; fail with a clear message instead of an IndexError
        raise excs.Error('Cannot create a batch update plan for an empty batch')
    # rowids (when given) and the primary key values (computed below) must line up 1:1 with the rows of batch
    assert len(rowids) == 0 or len(rowids) == len(batch)
    target = tbl.tbl_version  # the one we need to update
    pk_cols = target.primary_key_columns()
    use_rowids = len(rowids) > 0
    if use_rowids:
        sa_key_cols: list[sql.Column] = target.store_tbl.rowid_columns()
        key_vals: list[tuple] = rowids
    else:
        sa_key_cols = [c.sa_col for c in pk_cols]
        key_vals = [tuple(row[col].val for col in pk_cols) for row in batch]

    # retrieve all stored cols and all target exprs
    updated_cols = batch[0].keys() - pk_cols
    recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
    # regardless of cascade, we need to update all indices on any updated column
    idx_val_cols = target.get_idx_val_columns(updated_cols)
    recomputed_cols.update(idx_val_cols)
    # we only need to recompute stored columns (unstored ones are substituted away)
    recomputed_cols = {c for c in recomputed_cols if c.is_stored}
    recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
    copied_cols = [
        col for col in target.cols if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
    ]
    # Fix the iteration order of the sets exactly once: select_list and all_base_cols (below) must line up
    # positionally, so both are built from these lists.
    ordered_updated_cols = list(updated_cols)
    ordered_recomputed_base_cols = list(recomputed_base_cols)

    select_list = [exprs.ColumnRef(col) for col in copied_cols]
    select_list.extend(exprs.ColumnRef(col) for col in ordered_updated_cols)
    recomputed_exprs = [
        c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols)
        for c in ordered_recomputed_base_cols
    ]
    # the RowUpdateNode updates columns in-place, ie, in the original ColumnRef; no further substitution is needed
    select_list.extend(recomputed_exprs)

    # ExecNode tree (from bottom to top):
    # - SqlLookupNode to retrieve the existing rows
    # - RowUpdateNode to update the retrieved rows
    # - ExprEvalNode to evaluate the remaining output exprs
    analyzer = Analyzer(tbl, select_list)
    row_builder = exprs.RowBuilder(analyzer.all_exprs, [], analyzer.sql_exprs)
    analyzer.finalize(row_builder)
    plan = exec.SqlLookupNode(tbl, row_builder, analyzer.sql_exprs, sa_key_cols, key_vals)
    delete_where_clause = plan.where_clause
    col_vals = [{col: row[col].val for col in ordered_updated_cols} for row in batch]
    plan = row_update_node = exec.RowUpdateNode(tbl, key_vals, use_rowids, col_vals, row_builder, plan)
    if not cls._is_contained_in(analyzer.select_list, analyzer.sql_exprs):
        # we need an ExprEvalNode to evaluate the remaining output exprs
        plan = exec.ExprEvalNode(row_builder, analyzer.select_list, analyzer.sql_exprs, input=plan)
    # update row builder with column information
    all_base_cols = copied_cols + ordered_updated_cols + ordered_recomputed_base_cols  # same order as select_list
    row_builder.substitute_exprs(select_list, remove_duplicates=False)
    for col, col_expr in zip(all_base_cols, select_list):
        plan.row_builder.add_table_column(col, col_expr.slot_idx)

    ctx = exec.ExecContext(row_builder)
    # we're returning everything to the user, so we might as well do it in a single batch
    ctx.batch_size = 0
    plan.set_ctx(ctx)
    recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
    return (
        plan, row_update_node, delete_where_clause, ordered_updated_cols + recomputed_user_cols, recomputed_user_cols
    )
365
+
291
366
  @classmethod
292
367
  def create_view_update_plan(
293
368
  cls, view: catalog.TableVersionPath, recompute_targets: List[catalog.Column]
@@ -505,7 +580,7 @@ class Planner:
505
580
  @classmethod
506
581
  def create_query_plan(
507
582
  cls, tbl: catalog.TableVersionPath, select_list: Optional[List[exprs.Expr]] = None,
508
- where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
583
+ where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
509
584
  order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, limit: Optional[int] = None,
510
585
  with_pk: bool = False, ignore_errors: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
511
586
  ) -> exec.ExecNode:
@@ -597,7 +672,7 @@ class Planner:
597
672
  return plan
598
673
 
599
674
  @classmethod
600
- def analyze(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Predicate) -> Analyzer:
675
+ def analyze(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Expr) -> Analyzer:
601
676
  return Analyzer(tbl, [], where_clause=where_clause)
602
677
 
603
678
  @classmethod
@@ -61,7 +61,7 @@ class Dumper:
61
61
  info_dict = {'pixeltable-dump': {
62
62
  'metadata-version': md_version,
63
63
  'git-sha': git_sha,
64
- 'datetime': datetime.datetime.utcnow(),
64
+ 'datetime': datetime.datetime.now(tz=datetime.timezone.utc),
65
65
  'user': user
66
66
  }}
67
67
  with open(info_file, 'w') as info:
@@ -179,7 +179,7 @@ class Dumper:
179
179
  def __add_expr_columns(self, t: pxt.Table, col_prefix: str, include_expensive_functions=False) -> None:
180
180
  def add_column(col_name: str, col_expr: Any) -> None:
181
181
  t.add_column(**{f'{col_prefix}_{col_name}': col_expr})
182
-
182
+
183
183
  # arithmetic_expr
184
184
  add_column('plus', t.c2 + 6)
185
185
  add_column('minus', t.c2 - 5)
@@ -208,7 +208,7 @@ class Dumper:
208
208
  add_column('not', ~(t.c2 > 20))
209
209
 
210
210
  # function_call
211
- add_column('function_call', pxt.functions.string.str_format('{0} {key}', t.c1, key=t.c1)) # library function
211
+ add_column('function_call', pxt.functions.string.format('{0} {key}', t.c1, key=t.c1)) # library function
212
212
  add_column('test_udf', test_udf_stored(t.c2)) # stored udf
213
213
  add_column('test_udf_batched', test_udf_stored_batched(t.c1, upper=False)) # batched stored udf
214
214
  if include_expensive_functions:
@@ -242,8 +242,7 @@ class Dumper:
242
242
  add_column('str_const', 'str')
243
243
  add_column('int_const', 5)
244
244
  add_column('float_const', 5.0)
245
- add_column('timestamp_const_1', datetime.datetime.utcnow())
246
- add_column('timestamp_const_2', datetime.date.today())
245
+ add_column('timestamp_const_1', datetime.datetime.now(tz=datetime.timezone.utc))
247
246
 
248
247
  # type_cast
249
248
  add_column('astype', t.c2.astype(FloatType()))
@@ -0,0 +1,81 @@
1
+ import ast
2
+ from typing import Optional, Union
3
+ import warnings
4
+
5
+ import griffe
6
+ import griffe.expressions
7
+ from griffe import Extension, Object, ObjectNode
8
+
9
+ import pixeltable as pxt
10
+
11
logger = griffe.get_logger(__name__)


class PxtGriffeExtension(Extension):
    """Implementation of a Pixeltable custom griffe extension.

    For every documented function decorated with `@pxt.udf` (or `@pixeltable.func.udf`), switches mkdocstrings
    to the custom `udf.html.jinja` template and rewrites the displayed type hints as Pixeltable column types.
    """

    def on_instance(self, node: Union[ast.AST, ObjectNode], obj: Object) -> None:
        """Griffe hook called for each loaded object; updates the docs of documented `@pxt.udf` functions.

        Args:
            node: the node the object was loaded from (not used here)
            obj: the griffe object being loaded
        """
        if obj.docstring is None:
            # Skip over entities without a docstring
            return

        if isinstance(obj, griffe.Function):
            # See if the (Python) function has a @pxt.udf decorator
            if any(
                isinstance(dec.value, griffe.expressions.Expr)
                and dec.value.canonical_path in ['pixeltable.func.udf', 'pixeltable.udf']
                for dec in obj.decorators
            ):
                # Update the template
                self.__modify_pxt_udf(obj)

    def __modify_pxt_udf(self, func: griffe.Function) -> None:
        """
        Instructs the doc snippet for `func` to use the custom Pixeltable UDF jinja template, and
        converts all type hints to Pixeltable column type references, in accordance with the @udf
        decorator behavior.
        """
        # setdefault works whether or not `extra` auto-creates missing entries
        func.extra.setdefault('mkdocstrings', {})['template'] = 'udf.html.jinja'
        # Dynamically load the UDF reference so we can inspect the Pixeltable signature directly.
        # Warnings raised during the import are suppressed only for the duration of the import; a bare
        # `warnings.simplefilter("ignore")` would silence every warning for the rest of the docs build.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            udf = griffe.dynamic_import(func.path)
        assert isinstance(udf, pxt.Function)
        # Convert the return type to a Pixeltable type reference
        func.returns = self.__column_type_to_display_str(udf.signature.get_return_type())
        # Convert the parameter types to Pixeltable type references
        for griffe_param in func.parameters:
            assert isinstance(griffe_param.annotation, griffe.expressions.Expr)
            if griffe_param.name not in udf.signature.parameters:
                logger.warning(f'Parameter `{griffe_param.name}` not found in signature for UDF: {udf.display_name}')
                continue
            pxt_param = udf.signature.parameters[griffe_param.name]
            griffe_param.annotation = self.__column_type_to_display_str(pxt_param.col_type)

    def __column_type_to_display_str(self, column_type: Optional[pxt.ColumnType]) -> str:
        """Renders a Pixeltable column type as the type name shown in the generated docs.

        Returns 'None' when `column_type` is None, and wraps the name in `Optional[...]` for nullable types.

        Raises:
            TypeError: if `column_type` is not one of the column types handled below
        """
        # TODO: When we enhance the Pixeltable type system, we may want to refactor some of this logic out.
        # I'm putting it here for now though.
        if column_type is None:
            return 'None'
        if column_type.is_string_type():
            base = 'str'
        elif column_type.is_int_type():
            base = 'int'
        elif column_type.is_float_type():
            base = 'float'
        elif column_type.is_bool_type():
            base = 'bool'
        elif column_type.is_timestamp_type():
            base = 'datetime'
        elif column_type.is_array_type():
            base = 'ArrayT'
        elif column_type.is_json_type():
            base = 'JsonT'
        elif column_type.is_image_type():
            base = 'ImageT'
        elif column_type.is_video_type():
            base = 'VideoT'
        elif column_type.is_audio_type():
            base = 'AudioT'
        elif column_type.is_document_type():
            base = 'DocumentT'
        else:
            # An explicit raise (unlike `assert False`) survives `python -O`, which would otherwise fall through
            # to an unbound `base`.
            raise TypeError(f'Unsupported column type in UDF signature: {column_type}')
        return f'Optional[{base}]' if column_type.nullable else base
@@ -0,0 +1,6 @@
1
+ from pathlib import Path
2
+
3
+
4
def get_templates_path() -> Path:
    """Returns the directory that holds Pixeltable's custom jinja templates.

    This is the implementation of the 'mkdocstrings.python.templates' plugin; the directory is located
    next to this module.
    """
    module_dir = Path(__file__).parent
    return module_dir.joinpath("templates")
@@ -0,0 +1,135 @@
1
+ {#- Template for Pixeltable UDFs. Cargo-culted (with modification) from _base/function.html.jinja. -#}
2
+
3
+ {% block logs scoped %}
4
+ {#- Logging block.
5
+
6
+ This block can be used to log debug messages, deprecation messages, warnings, etc.
7
+ -#}
8
+ {{ log.debug("Rendering " + function.path) }}
9
+ {% endblock logs %}
10
+
11
+ {% import "language"|get_template as lang with context %}
12
+ {#- Language module providing the `t` translation method. -#}
13
+
14
+ <div class="doc doc-object doc-function">
15
+ {% with obj = function, html_id = function.path %}
16
+
17
+ {% if root %}
18
+ {% set show_full_path = config.show_root_full_path %}
19
+ {% set root_members = True %}
20
+ {% elif root_members %}
21
+ {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %}
22
+ {% set root_members = False %}
23
+ {% else %}
24
+ {% set show_full_path = config.show_object_full_path %}
25
+ {% endif %}
26
+
27
+ {% set function_name = function.path if show_full_path else function.name %}
28
+ {#- Brief or full function name depending on configuration. -#}
29
+ {% set symbol_type = "udf" %}
30
+ {#- Symbol type: always "udf", since this template is only applied to Pixeltable UDFs. -#}
31
+
32
+ {% if not root or config.show_root_heading %}
33
+ {% filter heading(
34
+ heading_level,
35
+ role="function",
36
+ id=html_id,
37
+ class="doc doc-heading",
38
+ toc_label=(('<code class="doc-symbol doc-symbol-toc doc-symbol-' + symbol_type + '"></code>&nbsp;')|safe if config.show_symbol_type_toc else '') + function.name,
39
+ ) %}
40
+
41
+ {% block heading scoped %}
42
+ {#- Heading block.
43
+
44
+ This block renders the heading for the function.
45
+ -#}
46
+ {% if config.show_symbol_type_heading %}<code class="doc-symbol doc-symbol-heading doc-symbol-{{ symbol_type }}"></code>{% endif %}
47
+ {% if config.separate_signature %}
48
+ <span class="doc doc-object-name doc-function-name">{{ function_name }}</span>
49
+ {% else %}
50
+ {%+ filter highlight(language="python", inline=True) %}
51
+ {{ function_name }}{% include "signature"|get_template with context %}
52
+ {% endfilter %}
53
+ {% endif %}
54
+ {% endblock heading %}
55
+
56
+ {% block labels scoped %}
57
+ {#- Labels block.
58
+
59
+ This block renders the labels for the function.
60
+ -#}
61
+ {% with labels = function.labels %}
62
+ {% include "labels"|get_template with context %}
63
+ {% endwith %}
64
+ {% endblock labels %}
65
+
66
+ {% endfilter %}
67
+
68
+ {% block signature scoped %}
69
+ {#- Signature block.
70
+
71
+ This block renders the signature for the function.
72
+ -#}
73
+ {% if config.separate_signature %}
74
+ {% filter format_signature(function, config.line_length, crossrefs=config.signature_crossrefs) %}
75
+ {{ function.name }}
76
+ {% endfilter %}
77
+ {% endif %}
78
+ {% endblock signature %}
79
+
80
+ {% else %}
81
+
82
+ {% if config.show_root_toc_entry %}
83
+ {% filter heading(
84
+ heading_level,
85
+ role="function",
86
+ id=html_id,
87
+ toc_label=(('<code class="doc-symbol doc-symbol-toc doc-symbol-' + symbol_type + '"></code>&nbsp;')|safe if config.show_symbol_type_toc else '') + function.name,
88
+ hidden=True,
89
+ ) %}
90
+ {% endfilter %}
91
+ {% endif %}
92
+ {% set heading_level = heading_level - 1 %}
93
+ {% endif %}
94
+
95
+ <div class="doc doc-contents {% if root %}first{% endif %}">
96
+ {% block contents scoped %}
97
+ {#- Contents block.
98
+
99
+ This block renders the contents of the function.
100
+ It contains other blocks that users can override.
101
+ Overriding the contents block allows to rearrange the order of the blocks.
102
+ -#}
103
+ {% block docstring scoped %}
104
+ {#- Docstring block.
105
+
106
+ This block renders the docstring for the function.
107
+ -#}
108
+ {% with docstring_sections = function.docstring.parsed %}
109
+ {% include "docstring"|get_template with context %}
110
+ {% endwith %}
111
+ {% endblock docstring %}
112
+
113
+ {% block source scoped %}
114
+ {#- Source block.
115
+
116
+ This block renders the source code for the function.
117
+ -#}
118
+ {% if config.show_source and function.source %}
119
+ <details class="quote">
120
+ <summary>{{ lang.t("Source code in") }} <code>
121
+ {%- if function.relative_filepath.is_absolute() -%}
122
+ {{ function.relative_package_filepath }}
123
+ {%- else -%}
124
+ {{ function.relative_filepath }}
125
+ {%- endif -%}
126
+ </code></summary>
127
+ {{ function.source|highlight(language="python", linestart=function.lineno, linenums=True) }}
128
+ </details>
129
+ {% endif %}
130
+ {% endblock source %}
131
+ {% endblock contents %}
132
+ </div>
133
+
134
+ {% endwith %}
135
+ </div>
pixeltable/type_system.py CHANGED
@@ -183,7 +183,7 @@ class ColumnType:
183
183
  """Two types match if they're equal, aside from nullability"""
184
184
  if not isinstance(other, ColumnType):
185
185
  pass
186
- assert isinstance(other, ColumnType)
186
+ assert isinstance(other, ColumnType), type(other)
187
187
  if type(self) != type(other):
188
188
  return False
189
189
  for member_var in vars(self).keys():
@@ -206,7 +206,7 @@ class ColumnType:
206
206
  if type1.is_scalar_type() and type2.is_scalar_type():
207
207
  t = cls.Type.supertype(type1._type, type2._type, cls.common_supertypes)
208
208
  if t is not None:
209
- return cls.make_type(t)
209
+ return cls.make_type(t).copy(nullable=(type1.nullable or type2.nullable))
210
210
  return None
211
211
 
212
212
  if type1._type == type2._type:
@@ -227,22 +227,23 @@ class ColumnType:
227
227
  def infer_literal_type(cls, val: Any) -> Optional[ColumnType]:
228
228
  if isinstance(val, str):
229
229
  return StringType()
230
+ if isinstance(val, bool):
231
+ # We have to check bool before int, because isinstance(b, int) is True if b is a Python bool
232
+ return BoolType()
230
233
  if isinstance(val, int):
231
234
  return IntType()
232
235
  if isinstance(val, float):
233
236
  return FloatType()
234
- if isinstance(val, bool):
235
- return BoolType()
236
- if isinstance(val, datetime.datetime) or isinstance(val, datetime.date):
237
+ if isinstance(val, datetime.datetime):
237
238
  return TimestampType()
238
239
  if isinstance(val, PIL.Image.Image):
239
- return ImageType(width=val.width, height=val.height)
240
+ return ImageType(width=val.width, height=val.height, mode=val.mode)
240
241
  if isinstance(val, np.ndarray):
241
242
  col_type = ArrayType.from_literal(val)
242
243
  if col_type is not None:
243
244
  return col_type
244
245
  # this could still be json-serializable
245
- if isinstance(val, dict) or isinstance(val, np.ndarray):
246
+ if isinstance(val, dict) or isinstance(val, list) or isinstance(val, np.ndarray):
246
247
  try:
247
248
  JsonType().validate_literal(val)
248
249
  return JsonType()
@@ -276,7 +277,7 @@ class ColumnType:
276
277
  return FloatType()
277
278
  if base is bool:
278
279
  return BoolType()
279
- if base is datetime.date or base is datetime.datetime:
280
+ if base is datetime.datetime:
280
281
  return TimestampType()
281
282
  if issubclass(base, Sequence) or issubclass(base, Mapping):
282
283
  return JsonType()
@@ -425,7 +426,7 @@ class StringType(ColumnType):
425
426
  def conversion_fn(self, target: ColumnType) -> Optional[Callable[[Any], Any]]:
426
427
  if not target.is_timestamp_type():
427
428
  return None
428
- def convert(val: str) -> Optional[datetime]:
429
+ def convert(val: str) -> Optional[datetime.datetime]:
429
430
  try:
430
431
  dt = datetime.datetime.fromisoformat(val)
431
432
  return dt
@@ -506,8 +507,8 @@ class TimestampType(ColumnType):
506
507
  return sql.TIMESTAMP()
507
508
 
508
509
  def _validate_literal(self, val: Any) -> None:
509
- if not isinstance(val, datetime.datetime) and not isinstance(val, datetime.date):
510
- raise TypeError(f'Expected datetime.datetime or datetime.date, got {val.__class__.__name__}')
510
+ if not isinstance(val, datetime.datetime):
511
+ raise TypeError(f'Expected datetime.datetime, got {val.__class__.__name__}')
511
512
 
512
513
  def _create_literal(self, val: Any) -> Any:
513
514
  if isinstance(val, str):
@@ -577,7 +578,7 @@ class ArrayType(ColumnType):
577
578
  if base_type is None:
578
579
  return None
579
580
  shape = [n1 if n1 == n2 else None for n1, n2 in zip(type1.shape, type2.shape)]
580
- return ArrayType(tuple(shape), base_type)
581
+ return ArrayType(tuple(shape), base_type, nullable=(type1.nullable or type2.nullable))
581
582
 
582
583
  def _as_dict(self) -> Dict:
583
584
  result = super()._as_dict()
@@ -609,7 +610,7 @@ class ArrayType(ColumnType):
609
610
  dtype = StringType()
610
611
  else:
611
612
  return None
612
- return cls(val.shape, dtype=dtype, nullable=True)
613
+ return cls(val.shape, dtype=dtype)
613
614
 
614
615
  def is_valid_literal(self, val: np.ndarray) -> bool:
615
616
  if not isinstance(val, np.ndarray):
@@ -695,7 +696,7 @@ class ImageType(ColumnType):
695
696
  return f'{self._type.name.lower()}{params_str}'
696
697
 
697
698
  def _is_supertype_of(self, other: ImageType) -> bool:
698
- if self.mode != other.mode:
699
+ if self.mode is not None and self.mode != other.mode:
699
700
  return False
700
701
  if self.width is None and self.height is None:
701
702
  return True
pixeltable/utils/s3.py CHANGED
@@ -10,4 +10,4 @@ def get_client() -> Any:
10
10
  except AttributeError:
11
11
  # No credentials available, use unsigned mode
12
12
  config = botocore.config.Config(signature_version=botocore.UNSIGNED)
13
- return boto3.client('s3', config=config)
13
+ return boto3.client('s3', config=config)