pixie-prompts 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,373 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ import json
4
+ from types import NoneType
5
+ from typing import Any, Generic, Protocol, Self, TypeVar, cast, overload
6
+ from uuid import uuid4
7
+
8
+ import jinja2
9
+ from jinja2 import StrictUndefined
10
+ from jsonsubschema import isSubschema
11
+ from pydantic import BaseModel
12
+
13
+
14
class Variables(BaseModel):
    # Base model for prompt variables. `TPromptVar` is bound to
    # `Variables | None`, and `BasePrompt.compile` renders the template with
    # the fields produced by `model_dump(mode="json")`.
    # TODO add validation to prevent fields using reserved names
    pass
17
+
18
+
19
# Type variable for a prompt's variables model: a `Variables` subclass, or
# `None` for prompts that take no variables.
TPromptVar = TypeVar("TPromptVar", bound=Variables | None)
20
+
21
+
22
# Keyed by prompt id; `BasePrompt.__init__` stores each new instance here and
# `get_prompt_by_id` reads from it.
_prompt_registry: dict[str, "BasePrompt"] = {}
"""Registry of all actualized prompts.

Purpose of the registry is to ensure there's single actualized prompt instance per prompt ID globally,
so that every compiled prompt can track back to one single instance of the prompt it was compiled from.
"""
28
+
29
+
30
def get_prompt_by_id(prompt_id: str) -> "BasePrompt":
    """Return the prompt registered under ``prompt_id``.

    Raises:
        KeyError: If no prompt with that id is present in the registry.
    """
    registered = _prompt_registry[prompt_id]
    return registered
32
+
33
+
34
@dataclass(frozen=True)
class _CompiledPrompt:
    """Immutable record describing one string returned by `BasePrompt.compile`."""

    # The compiled text exactly as it was returned to the caller.
    value: str
    # The prompt the text came from; replaced by an `OutdatedPrompt` snapshot
    # once the source prompt is updated.
    prompt: "BasePrompt | OutdatedPrompt"
    # Id of the version whose template produced `value`.
    version_id: str
    # The variables passed to `compile` (None when none were given).
    variables: Variables | None
40
+
41
+
42
# Written by `BasePrompt.compile`; each record keeps a reference to the
# compiled string in its `value` field.
_compiled_prompt_registry: dict[int, _CompiledPrompt] = {}
"""Registry of all compiled prompts.

This is to keep track of every result string returned by BasePrompt.compile().
key is the id() of the compiled string."""
47
+
48
+
49
+ def _find_matching_prompt(obj):
50
+ if isinstance(obj, str):
51
+ for compiled in _compiled_prompt_registry.values():
52
+ if compiled.value == obj:
53
+ return compiled
54
+ return None
55
+ elif isinstance(obj, dict):
56
+ for value in obj.values():
57
+ result = _find_matching_prompt(value)
58
+ if result:
59
+ return result
60
+ return None
61
+ elif isinstance(obj, list):
62
+ for item in obj:
63
+ result = _find_matching_prompt(item)
64
+ if result:
65
+ return result
66
+ return None
67
+ else:
68
+ return None
69
+
70
+
71
def get_compiled_prompt(text: str) -> _CompiledPrompt | None:
    """Look up the compiled-prompt record that corresponds to ``text``.

    Lookup order:
      1. by object identity (``id(text)``),
      2. by string equality against every registered compiled prompt,
      3. by parsing ``text`` as JSON and searching the decoded value for a
         string equal to a registered compiled prompt.

    Returns None when the registry is empty, nothing matches, or ``text`` is
    not valid JSON in step 3.
    """
    if len(_compiled_prompt_registry) == 0:
        return None

    by_identity = _compiled_prompt_registry.get(id(text))
    if by_identity:
        return by_identity

    for candidate in _compiled_prompt_registry.values():
        if candidate.value == text:
            return candidate

    # Last resort: the prompt may be embedded inside a JSON payload.
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return None
    return _find_matching_prompt(decoded)
