pixeltable 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (140)
  1. pixeltable/__init__.py +21 -4
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -31
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -48
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -86
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1086 -258
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -133
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.0.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.0.dist-info/METADATA +117 -0
  124. pixeltable-0.2.0.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.2.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.2.dist-info/LICENSE +0 -201
  139. pixeltable-0.1.2.dist-info/METADATA +0 -89
  140. pixeltable-0.1.2.dist-info/RECORD +0 -37
@@ -0,0 +1,227 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import importlib
5
+ import logging
6
+ import sys
7
+ import types
8
+ from typing import Optional, Dict, List, Tuple
9
+ from uuid import UUID
10
+
11
+ import cloudpickle
12
+ import sqlalchemy as sql
13
+
14
+ import pixeltable.env as env
15
+ import pixeltable.exceptions as excs
16
+ import pixeltable.type_system as ts
17
+ from pixeltable.metadata import schema
18
+ from .function import Function
19
+ from .globals import get_caller_module_path
20
+
21
_logger = logging.getLogger('pixeltable')


class FunctionRegistry:
    """
    A central registry for all Functions. Handles interactions with the backing store.

    Two kinds of Functions are tracked:
    - module functions, registered under their fully-qualified name via register_function()
    - stored functions, persisted in the store's Function table and cached by id once created or loaded
    Stored functions are loaded from the store on demand.
    """
    _instance: Optional[FunctionRegistry] = None

    @classmethod
    def get(cls) -> FunctionRegistry:
        """Return the process-wide registry, creating it on first use."""
        if cls._instance is None:
            cls._instance = FunctionRegistry()
        return cls._instance

    def __init__(self):
        self.stored_fns_by_id: Dict[UUID, Function] = {}  # cache of stored functions, keyed by store id
        self.module_fns: Dict[str, Function] = {}  # fqn -> Function

    def clear_cache(self) -> None:
        """
        Drop all cached stored functions; registered module functions are kept.
        Useful during testing.
        """
        self.stored_fns_by_id = {}

    def register_function(self, fqn: str, fn: Function) -> None:
        """Register a module function under its fully-qualified name, replacing any previous entry."""
        self.module_fns[fqn] = fn

    def list_functions(self) -> List[Function]:
        """Return all registered module functions.

        Stored functions are not included: self.stored_fns_by_id only holds the ones created or loaded so far.
        TODO: also list stored functions (read directly from the store), once the client takes over the Db
        functionality
        """
        return list(self.module_fns.values())

    def get_type_methods(self, name: str, base_type: ts.ColumnType.Type) -> List[Function]:
        """Return the module functions that can be invoked as method `name` on a value of type `base_type`.

        A function qualifies if its self_path ends in '.<name>' and the column type of its first parameter has
        type_enum == base_type. Functions without a typed first parameter (no parameters at all, or a first
        parameter whose col_type is None, e.g. *args) are skipped instead of raising.
        """
        suffix = '.' + name
        result: List[Function] = []
        for fn in self.module_fns.values():
            if fn.self_path is None or not fn.self_path.endswith(suffix):
                continue
            params = fn.signature.parameters_by_pos
            if len(params) == 0 or params[0].col_type is None:
                continue
            if params[0].col_type.type_enum == base_type:
                result.append(fn)
        return result

    def create_stored_function(self, pxt_fn: Function, dir_id: Optional[UUID] = None) -> UUID:
        """Persist pxt_fn in the store's Function table and cache it.

        Args:
            pxt_fn: the function to persist; serialized via pxt_fn.to_store()
            dir_id: id of the directory the function belongs to, if any

        Returns:
            the id assigned by the store
        """
        fn_md, binary_obj = pxt_fn.to_store()
        md = schema.FunctionMd(
            name=pxt_fn.name, md=fn_md, py_version=sys.version, class_name=pxt_fn.__class__.__name__)
        with env.Env.get().engine.begin() as conn:
            res = conn.execute(
                sql.insert(schema.Function.__table__)
                .values(dir_id=dir_id, md=dataclasses.asdict(md), binary_obj=binary_obj))
            fn_id = res.inserted_primary_key[0]
        # only cache once the transaction has been committed
        _logger.info(f'Created function {pxt_fn.name} (id {fn_id}) in store')
        self.stored_fns_by_id[fn_id] = pxt_fn
        return fn_id

    def get_stored_function(self, id: UUID) -> Function:
        """Return the stored function with the given id, loading it from the store on first access.

        The Function subclass named in the stored metadata (class_name) is looked up in the package that
        contains this module and reconstructs the instance via its from_store() classmethod.

        Raises:
            excs.Error: if no function with that id exists in the store
        """
        if id in self.stored_fns_by_id:
            return self.stored_fns_by_id[id]
        stmt = sql.select(schema.Function.md, schema.Function.binary_obj) \
            .where(schema.Function.id == id)
        with env.Env.get().engine.begin() as conn:
            row = conn.execute(stmt).fetchone()
        if row is None:
            raise excs.Error(f'Function with id {id} not found')
        # create instance of the referenced class
        md = schema.md_from_dict(schema.FunctionMd, row[0])
        func_module = importlib.import_module(self.__module__.rsplit('.', 1)[0])
        func_class = getattr(func_module, md.class_name)
        instance = func_class.from_store(md.name, md.md, row[1])
        self.stored_fns_by_id[id] = instance
        return instance

    def delete_function(self, id: UUID) -> None:
        """Delete the stored function with the given id from the store and evict it from the cache."""
        assert id is not None
        with env.Env.get().engine.begin() as conn:
            conn.execute(
                sql.delete(schema.Function.__table__)
                .where(schema.Function.id == id))
        # without the eviction, get_stored_function() would keep returning the deleted function
        self.stored_fns_by_id.pop(id, None)
        _logger.info(f'Deleted function with id {id} from store')
