pixie-prompts 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixie/prompts/__init__.py +19 -0
- pixie/prompts/file_watcher.py +327 -0
- pixie/prompts/graphql.py +212 -0
- pixie/prompts/prompt.py +373 -0
- pixie/prompts/prompt_management.py +82 -0
- pixie/prompts/server.py +231 -0
- pixie/prompts/storage.py +399 -0
- pixie_prompts-0.1.1.dist-info/METADATA +36 -0
- pixie_prompts-0.1.1.dist-info/RECORD +12 -0
- pixie_prompts-0.1.1.dist-info/WHEEL +4 -0
- pixie_prompts-0.1.1.dist-info/entry_points.txt +3 -0
- pixie_prompts-0.1.1.dist-info/licenses/LICENSE +21 -0
pixie/prompts/prompt.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
import json
|
|
4
|
+
from types import NoneType
|
|
5
|
+
from typing import Any, Generic, Protocol, Self, TypeVar, cast, overload
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
import jinja2
|
|
9
|
+
from jinja2 import StrictUndefined
|
|
10
|
+
from jsonsubschema import isSubschema
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Variables(BaseModel):
    """Base class for a prompt's template variables.

    Subclass it and declare pydantic fields. When a typed prompt is compiled,
    the instance is dumped with ``model_dump(mode="json")`` and each field is
    passed to the Jinja template as a variable of the same name.
    """

    # TODO add validation to prevent fields using reserved names
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Type parameter for prompts: a concrete ``Variables`` subclass, or ``None``
# for prompts whose templates take no variables.
TPromptVar = TypeVar("TPromptVar", bound=Variables | None)


# Maps prompt ID -> the live prompt instance. Populated by BasePrompt.__init__;
# a later prompt constructed with the same ID replaces the earlier entry.
_prompt_registry: dict[str, "BasePrompt"] = {}
"""Registry of all actualized prompts.

Purpose of the registry is to ensure there's single actualized prompt instance per prompt ID globally,
so that every compiled prompt can track back to one single instance of the prompt it was compiled from.
"""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_prompt_by_id(prompt_id: str) -> "BasePrompt":
    """Return the live prompt instance registered under ``prompt_id``.

    Raises:
        KeyError: if no prompt with that ID has been registered.
    """
    prompt = _prompt_registry[prompt_id]
    return prompt
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class _CompiledPrompt:
    """Metadata recorded for each string returned by ``BasePrompt.compile()``."""

    # The compiled (rendered) prompt text.
    value: str
    # The prompt it was compiled from. Replaced by an ``OutdatedPrompt``
    # snapshot (see _mark_compiled_prompts_outdated) once that prompt changes.
    prompt: "BasePrompt | OutdatedPrompt"
    # ID of the prompt version that was compiled.
    version_id: str
    # The variables passed to compile(), or None when none were given.
    variables: Variables | None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Each entry keeps a reference to its compiled string (``value``), which keeps