87
+
88
+
89
def _mark_compiled_prompts_outdated(
    prompt_id: str, outdated_prompt: "OutdatedPrompt"
) -> None:
    """Re-point live compiled prompts of ``prompt_id`` at ``outdated_prompt``.

    Records are frozen, so each affected entry is replaced by a new
    `_CompiledPrompt` carrying the same value, version id and variables.

    Entries already bound to an `OutdatedPrompt` are left alone: they were
    compiled from an earlier state of the prompt, and overwriting them would
    discard the snapshot of the state they actually came from.

    Args:
        prompt_id: Id of the prompt that is about to change.
        outdated_prompt: Snapshot of the prompt taken before the change.
    """
    # Materialise the items first because the registry is rewritten in the loop.
    for key, compiled_prompt in list(_compiled_prompt_registry.items()):
        if compiled_prompt.prompt.id != prompt_id:
            continue
        if isinstance(compiled_prompt.prompt, OutdatedPrompt):
            # Keep the snapshot this string was originally compiled from.
            continue
        _compiled_prompt_registry[key] = _CompiledPrompt(
            value=compiled_prompt.value,
            version_id=compiled_prompt.version_id,
            variables=compiled_prompt.variables,
            prompt=outdated_prompt,
        )
101
+
102
+
103
# Version id assigned by `_to_versions_dict` when a prompt is given as a
# single template string instead of a dict of versions.
DEFAULT_VERSION_ID = "v0"
104
+
105
+
106
+ def _to_versions_dict(versions: str | dict[str, str]) -> dict[str, str]:
107
+ if isinstance(versions, str):
108
+ return {DEFAULT_VERSION_ID: versions}
109
+ return deepcopy(versions)
110
+
111
+
112
class _UnTypedPrompt(Protocol):
    # Structural interface shared by all prompts: an id, a set of versioned
    # templates with a default, and a JSON schema for the variables.

    @property
    def id(self) -> str: ...

    # Id of the version used when none is requested explicitly.
    def get_default_version_id(self) -> str: ...

    # Mapping of version id -> template text.
    def get_versions(self) -> dict[str, str]: ...

    # JSON schema describing the variables the templates accept.
    def get_variables_schema(self) -> dict[str, Any]: ...
121
+
122
+
123
# Schema used for prompts without variables (see
# `variables_definition_to_schema` and `BaseUntypedPrompt.__init__`).
EMPTY_VARIABLES_SCHEMA = {"type": "object", "properties": {}}
124
+
125
+
126
class BaseUntypedPrompt(_UnTypedPrompt):
    """Prompt described by raw data: versioned templates plus a JSON schema.

    The instance owns private copies of both the versions and the schema, so
    later mutation of the constructor arguments cannot affect it.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        id: str | None = None,
        variables_schema: dict[str, Any] | None = None,
    ) -> None:
        """Create the prompt.

        Args:
            versions: A single template string, or ``{version_id: template}``.
            default_version_id: Version used by default; falls back to the
                first key of the versions dict when falsy.
            id: Prompt id; a random hex id that is not yet present in
                `_prompt_registry` is generated when falsy.
            variables_schema: JSON schema of the variables; the empty object
                schema is used when falsy.

        Raises:
            ValueError: If ``versions`` is empty.
        """
        if not id:
            # Draw random ids until one is free in the prompt registry.
            id = uuid4().hex
            while id in _prompt_registry:
                id = uuid4().hex

        self._id = id
        if not versions:
            raise ValueError("No versions provided for the prompt.")
        self._versions: dict[str, str]
        self._versions = _to_versions_dict(versions)
        self._default_version = default_version_id or next(iter(self._versions))
        # Copy so that neither the caller's dict nor the shared module-level
        # EMPTY_VARIABLES_SCHEMA constant is aliased by this instance.
        self._variables_schema: dict[str, Any] = deepcopy(
            variables_schema or EMPTY_VARIABLES_SCHEMA
        )

    @property
    def id(self) -> str:
        """The prompt's id."""
        return self._id

    def get_default_version_id(self) -> str:
        """Return the id of the default version."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the ``{version_id: template}`` mapping."""
        return deepcopy(self._versions)

    def get_variables_schema(self) -> dict[str, Any]:
        """Return a deep copy of the variables JSON schema."""
        return deepcopy(self._variables_schema)