@@ -0,0 +1,36 @@
1
+ from typing import Optional
2
+ from types import ModuleType
3
+ import importlib
4
+ import inspect
5
+
6
+
7
def resolve_symbol(symbol_path: str) -> object:
    """Return the object referenced by a dotted path such as 'package.module.attr'.

    The first path element is imported as a module; the remaining elements are resolved with getattr().
    Submodules of pixeltable.functions cannot be resolved as attributes of pixeltable.functions, so for paths
    of the form 'pixeltable.functions.<x>.<...>' the module 'pixeltable.functions.<x>' is imported directly
    first. If <x> is not a submodule, resolution falls back to attribute lookup from 'pixeltable'.

    Raises:
        ModuleNotFoundError: if a module on the path, or a module imported by it, cannot be found
        AttributeError: if an attribute along the path does not exist
    """
    path_elems = symbol_path.split('.')
    module: Optional[ModuleType] = None
    if path_elems[0:2] == ['pixeltable', 'functions'] and len(path_elems) > 2:
        # if this is a pixeltable.functions submodule, it cannot be resolved via pixeltable.functions;
        # try to import the submodule directly
        submodule_path = '.'.join(path_elems[0:3])
        try:
            module = importlib.import_module(submodule_path)
            path_elems = path_elems[3:]
        except ModuleNotFoundError as e:
            # Only "this submodule doesn't exist" means we should fall back to attribute lookup; a missing
            # module imported *by* the submodule (e.g. an optional dependency) is a genuine error that
            # would otherwise surface later as a misleading AttributeError.
            if e.name != submodule_path:
                raise
    if module is None:
        module = importlib.import_module(path_elems[0])
        path_elems = path_elems[1:]
    obj: object = module
    for el in path_elems:
        obj = getattr(obj, el)
    return obj
26
+
27
def get_caller_module_path() -> str:
    """Return the module path (__name__) of our caller's caller.

    Walks two frames up via f_back rather than calling inspect.stack(), which builds a FrameInfo (including
    source context lines) for every frame on the stack just to read one of them.

    Raises:
        IndexError: if the call stack is not deep enough to have a caller's caller
    """
    frame = inspect.currentframe()
    caller_frame = frame
    try:
        for _ in range(2):
            caller_frame = caller_frame.f_back if caller_frame is not None else None
        if caller_frame is None:
            raise IndexError("call stack is too shallow to determine the caller's caller")
        return caller_frame.f_globals['__name__']
    finally:
        # remove references to stack frames to avoid reference cycles
        del frame, caller_frame