# that object alive, so its id() cannot be reused while the entry exists.
# Nothing in this module removes entries; the registry grows with every
# compile() call.
_compiled_prompt_registry: dict[int, _CompiledPrompt] = {}
"""Registry of all compiled prompts.

This is to keep track of every result string returned by BasePrompt.compile().
key is the id() of the compiled string."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _find_matching_prompt(obj):
|
|
50
|
+
if isinstance(obj, str):
|
|
51
|
+
for compiled in _compiled_prompt_registry.values():
|
|
52
|
+
if compiled.value == obj:
|
|
53
|
+
return compiled
|
|
54
|
+
return None
|
|
55
|
+
elif isinstance(obj, dict):
|
|
56
|
+
for value in obj.values():
|
|
57
|
+
result = _find_matching_prompt(value)
|
|
58
|
+
if result:
|
|
59
|
+
return result
|
|
60
|
+
return None
|
|
61
|
+
elif isinstance(obj, list):
|
|
62
|
+
for item in obj:
|
|
63
|
+
result = _find_matching_prompt(item)
|
|
64
|
+
if result:
|
|
65
|
+
return result
|
|
66
|
+
return None
|
|
67
|
+
else:
|
|
68
|
+
return None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_compiled_prompt(text: str) -> _CompiledPrompt | None:
    """Find the compiled prompt metadata for a given compiled prompt string.

    Lookups are tried in this order:
      1. identity: ``text`` is the very string object returned by compile();
      2. equality: some compiled prompt has exactly this value;
      3. embedding: ``text`` parses as JSON and one of its (nested) string
         values equals a compiled prompt.

    Returns:
        The matching ``_CompiledPrompt``, or ``None`` if nothing matches,
        the registry is empty, or ``text`` is not valid JSON at step 3.
    """
    if not _compiled_prompt_registry:
        return None

    by_identity = _compiled_prompt_registry.get(id(text))
    if by_identity:
        return by_identity

    by_value = next(
        (entry for entry in _compiled_prompt_registry.values() if entry.value == text),
        None,
    )
    if by_value is not None:
        return by_value

    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return None
    return _find_matching_prompt(decoded)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _mark_compiled_prompts_outdated(
    prompt_id: str, outdated_prompt: "OutdatedPrompt"
) -> None:
    """Re-point a live prompt's compiled strings at an outdated snapshot.

    Called when the prompt ``prompt_id`` is updated, with ``outdated_prompt``
    being a snapshot of its state before the update. Compiled strings can
    then keep describing the state they were compiled from.

    Only entries that still reference a *live* prompt are updated. An entry
    that already references an ``OutdatedPrompt`` was compiled from an even
    older state, and must keep that snapshot. It shares the same prompt ID,
    so matching on the ID alone would attach it to the newer state.
    """
    # Iterate over a copy because entries are replaced during the loop.
    for key, compiled_prompt in list(_compiled_prompt_registry.items()):
        if compiled_prompt.prompt.id != prompt_id:
            continue
        if isinstance(compiled_prompt.prompt, OutdatedPrompt):
            continue
        _compiled_prompt_registry[key] = _CompiledPrompt(
            value=compiled_prompt.value,
            version_id=compiled_prompt.version_id,
            variables=compiled_prompt.variables,
            prompt=outdated_prompt,
        )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Version ID used when a prompt is given a single template string instead of
# a {version_id: template} mapping (see _to_versions_dict).
DEFAULT_VERSION_ID = "v0"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _to_versions_dict(versions: str | dict[str, str]) -> dict[str, str]:
|
|
107
|
+
if isinstance(versions, str):
|
|
108
|
+
return {DEFAULT_VERSION_ID: versions}
|
|
109
|
+
return deepcopy(versions)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class _UnTypedPrompt(Protocol):
    """Structural interface shared by typed and untyped prompts."""

    @property
    def id(self) -> str:
        """Unique identifier of the prompt."""
        ...

    def get_default_version_id(self) -> str:
        """ID of the version used when no version is requested."""
        ...

    def get_versions(self) -> dict[str, str]:
        """Mapping of version ID to template text."""
        ...

    def get_variables_schema(self) -> dict[str, Any]:
        """JSON schema describing the prompt's variables."""
        ...
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# JSON schema for prompts that take no variables. This is a single shared,
# mutable module-level dict: treat it as read-only.
EMPTY_VARIABLES_SCHEMA = {"type": "object", "properties": {}}
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class BaseUntypedPrompt(_UnTypedPrompt):
    """A prompt's versions and variables JSON schema, without a Python type.

    Args:
        versions: Either a single template string, stored under
            ``DEFAULT_VERSION_ID``, or a ``{version_id: template}`` mapping.
            It must not be empty. The value is copied.
        default_version_id: Version used when none is requested. Defaults to
            the first key of ``versions``.
        id: Prompt ID. When omitted or empty, a random hex ID is generated
            that is not already present in the prompt registry.
        variables_schema: JSON schema of the prompt's variables. Defaults to
            ``EMPTY_VARIABLES_SCHEMA``. The value is copied.

    Raises:
        ValueError: if ``versions`` is empty.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        id: str | None = None,
        variables_schema: dict[str, Any] | None = None,
    ) -> None:
        if not id:
            id = uuid4().hex
            while id in _prompt_registry:
                id = uuid4().hex

        self._id = id
        if not versions:
            raise ValueError("No versions provided for the prompt.")
        self._versions: dict[str, str]
        self._versions = _to_versions_dict(versions)
        self._default_version = default_version_id or next(iter(self._versions))
        # Copy the schema so this instance aliases neither the caller's dict
        # nor the shared EMPTY_VARIABLES_SCHEMA constant (mirrors how
        # `versions` is copied by _to_versions_dict).
        self._variables_schema = deepcopy(variables_schema or EMPTY_VARIABLES_SCHEMA)

    @property
    def id(self) -> str:
        """The prompt's unique identifier."""
        return self._id

    def get_default_version_id(self) -> str:
        """Return the ID of the version compiled by default."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the ``{version_id: template}`` mapping."""
        return deepcopy(self._versions)

    def get_variables_schema(self) -> dict[str, Any]:
        """Return a deep copy of the variables JSON schema."""
        return deepcopy(self._variables_schema)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class Prompt(_UnTypedPrompt, Generic[TPromptVar]):
    """Interface of a prompt whose variables are typed by ``TPromptVar``.

    NOTE(review): ``Protocol`` is not listed directly among the bases (only
    ``_UnTypedPrompt`` and ``Generic``), so under PEP 544 this is an ordinary
    nominal class rather than a protocol. Confirm whether that is intended.
    """

    @property
    def variables_definition(self) -> type[TPromptVar]:
        """The ``Variables`` subclass (or ``NoneType``) describing the variables."""
        ...

    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str:
        """Compile ``version_id`` (or the default version) into prompt text.

        Prompts without variables are compiled without the ``variables``
        argument. Typed prompts take an instance of their ``Variables``
        subclass. This is interface only; see ``BasePrompt.compile`` for
        the implementation.
        """
        ...
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def variables_definition_to_schema(definition: type[TPromptVar]) -> dict[str, Any]:
    """Return the JSON schema for a prompt's variables definition.

    Args:
        definition: A ``Variables`` subclass, or ``NoneType`` for prompts
            without variables.

    Returns:
        The pydantic model JSON schema of ``definition``. For ``NoneType``,
        a fresh copy of ``EMPTY_VARIABLES_SCHEMA`` is returned, so a caller
        that mutates the result cannot corrupt the shared constant.
    """
    if definition is NoneType:
        return deepcopy(EMPTY_VARIABLES_SCHEMA)

    return cast(type[Variables], definition).model_json_schema()
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class BasePrompt(BaseUntypedPrompt, Generic[TPromptVar]):
    """A prompt whose variables are described by a ``Variables`` subclass.

    Constructing an instance registers it in the module-level prompt
    registry under its ID, replacing any previous entry with the same ID.
    Every string it compiles is recorded, so it can be traced back with
    ``get_compiled_prompt``.
    """

    @classmethod
    def from_untyped(
        cls,
        untyped_prompt: "BaseUntypedPrompt",
        variables_definition: type[TPromptVar] = NoneType,
    ) -> "BasePrompt[TPromptVar]":
        """Build a typed prompt with the same ID and versions as ``untyped_prompt``.

        Raises:
            TypeError: if the schema of ``variables_definition`` is not a
                subschema of ``untyped_prompt``'s variables schema.
        """
        base_schema = untyped_prompt.get_variables_schema()
        typed_schema = variables_definition_to_schema(variables_definition)
        if not isSubschema(typed_schema, base_schema):
            raise TypeError(
                "The provided variables_definition is not compatible with the prompt's variables schema."
            )
        return cls(
            variables_definition=variables_definition,
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
            id=untyped_prompt.id,
        )

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str | None = None,
        variables_definition: type[TPromptVar] = NoneType,
        id: str | None = None,
    ) -> None:
        super().__init__(
            versions=versions,
            default_version_id=default_version_id,
            id=id,
            variables_schema=variables_definition_to_schema(variables_definition),
        )
        self._variables_definition = variables_definition
        _prompt_registry[self.id] = self

    @property
    def variables_definition(self) -> type[TPromptVar]:
        """The ``Variables`` subclass for this prompt, or ``NoneType``."""
        return self._variables_definition

    @overload
    def compile(
        self: "BasePrompt[NoneType]", *, version_id: str | None = None
    ) -> str: ...

    @overload
    def compile(
        self, variables: TPromptVar, *, version_id: str | None = None
    ) -> str: ...

    def compile(
        self,
        variables: TPromptVar | None = None,
        *,
        version_id: str | None = None,
    ) -> str:
        """Render a version of the prompt and record the result.

        Args:
            variables: Template values. Required unless the variables
                definition is ``NoneType``, in which case they are not used
                for rendering.
            version_id: Version to render. Defaults to the default version.

        Returns:
            The rendered text. Typed prompts are rendered with Jinja2 using
            ``StrictUndefined``, so a template variable absent from
            ``variables`` raises ``jinja2.UndefinedError``. Prompts without
            variables return the template text unchanged.

        Raises:
            KeyError: if ``version_id`` is not a known version.
            ValueError: if variables are required but not provided.
        """
        version_id = version_id or self._default_version
        template_txt = self._versions[version_id]
        if self._variables_definition is not NoneType:
            if variables is None:
                raise ValueError(
                    f"Variables[{self._variables_definition}] are required for this prompt."
                )
            template = jinja2.Template(template_txt, undefined=StrictUndefined)
            ret = template.render(**variables.model_dump(mode="json"))
        else:
            ret = template_txt
        # Keyed by id(ret). The entry holds `ret`, keeping that object alive,
        # so the key stays unique and the exact returned object is findable.
        _compiled_prompt_registry[id(ret)] = _CompiledPrompt(
            value=ret,
            version_id=version_id,
            prompt=self,
            variables=variables,
        )
        return ret

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "tuple[Self, OutdatedPrompt[TPromptVar]]":
        """Replace the versions and/or the default version in place.

        A snapshot of the current state is taken first. This prompt's
        compiled strings are then re-pointed at that snapshot. ``None``
        arguments leave the corresponding attribute unchanged.

        Returns:
            ``(self, snapshot_of_previous_state)``.
        """
        outdated_prompt = OutdatedPrompt.from_prompt(self)
        if versions is not None:
            self._versions = _to_versions_dict(versions)
        if default_version_id is not None:
            self._default_version = default_version_id
        _mark_compiled_prompts_outdated(self.id, outdated_prompt)
        return self, outdated_prompt

    @staticmethod
    def update_prompt_registry(
        untyped_prompt: "BaseUntypedPrompt",
    ) -> "BasePrompt":
        """IMPORTANT: should only be called from storage on storage load!

        Update the matching entry in the typed prompt registry in-place.
        DO NOT call other than from initial storage load, to keep immutability of prompts in code.

        Raises:
            KeyError: if no prompt with ``untyped_prompt.id`` is registered.
        """
        existing = get_prompt_by_id(untyped_prompt.id)
        # _update() snapshots the current state and marks this prompt's
        # compiled strings as outdated itself, so no separate snapshot is
        # taken here.
        existing._update(
            versions=untyped_prompt.get_versions(),
            default_version_id=untyped_prompt.get_default_version_id(),
        )
        return existing

    def append_version(
        self,
        *,
        version_id: str,
        content: str,
        set_as_default: bool = False,
    ) -> None:
        """Add a new version, optionally making it the default.

        The new version is placed first in the versions mapping.

        Raises:
            ValueError: if ``version_id`` already exists.
        """
        if version_id in self._versions:
            raise ValueError(f"Version ID '{version_id}' already exists.")
        self._update(
            versions={version_id: content, **self._versions},
            default_version_id=version_id if set_as_default else None,
        )

    def update_default_version_id(
        self,
        default_version_id: str,
    ) -> None:
        """Make an existing version the default. No-op if it already is.

        Raises:
            ValueError: if ``default_version_id`` is not a known version.
        """
        if default_version_id not in self._versions:
            raise ValueError(f"Version ID '{default_version_id}' does not exist.")
        if self._default_version == default_version_id:
            return
        self._update(
            default_version_id=default_version_id,
        )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class OutdatedPrompt(BasePrompt[TPromptVar]):
    """Read-only snapshot of a prompt's state from before an update.

    Compiled prompts are re-pointed at such a snapshot when their source
    prompt changes. A snapshot cannot be updated or compiled. Unlike a
    ``BasePrompt``, it is not added to the prompt registry.
    """

    def __init__(
        self,
        *,
        versions: str | dict[str, str],
        default_version_id: str,
        variables_definition: type[TPromptVar],
        id: str,
    ) -> None:
        # BasePrompt.__init__ is deliberately not called: it would register
        # this snapshot in the prompt registry under the live prompt's ID.
        self._id = id
        self._versions = _to_versions_dict(versions)
        self._default_version = default_version_id
        self._variables_definition = variables_definition
        # The inherited get_variables_schema() reads this attribute, and
        # BaseUntypedPrompt.__init__ (which normally sets it) is not run.
        self._variables_schema = variables_definition_to_schema(variables_definition)

    @classmethod
    def from_prompt(
        cls, prompt: BasePrompt[TPromptVar]
    ) -> "OutdatedPrompt[TPromptVar]":
        """Snapshot the current versions, default and variables type of ``prompt``."""
        return cls(
            variables_definition=prompt.variables_definition,
            versions=prompt.get_versions(),
            default_version_id=prompt.get_default_version_id(),
            id=prompt.id,
        )

    def _update(
        self,
        *,
        versions: str | dict[str, str] | None = None,
        default_version_id: str | None = None,
    ) -> "OutdatedPrompt[TPromptVar]":
        """Always raise ValueError: snapshots are immutable."""
        raise ValueError("Cannot update an outdated prompt.")

    def get_default_version_id(self) -> str:
        """Return the default version ID at the time of the snapshot."""
        return self._default_version

    def get_versions(self) -> dict[str, str]:
        """Return a deep copy of the versions at the time of the snapshot."""
        return deepcopy(self._versions)

    def compile(self, *_args: Any, **_kwargs: Any) -> str:
        """Always raise: an outdated snapshot cannot be compiled.

        Any arguments are accepted, so a call using ``BasePrompt.compile``'s
        signature (e.g. ``compile(vars, version_id="v1")``) reaches this
        ValueError instead of failing with a TypeError.

        Raises:
            ValueError: always.
        """
        raise ValueError("This prompt is outdated and can no longer be used.")
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
from types import NoneType
|
|
5
|
+
|
|
6
|
+
from pixie.prompts.prompt import TPromptVar
|
|
7
|
+
from pixie.prompts.storage import StorageBackedPrompt
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module logger; create_prompt() logs each newly registered prompt here.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class StorageBackedPromptWithRegistration:
    """A prompt created via ``create_prompt`` plus its registration metadata."""

    # The storage-backed prompt instance returned by create_prompt.
    prompt: StorageBackedPrompt
    # Optional human-readable description passed to create_prompt.
    description: str | None
    # Name of the module that called create_prompt.
    module: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Maps prompt ID -> registration entry; only create_prompt() adds entries.
