pixie-prompts 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,205 @@
1
+ """GraphQL schema for SDK server."""
2
+
3
+ import logging
4
+ from typing import Optional
5
+
6
+ from graphql import GraphQLError
7
+ import strawberry
8
+ from strawberry.scalars import JSON
9
+
10
+ from pixie.prompts.prompt import variables_definition_to_schema
11
+ from pixie.prompts.prompt_management import get_prompt, list_prompts
12
+ from importlib.metadata import PackageNotFoundError, version
13
+
14
# Module-level logger used by the GraphQL resolvers in this module.
logger = logging.getLogger(__name__)
15
+
16
+
17
@strawberry.type
class PromptMetadata:
    """Summary of a prompt registered through ``create_prompt``.

    Carries only a version count rather than the version contents.
    """

    # Identifier the prompt was registered under.
    id: strawberry.ID
    # JSON schema describing the prompt's template variables.
    variables_schema: JSON
    # Number of stored versions.
    version_count: int
    # Optional human-readable description given at registration.
    description: Optional[str] = None
    # Name of the module that registered the prompt.
    module: Optional[str] = None
26
+
27
+
28
@strawberry.input
class IKeyValue:
    """Key/value pair accepted as GraphQL input."""

    key: str
    value: str
34
+
35
+
36
@strawberry.type
class TKeyValue:
    """Key/value pair returned as GraphQL output (used for prompt versions)."""

    key: str
    value: str
42
+
43
+
44
@strawberry.type
class Prompt:
    """A prompt together with all of its stored versions."""

    id: strawberry.ID
    variables_schema: JSON
    # One entry per stored version: key is the version id, value its template text.
    versions: list[TKeyValue]
    # Only None when ``versions`` is empty (nothing stored yet).
    default_version_id: str | None
    description: Optional[str] = None
    module: Optional[str] = None
55
+
56
+
57
@strawberry.type
class Query:
    """Read-only GraphQL entry points."""

    @strawberry.field
    async def health_check(self) -> str:
        """Report the installed pixie-sdk version, or "0.0.0" when it cannot be found."""
        logger.debug("Health check endpoint called")
        try:
            sdk_version = version("pixie-sdk")
        except PackageNotFoundError as err:
            logger.warning("Failed to get Pixie SDK version: %s", str(err))
            return "0.0.0"
        logger.debug("Pixie SDK version: %s", sdk_version)
        return sdk_version

    @strawberry.field
    def list_prompts(self) -> list[PromptMetadata]:
        """List all registered prompt templates.

        Returns:
            One PromptMetadata (id, variables_schema, version_count,
            description, module) per prompt registered via create_prompt.
        """
        summaries: list[PromptMetadata] = []
        for registered in list_prompts():
            base = registered.prompt
            # Derive the schema from the in-code definition rather than
            # base.get_variables_schema(), which may fetch from storage. In
            # theory the two could differ; in practice they should not.
            schema = variables_definition_to_schema(base.variables_definition)
            summaries.append(
                PromptMetadata(
                    id=strawberry.ID(base.id),
                    variables_schema=JSON(schema),
                    version_count=base.get_version_count(),
                    description=registered.description,
                    module=registered.module,
                )
            )
        return summaries

    @strawberry.field
    async def get_prompt(self, id: strawberry.ID) -> Prompt:
        """Get full prompt information including versions.

        Args:
            id: The unique identifier of the prompt.

        Returns:
            Prompt with id, variables_schema, versions and default_version_id.
            When nothing is stored yet, versions is empty and
            default_version_id is None.

        Raises:
            GraphQLError: If no prompt with the given id is registered.
        """
        registered = get_prompt(str(id))
        if registered is None:
            raise GraphQLError(f"Prompt with id '{id}' not found.")
        base = registered.prompt

        if base.exists_in_storage():
            versions_by_id = base.get_versions()
            version_entries = [
                TKeyValue(key=version_key, value=template)
                for version_key, template in versions_by_id.items()
            ]
            default_version_id: str | None = base.get_default_version_id()
            schema = base.get_variables_schema()
        else:
            # Nothing stored yet: build the schema from the in-code definition
            # to avoid prompting a storage fetch via get_variables_schema().
            version_entries = []
            default_version_id = None
            schema = variables_definition_to_schema(base.variables_definition)

        return Prompt(
            id=id,
            variables_schema=JSON(schema),
            versions=version_entries,
            default_version_id=default_version_id,
            description=registered.description,
            module=registered.module,
        )