@@ -0,0 +1,202 @@
1
+ from typing import Optional, Any, Dict, List, Tuple
2
+ import inspect
3
+ import logging
4
+ import sys
5
+
6
+ import numpy as np
7
+
8
+ from .signature import Signature, Parameter
9
+ from .batched_function import BatchedFunction
10
+ import pixeltable.env as env
11
+ import pixeltable.type_system as ts
12
+ import pixeltable.exceptions as excs
13
+
14
+
15
_logger = logging.getLogger('pixeltable')


class NOSFunction(BatchedFunction):
    """A BatchedFunction backed by a NOS model, executed through the NOS client of the current Env.

    The Pixeltable signature is derived from model_spec.signature. For image models, images are converted to
    RGB and resized to the model resolution before the call. For 2D object detection, the returned bounding
    boxes are scaled back to the resolution of each input image.
    """

    def __init__(self, model_spec: 'nos.common.ModelSpec', self_path: str):
        inputs_spec = model_spec.signature.get_inputs_spec()
        return_type, param_types = self._convert_nos_signature(model_spec.signature)

        # every NOS input becomes a non-batched, positional-or-keyword parameter, both in the Pixeltable
        # Signature and in the corresponding inspect.Signature
        pxt_params: List[Parameter] = []
        py_params: List[inspect.Parameter] = []
        for param_name, param_type in zip(list(inputs_spec.keys()), param_types):
            pxt_params.append(
                Parameter(param_name, param_type, inspect.Parameter.POSITIONAL_OR_KEYWORD, is_batched=False))
            py_params.append(inspect.Parameter(param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD))
        super().__init__(
            Signature(return_type, pxt_params), py_signature=inspect.Signature(py_params), self_path=self_path)

        import nos
        self.model_spec = model_spec
        self.nos_param_names = inputs_spec.keys()
        self.scalar_nos_param_names = []  # NOS inputs that are not batched; invoke() passes them a single value
        # position of the image parameter in the function signature (image models only)
        self.img_param_pos: Optional[int] = None
        # per-resolution input specs; non-empty only for multi-resolution image models
        self.img_batch_params: List[nos.common.ObjectTypeInfo] = []
        self.img_resolutions: List[int] = []  # shape[0] * shape[1] of each entry of img_batch_params
        self.batch_size: Optional[int] = None
        self.img_size: Optional[Tuple[int, int]] = None  # W, H

        # try to determine batch_size and img_size
        min_batch_size = sys.maxsize
        for pos, (param_name, type_info) in enumerate(inputs_spec.items()):
            if isinstance(type_info, list):
                # multi-resolution image model: one type_info per supported resolution
                assert isinstance(type_info[0].base_spec(), nos.common.ImageSpec)
                self.img_batch_params = type_info
                self.img_param_pos = pos
                self.img_resolutions = [
                    res_info.base_spec().shape[0] * res_info.base_spec().shape[1] for res_info in type_info
                ]
                continue

            if type_info.is_batched():
                min_batch_size = min(min_batch_size, type_info.batch_size())
            else:
                self.scalar_nos_param_names.append(param_name)

            base_spec = type_info.base_spec()
            if isinstance(base_spec, nos.common.ImageSpec):
                # single-resolution image model
                if base_spec.shape is not None:
                    self.img_size = (base_spec.shape[1], base_spec.shape[0])
                self.img_param_pos = pos

        if min_batch_size != sys.maxsize:
            self.batch_size = min_batch_size

    def _convert_nos_type(
            self, type_info: 'nos.common.spec.ObjectTypeInfo', ignore_shape: bool = False
    ) -> ts.ColumnType:
        """Convert a NOS ObjectTypeInfo to a ColumnType.

        Plain str/int/float/bool map to the scalar column types, ImageSpec to ImageType (with (W, H) size
        unless ignore_shape is set or the shape is unknown), TensorSpec to a float ArrayType.

        Raises:
            excs.Error: for any other type
        """
        import nos
        spec = type_info.base_spec()
        if spec is None:
            base_type = type_info.base_type()
            scalar_mapping = ((str, ts.StringType), (int, ts.IntType), (float, ts.FloatType), (bool, ts.BoolType))
            for py_type, col_type_cls in scalar_mapping:
                if base_type == py_type:
                    return col_type_cls()
            raise excs.Error(f'Cannot convert {type_info} to ColumnType')
        if isinstance(spec, nos.common.ImageSpec):
            size = None if ignore_shape or spec.shape is None else (spec.shape[1], spec.shape[0])
            # TODO: set mode
            return ts.ImageType(size=size)
        if isinstance(spec, nos.common.TensorSpec):
            return ts.ArrayType(shape=spec.shape, dtype=ts.FloatType())
        raise excs.Error(f'Cannot convert {type_info} to ColumnType')

    def _convert_nos_signature(
            self, sig: 'nos.common.spec.FunctionSignature') -> Tuple[ts.ColumnType, List[ts.ColumnType]]:
        """Return (return type, parameter types) for a NOS function signature.

        More than one output is mapped to a JsonType return type.
        """
        outputs_spec = sig.get_outputs_spec()
        if len(outputs_spec) > 1:
            return_type = ts.JsonType()
        else:
            return_type = self._convert_nos_type(list(outputs_spec.values())[0])
        # if there are multiple input shapes we leave them out of the ColumnType and deal with them in FunctionCall
        param_types = [
            self._convert_nos_type(type_info[0], ignore_shape=True) if isinstance(type_info, list)
            else self._convert_nos_type(type_info, ignore_shape=False)
            for type_info in sig.get_inputs_spec().values()
        ]
        return return_type, param_types

    def is_multi_res_model(self) -> bool:
        """True if the model accepts images at several resolutions."""
        return len(self.img_batch_params) > 0 and self.img_param_pos is not None

    def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
        """Return the batch size; for multi-resolution models without a fixed batch size, pick it based on
        the resolution of the image argument in args."""
        needs_lookup = self.batch_size is None and len(self.img_batch_params) > 0 and len(args) > 0
        if not needs_lookup:
            return self.batch_size
        img_arg = args[self.img_param_pos]
        batch_size, _ = self._select_model_res(img_arg.size[0] * img_arg.size[1])
        return batch_size

    def _select_model_res(self, input_res: int) -> Tuple[int, Tuple[int, int]]:
        """Select the model resolution closest to input_res (a pixel count); ties go to the first one.

        Returns: batch size, (W, H) image size
        """
        best_idx = min(
            range(len(self.img_resolutions)), key=lambda i: abs(self.img_resolutions[i] - input_res))
        type_info = self.img_batch_params[best_idx]
        shape = type_info.base_spec().shape
        return type_info.batch_size(), (shape[1], shape[0])

    def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
        """Run the NOS model on one batch; arg_batches holds one list of values per NOS input, in order.

        NOTE(review): kwarg_batches is not used, and scalar args are not checked to be constant; only the first
        row's value is passed for them.
        """
        num_batch_rows = len(arg_batches[0])
        # per-row (x, y) factors mapping model resolution back to each input image's resolution;
        # only written and read when images are resized below
        scale_factors = np.ndarray((num_batch_rows, 2), dtype=np.float32)

        target_res: Optional[Tuple[int, int]] = None
        if self.img_param_pos is not None:
            img_pos = self.img_param_pos
            # for now, NOS will only receive RGB images
            rgb_imgs = [img if img.mode == 'RGB' else img.convert('RGB') for img in arg_batches[img_pos]]
            arg_batches[img_pos] = rgb_imgs
            if self.is_multi_res_model():
                # choose the model resolution closest to that of the first image in the batch
                first_img = rgb_imgs[0]
                _, target_res = self._select_model_res(first_img.size[0] * first_img.size[1])
            else:
                target_res = self.img_size

            if target_res is not None:
                # every image could have a different resolution, hence per-row scale factors
                target_w, target_h = target_res
                scale_factors[:, 0] = [img.size[0] / target_w for img in rgb_imgs]
                scale_factors[:, 1] = [img.size[1] / target_h for img in rgb_imgs]
                # only resize if necessary
                arg_batches[img_pos] = [img if img.size == target_res else img.resize(target_res) for img in rgb_imgs]

        nos_kwargs = dict(zip(self.nos_param_names, arg_batches))
        # fix up scalar parameters: they receive a single value instead of a batch
        for param_name in self.scalar_nos_param_names:
            nos_kwargs[param_name] = nos_kwargs[param_name][0]
        _logger.debug(
            f'Running NOS task {self.model_spec.task}: '
            f'batch_size={num_batch_rows} target_res={target_res}')
        result = env.Env.get().nos_client.Run(
            task=self.model_spec.task, model_name=self.model_spec.name, **nos_kwargs)

        import nos
        if target_res is not None and self.model_spec.task == nos.common.TaskType.OBJECT_DETECTION_2D:
            # rescale bounding boxes: columns 0 and 2 by the row's x factor, columns 1 and 3 by its y factor;
            # workaround: result['bboxes'][*] is immutable, so each entry is copied first
            rescaled = []
            for row_idx, row_bboxes in enumerate(result['bboxes']):
                boxes = np.copy(row_bboxes)
                for col, axis in ((0, 0), (1, 1), (2, 0), (3, 1)):
                    boxes[:, col] *= scale_factors[row_idx, axis]
                rescaled.append(boxes)
            result['bboxes'] = rescaled

        if len(result) == 1:
            [only_key] = result.keys()
            return result[only_key]
        # we rearrange result into one dict per row
        return [
            {key: values[row_idx].tolist() for key, values in result.items()}
            for row_idx in range(num_batch_rows)
        ]
