kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -1,20 +1,8 @@
-import os
-from dataclasses import dataclass
 from enum import Enum
-from
-from typing import Any, Dict, List, NoReturn
+from typing import Dict, List

-import httpx
-import requests
-from langchain_aws import ChatBedrockConverse
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 from pydantic import BaseModel

-from ..utils.config import Config
-
 """
 Provides model configuration and management for various LLM providers and models.
 This module handles the integration with different AI model providers and their respective models,
@@ -32,6 +20,9 @@ class ModelProviderName(str, Enum):
     amazon_bedrock = "amazon_bedrock"
     ollama = "ollama"
     openrouter = "openrouter"
+    fireworks_ai = "fireworks_ai"
+    kiln_fine_tune = "kiln_fine_tune"
+    kiln_custom_registry = "kiln_custom_registry"


 class ModelFamily(str, Enum):
@@ -46,6 +37,8 @@ class ModelFamily(str, Enum):
     gemma = "gemma"
     gemini = "gemini"
     claude = "claude"
+    mixtral = "mixtral"
+    qwen = "qwen"


 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -58,9 +51,11 @@ class ModelName(str, Enum):
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
+    llama_3_2_1b = "llama_3_2_1b"
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
+    llama_3_3_70b = "llama_3_3_70b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -75,6 +70,9 @@ class ModelName(str, Enum):
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
     gemini_1_5_pro = "gemini_1_5_pro"
     nemotron_70b = "nemotron_70b"
+    mixtral_8x7b = "mixtral_8x7b"
+    qwen_2p5_7b = "qwen_2p5_7b"
+    qwen_2p5_72b = "qwen_2p5_72b"


 class KilnModelProvider(BaseModel):
@@ -84,13 +82,20 @@ class KilnModelProvider(BaseModel):
     Attributes:
         name: The provider's identifier
         supports_structured_output: Whether the provider supports structured output formats
+        supports_data_gen: Whether the provider supports data generation
+        untested_model: Whether the model is untested (typically user added). The supports_ fields are not applicable.
+        provider_finetune_id: The finetune ID for the provider, if applicable
         provider_options: Additional provider-specific configuration options
+        adapter_options: Additional options specific to the adapter. Top level key should be adapter ID.
     """

     name: ModelProviderName
     supports_structured_output: bool = True
     supports_data_gen: bool = True
+    untested_model: bool = False
+    provider_finetune_id: str | None = None
     provider_options: Dict = {}
+    adapter_options: Dict = {}


 class KilnModel(BaseModel):
@@ -122,6 +127,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o-mini"},
+                provider_finetune_id="gpt-4o-mini-2024-07-18",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -138,6 +144,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o"},
+                provider_finetune_id="gpt-4o-2024-08-06",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -257,6 +264,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-8b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 70b
@@ -271,7 +287,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
-                #
+                # AWS 70b not working as well as the others.
                 supports_structured_output=False,
                 supports_data_gen=False,
                 provider_options={
@@ -287,6 +303,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.ollama,
                 provider_options={"model": "llama3.1:70b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-70b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 405b
@@ -311,6 +334,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-405b-instruct"
+                },
+            ),
         ],
     ),
     # Mistral Nemo
@@ -348,6 +378,35 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.2 1B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_1b,
+        friendly_name="Llama 3.2 1B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "meta-llama/llama-3.2-1b-instruct"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2:1b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-1b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-1b-instruct"
+                },
+            ),
+        ],
+    ),
     # Llama 3.2 3B
     KilnModel(
         family=ModelFamily.llama,
@@ -366,6 +425,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-3b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -376,9 +444,12 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -386,6 +457,18 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -396,15 +479,60 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.2-vision:90b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
+    ),
+    # Llama 3.3 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_3_70b,
+        friendly_name="Llama 3.3 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
+                # Openrouter not supporing tools yet. Once they do probably can remove. JSON mode sometimes works, but not consistently.
                 supports_structured_output=False,
                 supports_data_gen=False,
-
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.3"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # Finetuning not live yet
+                # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
+                },
             ),
         ],
     ),
@@ -427,6 +555,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/phi-3-vision-128k-instruct"
+                },
+            ),
         ],
     ),
     # Gemma 2 2.6b
@@ -465,6 +602,7 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
+            # fireworks AI errors - not allowing system role. Exclude until resolved.
         ],
     ),
     # Gemma 2 27b
