pytrilogy 0.0.3.108__py3-none-any.whl → 0.0.3.109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/METADATA +69 -1
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/RECORD +21 -8
- trilogy/__init__.py +1 -1
- trilogy/ai/__init__.py +19 -0
- trilogy/ai/constants.py +92 -0
- trilogy/ai/conversation.py +99 -0
- trilogy/ai/enums.py +7 -0
- trilogy/ai/execute.py +50 -0
- trilogy/ai/models.py +34 -0
- trilogy/ai/prompts.py +30 -0
- trilogy/ai/providers/__init__.py +0 -0
- trilogy/ai/providers/anthropic.py +105 -0
- trilogy/ai/providers/base.py +22 -0
- trilogy/ai/providers/google.py +142 -0
- trilogy/ai/providers/openai.py +88 -0
- trilogy/ai/providers/utils.py +68 -0
- trilogy/executor.py +35 -7
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/WHEEL +0 -0
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/entry_points.txt +0 -0
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/licenses/LICENSE.md +0 -0
- {pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/top_level.txt +0 -0
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pytrilogy
-Version: 0.0.3.108
+Version: 0.0.3.109
 Summary: Declarative, typed query language that compiles to SQL.
 Home-page:
 Author:
@@ -28,6 +28,8 @@ Provides-Extra: bigquery
 Requires-Dist: sqlalchemy-bigquery; extra == "bigquery"
 Provides-Extra: snowflake
 Requires-Dist: snowflake-sqlalchemy; extra == "snowflake"
+Provides-Extra: ai
+Requires-Dist: httpx; extra == "ai"
 Dynamic: author-email
 Dynamic: classifier
 Dynamic: description
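The only new dependency is httpx, gated behind the `ai` extra (install assumption: `pip install "pytrilogy[ai]"`). The provider modules added below import it lazily inside `generate_completion`, so a base install is unaffected; a minimal sketch of that guard pattern, mirroring the diff:

```python
# Lazy optional-dependency guard, as used by each new provider module below.
try:
    import httpx  # only required for the AI providers; pulled in by pytrilogy[ai]
except ImportError as exc:
    raise ImportError(
        "Missing httpx. Install pytrilogy[ai] to use AI providers."
    ) from exc
```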
@@ -113,6 +115,31 @@ ORDER BY
 LIMIT 10;
 ```
 
+## Trilogy is Easy to Write
+For humans *and* AI. Enjoy flexible, one-shot query generation without any DB access or security risks.
+
+(full code in the python API section.)
+
+```python
+query = text_to_query(
+    executor.environment,
+    "number of flights by month in 2005",
+    Provider.OPENAI,
+    "gpt-5-chat-latest",
+    api_key,
+)
+
+# get a ready to run query
+print(query)
+# typical output
+'''where local.dep_time.year = 2020
+select
+    local.dep_time.month,
+    count(local.id2) as number_of_flights
+order by
+    local.dep_time.month asc;'''
+```
+
 ## Goals
 
 Versus SQL, Trilogy aims to:
@@ -264,6 +291,47 @@ for row in results:
     print(x)
 ```
 
+### LLM Usage
+
+Connect to your favorite provider and generate queries with confidence and high accuracy.
+
+```python
+from trilogy import Environment, Dialects
+from trilogy.ai import Provider, text_to_query
+import os
+
+executor = Dialects.DUCK_DB.default_executor(
+    environment=Environment(working_path=Path(__file__).parent)
+)
+
+api_key = os.environ.get(OPENAI_API_KEY)
+if not api_key:
+    raise ValueError("OPENAI_API_KEY required for gpt generation")
+# load a model
+executor.parse_file("flight.preql")
+# create tables in the DB if needed
+executor.execute_file("setup.sql")
+# generate a query
+query = text_to_query(
+    executor.environment,
+    "number of flights by month in 2005",
+    Provider.OPENAI,
+    "gpt-5-chat-latest",
+    api_key,
+)
+
+# print the generated trilogy query
+print(query)
+# run it
+results = executor.execute_text(query)[-1].fetchall()
+assert len(results) == 12
+
+for row in results:
+    # all monthly flights are between 5000 and 7000
+    assert row[1] > 5000 and row[1] < 7000, row
+
+```
+
 ### CLI Usage
 
 Trilogy can be run through a CLI tool, also named 'trilogy'.
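A reviewer's note on that README snippet as shipped: it references `Path` without an import and passes `OPENAI_API_KEY` unquoted to `os.environ.get`, so both lines would raise `NameError` if run verbatim as excerpted. A corrected, self-contained variant (the fix-up is ours, not part of the release):

```python
import os
from pathlib import Path

from trilogy import Dialects, Environment
from trilogy.ai import Provider, text_to_query

executor = Dialects.DUCK_DB.default_executor(
    environment=Environment(working_path=Path(__file__).parent)
)

api_key = os.environ.get("OPENAI_API_KEY")  # env var name must be a quoted string
if not api_key:
    raise ValueError("OPENAI_API_KEY required for gpt generation")

executor.parse_file("flight.preql")   # load a semantic model
executor.execute_file("setup.sql")    # create tables in the DB if needed

query = text_to_query(
    executor.environment,
    "number of flights by month in 2005",
    Provider.OPENAI,
    "gpt-5-chat-latest",
    api_key,
)
print(query)
```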
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/RECORD CHANGED
@@ -1,12 +1,25 @@
-pytrilogy-0.0.3.
-trilogy/__init__.py,sha256=
+pytrilogy-0.0.3.109.dist-info/licenses/LICENSE.md,sha256=5ZRvtTyCCFwz1THxDTjAu3Lidds9WjPvvzgVwPSYNDo,1042
+trilogy/__init__.py,sha256=KT6UoNoE4ZPq_JfCkzB6yZ8a543YyAWsBDGNZxY8LEg,304
 trilogy/constants.py,sha256=g_zkVCNjGop6coZ1kM8eXXAzCnUN22ldx3TYFz0E9sc,1747
 trilogy/engine.py,sha256=3MiADf5MKcmxqiHBuRqiYdsXiLj7oitDfVvXvHrfjkA,2178
-trilogy/executor.py,sha256
+trilogy/executor.py,sha256=-VeOV0bTGmchHRHpRwFJDyl8FElUxDpwUTUix7hhIFM,17429
 trilogy/parser.py,sha256=o4cfk3j3yhUFoiDKq9ZX_GjBF3dKhDjXEwb63rcBkBM,293
 trilogy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 trilogy/render.py,sha256=qQWwduymauOlB517UtM-VGbVe8Cswa4UJub5aGbSO6c,1512
 trilogy/utility.py,sha256=euQccZLKoYBz0LNg5tzLlvv2YHvXh9HArnYp1V3uXsM,763
+trilogy/ai/__init__.py,sha256=H6gpzScruX2xgZNRDMjQ31Wy45irJbdebX1fU_gOwI8,581
+trilogy/ai/constants.py,sha256=Aj-_mFqskcXqIlBjX_A9eqH0V9M8mqX3uJwUhr9puak,5064
+trilogy/ai/conversation.py,sha256=I11xmUZikuKmh-W-jt38OvtyhpHwhpQ6Eeut6dkjI-c,3467
+trilogy/ai/enums.py,sha256=vghPPx0W-DioQSgq4T0MGL-8ekFh6O6d52dHo7KsKtg,118
+trilogy/ai/execute.py,sha256=DTARZxm_btCJq4Yd_jPRHJAcbsMLbjEsjR7KKyKBkTI,1335
+trilogy/ai/models.py,sha256=Au4QnTIlv7e-p3XgTJYZqTSndPMGRIbOvCUWlekE81A,683
+trilogy/ai/prompts.py,sha256=Uag0DJcKs7QWFGX7I3QFSm9o_4oYgASFyhNm4SJncVA,1788
+trilogy/ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+trilogy/ai/providers/anthropic.py,sha256=Wj2kEmz2CPuCdgUAqC8mYYrqaICTknnVN4QukTkr8tQ,4036
+trilogy/ai/providers/base.py,sha256=PgrD3y9-S42GAfyZUm8cNLgHQx7Wew_kCcVj9WKoImo,693
+trilogy/ai/providers/google.py,sha256=WnAqD84pLPMs5iAgjEOX8BpxowRCzJNEbxPyc_c_AtE,5252
+trilogy/ai/providers/openai.py,sha256=_lhY795q6XMKSyh4pAskQ8Ft2fMgD8tDWWhwlPh5FB0,3273
+trilogy/ai/providers/utils.py,sha256=yttP6y2E_XzdytBCwhaKekfXfxM6gE6MRce4AtyLL60,2047
 trilogy/authoring/__init__.py,sha256=TABMOETSMERrWuyDLR0nK4ISlqR0yaqeXrmuOdrSvAY,3060
 trilogy/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 trilogy/core/constants.py,sha256=nizWYDCJQ1bigQMtkNIEMNTcN0NoEAXiIHLzpelxQ24,201
@@ -119,8 +132,8 @@ trilogy/std/money.preql,sha256=XWwvAV3WxBsHX9zfptoYRnBigcfYwrYtBHXTME0xJuQ,2082
 trilogy/std/net.preql,sha256=WZCuvH87_rZntZiuGJMmBDMVKkdhTtxeHOkrXNwJ1EE,416
 trilogy/std/ranking.preql,sha256=LDoZrYyz4g3xsII9XwXfmstZD-_92i1Eox1UqkBIfi8,83
 trilogy/std/report.preql,sha256=LbV-XlHdfw0jgnQ8pV7acG95xrd1-p65fVpiIc-S7W4,202
-pytrilogy-0.0.3.
-pytrilogy-0.0.3.
-pytrilogy-0.0.3.
-pytrilogy-0.0.3.
-pytrilogy-0.0.3.
+pytrilogy-0.0.3.109.dist-info/METADATA,sha256=U_r100YWYUQoWKW48qdPDL2eZVxvDOia7fKkfOOiK3I,13460
+pytrilogy-0.0.3.109.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pytrilogy-0.0.3.109.dist-info/entry_points.txt,sha256=ewBPU2vLnVexZVnB-NrVj-p3E-4vukg83Zk8A55Wp2w,56
+pytrilogy-0.0.3.109.dist-info/top_level.txt,sha256=cAy__NW_eMAa_yT9UnUNlZLFfxcg6eimUAZ184cdNiE,8
+pytrilogy-0.0.3.109.dist-info/RECORD,,
trilogy/__init__.py CHANGED
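The view does not expand this one-line change; given the +1/-1 count in the listing above, it is presumably just the version constant:

```python
# Presumed content of the change (hypothetical; the diff body is not shown):
__version__ = "0.0.3.109"  # was "0.0.3.108"
```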
trilogy/ai/__init__.py ADDED
@@ -0,0 +1,19 @@
+from trilogy.ai.conversation import Conversation
+from trilogy.ai.enums import Provider
+from trilogy.ai.execute import text_to_query
+from trilogy.ai.models import LLMMessage
+from trilogy.ai.prompts import create_query_prompt
+from trilogy.ai.providers.anthropic import AnthropicProvider
+from trilogy.ai.providers.google import GoogleProvider
+from trilogy.ai.providers.openai import OpenAIProvider
+
+__all__ = [
+    "Conversation",
+    "LLMMessage",
+    "OpenAIProvider",
+    "GoogleProvider",
+    "AnthropicProvider",
+    "create_query_prompt",
+    "text_to_query",
+    "Provider",
+]
trilogy/ai/constants.py ADDED
@@ -0,0 +1,92 @@
+from trilogy.core.enums import FunctionClass, FunctionType
+from trilogy.core.functions import FUNCTION_REGISTRY
+
+RULE_PROMPT = """Trilogy statements define a semantic model or query. If a user is asking for data, they want a SELECT.
+Semantic model statements:
+- import <> imports a model to reuse. The output of imports will be visible in fields available to use.
+- key|property|auto|metric defines fields locally. The output will also be visible in fields available to use, so you generally don't need to edit these unless requested.
+- datasource statements define a datasource, which is a mapping of fields to a SQL database table. The left side is the SQL column name, the right side is the field name.
+
+SELECT RULES:
+- No FROM, JOIN, GROUP BY, SUB SELECTS, DISTINCT, UNION, or SELECT *.
+- All fields exist in a global namespace; field paths look like `order.product.id`. Always use the full path. NEVER include a from clause.
+- If a field has a grain defined, and that grain is not in the query output, aggregate it to get desired result.
+- If a field has a 'alias_for' defined, it is shorthand for that calculation. Use the field name instead of the calculation in your query to be concise.
+- Newly created fields at the output of the select must be aliased with as (e.g. `sum(births) as all_births`).
+- Aliases cannot happen inside calculations or in the where/having/order clause. Never alias fields with existing names. 'sum(revenue) as total_revenue' is valid, but '(sum(births) as total_revenue) +1 as revenue_plus_one' is not.
+- Implicit grouping: NEVER include a group by clause. Grouping is by non-aggregated fields in the SELECT clause.
+- You can dynamically group inline to get groups at different grains - ex: `sum(metric) by dim1, dim2 as sum_by_dim1_dm2` for alternate grouping. If you are grouping a defined aggregate
+- Count must specify a field (no `count(*)`) Counts are automatically deduplicated. Do not ever use DISTINCT.
+- Since there are no underlying tables, sum/count of a constant should always specify a grain field (e.g. `sum(1) by x as count`).
+- Aggregates in SELECT must be filtered via HAVING. Use WHERE for pre-aggregation filters.
+- Use `field ? condition` for inline filters (e.g. `sum(x ? x > 0)`).
+- Always use a reasonable `LIMIT` for final queries unless the request is for a time series or line chart.
+- Window functions: `rank entity [optional over group] by field desc` (e.g. `rank name over state by sum(births) desc as top_name`) Do not use parentheses for over.
+- Functions. All function names have parenthese (e.g. `sum(births)`, `date_part('year', dep_time)`). For no arguments, use empty parentheses (e.g. `current_date()`).
+- For lag/lead, offset is first: lag/lead offset field order by expr asc/desc.
+- For lag/lead with a window clause: lag/lead offset field by window_clause order by expr asc/desc.
+- Use `::type` casting, e.g., `"2020-01-01"::date`.
+- Date_parts have no quotes; use `date_part(order_date, year)` instead of `date_part(order_date, 'year')`.
+- Comments use `#` only, per line.
+- Two example queries: "where year between 1940 and 1950
+select
+    name,
+    state,
+    sum(births) AS all_births,
+    sum(births ? state = 'VT') AS vermont_births,
+    rank name over state by all_births desc AS state_rank,
+    rank name by sum(births) by name desc AS all_rank
+having
+    all_rank<11
+    and state = 'ID'
+order by
+    all_rank asc
+limit 5;", "where dep_time between '2002-01-01'::datetime and '2010-01-31'::datetime
+select
+    carrier.name,
+    count(id2) AS total_flights,
+    total_flights / date_diff(min(dep_time.date), max(dep_time.date), DAY) AS average_daily_flights
+order by
+    total_flights desc;"""
+
+
+def render_function(function_type: FunctionType, example: str | None = None):
+    info = FUNCTION_REGISTRY[function_type]
+
+    if info.arg_count == -1:
+        # Infinite/variable number of arguments
+        base = f"{function_type.value}(<arg1>, <arg2>, ..., <argN>)"
+    elif info.arg_count == 0:
+        # No arguments
+        base = f"{function_type.value}()"
+    else:
+        # Fixed number of arguments
+        base = f"{function_type.value}({', '.join([f'<arg{p}>' for p in range(1, info.arg_count + 1)])})"
+
+    if example:
+        base += f" e.g. {example}"
+    return base
+
+
+FUNCTION_EXAMPLES = {
+    FunctionType.DATE_ADD: "date_add('2020-01-01'::date, month, 1)",
+    FunctionType.DATE_DIFF: "date_diff('2020-01-01'::date, '2020-01-02'::date, day)",
+    FunctionType.DATE_PART: "date_part('2020-01-01'::date, year)",
+    FunctionType.DATE_SUB: "date_sub('2020-01-01'::date, day, 1)",
+    FunctionType.DATE_TRUNCATE: "date_trunc('2020-01-01'::date, month)",
+    FunctionType.CURRENT_TIMESTAMP: "now()",
+}
+
+FUNCTIONS = "\n".join(
+    [
+        render_function(v, example=FUNCTION_EXAMPLES.get(v))
+        for x, v in FunctionType.__members__.items()
+        if v in FUNCTION_REGISTRY
+    ]
+)
+
+AGGREGATE_FUNCTIONS = [
+    x
+    for x, info in FunctionType.__members__.items()
+    if x in FunctionClass.AGGREGATE_FUNCTIONS.value
+]
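For intuition, `render_function` turns each registry entry into one prompt line. A standalone re-implementation of the same formatting logic (decoupled from `FUNCTION_REGISTRY`, which we do not inspect here; the arg count of 3 for `date_add` is an assumption for illustration):

```python
# Mirrors render_function's formatting without touching FUNCTION_REGISTRY.
def render_function_demo(name: str, arg_count: int, example: str | None = None) -> str:
    if arg_count == -1:  # variadic
        base = f"{name}(<arg1>, <arg2>, ..., <argN>)"
    elif arg_count == 0:  # no arguments
        base = f"{name}()"
    else:  # fixed arity
        base = f"{name}({', '.join(f'<arg{p}>' for p in range(1, arg_count + 1))})"
    if example:
        base += f" e.g. {example}"
    return base


print(render_function_demo("date_add", 3, "date_add('2020-01-01'::date, month, 1)"))
# -> date_add(<arg1>, <arg2>, <arg3>) e.g. date_add('2020-01-01'::date, month, 1)
```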
trilogy/ai/conversation.py ADDED
@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from typing import Literal, Union
+
+from trilogy import Environment
+from trilogy.ai.models import LLMMessage, LLMRequestOptions
+from trilogy.ai.prompts import TRILOGY_LEAD_IN, create_query_prompt
+from trilogy.ai.providers.base import LLMProvider
+from trilogy.core.exceptions import (
+    InvalidSyntaxException,
+    NoDatasourceException,
+    UndefinedConceptException,
+    UnresolvableQueryException,
+)
+from trilogy.core.query_processor import process_query
+
+
+@dataclass
+class Conversation:
+
+    messages: list[LLMMessage]
+    provider: LLMProvider
+    id: str | None = None
+
+    @classmethod
+    def create(
+        cls,
+        provider: LLMProvider,
+        model_prompt: str = TRILOGY_LEAD_IN,
+        id: str | None = None,
+    ) -> "Conversation":
+        system_message = LLMMessage(role="system", content=model_prompt)
+        messages = [system_message]
+        return cls(id=id, messages=messages, provider=provider)
+
+    def add_message(
+        self,
+        message: Union[LLMMessage, str],
+        role: Literal["user", "assistant"] = "user",
+    ) -> None:
+        """
+        Add a message to the conversation.
+
+        Args:
+            message: Either an LLMMessage object or a string content
+            role: The role for the message if a string is provided (default: 'user')
+        """
+        if isinstance(message, str):
+            message = LLMMessage(role=role, content=message)
+        self.messages.append(message)
+
+    def get_response(self) -> LLMMessage:
+        options = LLMRequestOptions()
+        response = self.provider.generate_completion(options, history=self.messages)
+        response_message = LLMMessage(role="assistant", content=response.text)
+        self.add_message(response_message)
+        return response_message
+
+    def extract_response(self, content: str) -> str:
+        # get contents in triple backticks
+        content = content.replace('"""', "```")
+        if "```" in content:
+            parts = content.split("```")
+            if len(parts) >= 3:
+                return parts[1].strip()
+        return content
+
+    def generate_query(
+        self, user_input: str, environment: Environment, attempts: int = 4
+    ) -> str:
+        attempts = 0
+        self.add_message(create_query_prompt(user_input, environment), role="user")
+        e = None
+        while attempts < 4:
+            attempts += 1
+
+            response_message = self.get_response()
+            response = self.extract_response(response_message.content)
+            if not response.strip()[-1] == ";":
+                response += ";"
+            try:
+                env, raw = environment.parse(response)
+                process_query(statement=raw[-1], environment=environment)
+                return response
+            except (
+                InvalidSyntaxException,
+                NoDatasourceException,
+                UnresolvableQueryException,
+                UndefinedConceptException,
+                SyntaxError,
+            ) as e2:
+                e = e2
+                self.add_message(
+                    f"The previous response could not be parsed due to the error: {str(e)}. Please generate a new query with the issues fixed. Use the same response format.",
+                    role="user",
+                )
+
+        raise Exception(
+            f"Failed to generate a valid query after {attempts} attempts. Last error: {str(e)}. Full conversation: {self.messages}"
+        )
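How the pieces fit: `Conversation` owns the message history and the validate-and-retry loop, while the provider only does completions. A usage sketch (assumes `env` is an `Environment` already populated with a semantic model; the model name is a placeholder):

```python
from trilogy.ai import Conversation, OpenAIProvider

# api_key=None falls back to the OPENAI_API_KEY environment variable.
provider = OpenAIProvider(name="openai", model="gpt-4o", api_key=None)
convo = Conversation.create(provider=provider)  # seeds the TRILOGY_LEAD_IN system prompt
# generate_query retries up to 4 times, feeding parse errors back as user messages.
query = convo.generate_query("number of flights by month", environment=env)
print(query)
```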
trilogy/ai/enums.py ADDED
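The diff view does not expand this 7-line file; judging from the imports and the branching in execute.py below, it presumably defines the Provider enum. A hypothetical reconstruction:

```python
# Hypothetical reconstruction; the member values are assumptions. Only the
# OPENAI/ANTHROPIC/GOOGLE member names are confirmed by execute.py below.
from enum import Enum


class Provider(Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
```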
trilogy/ai/execute.py ADDED
@@ -0,0 +1,50 @@
+from trilogy import Environment
+from trilogy.ai.conversation import Conversation
+from trilogy.ai.enums import Provider
+from trilogy.ai.providers.base import LLMProvider
+
+
+def text_to_query(
+    environment: Environment,
+    user_input: str,
+    provider: Provider,
+    model: str,
+    secret: str | None = None,
+) -> str:
+    llm_provider: LLMProvider
+
+    if provider == Provider.OPENAI:
+        from trilogy.ai.providers.openai import OpenAIProvider
+
+        llm_provider = OpenAIProvider(
+            name="openai",
+            api_key=secret,
+            model=model,
+        )
+    elif provider == Provider.ANTHROPIC:
+        from trilogy.ai.providers.anthropic import AnthropicProvider
+
+        llm_provider = AnthropicProvider(
+            name="anthropic",
+            api_key=secret,
+            model=model,
+        )
+    elif provider == Provider.GOOGLE:
+        from trilogy.ai.providers.google import GoogleProvider
+
+        llm_provider = GoogleProvider(
+            name="google",
+            api_key=secret,
+            model=model,
+        )
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
+    conversation = Conversation.create(
+        provider=llm_provider,
+    )
+
+    response = conversation.generate_query(
+        user_input=user_input, environment=environment
+    )
+
+    return response
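Assuming `env` is an `Environment` already loaded with a model, the same entry point covers all three providers; only the enum member, model name, and secret change (the model names below are placeholders, not from the diff):

```python
from trilogy.ai import Provider, text_to_query

# Placeholder model identifiers; use whatever the provider account exposes.
anthropic_q = text_to_query(env, "top carriers by flights", Provider.ANTHROPIC, "claude-sonnet-4-5")
google_q = text_to_query(env, "top carriers by flights", Provider.GOOGLE, "gemini-2.0-flash")
# Omitting the secret falls back to the ANTHROPIC_API_KEY / GOOGLE_API_KEY env vars.
```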
trilogy/ai/models.py ADDED
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import Literal, Optional
+
+
+@dataclass
+class UsageDict:
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+@dataclass
+class LLMResponse:
+    text: str
+    usage: UsageDict
+
+
+@dataclass
+class LLMRequestOptions:
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+
+
+@dataclass
+class LLMMessage:
+    role: Literal["user", "assistant", "system"]
+    content: str
+    model_info: Optional[dict] = None
+    hidden: bool = False  # Used to hide messages in the UI
+
+    def __post_init__(self):
+        if self.model_info is None:
+            self.model_info = {}
trilogy/ai/prompts.py ADDED
@@ -0,0 +1,30 @@
+from trilogy import Environment
+from trilogy.ai.constants import AGGREGATE_FUNCTIONS, FUNCTIONS, RULE_PROMPT
+from trilogy.authoring import Concept, DataType
+
+TRILOGY_LEAD_IN = f'''You are a world-class expert in Trilogy, a SQL inspired language with similar syntax and a built in semantic layer. Use the following syntax description to help answer whatever questions they have. Often, they will be asking you to generate a query for them.
+
+Key Trilogy Syntax Rules:
+{RULE_PROMPT}
+
+Aggregate Functions:
+{AGGREGATE_FUNCTIONS}
+
+Functions:
+{FUNCTIONS}
+
+Valid types:
+{[x.value for x in DataType]}
+
+For any response to the user, use this format -> put your actual response within triple double quotes with thinking and justification before it, in this format (replace placeholders with relevant content): Reasoning: {{reasoning}} """{{response}}"""
+'''
+
+
+def concepts_to_fields_prompt(concepts: list[Concept]) -> str:
+    return ", ".join([f"[name: {c.address} | type: {c.datatype}" for c in concepts])
+
+
+def create_query_prompt(query: str, environment: Environment) -> str:
+    fields = concepts_to_fields_prompt(list(environment.concepts.values()))
+    return f'''
+Using these base and aliased calculations, derivations thereof created with valid Trilogy, and any extra context you have: {fields}, create the best valid Trilogy query to answer the following user input: "{query}" Return the query within triple double quotes with your thinking and justification before it, so of this form as a jinja template: Reasoning: {{reasoning_placeholder}} """{{trilogy}}""". Example: Because the user asked for sales by year, and revenue is the best sales related field available, we can aggregate revenue by year: """SELECT order.year, sum(revenue) as year_revenue order by order.year asc;"""'''
trilogy/ai/providers/__init__.py ADDED
File without changes (empty file)
trilogy/ai/providers/anthropic.py ADDED
@@ -0,0 +1,105 @@
+from os import environ
+from typing import List, Optional
+
+from trilogy.ai.enums import Provider
+from trilogy.ai.models import LLMMessage, LLMResponse, UsageDict
+
+from .base import LLMProvider, LLMRequestOptions
+from .utils import RetryOptions, fetch_with_retry
+
+DEFAULT_MAX_TOKENS = 10000
+
+
+class AnthropicProvider(LLMProvider):
+    def __init__(
+        self,
+        name: str,
+        model: str,
+        api_key: str | None = None,
+        retry_options: Optional[RetryOptions] = None,
+    ):
+        api_key = api_key or environ.get("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "API key argument or environment variable ANTHROPIC_API_KEY is required"
+            )
+        super().__init__(name, api_key, model, Provider.ANTHROPIC)
+        self.base_completion_url = "https://api.anthropic.com/v1/messages"
+        self.base_model_url = "https://api.anthropic.com/v1/models"
+        self.models: List[str] = []
+        self.type = Provider.ANTHROPIC
+        self.retry_options = retry_options or RetryOptions(
+            max_retries=5,
+            initial_delay_ms=5000,
+            retry_status_codes=[429, 500, 502, 503, 504],
+            on_retry=lambda attempt, delay_ms, error: print(
+                f"Anthropic API retry attempt {attempt} after {delay_ms}ms delay due to error: {str(error)}"
+            ),
+        )
+
+    def generate_completion(
+        self, options: LLMRequestOptions, history: List[LLMMessage]
+    ) -> LLMResponse:
+        try:
+            import httpx
+        except ImportError:
+            raise ImportError(
+                "Missing httpx. Install pytrilogy[ai] to use AnthropicProvider."
+            )
+
+        # Separate system messages from user/assistant messages
+        system_messages = [msg.content for msg in history if msg.role == "system"]
+        conversation_messages = [
+            {"role": msg.role, "content": msg.content}
+            for msg in history
+            if msg.role != "system"
+        ]
+
+        try:
+
+            def make_request():
+                with httpx.Client(timeout=60) as client:
+                    payload = {
+                        "model": self.model,
+                        "messages": conversation_messages,
+                        "max_tokens": options.max_tokens or 10000,
+                        # "temperature": options.temperature or 0.7,
+                        # "top_p": options.top_p if hasattr(options, "top_p") else 1.0,
+                    }
+
+                    # Add system parameter if there are system messages
+                    if system_messages:
+                        # Combine multiple system messages with newlines
+                        payload["system"] = "\n\n".join(system_messages)
+
+                    response = client.post(
+                        url=self.base_completion_url,
+                        headers={
+                            "Content-Type": "application/json",
+                            "x-api-key": self.api_key,
+                            "anthropic-version": "2023-06-01",
+                        },
+                        json=payload,
+                    )
+                    response.raise_for_status()
+                    return response.json()
+
+            data = fetch_with_retry(make_request, self.retry_options)
+
+            return LLMResponse(
+                text=data["content"][0]["text"],
+                usage=UsageDict(
+                    prompt_tokens=data["usage"]["input_tokens"],
+                    completion_tokens=data["usage"]["output_tokens"],
+                    total_tokens=data["usage"]["input_tokens"]
+                    + data["usage"]["output_tokens"],
+                ),
+            )
+
+        except httpx.HTTPStatusError as error:
+            error_detail = error.response.text
+            raise Exception(
+                f"Anthropic API error ({error.response.status_code}): {error_detail}"
+            )
+        except Exception as error:
+            raise Exception(f"Anthropic API error: {str(error)}")
trilogy/ai/providers/base.py ADDED
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from trilogy.ai.enums import Provider
+from trilogy.ai.models import LLMMessage, LLMRequestOptions, LLMResponse
+
+
+class LLMProvider(ABC):
+    def __init__(self, name: str, api_key: str, model: str, provider: Provider):
+        self.api_key = api_key
+        self.models: List[str] = []
+        self.name = name
+        self.model = model
+        self.type = provider
+        self.error: Optional[str] = None
+
+    # Abstract method to be implemented by specific providers
+    @abstractmethod
+    def generate_completion(
+        self, options: LLMRequestOptions, history: List[LLMMessage]
+    ) -> LLMResponse:
+        pass
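`LLMProvider` is the one extension point: subclass it and implement `generate_completion`. A hypothetical offline test double (not part of the release) that exercises `Conversation` without network access:

```python
from typing import List

from trilogy.ai.enums import Provider
from trilogy.ai.models import LLMMessage, LLMRequestOptions, LLMResponse, UsageDict
from trilogy.ai.providers.base import LLMProvider


class EchoProvider(LLMProvider):
    """Returns a canned completion; useful for unit-testing conversation flow."""

    def __init__(self, canned: str):
        # The provider tag is arbitrary here; the api_key is never used.
        super().__init__(name="echo", api_key="unused", model="none", provider=Provider.OPENAI)
        self.canned = canned

    def generate_completion(
        self, options: LLMRequestOptions, history: List[LLMMessage]
    ) -> LLMResponse:
        return LLMResponse(text=self.canned, usage=UsageDict(0, 0, 0))
```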
trilogy/ai/providers/google.py ADDED
@@ -0,0 +1,142 @@
+from os import environ
+from typing import Any, Dict, List, Optional
+
+from trilogy.ai.enums import Provider
+from trilogy.ai.models import LLMMessage, LLMResponse, UsageDict
+
+from .base import LLMProvider, LLMRequestOptions
+from .utils import RetryOptions, fetch_with_retry
+
+
+class GoogleProvider(LLMProvider):
+    def __init__(
+        self,
+        name: str,
+        model: str,
+        api_key: str | None = None,
+        retry_options: Optional[RetryOptions] = None,
+    ):
+        api_key = api_key or environ.get("GOOGLE_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "API key argument or environment variable GOOGLE_API_KEY is required"
+            )
+        super().__init__(name, api_key, model, Provider.GOOGLE)
+        self.base_model_url = "https://generativelanguage.googleapis.com/v1/models"
+        self.base_completion_url = "https://generativelanguage.googleapis.com/v1beta"
+        self.models: List[str] = []
+        self.type = Provider.GOOGLE
+        self.retry_options = retry_options or RetryOptions(
+            max_retries=3,
+            initial_delay_ms=30000,  # 30s default for Google's 429 rate limits
+            retry_status_codes=[429, 500, 502, 503, 504],
+            on_retry=lambda attempt, delay_ms, error: print(
+                f"Google API retry attempt {attempt} after {delay_ms}ms delay due to error: {str(error)}"
+            ),
+        )
+
+    def _convert_to_gemini_history(
+        self, messages: List[LLMMessage]
+    ) -> List[Dict[str, Any]]:
+        """Convert standard message format to Gemini format."""
+        return [
+            {
+                "role": "model" if msg.role == "assistant" else "user",
+                "parts": [{"text": msg.content}],
+            }
+            for msg in messages
+        ]
+
+    def generate_completion(
+        self, options: LLMRequestOptions, history: List[LLMMessage]
+    ) -> LLMResponse:
+        try:
+            import httpx
+        except ImportError:
+            raise ImportError(
+                "Missing httpx. Install pytrilogy[ai] to use GoogleProvider."
+            )
+
+        # Convert messages to Gemini format
+        gemini_history = self._convert_to_gemini_history(history)
+
+        # Separate system message if present
+        system_instruction = None
+        contents = gemini_history
+
+        # Check if first message is a system message
+        if history and history[0].role == "system":
+            system_instruction = {"parts": [{"text": history[0].content}]}
+            contents = gemini_history[1:]  # Remove system message from history
+
+        # Build the request URL
+        url = f"{self.base_completion_url}/models/{self.model}:generateContent"
+
+        # Build request body
+        request_body: Dict[str, Any] = {"contents": contents, "generationConfig": {}}
+
+        # Add system instruction if present
+        if system_instruction:
+            request_body["systemInstruction"] = system_instruction
+
+        # Add generation config options
+        if options.temperature is not None:
+            request_body["generationConfig"]["temperature"] = options.temperature
+
+        if options.max_tokens is not None:
+            request_body["generationConfig"]["maxOutputTokens"] = options.max_tokens
+
+        if options.top_p is not None:
+            request_body["generationConfig"]["topP"] = options.top_p
+
+        try:
+            # Make the API request with retry logic using a lambda
+            response = fetch_with_retry(
+                fetch_fn=lambda: httpx.post(
+                    url,
+                    headers={
+                        "Content-Type": "application/json",
+                        "x-goog-api-key": self.api_key,
+                    },
+                    json=request_body,
+                    timeout=60.0,
+                ),
+                options=self.retry_options,
+            )
+
+            response.raise_for_status()
+            data = response.json()
+
+            # Extract text from response
+            candidates = data.get("candidates", [])
+            if not candidates:
+                raise Exception("No candidates returned from Google API")
+
+            content = candidates[0].get("content", {})
+            parts = content.get("parts", [])
+
+            if not parts:
+                raise Exception("No parts in response content")
+
+            text = parts[0].get("text", "")
+
+            # Extract usage metadata
+            usage_metadata = data.get("usageMetadata", {})
+            prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+            completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
+
+            return LLMResponse(
+                text=text,
+                usage=UsageDict(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                ),
+            )
+        except httpx.HTTPStatusError as error:
+            error_detail = error.response.text
+            raise Exception(
+                f"Google API error ({error.response.status_code}): {error_detail}"
+            )
+        except Exception as error:
+            raise Exception(f"Google API error: {str(error)}")
trilogy/ai/providers/openai.py ADDED
@@ -0,0 +1,88 @@
+from os import environ
+from typing import List, Optional
+
+from trilogy.ai.enums import Provider
+from trilogy.ai.models import LLMMessage, LLMResponse, UsageDict
+
+from .base import LLMProvider, LLMRequestOptions
+from .utils import RetryOptions, fetch_with_retry
+
+
+class OpenAIProvider(LLMProvider):
+    def __init__(
+        self,
+        name: str,
+        model: str,
+        api_key: str | None = None,
+        retry_options: Optional[RetryOptions] = None,
+    ):
+        api_key = api_key or environ.get("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "API key argument or environment variable OPENAI_API_KEY is required"
+            )
+        super().__init__(name, api_key, model, Provider.OPENAI)
+        self.base_completion_url = "https://api.openai.com/v1/chat/completions"
+        self.base_model_url = "https://api.openai.com/v1/models"
+        self.models: List[str] = []
+        self.type = Provider.OPENAI
+
+        self.retry_options = retry_options or RetryOptions(
+            max_retries=3,
+            initial_delay_ms=1000,
+            retry_status_codes=[429, 500, 502, 503, 504],  # Add common API error codes
+            on_retry=lambda attempt, delay_ms, error: print(
+                f"Retry attempt {attempt} after {delay_ms}ms delay due to error: {str(error)}"
+            ),
+        )
+
+    def generate_completion(
+        self, options: LLMRequestOptions, history: List[LLMMessage]
+    ) -> LLMResponse:
+        try:
+            import httpx
+        except ImportError:
+            raise ImportError(
+                "Missing httpx. Install pytrilogy[ai] to use OpenAIProvider."
+            )
+
+        messages: List[dict] = []
+        messages = [{"role": msg.role, "content": msg.content} for msg in history]
+        try:
+
+            def make_request():
+                with httpx.Client(timeout=30) as client:
+                    payload = {
+                        "model": self.model,
+                        "messages": messages,
+                    }
+
+                    response = client.post(
+                        url=self.base_completion_url,
+                        headers={
+                            "Content-Type": "application/json",
+                            "Authorization": f"Bearer {self.api_key}",
+                        },
+                        json=payload,
+                    )
+                    response.raise_for_status()
+                    return response.json()
+
+            data = fetch_with_retry(make_request, self.retry_options)
+            return LLMResponse(
+                text=data["choices"][0]["message"]["content"],
+                usage=UsageDict(
+                    prompt_tokens=data["usage"]["prompt_tokens"],
+                    completion_tokens=data["usage"]["completion_tokens"],
+                    total_tokens=data["usage"]["total_tokens"],
+                ),
+            )
+        except httpx.HTTPStatusError as error:
+            # Capture the response body text
+            error_detail = error.response.text
+            raise Exception(
+                f"OpenAI API error ({error.response.status_code}): {error_detail}"
+            )
+
+        except Exception as error:
+            raise Exception(f"OpenAI API error: {str(error)}")
trilogy/ai/providers/utils.py ADDED
@@ -0,0 +1,68 @@
+import time
+from dataclasses import dataclass, field
+from typing import Callable, List, TypeVar
+
+T = TypeVar("T")
+
+
+@dataclass
+class RetryOptions:
+    max_retries: int = 3
+    initial_delay_ms: int = 1000
+    retry_status_codes: List[int] = field(
+        default_factory=lambda: [429, 500, 502, 503, 504, 525]
+    )
+    on_retry: Callable[[int, int, Exception], None] | None = None
+
+
+def fetch_with_retry(fetch_fn: Callable[[], T], options: RetryOptions) -> T:
+    from httpx import HTTPError
+
+    """
+    Retry a fetch operation with exponential backoff.
+
+    Args:
+        fetch_fn: Function that performs the fetch operation
+        options: Retry configuration options
+
+    Returns:
+        The result from the successful fetch operation
+
+    Raises:
+        The last exception encountered if all retries fail
+    """
+    from httpx import HTTPStatusError, TimeoutException
+
+    last_error = None
+    delay_ms = options.initial_delay_ms
+
+    for attempt in range(options.max_retries + 1):
+        try:
+            return fetch_fn()
+        except (HTTPError, TimeoutException) as error:
+            last_error = error
+            should_retry = False
+
+            if isinstance(error, HTTPStatusError):
+                if (
+                    options.retry_status_codes
+                    and error.response.status_code in options.retry_status_codes
+                ):
+                    should_retry = True
+            elif isinstance(error, TimeoutException):
+                should_retry = True
+            if not should_retry or attempt >= options.max_retries:
+                raise
+
+            # Call the retry callback if provided
+            if options.on_retry:
+                options.on_retry(attempt + 1, delay_ms, error)
+
+            # Wait before retrying with exponential backoff
+            time.sleep(delay_ms / 1000.0)
+            delay_ms *= 2  # Exponential backoff
+
+    # This should never be reached, but just in case
+    if last_error:
+        raise last_error
+    raise Exception("Retry logic failed unexpectedly")
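`fetch_with_retry` is generic over the callable's return type, so it is usable outside the providers too. A standalone sketch (the URL is a placeholder):

```python
import httpx

from trilogy.ai.providers.utils import RetryOptions, fetch_with_retry


def ping() -> dict:
    resp = httpx.get("https://example.com/health", timeout=10)  # placeholder endpoint
    resp.raise_for_status()  # raises HTTPStatusError, which fetch_with_retry inspects
    return resp.json()


opts = RetryOptions(
    max_retries=2,
    initial_delay_ms=500,  # doubles after each failed attempt: 500ms, then 1000ms
    on_retry=lambda attempt, delay, err: print(f"retry {attempt} after {delay}ms: {err}"),
)
data = fetch_with_retry(ping, opts)
```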
trilogy/executor.py CHANGED
@@ -6,7 +6,7 @@ from sqlalchemy import text
 
 from trilogy.constants import MagicConstants, Rendering, logger
 from trilogy.core.enums import FunctionType, Granularity, IOType, ValidationScope
-from trilogy.core.models.author import Concept, Function
+from trilogy.core.models.author import Comment, Concept, Function
 from trilogy.core.models.build import BuildFunction
 from trilogy.core.models.core import ListWrapper, MapWrapper
 from trilogy.core.models.datasource import Datasource
@@ -86,6 +86,10 @@ class Executor(object):
     def execute_query(self, query) -> ResultProtocol | None:
         raise NotImplementedError("Cannot execute type {}".format(type(query)))
 
+    @execute_query.register
+    def _(self, query: Comment) -> ResultProtocol | None:
+        return None
+
     @execute_query.register
     def _(self, query: ConceptDeclarationStatement) -> ResultProtocol | None:
         return handle_concept_declaration(query)
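A small behavioral consequence: a parsed `Comment` statement is now dispatched to the no-op handler above instead of presumably falling through to the `NotImplementedError` default. A sketch (DuckDB dialect assumed):

```python
from trilogy import Dialects, Environment

executor = Dialects.DUCK_DB.default_executor(environment=Environment())
# A standalone comment statement now simply yields no result.
executor.execute_text("# commentary only, no query\n")
```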
@@ -266,9 +270,22 @@
         None,
     ]:
         file = Path(file)
-
-
-
+        candidates = [file, self.environment.working_path / file]
+        err = None
+        for file in candidates:
+            try:
+                with open(file, "r") as f:
+                    command = f.read()
+                    return self.parse_text_generator(
+                        command, persist=persist, root=file
+                    )
+            except FileNotFoundError as e:
+                if not err:
+                    err = e
+                continue
+        if err:
+            raise err
+        raise FileNotFoundError(f"File {file} not found")
 
     def parse_text(
         self, command: str, persist: bool = False, root: Path | None = None
@@ -440,9 +457,20 @@
         self, file: str | Path, non_interactive: bool = False
     ) -> List[ResultProtocol]:
         file = Path(file)
-
-
-
+        candidates = [file, self.environment.working_path / file]
+        err = None
+        for file in candidates:
+            if not file.exists():
+                continue
+            with open(file, "r") as f:
+                command = f.read()
+            if file.suffix == ".sql":
+                return [self.execute_raw_sql(command)]
+            else:
+                return self.execute_text(command, non_interactive=non_interactive)
+        if err:
+            raise err
+        raise FileNotFoundError(f"File {file} not found")
 
     def validate_environment(
         self,
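Both `parse_file` and `execute_file` now fall back to the environment's `working_path`, so relative names resolve from the model directory as well as the current directory. A sketch (the paths are illustrative):

```python
from pathlib import Path

from trilogy import Dialects, Environment

# Illustrative layout: assumes ./models/flight.preql and ./models/setup.sql exist.
executor = Dialects.DUCK_DB.default_executor(
    environment=Environment(working_path=Path(__file__).parent / "models")
)
# Each name is tried as given first, then relative to working_path.
executor.parse_file("flight.preql")
executor.execute_file("setup.sql")  # the .sql suffix routes to execute_raw_sql
```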
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/WHEEL: File without changes
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/entry_points.txt: File without changes
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/licenses/LICENSE.md: File without changes
{pytrilogy-0.0.3.108.dist-info → pytrilogy-0.0.3.109.dist-info}/top_level.txt: File without changes