137
+
138
+
139
def _require_registered_prompt(prompt_id: strawberry.ID):
    """Return the prompt registered under ``prompt_id`` via ``create_prompt``.

    Raises:
        GraphQLError: If no prompt with that id is registered.
    """
    registration = get_prompt(str(prompt_id))
    if registration is None:
        raise GraphQLError(f"Prompt with id '{prompt_id}' not found.")
    return registration.prompt


@strawberry.type
class Mutation:
    """GraphQL mutations that modify a prompt's stored versions."""

    @strawberry.mutation
    async def add_prompt_version(
        self,
        prompt_id: strawberry.ID,
        version_id: str,
        content: str,
        set_as_default: bool = False,
    ) -> str:
        """Add a new version to an existing prompt.

        Args:
            prompt_id: The unique identifier of the prompt.
            version_id: The identifier for the new version.
            content: The content of the new prompt version.
            set_as_default: Whether to set this version as the default.

        Returns:
            The string "OK" once the version has been added.

        Raises:
            GraphQLError: If the prompt is not registered, or if adding the
                version fails for any reason.
        """
        prompt = _require_registered_prompt(prompt_id)
        try:
            prompt.append_version(
                version_id=version_id,
                content=content,
                set_as_default=set_as_default,
            )
        except Exception as e:
            # API boundary: surface any validation/storage failure to the client.
            raise GraphQLError(f"Failed to add prompt version: {str(e)}") from e
        return "OK"

    @strawberry.mutation
    async def update_default_prompt_version(
        self,
        prompt_id: strawberry.ID,
        default_version_id: str,
    ) -> str:
        """Update the default version of an existing prompt.

        Args:
            prompt_id: The unique identifier of the prompt.
            default_version_id: The identifier of the version to set as default.

        Returns:
            The string "OK" once the default has been updated.

        Raises:
            GraphQLError: If the prompt is not registered, or if updating the
                default version fails for any reason.
        """
        prompt = _require_registered_prompt(prompt_id)
        try:
            prompt.update_default_version_id(default_version_id)
        except Exception as e:
            # API boundary: surface any validation/storage failure to the client.
            raise GraphQLError(
                f"Failed to update default prompt version: {str(e)}"
            ) from e
        return "OK"
202
+
203
+
204
# Executable GraphQL schema built from the Query and Mutation root types.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
)
@@ -0,0 +1,373 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ import json
4
+ from types import NoneType
5
+ from typing import Any, Generic, Protocol, Self, TypeVar, cast, overload
6
+ from uuid import uuid4
7
+
8
+ import jinja2
9
+ from jinja2 import StrictUndefined
10
+ from jsonsubschema import isSubschema
11
+ from pydantic import BaseModel
12
+
13
+
14
class PromptVariables(BaseModel):
    """Base class for the typed variables a prompt template is rendered with.

    Subclass it with pydantic fields. ``BasePrompt.compile`` dumps an instance
    with ``model_dump(mode="json")`` and passes the result to the template.
    """

    # TODO add validation to prevent fields using reserved names
17
+
18
+
19
# Variables model a prompt is parameterised with; NoneType is used (as the
# default in BasePrompt) for prompts whose templates take no variables.
TPromptVar = TypeVar("TPromptVar", bound=PromptVariables | None)
20
+
21
+
22
+ _prompt_registry: dict[str, "BasePrompt"] = {}
23
+ """Registry of all actualized prompts.
24
+
25
+ Purpose of the registry is to ensure there's single actualized prompt instance per prompt ID globally,
26
+ so that every compiled prompt can track back to one single instance of the prompt it was compiled from.
27
+ """
28
+
29
+
30
def get_prompt_by_id(prompt_id: str) -> "BasePrompt":
    """Return the actualized prompt registered under ``prompt_id``.

    Raises:
        KeyError: If no prompt with that id has been registered.
    """
    registered = _prompt_registry[prompt_id]
    return registered
+ return _prompt_registry[prompt_id]
32
+
33
+
34
@dataclass(frozen=True)
class _CompiledPrompt:
    """Provenance record for one string returned by ``BasePrompt.compile``."""

    # The rendered text handed back to the caller.
    value: str
    # Prompt the text came from; replaced by an OutdatedPrompt snapshot when
    # the live prompt is later updated.
    prompt: "BasePrompt | OutdatedPrompt"
    # Version of the template that was rendered.
    version_id: str
    # Variables as passed to compile() (may be None).
    variables: PromptVariables | None