161
+
162
+
163
class Prompt(_UnTypedPrompt, Generic[TPromptVar]):
    # Typed prompt interface: adds the variables model type and `compile` on
    # top of the untyped prompt protocol.

    # The variables model class (NoneType for prompts without variables).
    @property
    def variables_definition(self) -> type[TPromptVar]: ...

    # Overload for prompts without variables: no `variables` argument.
    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    # Overload for prompts with variables: `variables` is required.
    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    # Produce the prompt text for the given (or default) version.
    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str: ...
183
+
184
+
185
def variables_definition_to_schema(definition: type[TPromptVar]) -> dict[str, Any]:
    """Return the JSON schema for a variables definition.

    Args:
        definition: A `Variables` subclass, or ``NoneType`` for prompts that
            take no variables.

    Returns:
        The model's ``model_json_schema()``, or a fresh copy of the empty
        object schema for ``NoneType``.
    """
    if definition is NoneType:
        # Hand out a copy: returning the module-level constant itself would
        # let any caller that mutates the result corrupt it for everyone.
        return deepcopy(EMPTY_VARIABLES_SCHEMA)

    return cast(type[Variables], definition).model_json_schema()
190
+
191
+
192
class BasePrompt(BaseUntypedPrompt, Generic[TPromptVar]):
    """Typed prompt whose templates are compiled with a `Variables` model.

    Each instance stores itself in `_prompt_registry` under its id, and every
    string returned by `compile` is recorded in `_compiled_prompt_registry`.
    """

    @classmethod
    def from_untyped(
        cls,
        untyped_prompt: "BaseUntypedPrompt",
        variables_definition: type[TPromptVar] = NoneType,
    ) -> "BasePrompt[TPromptVar]":
        """Build a typed prompt from an untyped one.

        Args:
            untyped_prompt: Source of the id, versions and default version.
            variables_definition: Variables model to attach.

        Raises:
            TypeError: If the schema of ``variables_definition`` is not a
                subschema of the untyped prompt's variables schema.
        """
        base_schema = untyped_prompt.get_variables_schema()
        typed_schema = variables_definition_to_schema(variables_definition)
        if not isSubschema(typed_schema, base_schema):
            raise TypeError(
                "The provided variables_definition is not compatible with the prompt's variables schema."
            )
        return cls(
            variables_definition=variables_definition,
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
            id=untyped_prompt.id,
        )

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        variables_definition: type[TPromptVar] = NoneType,
        id: str | None = None,
    ) -> None:
        """Create the prompt and register it under its id.

        The variables schema is derived from ``variables_definition``. Any
        prompt previously registered under the same id is replaced.
        """
        super().__init__(
            versions=versions,
            default_version_id=default_version_id,
            id=id,
            variables_schema=variables_definition_to_schema(variables_definition),
        )
        self._variables_definition = variables_definition
        _prompt_registry[self.id] = self

    @property
    def variables_definition(self) -> type[TPromptVar]:
        """The variables model class (``NoneType`` when there are no variables)."""
        return self._variables_definition

    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str:
        """Produce the prompt text for one version and record the result.

        Args:
            variables: Values for the template; required unless the prompt's
                variables definition is ``NoneType``.
            version_id: Version to compile; the default version when falsy.

        Returns:
            The rendered text. For prompts without a variables definition the
            template text is returned as-is, without Jinja rendering.

        Raises:
            KeyError: If ``version_id`` is not one of the prompt's versions.
            ValueError: If variables are required but were not given.
        """
        version_id = version_id or self._default_version
        template_txt = self._versions[version_id]
        if self._variables_definition is not NoneType:
            if variables is None:
                raise ValueError(
                    f"Variables[{self._variables_definition}] are required for this prompt."
                )
            template = jinja2.Template(template_txt, undefined=StrictUndefined)
            ret = template.render(**variables.model_dump(mode="json"))
        else:
            ret = template_txt
        # Keyed by object identity; the record's `value` keeps the string
        # alive, so the id stays valid for as long as the entry exists.
        _compiled_prompt_registry[id(ret)] = _CompiledPrompt(
            value=ret,
            version_id=version_id,
            prompt=self,
            variables=variables,
        )
        return ret

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "tuple[Self, OutdatedPrompt[TPromptVar]]":
        """Mutate this prompt in place and retire its compiled strings.

        A snapshot of the current state is taken first; after the change,
        compiled prompts of this id are re-pointed at that snapshot.

        Returns:
            ``(self, snapshot_of_previous_state)``.
        """
        outdated_prompt = OutdatedPrompt.from_prompt(self)
        if versions is not None:
            self._versions = _to_versions_dict(versions)
        if default_version_id is not None:
            self._default_version = default_version_id
        _mark_compiled_prompts_outdated(self.id, outdated_prompt)
        return self, outdated_prompt

    @staticmethod
    def update_prompt_registry(
        untyped_prompt: "BaseUntypedPrompt",
    ) -> "BasePrompt":
        """IMPORTANT: should only be called from storage on storage load!

        Update the matching entry in type prompt registry in-place.
        DO NOT call other than from initial storage load, to keep immutability of prompts in code.

        Raises:
            KeyError: If no prompt with the untyped prompt's id is registered.
        """
        existing = get_prompt_by_id(untyped_prompt.id)
        # `_update` already snapshots the pre-update state and retires the
        # compiled prompts, so no separate marking pass is needed here.
        existing._update(
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
        )
        return existing

    def append_version(
        self,
        *,
        version_id: str,
        content: str,
        set_as_default: bool = False,
    ) -> None:
        """Add a new version, placed first in the versions mapping.

        Args:
            version_id: Id of the new version; must not exist yet.
            content: Template text of the new version.
            set_as_default: Also make the new version the default.

        Raises:
            ValueError: If ``version_id`` already exists.
        """
        if version_id in self._versions:
            raise ValueError(f"Version ID '{version_id}' already exists.")
        self._update(
            versions={version_id: content, **self._versions},
            default_version_id=version_id if set_as_default else None,
        )

    def update_default_version_id(
        self,
        default_version_id: str,
    ) -> None:
        """Switch the default version; a no-op when it is already the default.

        Raises:
            ValueError: If ``default_version_id`` is not an existing version.
        """
        if default_version_id not in self._versions:
            raise ValueError(f"Version ID '{default_version_id}' does not exist.")
        if self._default_version == default_version_id:
            return
        self._update(
            default_version_id=default_version_id,
        )