202
+
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import enum
5
+ import inspect
6
+ import logging
7
+ import typing
8
+ from typing import Optional, Callable, Dict, List, Any, Union, Tuple
9
+
10
+ import pixeltable.exceptions as excs
11
+ import pixeltable.type_system as ts
12
+
13
_logger = logging.getLogger('pixeltable')


@dataclasses.dataclass
class Parameter:
    """A single parameter of a Pixeltable function Signature."""
    name: str
    # Pixeltable type of the parameter; None for variable parameters (*args / **kwargs)
    col_type: Optional[ts.ColumnType]  # None for variable parameters
    kind: enum.Enum  # inspect.Parameter.kind; inspect._ParameterKind is private
    is_batched: bool = False  # True if the parameter is a batched parameter (eg, Batch[dict])


# Type annotation marker for batched values: Batch[X] is list[X] tagged with the metadata string 'pxt-batch',
# which Signature._infer_type() looks for to set is_batched.
T = typing.TypeVar('T')
Batch = typing.Annotated[list[T], 'pxt-batch']
26
+
27
+
28
class Signature:
    """
    Represents the signature of a Pixeltable function.

    Regarding return type:
    - most functions will have a fixed return type, which is specified directly
    - some functions will have a return type that depends on the argument values;
      ex.: PIL.Image.Image.resize() returns an image with dimensions specified as a parameter
    - in the latter case, the 'return_type' field is a function that takes the bound arguments and returns the
      return type; if no bound arguments are specified, a generic return type is returned (eg, ImageType() without a
      size)
    - self.is_batched: return type is a Batch[...] type
    """
    # parameter names that create() rejects in user-supplied callables
    SPECIAL_PARAM_NAMES = ['group_by', 'order_by']

    def __init__(
            self,
            return_type: Union[ts.ColumnType, Callable[[Dict[str, Any]], ts.ColumnType]],
            parameters: List[Parameter], is_batched: bool = False):
        self.return_type = return_type
        self.is_batched = is_batched
        # we rely on the ordering guarantee of dicts in Python >=3.7
        self.parameters = {p.name: p for p in parameters}
        self.parameters_by_pos = parameters.copy()
        self.constant_parameters = [p for p in parameters if not p.is_batched]
        self.batched_parameters = [p for p in parameters if p.is_batched]

    def get_return_type(self, bound_args: Optional[Dict[str, Any]] = None) -> ts.ColumnType:
        """Return the return type; a callable return type is evaluated with bound_args (None: generic type)."""
        if isinstance(self.return_type, ts.ColumnType):
            return self.return_type
        return self.return_type(bound_args)

    def as_dict(self) -> Dict[str, Any]:
        """Serialize this signature; a callable return type is stored as its generic (unbound) return type.

        Each parameter is stored as [name, col_type dict or None, kind, is_batched].
        """
        return {
            'return_type': self.get_return_type().as_dict(),
            'parameters': [
                [p.name, p.col_type.as_dict() if p.col_type is not None else None, p.kind, p.is_batched]
                for p in self.parameters.values()
            ],
            'is_batched': self.is_batched,
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> Signature:
        """Inverse of as_dict().

        Variable parameters (stored with col_type None) round-trip with col_type None; dicts written without
        'is_batched' (older format) yield is_batched=False.
        """
        parameters = [
            Parameter(p[0], ts.ColumnType.from_dict(p[1]) if p[1] is not None else None, p[2], p[3])
            for p in d['parameters']
        ]
        return cls(ts.ColumnType.from_dict(d['return_type']), parameters, d.get('is_batched', False))

    def __eq__(self, other: object) -> bool:
        """Structural equality: same return type and same parameter types and kinds, by position.

        Parameter names and batching flags are not compared.
        """
        if not isinstance(other, Signature):
            return NotImplemented
        if self.get_return_type() != other.get_return_type():
            return False
        if len(self.parameters) != len(other.parameters):
            return False
        # ignore the parameter name
        for param, other_param in zip(self.parameters.values(), other.parameters.values()):
            if param.col_type != other_param.col_type or param.kind != other_param.kind:
                return False
        return True

    def __str__(self) -> str:
        param_strs: List[str] = []
        for p in self.parameters.values():
            if p.kind == inspect.Parameter.VAR_POSITIONAL:
                param_strs.append(f'*{p.name}')
            elif p.kind == inspect.Parameter.VAR_KEYWORD:
                param_strs.append(f'**{p.name}')
            else:
                param_strs.append(f'{p.name}: {str(p.col_type)}')
        return f'({", ".join(param_strs)}) -> {str(self.get_return_type())}'

    @classmethod
    def _infer_type(cls, annotation: Optional[type]) -> Tuple[Optional[ts.ColumnType], Optional[bool]]:
        """Returns: (column type, is_batched) or (None, ...) if the type cannot be inferred

        A Batch[X] annotation is unwrapped to X with is_batched=True.
        """
        if annotation is None:
            return (None, None)
        py_type: Optional[type] = None
        is_batched = False
        if typing.get_origin(annotation) == typing.Annotated:
            type_args = typing.get_args(annotation)
            if len(type_args) == 2 and type_args[1] == 'pxt-batch':
                # this is our Batch
                assert typing.get_origin(type_args[0]) == list
                is_batched = True
                py_type = typing.get_args(type_args[0])[0]
        if py_type is None:
            py_type = annotation
        col_type = ts.ColumnType.from_python_type(py_type)
        return (col_type, is_batched)

    @classmethod
    def create(
            cls, c: Callable,
            param_types: Optional[List[ts.ColumnType]] = None,
            return_type: Optional[Union[ts.ColumnType, Callable]] = None
    ) -> Signature:
        """Create a signature for the given Callable.
        Infer the parameter and return types, if none are specified.

        Args:
            c: the callable to describe
            param_types: explicit parameter types, indexed by parameter position (including any variable
                parameters); if None, types are inferred from annotations
            return_type: explicit return type; if None, it is inferred from the return annotation

        Raises:
            excs.Error: for reserved parameter names, missing or uninferrable types, or default values that are
                incompatible with their parameter type
        """
        sig = inspect.signature(c)
        py_parameters = list(sig.parameters.values())

        # check non-var parameters for name collisions and default value compatibility
        parameters: List[Parameter] = []
        for idx, param in enumerate(py_parameters):
            if param.name in cls.SPECIAL_PARAM_NAMES:
                raise excs.Error(f"'{param.name}' is a reserved parameter name")
            if param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD:
                parameters.append(Parameter(param.name, None, param.kind, False))
                continue

            if param_types is not None:
                if idx >= len(param_types):
                    raise excs.Error(f'Missing type for parameter {param.name}')
                param_type = param_types[idx]
                is_batched = False
            else:
                param_type, is_batched = cls._infer_type(param.annotation)
                if param_type is None:
                    raise excs.Error(f'Cannot infer pixeltable type for parameter {param.name}')

            # check default value compatibility; identity checks, because `!=` on a default such as a
            # numpy array yields an array whose truth value is ambiguous
            default_val = param.default
            if default_val is not inspect.Parameter.empty and default_val is not None:
                try:
                    _ = param_type.create_literal(default_val)
                except TypeError as e:
                    raise excs.Error(f'Default value for parameter {param.name}: {str(e)}') from e

            parameters.append(Parameter(param.name, param_type, param.kind, is_batched))

        if return_type is None:
            return_type, return_is_batched = cls._infer_type(sig.return_annotation)
            if return_type is None:
                raise excs.Error('Cannot infer pixeltable return type')
        else:
            # the return type is given explicitly, but batching is still taken from the annotation
            _, return_is_batched = cls._infer_type(sig.return_annotation)

        return Signature(return_type, parameters, return_is_batched)