+ variables: PromptVariables | None
40
+
41
+
42
# Registry of every string returned by BasePrompt.compile(), keyed by the id()
# of that string. Each entry keeps a reference to the string in ``value``, so
# the id cannot be recycled for another object while the entry exists.
_compiled_prompt_registry: dict[int, _CompiledPrompt] = {}
+ key is the id() of the compiled string."""
47
+
48
+
49
def _find_matching_prompt(obj):
    """Depth-first search a decoded JSON value for a registered compiled string.

    Strings are compared by value against every registry entry. Dict values
    and list items are searched recursively; any other type never matches.

    Returns:
        The first matching ``_CompiledPrompt``, or ``None``.
    """
    if isinstance(obj, str):
        return next(
            (record for record in _compiled_prompt_registry.values() if record.value == obj),
            None,
        )

    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        return None

    for child in children:
        hit = _find_matching_prompt(child)
        if hit:
            return hit
    return None
69
+
70
+
71
def get_compiled_prompt(text: str) -> _CompiledPrompt | None:
    """Find the compiled prompt metadata for a given compiled prompt string.

    Tries, in order:
      1. an identity match on ``id(text)``;
      2. an exact value match against every registered string;
      3. if ``text`` is valid JSON, a recursive search of its dict values and
         list items for a registered string.

    Returns ``None`` when nothing matches or the registry is empty.
    """
    if not _compiled_prompt_registry:
        return None

    same_object = _compiled_prompt_registry.get(id(text))
    if same_object:
        return same_object

    equal_value = next(
        (record for record in _compiled_prompt_registry.values() if record.value == text),
        None,
    )
    if equal_value is not None:
        return equal_value

    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return None
    return _find_matching_prompt(decoded)
+ return None
87
+
88
+
89
def _mark_compiled_prompts_outdated(
    prompt_id: str, outdated_prompt: "OutdatedPrompt"
) -> None:
    """Point every compiled-string record of ``prompt_id`` at ``outdated_prompt``.

    Records are frozen, so each affected entry is replaced by a new record
    that keeps value, version_id and variables but swaps in the snapshot.
    """
    for registry_key, record in list(_compiled_prompt_registry.items()):
        if record.prompt.id != prompt_id:
            continue
        _compiled_prompt_registry[registry_key] = _CompiledPrompt(
            value=record.value,
            prompt=outdated_prompt,
            version_id=record.version_id,
            variables=record.variables,
        )
+ )
101
+
102
+
103
# Version id given to the only version when a prompt is built from a single string.
DEFAULT_VERSION_ID = "v0"
104
+
105
+
106
+ def _to_versions_dict(versions: str | dict[str, str]) -> dict[str, str]:
107
+ if isinstance(versions, str):
108
+ return {DEFAULT_VERSION_ID: versions}
109
+ return deepcopy(versions)
110
+
111
+
112
+ class _UnTypedPrompt(Protocol):
113
+ @property
114
+ def id(self) -> str: ...
115
+
116
+ def get_default_version_id(self) -> str: ...
117
+
118
+ def get_versions(self) -> dict[str, str]: ...
119
+
120
+ def get_variables_schema(self) -> dict[str, Any]: ...
121
+
122
+
123
# JSON schema for prompts that take no variables: an object with no properties.
EMPTY_VARIABLES_SCHEMA = {
    "type": "object",
    "properties": {},
}
124
+
125
+
126
class BaseUntypedPrompt(_UnTypedPrompt):
    """In-memory prompt whose variables are described only by a JSON schema.

    Args:
        versions: A single template string (stored under DEFAULT_VERSION_ID)
            or a ``{version_id: template}`` mapping; must be non-empty.
        default_version_id: Version used when none is requested; falls back to
            the first key of ``versions`` when omitted or empty.
        id: Prompt id. When omitted or empty, a fresh uuid4 hex that is not
            already in the prompt registry is generated.
        variables_schema: JSON schema for template variables; falls back to
            EMPTY_VARIABLES_SCHEMA when omitted or empty.

    Raises:
        ValueError: If ``versions`` is empty.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        id: str | None = None,
        variables_schema: dict[str, Any] | None = None,
    ) -> None:
        prompt_id = id
        if not prompt_id:
            # Keep drawing random ids until one is not already taken.
            prompt_id = uuid4().hex
            while prompt_id in _prompt_registry:
                prompt_id = uuid4().hex
        self._id = prompt_id

        if not versions:
            raise ValueError("No versions provided for the prompt.")
        self._versions: dict[str, str] = _to_versions_dict(versions)
        self._default_version = (
            default_version_id if default_version_id else next(iter(self._versions))
        )
        self._variables_schema = (
            variables_schema if variables_schema else EMPTY_VARIABLES_SCHEMA
        )

    @property
    def id(self) -> str:
        """Unique identifier of this prompt."""
        return self._id

    def get_default_version_id(self) -> str:
        """Return the default version id."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the ``{version_id: template}`` mapping."""
        return deepcopy(self._versions)

    def get_variables_schema(self) -> dict[str, Any]:
        """Return a deep copy of the variables JSON schema."""
        return deepcopy(self._variables_schema)