325
+
326
+
327
class OutdatedPrompt(BasePrompt[TPromptVar]):
    """Read-only snapshot of a prompt, taken before the prompt was updated.

    The snapshot keeps the id, versions, default version and variables
    definition of the original. It can be inspected but neither updated nor
    compiled.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str,
        variables_definition: type[TPromptVar],
        id: str,
    ) -> None:
        """Store the snapshot's state.

        `BasePrompt.__init__` is not called here: it would store this
        snapshot in `_prompt_registry` under the live prompt's id.
        """
        self._id = id
        self._versions = _to_versions_dict(versions)
        self._default_version = default_version_id
        self._variables_definition = variables_definition
        # Needed by the inherited get_variables_schema(); without it that
        # method raises AttributeError on a snapshot.
        self._variables_schema = variables_definition_to_schema(variables_definition)

    @classmethod
    def from_prompt(
        cls, prompt: BasePrompt[TPromptVar]
    ) -> "OutdatedPrompt[TPromptVar]":
        """Create a snapshot holding the current state of ``prompt``."""
        return cls(
            variables_definition=prompt.variables_definition,
            versions=prompt.get_versions(),
            default_version_id=prompt.get_default_version_id(),
            id=prompt.id,
        )

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "tuple[Self, OutdatedPrompt[TPromptVar]]":
        """Always raises: snapshots are immutable.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError("Cannot update an outdated prompt.")

    def get_default_version_id(self) -> str:
        """Return the default version id captured in the snapshot."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the versions captured in the snapshot."""
        return deepcopy(self._versions)

    def compile(self, *_args: Any, **_kwargs: Any) -> str:
        """Always raises: an outdated prompt can no longer be compiled.

        Any call shape is accepted (including the base class keywords
        ``variables=`` and ``version_id=``) so that the caller always gets
        the ValueError below rather than a TypeError about argument names.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError("This prompt is outdated and can no longer be used.")
@@ -0,0 +1,82 @@
1
+ from dataclasses import dataclass
2
+ import inspect
3
+ import logging
4
+ from types import NoneType
5
+
6
+ from pixie.prompts.prompt import TPromptVar
7
+ from pixie.prompts.storage import StorageBackedPrompt
8
+
9
+
10
# Module-level logger; `create_prompt` logs each new registration through it.
logger = logging.getLogger(__name__)
11
+
12
+
13
@dataclass
class StorageBackedPromptWithRegistration:
    """A prompt created via `create_prompt`, plus its registration metadata."""

    # The prompt object handed back to the caller of `create_prompt`.
    prompt: StorageBackedPrompt
    # Optional description supplied at creation time.
    description: str | None
    # Name of the module that called `create_prompt`.
    module: str
18
+
19
+
20
# Maps prompt id -> registration record; filled in by `create_prompt`.
_registry: dict[str, StorageBackedPromptWithRegistration] = {}
21
+ """Registry for StorageBackedPrompts created by `create_prompt`.
22
+
23
+ StorageBackedPrompt is different from BasePrompt because it can be incomplete
24
+ (when record is not yet fetched from storage, or record doesn't exist at all).
25
+ Thus this registry could contain more entries than _prompt_registry in prompt.py."""
26
+
27
+
28
def list_prompts() -> list[StorageBackedPromptWithRegistration]:
    """Return a new list with every registration made through `create_prompt`."""
    return [*_registry.values()]
31
+
32
+
33
def get_prompt(id: str) -> StorageBackedPromptWithRegistration | None:
    """Return the registration made through `create_prompt` for ``id``, or None."""
    try:
        return _registry[id]
    except KeyError:
        return None
36
+
37
+
38
+ def _get_calling_module_name():
39
+ """Find the name of the module that called this function."""
40
+ # Get the current frame and the frame above it (the caller's frame)
41
+ try:
42
+ # inspect.stack()[2] gets the caller's frame record
43
+ # frame[0] or frame.frame is the actual frame object
44
+ caller_frame_record = inspect.stack()[2]
45
+ caller_frame = caller_frame_record.frame
46
+ except IndexError:
47
+ # Handle cases where the stack might be shallower, though unlikely for normal calls
48
+ return "__main__"
49
+
50
+ # Get the module object from the frame object
51
+ module = inspect.getmodule(caller_frame)
52
+
53
+ if module is not None:
54
+ return module.__name__
55
+ else:
56
+ # If getmodule returns None (e.g., if called from the interactive prompt or __main__),
57
+ # try to get the name from the frame's globals
58
+ return caller_frame.f_globals.get("__name__", "__main__")
59
+
60
+
61
def create_prompt(
    id: str,
    variables_definition: type[TPromptVar] = NoneType,
    *,
    description: str | None = None,
) -> StorageBackedPrompt[TPromptVar]:
    """Create a storage-backed prompt, or return the one already registered.

    Args:
        id: Prompt id; also the registry key.
        variables_definition: Variables model for the prompt.
        description: Optional description stored with a new registration.

    Returns:
        The prompt registered under ``id``.

    Raises:
        ValueError: If ``id`` is already registered with a different
            variables definition.
    """
    registration = _registry.get(id)
    if registration is not None:
        existing = registration.prompt
        if existing.variables_definition != variables_definition:
            raise ValueError(
                f"Prompt with id '{id}' already exists with a different variables definition."
            )
        return existing

    created = StorageBackedPrompt(id=id, variables_definition=variables_definition)
    # Must be called directly from this function: the helper reports the
    # module two frames above itself, i.e. the caller of create_prompt.
    calling_module = _get_calling_module_name()
    _registry[id] = StorageBackedPromptWithRegistration(
        prompt=created,
        description=description,
        module=calling_module,
    )
    logger.info(f"✅ Registered prompt: {id} ({calling_module})")
    return created
@@ -0,0 +1,231 @@
1
+ """FastAPI server for SDK."""
2
+
3
+ import argparse
4
+ import os
5
+ import colorlog
6
+ import logging
7
+ from urllib.parse import quote
8
+
9
+ import dotenv
10
+ from fastapi import FastAPI
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from strawberry.fastapi import GraphQLRouter
13
+ import uvicorn
14
+
15
+ from pixie.prompts.file_watcher import (
16
+ discover_and_load_modules,
17
+ init_prompt_storage,
18
+ )
19
+ from pixie.prompts.graphql import schema
20
+
21
+
22
logger = logging.getLogger(__name__)

# Global logging mode
# Written by `setup_logging` and `start_server`; read by `create_app`.
_logging_mode: str = "default"
26
+
27
+
28
def setup_logging(mode: str = "default"):
    """Install coloured, uniformly formatted logging for the whole process.

    The chosen mode is also remembered in the module-level `_logging_mode`.

    Args:
        mode: One of "default", "verbose" or "debug".
            - default: root logger at WARNING, "pixie" loggers at INFO,
              uvicorn access log at WARNING
            - verbose: INFO and above everywhere
            - debug: DEBUG and above everywhere
    """
    global _logging_mode
    _logging_mode = mode

    # Only "debug" lowers the base level; every other mode starts at INFO.
    base_level = logging.DEBUG if mode == "debug" else logging.INFO
    palette = {
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    }
    colorlog.basicConfig(
        level=base_level,
        format="[%(log_color)s%(levelname)-8s%(reset)s][%(asctime)s]\t%(message)s",
        datefmt="%H:%M:%S",
        log_colors=palette,
        force=True,
    )

    # Strip uvicorn's own handlers and let its records propagate to the
    # root logger so they share the format configured above.
    for name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
        uvicorn_log = logging.getLogger(name)
        uvicorn_log.handlers = []
        uvicorn_log.propagate = True

    if mode != "default":
        return

    # Default mode: quiet everything except the pixie modules.
    logging.getLogger().setLevel(logging.WARNING)
    logging.getLogger("pixie").setLevel(logging.INFO)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
78
+
79
+
80
def create_app() -> FastAPI:
    """Build the FastAPI application served by the dev server.

    Steps, in order: configure logging with the current global mode, discover
    and load modules, load ``.env`` from the working directory, initialise
    prompt storage (which supplies the app's lifespan), then assemble the app
    with CORS, the GraphQL router and a root info endpoint.

    Returns:
        The configured FastAPI application.
    """
    # Runs on every app creation, which includes reloads.
    setup_logging(_logging_mode)
    discover_and_load_modules()

    dotenv.load_dotenv(os.getcwd() + "/.env")
    storage_lifespan = init_prompt_storage()

    app = FastAPI(
        title="Pixie Prompts Dev Server",
        description="Server for managing prompts",
        version="0.1.0",
        lifespan=storage_lifespan,
    )

    # Allowed origins:
    #   - http://localhost or http://127.0.0.1, with an optional port
    #   - https://gopixie.ai
    allowed_origins = r"http://(localhost|127\.0\.0\.1)(:\d+)?|https://gopixie\.ai"
    app.add_middleware(
        CORSMiddleware,
        allow_origin_regex=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # GraphQL endpoint with the GraphiQL UI enabled.
    app.include_router(GraphQLRouter(schema, graphiql=True), prefix="/graphql")

    @app.get("/")
    async def root():
        return {
            "message": "Pixie Prompts Dev Server",
            "graphiql": "/graphql",
            "version": "0.1.0",
        }

    return app
132
+
133
+
134
def start_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    log_mode: str = "default",
) -> None:
    """Log the startup banner and run the server with uvicorn.

    Args:
        host: Host to bind to.
        port: Port to bind to.
        reload: Enable uvicorn auto-reload.
        log_mode: Logging mode - "default", "verbose", or "debug".
    """
    global _logging_mode
    _logging_mode = log_mode

    # create_app() configures logging again, which covers reload scenarios.
    setup_logging(log_mode)

    # A wildcard bind address is shown as 127.0.0.1 in the logged URLs.
    server_url = (
        f"http://127.0.0.1:{port}" if host == "0.0.0.0" else f"http://{host}:{port}"
    )

    logger.info("Starting Pixie SDK Server")
    logger.info("Server: %s", server_url)
    logger.info("GraphQL: %s/graphql", server_url)

    # Link to the gopixie.ai web UI, carrying the URL-encoded GraphQL endpoint.
    graphql_endpoint = f"{server_url}/graphql"
    pixie_web_url = f"https://gopixie.ai?url={quote(graphql_endpoint, safe='')}"
    separator = "=" * 60
    logger.info("")
    logger.info(separator)
    logger.info("")
    logger.info("🎨 Open Pixie Web UI:")
    logger.info("")
    logger.info(" %s", pixie_web_url)
    logger.info("")
    logger.info(separator)
    logger.info("")

    # Logging is already configured above, so uvicorn's own log config is off.
    uvicorn.run(
        "pixie.prompts.server:create_app",
        host=host,
        port=port,
        loop="asyncio",
        reload=reload,
        factory=True,
        log_config=None,
    )
187
+
188
+
189
def main():
    """Command-line entry point: parse flags and start the server.

    Loads ``.env`` from the working directory and starts the server with
    auto-reload enabled. ``--debug`` takes precedence over ``--verbose``.
    The port comes from ``--port`` when given, otherwise from the
    ``PIXIE_SDK_PORT`` environment variable, otherwise 8000.
    """
    parser = argparse.ArgumentParser(description="Pixie Prompts development server")
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging (INFO+ for all modules)",
    )
    parser.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="Enable debug logging (DEBUG+ for all modules)",
    )
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        default=None,
        help="Port to run the server on (overrides PIXIE_SDK_PORT env var)",
    )
    args = parser.parse_args()

    # Determine logging mode
    log_mode = "default"
    if args.debug:
        log_mode = "debug"
    elif args.verbose:
        log_mode = "verbose"

    dotenv.load_dotenv(os.getcwd() + "/.env")
    # Test against None rather than truthiness so that an explicit
    # `--port 0` is honoured instead of silently falling back to the env var.
    if args.port is not None:
        port = args.port
    else:
        port = int(os.getenv("PIXIE_SDK_PORT", "8000"))

    start_server(port=port, reload=True, log_mode=log_mode)
228
+
229
+
230
# Run the CLI entry point when this module is executed as a script.
if __name__ == "__main__":
    main()