@@ -488,290 +626,88 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    Returns:
-        The human-readable name of the provider
-
-    Raises:
-        ValueError: If the provider ID is invalid or unhandled
-    """
-    if id in ModelProviderName.__members__:
-        enum_id = ModelProviderName(id)
-        match enum_id:
-            case ModelProviderName.amazon_bedrock:
-                return "Amazon Bedrock"
-            case ModelProviderName.openrouter:
-                return "OpenRouter"
-            case ModelProviderName.groq:
-                return "Groq"
-            case ModelProviderName.ollama:
-                return "Ollama"
-            case ModelProviderName.openai:
-                return "OpenAI"
-            case _:
-                # triggers pyright warning if I miss a case
-                raise_exhaustive_error(enum_id)
-
-    return "Unknown provider: " + id
-
-
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
-@dataclass
-class ModelProviderWarning:
-    required_config_keys: List[str]
-    message: str
-
-
-provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
-    ModelProviderName.amazon_bedrock: ModelProviderWarning(
-        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
-        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
-    ),
-    ModelProviderName.openrouter: ModelProviderWarning(
-        required_config_keys=["open_router_api_key"],
-        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    # Mixtral 8x7B
+    KilnModel(
+        family=ModelFamily.mixtral,
+        name=ModelName.mixtral_8x7b,
+        friendly_name="Mixtral 8x7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "mixtral"},
+            ),
+        ],
     ),
-
-
-
+    # Qwen 2.5 7B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_7b,
+        friendly_name="Qwen 2.5 7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-7b-instruct"},
+                # Tool calls not supported. JSON doesn't error, but fails.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5"},
+            ),
+        ],
     ),
-
-
-
+    # Qwen 2.5 72B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_72b,
+        friendly_name="Qwen 2.5 72B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-72b-instruct"},
+                # Not consistent with structure data. Works sometimes but not often
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5:72b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/qwen2p5-72b-instruct"
+                },
+                # Fireworks will start tuning, but it never finishes.
+                # provider_finetune_id="accounts/fireworks/models/qwen2p5-72b-instruct",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
     ),
-
-
-
-def get_config_value(key: str):
-    try:
-        return Config.shared().__getattr__(key)
-    except AttributeError:
-        return None
-
-
-def check_provider_warnings(provider_name: ModelProviderName):
-    """
-    Validates that required configuration is present for a given provider.
-
-    Args:
-        provider_name: The provider to check
-
-    Raises:
-        ValueError: If required configuration keys are missing
-    """
-    warning_check = provider_warnings.get(provider_name)
-    if warning_check is None:
-        return
-    for key in warning_check.required_config_keys:
-        if get_config_value(key) is None:
-            raise ValueError(warning_check.message)
-
-
-async def langchain_model_from(
-    name: str, provider_name: str | None = None
-) -> BaseChatModel:
-    """
-    Creates a LangChain chat model instance for the specified model and provider.
-
-    Args:
-        name: The name of the model to instantiate
-        provider_name: Optional specific provider to use (defaults to first available)
-
-    Returns:
-        A configured LangChain chat model instance
-
-    Raises:
-        ValueError: If the model/provider combination is invalid or misconfigured
-    """
-    if name not in ModelName.__members__:
-        raise ValueError(f"Invalid name: {name}")
-
-    # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
-    if model is None:
-        raise ValueError(f"Model {name} not found")
-
-    # If a provider is provided, select the provider from the model's provider_config
-    provider: KilnModelProvider | None = None
-    if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
-    elif provider_name is None:
-        # TODO: priority order
-        provider = model.providers[0]
-    else:
-        provider = next(
-            filter(lambda p: p.name == provider_name, model.providers), None
-        )
-        if provider is None:
-            raise ValueError(f"Provider {provider_name} not found for model {name}")
-
-    check_provider_warnings(provider.name)
-
-    if provider.name == ModelProviderName.openai:
-        api_key = Config.shared().open_ai_api_key
-        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.groq:
-        api_key = Config.shared().groq_api_key
-        if api_key is None:
-            raise ValueError(
-                "Attempted to use Groq without an API key set. "
-                "Get your API key from https://console.groq.com/keys"
-            )
-        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.amazon_bedrock:
-        api_key = Config.shared().bedrock_access_key
-        secret_key = Config.shared().bedrock_secret_key
-        # langchain doesn't allow passing these, so ugly hack to set env vars
-        os.environ["AWS_ACCESS_KEY_ID"] = api_key
-        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
-        return ChatBedrockConverse(
-            **provider.provider_options,
-        )
-    elif provider.name == ModelProviderName.ollama:
-        # Ollama model naming is pretty flexible. We try a few versions of the model name
-        potential_model_names = []
-        if "model" in provider.provider_options:
-            potential_model_names.append(provider.provider_options["model"])
-        if "model_aliases" in provider.provider_options:
-            potential_model_names.extend(provider.provider_options["model_aliases"])
-
-        # Get the list of models Ollama supports
-        ollama_connection = await get_ollama_connection()
-        if ollama_connection is None:
-            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
-
-        for model_name in potential_model_names:
-            if ollama_model_supported(ollama_connection, model_name):
-                return ChatOllama(model=model_name, base_url=ollama_base_url())
-
-        raise ValueError(f"Model {name} not installed on Ollama")
-    elif provider.name == ModelProviderName.openrouter:
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        return ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  # type: ignore[arg-type]
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
-        )
-    else:
-        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
-
-
-def ollama_base_url() -> str:
-    """
-    Gets the base URL for Ollama API connections.
-
-    Returns:
-        The base URL to use for Ollama API calls, using environment variable if set
-        or falling back to localhost default
-    """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
-    return "http://localhost:11434"
-
-
-async def ollama_online() -> bool:
-    """
-    Checks if the Ollama service is available and responding.
-
-    Returns:
-        True if Ollama is available and responding, False otherwise
-    """
-    try:
-        httpx.get(ollama_base_url() + "/api/tags")
-    except httpx.RequestError:
-        return False
-    return True
-
-
-class OllamaConnection(BaseModel):
-    message: str
-    models: List[str]
-
-
-# Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
-    # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.provider_options["model"]
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
-    # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
-        [
-            alias
-            for model in built_in_models
-            for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
-        ]
-    )
-
-    if "models" in tags:
-        models = tags["models"]
-        if isinstance(models, list):
-            model_names = [model["model"] for model in models]
-            available_supported_models = [
-                model
-                for model in model_names
-                if model in supported_ollama_models
-                or model in [f"{m}:latest" for m in supported_ollama_models]
-            ]
-            if available_supported_models:
-                return OllamaConnection(
-                    message="Ollama connected",
-                    models=available_supported_models,
-                )
-
-    return OllamaConnection(
-        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
-        models=[],
-    )
-
-
-async def get_ollama_connection() -> OllamaConnection | None:
-    """
-    Gets the connection status for Ollama.
-    """
-    try:
-        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
-
-    except Exception:
-        return None
-
-    return parse_ollama_tags(tags)
-
-
-def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
-    return model_name in conn.models or f"{model_name}:latest" in conn.models
+]