pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +34 -6
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +590 -30
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +359 -45
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +116 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +195 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +34 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +256 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +122 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +418 -182
- pixeltable/tests/conftest.py +146 -88
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/functions/test_huggingface.py +158 -0
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +370 -0
- pixeltable/tests/test_dataframe.py +439 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +120 -0
- pixeltable/tests/test_exprs.py +592 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1195 -263
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +151 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +320 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/tool/create_test_video.py +81 -0
- pixeltable/type_system.py +445 -124
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +167 -0
- pixeltable/utils/pytorch.py +91 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.4.dist-info/LICENSE +18 -0
- pixeltable-0.2.4.dist-info/METADATA +127 -0
- pixeltable-0.2.4.dist-info/RECORD +132 -0
- {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_functions.py +0 -11
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.0.dist-info/METADATA +0 -34
- pixeltable-0.1.0.dist-info/RECORD +0 -36
|
@@ -0,0 +1,227 @@ pixeltable/func/function_registry.py
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import importlib
|
|
5
|
+
import logging
|
|
6
|
+
import sys
|
|
7
|
+
import types
|
|
8
|
+
from typing import Optional, Dict, List, Tuple
|
|
9
|
+
from uuid import UUID
|
|
10
|
+
|
|
11
|
+
import cloudpickle
|
|
12
|
+
import sqlalchemy as sql
|
|
13
|
+
|
|
14
|
+
import pixeltable.env as env
|
|
15
|
+
import pixeltable.exceptions as excs
|
|
16
|
+
import pixeltable.type_system as ts
|
|
17
|
+
from pixeltable.metadata import schema
|
|
18
|
+
from .function import Function
|
|
19
|
+
from .globals import get_caller_module_path
|
|
20
|
+
|
|
21
|
+
_logger = logging.getLogger('pixeltable')


class FunctionRegistry:
    """
    A central registry for all Functions. Handles interactions with the backing store.

    Module (library) functions are registered in-process via register_function(). Stored functions are
    persisted in the store, loaded on demand and cached by id.
    """
    _instance: Optional[FunctionRegistry] = None

    @classmethod
    def get(cls) -> FunctionRegistry:
        """Return the process-wide registry, creating it on first use."""
        if cls._instance is None:
            cls._instance = FunctionRegistry()
        return cls._instance

    def __init__(self):
        self.stored_fns_by_id: Dict[UUID, Function] = {}  # cache of stored functions created or loaded so far
        self.module_fns: Dict[str, Function] = {}  # fqn -> Function

    def clear_cache(self) -> None:
        """
        Drop all cached stored functions; they are reloaded from the store on demand.
        Registered module functions are kept. Useful during testing.
        """
        self.stored_fns_by_id = {}

    def register_function(self, fqn: str, fn: Function) -> None:
        """Register a module function under its fully-qualified name, replacing any previous entry."""
        self.module_fns[fqn] = fn

    def list_functions(self) -> List[Function]:
        """Return all registered module functions (stored functions are not included)."""
        # TODO: also list stored functions; self.stored_fns_by_id only contains the ones loaded so far,
        # so this would require querying schema.Function directly
        return list(self.module_fns.values())

    def get_type_methods(self, name: str, base_type: ts.ColumnType.Type) -> List[Function]:
        """
        Return the module functions that act as method `name` on values of type `base_type`.

        A function qualifies if its self_path ends in '.<name>' and the type_enum of its first parameter's
        column type equals `base_type`. Functions without parameters, or whose first parameter has no column
        type (variable parameters), cannot be methods and are skipped instead of raising.
        """
        suffix = '.' + name
        methods: List[Function] = []
        for fn in self.module_fns.values():
            if fn.self_path is None or not fn.self_path.endswith(suffix):
                continue
            params = fn.signature.parameters_by_pos
            if len(params) == 0 or params[0].col_type is None:
                continue
            if params[0].col_type.type_enum == base_type:
                methods.append(fn)
        return methods

    def create_stored_function(self, pxt_fn: Function, dir_id: Optional[UUID] = None) -> UUID:
        """
        Persist pxt_fn in the store and cache it.

        Args:
            pxt_fn: the function to store; serialized via its to_store() method
            dir_id: id of the directory that contains the function, if any

        Returns:
            the id the store assigned to the new function
        """
        fn_md, binary_obj = pxt_fn.to_store()
        md = schema.FunctionMd(
            name=pxt_fn.name, md=fn_md, py_version=sys.version, class_name=pxt_fn.__class__.__name__)
        with env.Env.get().engine.begin() as conn:
            res = conn.execute(
                sql.insert(schema.Function.__table__)
                .values(dir_id=dir_id, md=dataclasses.asdict(md), binary_obj=binary_obj))
            fn_id = res.inserted_primary_key[0]
        # only cache the function once the transaction has been committed
        _logger.info(f'Created function {pxt_fn.name} (id {fn_id}) in store')
        self.stored_fns_by_id[fn_id] = pxt_fn
        return fn_id

    def get_stored_function(self, id: UUID) -> Function:
        """
        Return the stored function with the given id, loading it from the store on a cache miss.

        Raises:
            excs.Error: if the store contains no function with that id
        """
        if id in self.stored_fns_by_id:
            return self.stored_fns_by_id[id]
        stmt = sql.select(schema.Function.md, schema.Function.binary_obj, schema.Function.dir_id)\
            .where(schema.Function.id == id)
        with env.Env.get().engine.begin() as conn:
            row = conn.execute(stmt).fetchone()
        if row is None:
            raise excs.Error(f'Function with id {id} not found')
        # create instance of the referenced class; md.class_name is looked up in the package containing
        # this module
        md = schema.md_from_dict(schema.FunctionMd, row[0])
        func_module = importlib.import_module(self.__module__.rsplit('.', 1)[0])
        func_class = getattr(func_module, md.class_name)
        instance = func_class.from_store(md.name, md.md, row[1])
        self.stored_fns_by_id[id] = instance
        return instance

    def delete_function(self, id: UUID) -> None:
        """Delete the stored function with the given id from the store and evict it from the cache."""
        assert id is not None
        with env.Env.get().engine.begin() as conn:
            conn.execute(
                sql.delete(schema.Function.__table__)
                .where(schema.Function.id == id))
        # without eviction, get_stored_function() would keep returning the deleted function
        self.stored_fns_by_id.pop(id, None)
        _logger.info(f'Deleted function with id {id} from store')
|
|
@@ -0,0 +1,36 @@ pixeltable/func/globals.py
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from types import ModuleType
|
|
3
|
+
import importlib
|
|
4
|
+
import inspect
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def resolve_symbol(symbol_path: str) -> object:
    """
    Resolve a dotted path such as 'os.path.join' to the object it names.

    The first path element is imported as a module and the remaining elements are resolved as attributes.
    Paths into pixeltable.functions submodules are special-cased: such a submodule cannot be resolved via
    pixeltable.functions, so it is imported directly. If no submodule of that name exists, the path is
    resolved the regular way.

    Raises:
        ModuleNotFoundError: if the top-level module can't be imported, or if a pixeltable.functions submodule
            exists but one of its own imports is missing
        AttributeError: if an element of the path does not exist
    """
    path_elems = symbol_path.split('.')
    module: Optional[ModuleType] = None
    if path_elems[0:2] == ['pixeltable', 'functions'] and len(path_elems) > 2:
        # if this is a pixeltable.functions submodule, it cannot be resolved via pixeltable.functions;
        # try to import the submodule directly
        submodule_path = '.'.join(path_elems[0:3])
        try:
            module = importlib.import_module(submodule_path)
        except ModuleNotFoundError as exc:
            # Fall back only if what's missing is the submodule itself (or one of its parent packages).
            # A missing dependency *of* the submodule must surface here; otherwise it would be masked by a
            # confusing AttributeError further down.
            missing = exc.name
            is_submodule_missing = \
                missing is None or submodule_path == missing or submodule_path.startswith(missing + '.')
            if not is_submodule_missing:
                raise
        else:
            path_elems = path_elems[3:]
    if module is None:
        module = importlib.import_module(path_elems[0])
        path_elems = path_elems[1:]
    obj = module
    for el in path_elems:
        obj = getattr(obj, el)
    return obj
|
|
26
|
+
|
|
27
|
+
def get_caller_module_path() -> str:
    """
    Return the module path (the ``__name__`` global) of our caller's caller.

    Raises:
        IndexError: if the call stack has no caller's caller
    """
    # context=0: only the frames are needed, not the surrounding source lines, which inspect.stack() would
    # otherwise read from disk for every frame on the stack
    stack = inspect.stack(context=0)
    try:
        # stack[0] is this function, stack[1] our caller, stack[2] the caller's caller
        caller_frame = stack[2].frame
        module_path = caller_frame.f_globals['__name__']
    finally:
        # remove references to stack frames to avoid reference cycles
        del stack
    return module_path
|
|
@@ -0,0 +1,202 @@ pixeltable/func/nos_function.py
|
|
|
1
|
+
from typing import Optional, Any, Dict, List, Tuple
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .signature import Signature, Parameter
|
|
9
|
+
from .batched_function import BatchedFunction
|
|
10
|
+
import pixeltable.env as env
|
|
11
|
+
import pixeltable.type_system as ts
|
|
12
|
+
import pixeltable.exceptions as excs
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# logger used by this module (shared 'pixeltable' logger)
_logger = logging.getLogger('pixeltable')


class NOSFunction(BatchedFunction):
    """
    A BatchedFunction that runs a NOS model.

    The Pixeltable signature is derived from model_spec.signature. invoke() sends each batch to the NOS
    server via env.Env.get().nos_client.Run(). It converts image arguments to RGB and resizes them to the
    model's input resolution (if known). For 2D object detection it rescales the returned bounding boxes
    back to the original image sizes.
    """
    def __init__(self, model_spec: 'nos.common.ModelSpec', self_path: str):
        """
        Args:
            model_spec: the NOS model spec; its signature determines parameter and return types
            self_path: passed through to BatchedFunction.__init__()
        """
        return_type, param_types = self._convert_nos_signature(model_spec.signature)
        param_names = list(model_spec.signature.get_inputs_spec().keys())
        # all NOS inputs become POSITIONAL_OR_KEYWORD, non-batched Pixeltable parameters
        params = [
            Parameter(name, col_type, inspect.Parameter.POSITIONAL_OR_KEYWORD, is_batched=False)
            for name, col_type in zip(param_names, param_types)
        ]
        signature = Signature(return_type, params)

        # construct inspect.Signature
        py_params = [
            inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
            for name, col_type in zip(param_names, param_types)
        ]
        py_signature = inspect.Signature(py_params)
        super().__init__(signature, py_signature=py_signature, self_path=self_path)

        self.model_spec = model_spec
        # NOS input names, in signature order; invoke() zips them with the positional arg batches
        self.nos_param_names = model_spec.signature.get_inputs_spec().keys()
        # NOS inputs that are not batched; invoke() passes only the first row's value for these
        self.scalar_nos_param_names = []

        # for models on images
        self.img_param_pos: Optional[int] = None  # position of the image parameter in the function signature
        # for multi-resolution image models
        import nos
        self.img_batch_params: List[nos.common.ObjectTypeInfo] = []
        self.img_resolutions: List[int] = []  # for multi-resolution models; pixel count (shape[0] * shape[1]) per entry
        self.batch_size: Optional[int] = None
        self.img_size: Optional[Tuple[int, int]] = None  # W, H

        # try to determine batch_size and img_size
        # (batch_size: minimum over all batched, single-resolution inputs; sys.maxsize means "none found")
        batch_size = sys.maxsize
        for pos, (param_name, type_info) in enumerate(model_spec.signature.get_inputs_spec().items()):
            if isinstance(type_info, list):
                assert isinstance(type_info[0].base_spec(), nos.common.ImageSpec)
                # this is a multi-resolution image model
                self.img_batch_params = type_info
                self.img_param_pos = pos
                self.img_resolutions = [
                    info.base_spec().shape[0] * info.base_spec().shape[1] for info in self.img_batch_params
                ]
            else:
                if not type_info.is_batched():
                    self.scalar_nos_param_names.append(param_name)
                else:
                    batch_size = min(batch_size, type_info.batch_size())

                if isinstance(type_info.base_spec(), nos.common.ImageSpec):
                    # this is a single-resolution image model
                    # (shape is presumably (H, W, ...), hence the swap to (W, H))
                    if type_info.base_spec().shape is not None:
                        self.img_size = (type_info.base_spec().shape[1], type_info.base_spec().shape[0])
                    self.img_param_pos = pos

        if batch_size != sys.maxsize:
            self.batch_size = batch_size

    def _convert_nos_type(
            self, type_info: 'nos.common.spec.ObjectTypeInfo', ignore_shape: bool = False
    ) -> ts.ColumnType:
        """Convert ObjectTypeInfo to ColumnType

        Maps str/int/float/bool base types to the corresponding scalar types, ImageSpec to ImageType (with a
        (W, H) size taken from the shape unless ignore_shape is set), and TensorSpec to a float ArrayType.

        Raises:
            excs.Error: for any other type
        """
        import nos
        if type_info.base_spec() is None:
            if type_info.base_type() == str:
                return ts.StringType()
            if type_info.base_type() == int:
                return ts.IntType()
            if type_info.base_type() == float:
                return ts.FloatType()
            if type_info.base_type() == bool:
                return ts.BoolType()
            else:
                raise excs.Error(f'Cannot convert {type_info} to ColumnType')
        elif isinstance(type_info.base_spec(), nos.common.ImageSpec):
            size = None
            if not ignore_shape and type_info.base_spec().shape is not None:
                size = (type_info.base_spec().shape[1], type_info.base_spec().shape[0])
            # TODO: set mode
            return ts.ImageType(size=size)
        elif isinstance(type_info.base_spec(), nos.common.TensorSpec):
            return ts.ArrayType(shape=type_info.base_spec().shape, dtype=ts.FloatType())
        else:
            raise excs.Error(f'Cannot convert {type_info} to ColumnType')

    def _convert_nos_signature(
            self, sig: 'nos.common.spec.FunctionSignature') -> Tuple[ts.ColumnType, List[ts.ColumnType]]:
        """
        Return (return type, parameter types) for a NOS function signature.

        Multiple outputs are returned as a single JsonType; a single output is converted directly.
        """
        if len(sig.get_outputs_spec()) > 1:
            return_type = ts.JsonType()
        else:
            return_type = self._convert_nos_type(list(sig.get_outputs_spec().values())[0])
        param_types: List[ts.ColumnType] = []
        for _, type_info in sig.get_inputs_spec().items():
            # if there are multiple input shapes we leave them out of the ColumnType and deal with them in FunctionCall
            if isinstance(type_info, list):
                param_types.append(self._convert_nos_type(type_info[0], ignore_shape=True))
            else:
                param_types.append(self._convert_nos_type(type_info, ignore_shape=False))
        return return_type, param_types

    def is_multi_res_model(self) -> bool:
        """True if the image input accepts several resolutions (a list of ObjectTypeInfos)."""
        return self.img_param_pos is not None and len(self.img_batch_params) > 0

    def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
        """
        Return the batch size to use for the given arguments, or None if unknown.

        For multi-resolution models without a fixed batch size, the batch size of the resolution closest to
        the image argument's pixel count is returned. That image argument is looked up positionally.
        NOTE(review): presumably a PIL image (uses .size as (W, H)) — confirm.
        """
        if self.batch_size is not None or len(self.img_batch_params) == 0 or len(args) == 0:
            return self.batch_size

        # return batch size appropriate for the given image size
        img_arg = args[self.img_param_pos]
        input_res = img_arg.size[0] * img_arg.size[1]
        batch_size, _ = self._select_model_res(input_res)
        return batch_size

    def _select_model_res(self, input_res: int) -> Tuple[int, Tuple[int, int]]:
        """Select the model resolution that is closest to the input resolution
        Returns: batch size, image size

        Closeness is the absolute difference in pixel count; on ties the first listed resolution wins.
        The image size is returned as (W, H).
        """
        deltas = [abs(res - input_res) for res in self.img_resolutions]
        idx = deltas.index(min(deltas))
        type_info = self.img_batch_params[idx]
        return type_info.batch_size(), (type_info.base_spec().shape[1], type_info.base_spec().shape[0])

    def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
        """
        Run the NOS model on one batch.

        Args:
            arg_batches: one list per positional parameter, each with one value per row; the image batch is
                replaced in place (RGB conversion, resizing)
            kwarg_batches: NOTE(review): not used by this implementation

        Returns:
            the single NOS output if there is only one; otherwise one dict per row mapping output name to
            value (via .tolist(), presumably on numpy arrays)
        """
        # check that scalar args are constant
        # NOTE(review): no such check is actually performed below

        num_batch_rows = len(arg_batches[0])
        # if we need to rescale image args, and we're doing object detection, we need to rescale the
        # bounding boxes as well
        # (uninitialized memory; only read below when target_res is not None, in which case all rows are set)
        scale_factors = np.ndarray((num_batch_rows, 2), dtype=np.float32)

        target_res: Optional[Tuple[int, int]] = None
        if self.img_param_pos is not None:
            # for now, NOS will only receive RGB images
            arg_batches[self.img_param_pos] = \
                [img.convert('RGB') if img.mode != 'RGB' else img for img in arg_batches[self.img_param_pos ]]
            if self.is_multi_res_model():
                # we need to select the resolution that is closest to the input resolution
                # (only the first image of the batch is considered; all images are resized to that resolution)
                sample_img = arg_batches[self.img_param_pos][0]
                _, target_res = self._select_model_res(sample_img.size[0] * sample_img.size[1])
            else:
                target_res = self.img_size

            if target_res is not None:
                # we need to record the scale factors and resize the images;
                # keep in mind that every image could have a different resolution
                scale_factors[:, 0] = \
                    [img.size[0] / target_res[0] for img in arg_batches[self.img_param_pos]]
                scale_factors[:, 1] = \
                    [img.size[1] / target_res[1] for img in arg_batches[self.img_param_pos]]
                arg_batches[self.img_param_pos] = [
                    # only resize if necessary
                    img.resize(target_res) if img.size != target_res else img
                    for img in arg_batches[self.img_param_pos]
                ]

        kwargs = {param_name: args for param_name, args in zip(self.nos_param_names, arg_batches)}
        # fix up scalar parameters
        kwargs.update(
            {param_name: kwargs[param_name][0] for param_name in self.scalar_nos_param_names})
        _logger.debug(
            f'Running NOS task {self.model_spec.task}: '
            f'batch_size={num_batch_rows} target_res={target_res}')
        result = env.Env.get().nos_client.Run(
            task=self.model_spec.task, model_name=self.model_spec.name, **kwargs)

        import nos
        if self.model_spec.task == nos.common.TaskType.OBJECT_DETECTION_2D and target_res is not None:
            # we need to rescale the bounding boxes
            # (columns 0/2 by the x factor, 1/3 by the y factor; presumably (x1, y1, x2, y2) — confirm)
            result_bboxes = []  # workaround: result['bboxes'][*] is immutable
            for i, bboxes in enumerate(result['bboxes']):
                bboxes = np.copy(bboxes)
                nos_batch_row_idx = i
                bboxes[:, 0] *= scale_factors[nos_batch_row_idx, 0]
                bboxes[:, 1] *= scale_factors[nos_batch_row_idx, 1]
                bboxes[:, 2] *= scale_factors[nos_batch_row_idx, 0]
                bboxes[:, 3] *= scale_factors[nos_batch_row_idx, 1]
                result_bboxes.append(bboxes)
            result['bboxes'] = result_bboxes

        if len(result) == 1:
            key = list(result.keys())[0]
            row_results = result[key]
        else:
            # we rearrange result into one dict per row
            row_results = [{k: v[i].tolist() for k, v in result.items()} for i in range(num_batch_rows)]
        return row_results
|
|
202
|
+
|
|
@@ -0,0 +1,166 @@ pixeltable/func/signature.py
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import enum
|
|
5
|
+
import inspect
|
|
6
|
+
import logging
|
|
7
|
+
import typing
|
|
8
|
+
from typing import Optional, Callable, Dict, List, Any, Union, Tuple
|
|
9
|
+
|
|
10
|
+
import pixeltable.exceptions as excs
|
|
11
|
+
import pixeltable.type_system as ts
|
|
12
|
+
|
|
13
|
+
_logger = logging.getLogger('pixeltable')


@dataclasses.dataclass
class Parameter:
    """One parameter of a Pixeltable function signature."""
    # parameter name
    name: str
    # column type of the parameter; None for variable parameters
    col_type: Optional[ts.ColumnType]
    # an inspect.Parameter kind (e.g. POSITIONAL_OR_KEYWORD); annotated as enum.Enum because
    # inspect._ParameterKind is private
    kind: enum.Enum
    # True if the parameter is a batched parameter (eg, Batch[dict])
    is_batched: bool = dataclasses.field(default=False)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# element type variable for Batch
T = typing.TypeVar('T')
# Batch[X]: a list of X tagged with the 'pxt-batch' marker, used to declare batched parameters/return types
# (eg, Batch[dict]); NOTE(review): presumably detected via the marker during signature inference — confirm
Batch = typing.Annotated[list[T], 'pxt-batch']
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Signature:
|
|
29
|
+
"""
|
|
30
|
+
Represents the signature of a Pixeltable function.
|
|
31
|
+
|
|
32
|
+
Regarding return type:
|
|
33
|
+
- most functions will have a fixed return type, which is specified directly
|
|
34
|
+
- some functions will have a return type that depends on the argument values;
|
|
35
|
+
ex.: PIL.Image.Image.resize() returns an image with dimensions specified as a parameter
|
|
36
|
+
- in the latter case, the 'return_type' field is a function that takes the bound arguments and returns the
|
|
37
|
+
return type; if no bound arguments are specified, a generic return type is returned (eg, ImageType() without a
|
|
38
|
+
size)
|
|
39
|
+
- self.is_batched: return type is a Batch[...] type
|
|
40
|
+
"""
|
|
41
|
+
SPECIAL_PARAM_NAMES = ['group_by', 'order_by']
|
|
42
|
+
|
|
43
|
+
def __init__(
        self,
        return_type: Union[ts.ColumnType, Callable[[Dict[str, Any]], ts.ColumnType]],
        parameters: List[Parameter], is_batched: bool = False):
    """
    Args:
        return_type: a fixed ColumnType, or a callable computing the return type from the bound arguments
        parameters: the function's parameters, in positional order
        is_batched: True if the return type is a Batch[...] type
    """
    self.return_type = return_type
    self.is_batched = is_batched
    # name -> Parameter; dicts keep insertion order (Python >= 3.7), so this is also positional order
    self.parameters = {param.name: param for param in parameters}
    self.parameters_by_pos = parameters.copy()
    # split the parameters into constant and batched ones, each preserving positional order
    self.constant_parameters = []
    self.batched_parameters = []
    for param in parameters:
        if param.is_batched:
            self.batched_parameters.append(param)
        else:
            self.constant_parameters.append(param)
|
|
54
|
+
|
|
55
|
+
def get_return_type(self, bound_args: Optional[Dict[str, Any]] = None) -> ts.ColumnType:
|
|
56
|
+
if isinstance(self.return_type, ts.ColumnType):
|
|
57
|
+
return self.return_type
|
|
58
|
+
return self.return_type(bound_args)
|
|
59
|
+
|
|
60
|
+
def as_dict(self) -> Dict[str, Any]:
|
|
61
|
+
result = {
|
|
62
|
+
'return_type': self.get_return_type().as_dict(),
|
|
63
|
+
'parameters': [
|
|
64
|
+
[p.name, p.col_type.as_dict() if p.col_type is not None else None, p.kind, p.is_batched]
|
|
65
|
+
for p in self.parameters.values()
|
|
66
|
+
]
|
|
67
|
+
}
|
|
68
|
+
return result
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def from_dict(cls, d: Dict[str, Any]) -> Signature:
|
|
72
|
+
parameters = [Parameter(p[0], ts.ColumnType.from_dict(p[1]), p[2], p[3]) for p in d['parameters']]
|
|
73
|
+
return cls(ts.ColumnType.from_dict(d['return_type']), parameters)
|
|
74
|
+
|
|
75
|
+
def __eq__(self, other: Signature) -> bool:
|
|
76
|
+
if self.get_return_type() != other.get_return_type():
|
|
77
|
+
return False
|
|
78
|
+
if len(self.parameters) != len(other.parameters):
|
|
79
|
+
return False
|
|
80
|
+
# ignore the parameter name
|
|
81
|
+
for param, other_param in zip(self.parameters.values(), other.parameters.values()):
|
|
82
|
+
if param.col_type != other_param.col_type or param.kind != other_param.kind:
|
|
83
|
+
return False
|
|
84
|
+
return True
|
|
85
|
+
|
|
86
|
+
def __str__(self) -> str:
|
|
87
|
+
param_strs: List[str] = []
|
|
88
|
+
for p in self.parameters.values():
|
|
89
|
+
if p.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
90
|
+
param_strs.append(f'*{p.name}')
|
|
91
|
+
elif p.kind == inspect.Parameter.VAR_KEYWORD:
|
|
92
|
+
param_strs.append(f'**{p.name}')
|
|
93
|
+
else:
|
|
94
|
+
param_strs.append(f'{p.name}: {str(p.col_type)}')
|
|
95
|
+
return f'({", ".join(param_strs)}) -> {str(self.get_return_type())}'
|
|
96
|
+
|
|
97
|
+
@classmethod
|
|
98
|
+
def _infer_type(cls, annotation: Optional[type]) -> Tuple[Optional[ts.ColumnType], Optional[bool]]:
|
|
99
|
+
"""Returns: (column type, is_batched) or (None, ...) if the type cannot be inferred"""
|
|
100
|
+
if annotation is None:
|
|
101
|
+
return (None, None)
|
|
102
|
+
py_type: Optional[type] = None
|
|
103
|
+
is_batched = False
|
|
104
|
+
if typing.get_origin(annotation) == typing.Annotated:
|
|
105
|
+
type_args = typing.get_args(annotation)
|
|
106
|
+
if len(type_args) == 2 and type_args[1] == 'pxt-batch':
|
|
107
|
+
# this is our Batch
|
|
108
|
+
assert typing.get_origin(type_args[0]) == list
|
|
109
|
+
is_batched = True
|
|
110
|
+
py_type = typing.get_args(type_args[0])[0]
|
|
111
|
+
if py_type is None:
|
|
112
|
+
py_type = annotation
|
|
113
|
+
col_type = ts.ColumnType.from_python_type(py_type)
|
|
114
|
+
return (col_type, is_batched)
|
|
115
|
+
|
|
116
|
+
@classmethod
|
|
117
|
+
def create(
|
|
118
|
+
cls, c: Callable,
|
|
119
|
+
param_types: Optional[List[ts.ColumnType]] = None,
|
|
120
|
+
return_type: Optional[Union[ts.ColumnType, Callable]] = None
|
|
121
|
+
) -> Signature:
|
|
122
|
+
"""Create a signature for the given Callable.
|
|
123
|
+
Infer the parameter and return types, if none are specified.
|
|
124
|
+
Raises an exception if the types cannot be inferred.
|
|
125
|
+
"""
|
|
126
|
+
sig = inspect.signature(c)
|
|
127
|
+
py_parameters = list(sig.parameters.values())
|
|
128
|
+
|
|
129
|
+
# check non-var parameters for name collisions and default value compatibility
|
|
130
|
+
parameters: List[Parameter] = []
|
|
131
|
+
for idx, param in enumerate(py_parameters):
|
|
132
|
+
if param.name in cls.SPECIAL_PARAM_NAMES:
|
|
133
|
+
raise excs.Error(f"'{param.name}' is a reserved parameter name")
|
|
134
|
+
if param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD:
|
|
135
|
+
parameters.append(Parameter(param.name, None, param.kind, False))
|
|
136
|
+
continue
|
|
137
|
+
|
|
138
|
+
if param_types is not None:
|
|
139
|
+
if idx >= len(param_types):
|
|
140
|
+
raise excs.Error(f'Missing type for parameter {param.name}')
|
|
141
|
+
param_type = param_types[idx]
|
|
142
|
+
is_batched = False
|
|
143
|
+
else:
|
|
144
|
+
param_type, is_batched = cls._infer_type(param.annotation)
|
|
145
|
+
if param_type is None:
|
|
146
|
+
raise excs.Error(f'Cannot infer pixeltable type for parameter {param.name}')
|
|
147
|
+
|
|
148
|
+
# check default value compatibility
|
|
149
|
+
default_val = sig.parameters[param.name].default
|
|
150
|
+
if default_val != inspect.Parameter.empty and default_val is not None:
|
|
151
|
+
try:
|
|
152
|
+
_ = param_type.create_literal(default_val)
|
|
153
|
+
except TypeError as e:
|
|
154
|
+
raise excs.Error(f'Default value for parameter {param.name}: {str(e)}')
|
|
155
|
+
|
|
156
|
+
parameters.append(Parameter(param.name, param_type, param.kind, is_batched))
|
|
157
|
+
|
|
158
|
+
return_is_batched = False
|
|
159
|
+
if return_type is None:
|
|
160
|
+
return_type, return_is_batched = cls._infer_type(sig.return_annotation)
|
|
161
|
+
if return_type is None:
|
|
162
|
+
raise excs.Error('Cannot infer pixeltable return type')
|
|
163
|
+
else:
|
|
164
|
+
_, return_is_batched = cls._infer_type(sig.return_annotation)
|
|
165
|
+
|
|
166
|
+
return Signature(return_type, parameters, return_is_batched)
|