+ return deepcopy(self._variables_schema)
161
+
162
+
163
class Prompt(_UnTypedPrompt, Generic[TPromptVar]):
    """Interface of a prompt that is typed by its variables model."""

    @property
    def variables_definition(self) -> type[TPromptVar]:
        """Class describing the variables accepted by ``compile``."""

    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str:
        """Render the requested (or default) version with ``variables``."""
+ ) -> str: ...
183
+
184
+
185
def variables_definition_to_schema(definition: type[TPromptVar]) -> dict[str, Any]:
    """Return the JSON schema describing a prompt's variables.

    Args:
        definition: A PromptVariables subclass, or NoneType for prompts that
            take no variables. Plain ``None`` is accepted as NoneType.

    Returns:
        A schema dict the caller may freely mutate. For variable-free prompts
        this is a copy of EMPTY_VARIABLES_SCHEMA, so the shared module-level
        constant cannot be modified through the returned value.
    """
    if definition is None or definition is NoneType:
        return deepcopy(EMPTY_VARIABLES_SCHEMA)

    return cast(type[PromptVariables], definition).model_json_schema()
+ return cast(type[PromptVariables], definition).model_json_schema()
190
+
191
+
192
class BasePrompt(BaseUntypedPrompt, Generic[TPromptVar]):
    """Actualized prompt whose variables are described by a pydantic model.

    Constructing an instance registers it in ``_prompt_registry`` under its
    id. Every string returned by ``compile`` is recorded in
    ``_compiled_prompt_registry`` so it can be traced back to this prompt.
    """

    @classmethod
    def from_untyped(
        cls,
        untyped_prompt: "BaseUntypedPrompt",
        variables_definition: type[TPromptVar] = NoneType,
    ) -> "BasePrompt[TPromptVar]":
        """Build a typed prompt carrying the versions/default/id of ``untyped_prompt``.

        Raises:
            TypeError: If the schema of ``variables_definition`` is not a
                subschema of the untyped prompt's variables schema.
        """
        base_schema = untyped_prompt.get_variables_schema()
        typed_schema = variables_definition_to_schema(variables_definition)
        if not isSubschema(typed_schema, base_schema):
            raise TypeError(
                "The provided variables_definition is not compatible with the prompt's variables schema."
            )
        return cls(
            variables_definition=variables_definition,
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
            id=untyped_prompt.id,
        )

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        variables_definition: type[TPromptVar] = NoneType,
        id: str | None = None,
    ) -> None:
        super().__init__(
            versions=versions,
            default_version_id=default_version_id,
            id=id,
            variables_schema=variables_definition_to_schema(variables_definition),
        )
        self._variables_definition = variables_definition
        _prompt_registry[self.id] = self

    @property
    def variables_definition(self) -> type[TPromptVar]:
        """Class describing the variables accepted by ``compile``."""
        return self._variables_definition

    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str:
        """Render a version of this prompt and record where the result came from.

        Args:
            variables: Values for the template. Required unless the prompt's
                variables_definition is NoneType, in which case it is ignored
                and the template text is returned unrendered.
            version_id: Version to render; defaults to the default version.

        Returns:
            The rendered text, also recorded in ``_compiled_prompt_registry``.

        Raises:
            KeyError: If ``version_id`` is not a known version.
            ValueError: If variables are required but not provided.
            jinja2.exceptions.UndefinedError: If the template references a
                variable that is missing (StrictUndefined).
        """
        version_id = version_id or self._default_version
        template_txt = self._versions[version_id]
        if self._variables_definition is not NoneType:
            if variables is None:
                raise ValueError(
                    f"Variables[{self._variables_definition}] are required for this prompt."
                )
            # NOTE(review): jinja2.Template is not sandboxed. If template text
            # can come from untrusted clients (presumably via the GraphQL
            # add_prompt_version mutation), consider
            # jinja2.sandbox.SandboxedEnvironment to prevent template injection.
            template = jinja2.Template(template_txt, undefined=StrictUndefined)
            ret = template.render(**variables.model_dump(mode="json"))
        else:
            ret = template_txt
        # Keyed by id(ret); the record keeps ``ret`` alive so the id stays unique.
        _compiled_prompt_registry[id(ret)] = _CompiledPrompt(
            value=ret,
            version_id=version_id,
            prompt=self,
            variables=variables,
        )
        return ret

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "tuple[Self, OutdatedPrompt[TPromptVar]]":
        """Apply changes in place after snapshotting the current state.

        Takes an OutdatedPrompt snapshot first, applies whichever of
        ``versions`` / ``default_version_id`` is not None, then re-points
        previously compiled strings of this prompt at the snapshot.

        Returns:
            ``(self, snapshot)``.
        """
        outdated_prompt = OutdatedPrompt.from_prompt(self)
        if versions is not None:
            self._versions = _to_versions_dict(versions)
        if default_version_id is not None:
            self._default_version = default_version_id
        _mark_compiled_prompts_outdated(self.id, outdated_prompt)
        return self, outdated_prompt

    @staticmethod
    def update_prompt_registry(
        untyped_prompt: "BaseUntypedPrompt",
    ) -> "BasePrompt":
        """IMPORTANT: should only be called from storage on storage load!

        Update the matching entry in the typed prompt registry in-place.
        DO NOT call other than from initial storage load, to keep immutability of prompts in code.

        Raises:
            KeyError: If no prompt with ``untyped_prompt.id`` is registered.
        """
        existing = get_prompt_by_id(untyped_prompt.id)
        # _update() already snapshots the current state and re-points compiled
        # strings at that snapshot, so no separate snapshot is taken here.
        existing._update(
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
        )
        return existing

    def append_version(
        self,
        *,
        version_id: str,
        content: str,
        set_as_default: bool = False,
    ) -> None:
        """Add a new version, placed first in the versions mapping.

        Raises:
            ValueError: If ``version_id`` already exists.
        """
        if version_id in self._versions:
            raise ValueError(f"Version ID '{version_id}' already exists.")
        self._update(
            versions={version_id: content, **self._versions},
            default_version_id=version_id if set_as_default else None,
        )

    def update_default_version_id(
        self,
        default_version_id: str,
    ) -> None:
        """Make ``default_version_id`` the default; a no-op if it already is.

        Raises:
            ValueError: If ``default_version_id`` is not a known version.
        """
        if default_version_id not in self._versions:
            raise ValueError(f"Version ID '{default_version_id}' does not exist.")
        if self._default_version == default_version_id:
            return
        self._update(
            default_version_id=default_version_id,
        )