_registry: dict[str, StorageBackedPromptWithRegistration] = {}
"""Registry for StorageBackedPrompts created by `create_prompt`, keyed by prompt ID.

StorageBackedPrompt is different from BasePrompt because it can be incomplete
(when the record is not yet fetched from storage, or the record doesn't exist at all).
Thus this registry could contain more entries than _prompt_registry in prompt.py."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def list_prompts() -> list[StorageBackedPromptWithRegistration]:
    """Return every prompt registered through `create_prompt`, in registration order."""
    return [*_registry.values()]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_prompt(id: str) -> StorageBackedPromptWithRegistration | None:
    """Look up a prompt registered through `create_prompt`.

    Returns:
        The registration entry for ``id``, or ``None`` if there is none.
    """
    try:
        return _registry[id]
    except KeyError:
        return None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _get_calling_module_name():
|
|
39
|
+
"""Find the name of the module that called this function."""
|
|
40
|
+
# Get the current frame and the frame above it (the caller's frame)
|
|
41
|
+
try:
|
|
42
|
+
# inspect.stack()[2] gets the caller's frame record
|
|
43
|
+
# frame[0] or frame.frame is the actual frame object
|
|
44
|
+
caller_frame_record = inspect.stack()[2]
|
|
45
|
+
caller_frame = caller_frame_record.frame
|
|
46
|
+
except IndexError:
|
|
47
|
+
# Handle cases where the stack might be shallower, though unlikely for normal calls
|
|
48
|
+
return "__main__"
|
|
49
|
+
|
|
50
|
+
# Get the module object from the frame object
|
|
51
|
+
module = inspect.getmodule(caller_frame)
|
|
52
|
+
|
|
53
|
+
if module is not None:
|
|
54
|
+
return module.__name__
|
|
55
|
+
else:
|
|
56
|
+
# If getmodule returns None (e.g., if called from the interactive prompt or __main__),
|
|
57
|
+
# try to get the name from the frame's globals
|
|
58
|
+
return caller_frame.f_globals.get("__name__", "__main__")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def create_prompt(
    id: str,
    variables_definition: type[TPromptVar] = NoneType,
    *,
    description: str | None = None,
) -> StorageBackedPrompt[TPromptVar]:
    """Create, or return the existing, storage-backed prompt for ``id``.

    The first call registers the prompt together with ``description`` and
    the calling module's name. Later calls with the same ``id`` return the
    already-registered prompt, and their ``description`` is ignored.

    Raises:
        ValueError: if ``id`` is already registered with a different
            ``variables_definition``.
    """
    existing = _registry.get(id)
    if existing is not None:
        if existing.prompt.variables_definition != variables_definition:
            raise ValueError(
                f"Prompt with id '{id}' already exists with a different variables definition."
            )
        return existing.prompt

    prompt = StorageBackedPrompt(id=id, variables_definition=variables_definition)
    # Must be called directly from here: it inspects the stack two frames up.
    calling_module = _get_calling_module_name()
    _registry[id] = StorageBackedPromptWithRegistration(
        prompt=prompt,
        description=description,
        module=calling_module,
    )
    logger.info("✅ Registered prompt: %s (%s)", id, calling_module)
    return prompt
|
pixie/prompts/server.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""FastAPI server for SDK."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import os
|
|
5
|
+
import colorlog
|
|
6
|
+
import logging
|
|
7
|
+
from urllib.parse import quote
|
|
8
|
+
|
|
9
|
+
import dotenv
|
|
10
|
+
from fastapi import FastAPI
|
|
11
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
12
|
+
from strawberry.fastapi import GraphQLRouter
|
|
13
|
+
import uvicorn
|
|
14
|
+
|
|
15
|
+
from pixie.prompts.file_watcher import (
|
|
16
|
+
discover_and_load_modules,
|
|
17
|
+
init_prompt_storage,
|
|
18
|
+
)
|
|
19
|
+
from pixie.prompts.graphql import schema
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)

# Global logging mode ("default", "verbose" or "debug"). Set by
# setup_logging() and start_server(); create_app() passes it back to
# setup_logging().
_logging_mode: str = "default"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def setup_logging(mode: str = "default"):
    """Configure logging for the entire application.

    Sets up colored logging with consistent formatting for all loggers.
    It is safe to call repeatedly, even with a different mode. Every logger
    level a previous call may have changed is set explicitly again.

    Args:
        mode: Logging mode - "default", "verbose", or "debug"
            - default: INFO for server events, WARNING+ for all modules
            - verbose: INFO+ for all modules
            - debug: DEBUG+ for all modules
            Unrecognized values behave like "default".
    """
    global _logging_mode
    _logging_mode = mode

    # Determine log level based on mode
    if mode == "debug":
        level = logging.DEBUG
    elif mode == "verbose":
        level = logging.INFO
    else:  # default
        level = logging.INFO

    colorlog.basicConfig(
        level=level,
        format="[%(log_color)s%(levelname)-8s%(reset)s][%(asctime)s]\t%(message)s",
        datefmt="%H:%M:%S",
        log_colors={
            "DEBUG": "cyan",
            "INFO": "green",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "red,bg_white",
        },
        force=True,
    )

    # Configure uvicorn loggers to use the same format: drop their own
    # handlers and let records propagate to the root handler set up above.
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error"]:
        uvicorn_logger = logging.getLogger(logger_name)
        uvicorn_logger.handlers = []
        uvicorn_logger.propagate = True

    if mode == "default":
        # Set root logger to WARNING
        logging.getLogger().setLevel(logging.WARNING)
        # Allow INFO for pixie modules
        logging.getLogger("pixie").setLevel(logging.INFO)
        # Suppress uvicorn access logs in default mode
        logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
    else:
        # Clear overrides a previous "default" call may have left behind, so
        # these loggers inherit the root level again. Otherwise "pixie"
        # would stay capped at INFO in debug mode.
        logging.getLogger("pixie").setLevel(logging.NOTSET)
        logging.getLogger("uvicorn.access").setLevel(logging.NOTSET)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Environment variable carrying the logging mode from start_server() to the
# app factory. Environment variables are inherited by uvicorn's reload worker
# process; module globals are not.
_LOG_MODE_ENV_VAR = "PIXIE_PROMPTS_LOG_MODE"


def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    This is also the uvicorn app factory ("pixie.prompts.server:create_app").
    With reload enabled, uvicorn runs it in a separately spawned process
    that imports this module afresh, so ``_logging_mode`` is back at its
    default there. The mode chosen in ``start_server`` is therefore read
    from the ``PIXIE_PROMPTS_LOG_MODE`` environment variable first.

    Returns:
        Configured FastAPI application instance with GraphQL router.
    """
    # Setup logging first (environment variable, then the global mode)
    setup_logging(os.environ.get(_LOG_MODE_ENV_VAR, _logging_mode))

    # Load .env before discovering user modules, so code run at import time
    # sees the same environment as the rest of the app.
    dotenv.load_dotenv(os.path.join(os.getcwd(), ".env"))

    # Discover and load applications on every app creation (including reloads)
    discover_and_load_modules()

    lifespan = init_prompt_storage()

    app = FastAPI(
        title="Pixie Prompts Dev Server",
        description="Server for managing prompts",
        version="0.1.0",
        lifespan=lifespan,
    )
    # Matches:
    # 1. http://localhost followed by an optional port (:8080, :3000, etc.)
    # 2. http://127.0.0.1 followed by an optional port
    # 3. https://gopixie.ai (the hosted Pixie web UI)
    origins_regex = r"http://(localhost|127\.0\.0\.1)(:\d+)?|https://gopixie\.ai"
    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origin_regex=origins_regex,
        allow_credentials=True,
        allow_methods=["*"],  # Allows all methods
        allow_headers=["*"],  # Allows all headers
    )

    # Add GraphQL router with GraphiQL enabled
    graphql_app = GraphQLRouter(
        schema,
        graphiql=True,
    )

    app.include_router(graphql_app, prefix="/graphql")

    @app.get("/")
    async def root():
        return {
            "message": "Pixie Prompts Dev Server",
            "graphiql": "/graphql",
            "version": "0.1.0",
        }

    return app