325
+
326
+
327
class OutdatedPrompt(BasePrompt[TPromptVar]):
    """Frozen snapshot of a BasePrompt taken just before it was updated.

    Compiled strings produced before an update are re-pointed at a snapshot,
    so their provenance stays accurate. Snapshots are not registered in
    ``_prompt_registry`` and cannot be updated or compiled.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str,
        variables_definition: type[TPromptVar],
        id: str,
    ) -> None:
        # Deliberately bypasses BasePrompt.__init__ so the snapshot does not
        # replace the live prompt in _prompt_registry. Every attribute read by
        # the inherited accessors must therefore be set here.
        self._id = id
        self._versions = _to_versions_dict(versions)
        self._default_version = default_version_id
        self._variables_definition = variables_definition
        # Required by the inherited get_variables_schema().
        self._variables_schema = variables_definition_to_schema(variables_definition)

    @classmethod
    def from_prompt(
        cls, prompt: BasePrompt[TPromptVar]
    ) -> "OutdatedPrompt[TPromptVar]":
        """Snapshot the id, versions, default version and variables definition of ``prompt``."""
        return cls(
            variables_definition=prompt.variables_definition,
            versions=prompt.get_versions(),
            default_version_id=prompt.get_default_version_id(),
            id=prompt.id,
        )

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "OutdatedPrompt[TPromptVar]":
        """Always raise: snapshots are immutable.

        Raises:
            ValueError: Always.
        """
        raise ValueError("Cannot update an outdated prompt.")

    def get_default_version_id(self) -> str:
        """Return the default version id captured in the snapshot."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the versions captured in the snapshot."""
        return deepcopy(self._versions)

    def compile(self, *_args: Any, **_kwargs: Any) -> str:
        """Always raise: an outdated snapshot can no longer be rendered.

        Accepts any arguments so that every call form valid on BasePrompt
        (for example ``version_id=...``) reaches the ValueError below instead
        of failing earlier with a TypeError.

        Raises:
            ValueError: Always.
        """
        raise ValueError("This prompt is outdated and can no longer be used.")
@@ -0,0 +1,82 @@
1
+ from dataclasses import dataclass
2
+ import inspect
3
+ import logging
4
+ from types import NoneType
5
+
6
+ from pixie.prompts.prompt import TPromptVar
7
+ from pixie.prompts.storage import StorageBackedPrompt
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ @dataclass
14
+ class StorageBackedPromptWithRegistration:
15
+ prompt: StorageBackedPrompt
16
+ description: str | None
17
+ module: str
18
+
19
+
20
# Registry for StorageBackedPrompts created by `create_prompt`, keyed by prompt id.
#
# A StorageBackedPrompt differs from a BasePrompt in that it can be incomplete:
# its record may not have been fetched from storage yet, or may not exist at
# all. This registry can therefore contain more entries than `_prompt_registry`
# in prompt.py.
_registry: dict[str, StorageBackedPromptWithRegistration] = {}
+ Thus this registry could contain more entries than _prompt_registry in prompt.py."""
26
+
27
+
28
+ def list_prompts() -> list[StorageBackedPromptWithRegistration]:
29
+ """List all StorageBackedPrompts created via `create_prompt`."""
30
+ return list(_registry.values())
31
+
32
+
33
+ def get_prompt(id: str) -> StorageBackedPromptWithRegistration | None:
34
+ """Get a StorageBackedPrompt by id, if it was created via `create_prompt`."""
35
+ return _registry.get(id)
36
+
37
+
38
+ def _get_calling_module_name():
39
+ """Find the name of the module that called this function."""
40
+ # Get the current frame and the frame above it (the caller's frame)
41
+ try:
42
+ # inspect.stack()[2] gets the caller's frame record
43
+ # frame[0] or frame.frame is the actual frame object
44
+ caller_frame_record = inspect.stack()[2]
45
+ caller_frame = caller_frame_record.frame
46
+ except IndexError:
47
+ # Handle cases where the stack might be shallower, though unlikely for normal calls
48
+ return "__main__"
49
+
50
+ # Get the module object from the frame object
51
+ module = inspect.getmodule(caller_frame)
52
+
53
+ if module is not None:
54
+ return module.__name__
55
+ else:
56
+ # If getmodule returns None (e.g., if called from the interactive prompt or __main__),
57
+ # try to get the name from the frame's globals
58
+ return caller_frame.f_globals.get("__name__", "__main__")
59
+
60
+
61
+ def create_prompt(
62
+ id: str,
63
+ variables_definition: type[TPromptVar] = NoneType,
64
+ *,
65
+ description: str | None = None,
66
+ ) -> StorageBackedPrompt[TPromptVar]:
67
+ if id in _registry:
68
+ ret = _registry[id].prompt
69
+ if ret.variables_definition != variables_definition:
70
+ raise ValueError(
71
+ f"Prompt with id '{id}' already exists with a different variables definition."
72
+ )
73
+ return ret
74
+ ret = StorageBackedPrompt(id=id, variables_definition=variables_definition)
75
+ calling_module = _get_calling_module_name()
76
+ _registry[id] = StorageBackedPromptWithRegistration(
77
+ prompt=ret,
78
+ description=description,
79
+ module=calling_module,
80
+ )
81
+ logger.info(f"✅ Registered prompt: {id} ({calling_module})")
82
+ return ret