def start_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    log_mode: str = "default",
) -> None:
    """Start the SDK server.

    Args:
        host: Host to bind to
        port: Port to bind to
        reload: Enable auto-reload for development
        log_mode: Logging mode - "default", "verbose", or "debug". It is also
            exported as ``PIXIE_PROMPTS_LOG_MODE``, so the app factory
            applies it inside uvicorn's reload worker process.
    """
    global _logging_mode
    _logging_mode = log_mode
    # Inherited by the process uvicorn spawns when reload=True.
    os.environ[_LOG_MODE_ENV_VAR] = log_mode

    # Setup logging (will be called again in create_app for reload scenarios)
    setup_logging(log_mode)

    # Determine server URL
    server_url = f"http://{host}:{port}"
    if host == "0.0.0.0":
        server_url = f"http://127.0.0.1:{port}"

    # Log server start info
    logger.info("Starting Pixie SDK Server")
    logger.info("Server: %s", server_url)
    logger.info("GraphQL: %s/graphql", server_url)

    # Display gopixie.ai web link
    encoded_url = quote(f"{server_url}/graphql", safe="")
    pixie_web_url = f"https://gopixie.ai?url={encoded_url}"
    logger.info("")
    logger.info("=" * 60)
    logger.info("")
    logger.info("🎨 Open Pixie Web UI:")
    logger.info("")
    logger.info("   %s", pixie_web_url)
    logger.info("")
    logger.info("=" * 60)
    logger.info("")

    uvicorn.run(
        "pixie.prompts.server:create_app",
        host=host,
        port=port,
        loop="asyncio",
        reload=reload,
        factory=True,
        log_config=None,
    )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def main():
    """Start the Pixie server.

    Loads environment variables from ``./.env`` and starts the server with
    auto-reload enabled.

    Supports --verbose and --debug flags for enhanced logging; --debug wins
    if both are given. Also supports --port, which takes precedence over the
    PIXIE_SDK_PORT environment variable. The port defaults to 8000.
    """
    parser = argparse.ArgumentParser(description="Pixie Prompts development server")
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging (INFO+ for all modules)",
    )
    parser.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="Enable debug logging (DEBUG+ for all modules)",
    )
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        default=None,
        help="Port to run the server on (overrides PIXIE_SDK_PORT env var)",
    )
    args = parser.parse_args()

    # Determine logging mode
    log_mode = "default"
    if args.debug:
        log_mode = "debug"
    elif args.verbose:
        log_mode = "verbose"

    dotenv.load_dotenv(os.path.join(os.getcwd(), ".env"))

    if args.port is not None:
        # Compare with None, not truthiness, so an explicit `--port 0` is
        # honoured instead of silently falling back to the environment.
        port = args.port
    else:
        env_port = os.getenv("PIXIE_SDK_PORT", "8000")
        try:
            port = int(env_port)
        except ValueError:
            # Report a clear CLI error instead of a traceback.
            parser.error(f"PIXIE_SDK_PORT must be an integer, got {env_port!r}")

    start_server(port=port, reload=True, log_mode=log_mode)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# Allow running the server by executing this module directly
# (e.g. `python -m pixie.prompts.server`).
if __name__ == "__main__":
    main()
|