kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -1,20 +1,8 @@
-import os
-from dataclasses import dataclass
 from enum import Enum
-from os import getenv
-from typing import Any, Dict, List, NoReturn
+from typing import Dict, List
 
-import httpx
-import requests
-from langchain_aws import ChatBedrockConverse
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 from pydantic import BaseModel
 
-from ..utils.config import Config
-
 """
 Provides model configuration and management for various LLM providers and models.
 This module handles the integration with different AI model providers and their respective models,
@@ -32,6 +20,8 @@ class ModelProviderName(str, Enum):
     amazon_bedrock = "amazon_bedrock"
     ollama = "ollama"
     openrouter = "openrouter"
+    fireworks_ai = "fireworks_ai"
+    kiln_fine_tune = "kiln_fine_tune"
 
 
 class ModelFamily(str, Enum):
@@ -46,6 +36,8 @@ class ModelFamily(str, Enum):
     gemma = "gemma"
     gemini = "gemini"
     claude = "claude"
+    mixtral = "mixtral"
+    qwen = "qwen"
 
 
 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -58,6 +50,7 @@ class ModelName(str, Enum):
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
+    llama_3_2_1b = "llama_3_2_1b"
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
@@ -75,6 +68,9 @@ class ModelName(str, Enum):
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
     gemini_1_5_pro = "gemini_1_5_pro"
     nemotron_70b = "nemotron_70b"
+    mixtral_8x7b = "mixtral_8x7b"
+    qwen_2p5_7b = "qwen_2p5_7b"
+    qwen_2p5_72b = "qwen_2p5_72b"
 
 
 class KilnModelProvider(BaseModel):
@@ -84,13 +80,20 @@ class KilnModelProvider(BaseModel):
     Attributes:
         name: The provider's identifier
         supports_structured_output: Whether the provider supports structured output formats
+        supports_data_gen: Whether the provider supports data generation
+        untested_model: Whether the model is untested (typically user added). The supports_ fields are not applicable.
+        provider_finetune_id: The finetune ID for the provider, if applicable
         provider_options: Additional provider-specific configuration options
+        adapter_options: Additional options specific to the adapter. Top level key should be adapter ID.
     """
 
     name: ModelProviderName
     supports_structured_output: bool = True
     supports_data_gen: bool = True
+    untested_model: bool = False
+    provider_finetune_id: str | None = None
     provider_options: Dict = {}
+    adapter_options: Dict = {}
 
 
 class KilnModel(BaseModel):
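Taken together, the new KilnModelProvider fields describe what a provider entry can and cannot do (structured output, data generation, fine-tuning) and carry per-adapter options. The snippet below is illustrative only, assuming the 0.7.0 field names shown above; the model and finetune IDs are placeholders, not entries shipped in this release.

# Illustrative only: a provider entry exercising the fields added in 0.7.0.
# The IDs below are placeholders, not values from the built-in model list.
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

example_provider = KilnModelProvider(
    name=ModelProviderName.fireworks_ai,
    supports_structured_output=False,  # provider cannot reliably return structured output
    supports_data_gen=False,           # so it is also excluded from data generation
    provider_finetune_id="accounts/fireworks/models/placeholder-model",  # hypothetical
    provider_options={"model": "accounts/fireworks/models/placeholder-model"},
    adapter_options={
        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
    },
)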
@@ -122,6 +125,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o-mini"},
+                provider_finetune_id="gpt-4o-mini-2024-07-18",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -138,6 +142,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o"},
+                provider_finetune_id="gpt-4o-2024-08-06",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -257,6 +262,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-8b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 70b
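As the Llama 3.1 8B entry shows, Fireworks providers now carry a provider_finetune_id alongside provider_options. A rough, hypothetical helper (not part of kiln_ai) that scans built_in_models for fine-tunable model/provider pairs:

# Hypothetical helper: list (model, provider, finetune id) triples for every
# provider entry that declares a provider_finetune_id.
from kiln_ai.adapters.ml_model_list import built_in_models

def tunable_models() -> list[tuple[str, str, str]]:
    found = []
    for model in built_in_models:
        for provider in model.providers:
            if provider.provider_finetune_id is not None:
                found.append((model.name, provider.name, provider.provider_finetune_id))
    return found

for model_name, provider_name, finetune_id in tunable_models():
    print(f"{model_name}: {provider_name} -> {finetune_id}")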
@@ -271,7 +285,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
-                #
+                # AWS 70b not working as well as the others.
                 supports_structured_output=False,
                 supports_data_gen=False,
                 provider_options={
@@ -287,6 +301,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.ollama,
                 provider_options={"model": "llama3.1:70b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-70b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 405b
@@ -311,6 +332,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-405b-instruct"
+                },
+            ),
         ],
     ),
     # Mistral Nemo
@@ -348,6 +376,35 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.2 1B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_1b,
+        friendly_name="Llama 3.2 1B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "meta-llama/llama-3.2-1b-instruct"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2:1b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-1b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-1b-instruct"
+                },
+            ),
+        ],
+    ),
     # Llama 3.2 3B
     KilnModel(
         family=ModelFamily.llama,
@@ -366,6 +423,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-3b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -376,9 +442,12 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
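The adapter_options blocks introduced here are keyed by adapter ID ("langchain" in every case above), matching the new docstring note that the top-level key should be the adapter ID. Below is a minimal sketch of how an adapter might pass those options through to LangChain; the helper is hypothetical, not the package's adapter code, and assumes a chat model that implements with_structured_output (its method="json_mode" option is what these entries select).

# Hypothetical glue code: read the per-adapter options off a KilnModelProvider
# and forward them to LangChain's with_structured_output().
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel

def structured_runnable(chat_model: BaseChatModel, schema: type[BaseModel], provider):
    options = provider.adapter_options.get("langchain", {}).get(
        "with_structured_output_options", {}
    )
    # e.g. {"method": "json_mode"} requests JSON-mode output instead of tool calling
    return chat_model.with_structured_output(schema, **options)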
@@ -386,6 +455,18 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -396,16 +477,29 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision:90b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Phi 3.5
@@ -427,6 +521,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/phi-3-vision-128k-instruct"
+                },
+            ),
         ],
     ),
     # Gemma 2 2.6b
@@ -465,6 +568,7 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
+            # fireworks AI errors - not allowing system role. Exclude until resolved.
         ],
     ),
     # Gemma 2 27b
@@ -488,290 +592,100 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
[old lines 491-524 were also removed; their content is not captured in this rendering]
-            case ModelProviderName.openrouter:
-                return "OpenRouter"
-            case ModelProviderName.groq:
-                return "Groq"
-            case ModelProviderName.ollama:
-                return "Ollama"
-            case ModelProviderName.openai:
-                return "OpenAI"
-            case _:
-                # triggers pyright warning if I miss a case
-                raise_exhaustive_error(enum_id)
-
-    return "Unknown provider: " + id
-
-
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
-@dataclass
-class ModelProviderWarning:
-    required_config_keys: List[str]
-    message: str
-
-
-provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
-    ModelProviderName.amazon_bedrock: ModelProviderWarning(
-        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
-        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
-    ),
-    ModelProviderName.openrouter: ModelProviderWarning(
-        required_config_keys=["open_router_api_key"],
-        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    # Mixtral 8x7B
+    KilnModel(
+        family=ModelFamily.mixtral,
+        name=ModelName.mixtral_8x7b,
+        friendly_name="Mixtral 8x7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
+                },
+                provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "mixtral"},
+            ),
+        ],
     ),
-
-
-
+    # Qwen 2.5 7B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_7b,
+        friendly_name="Qwen 2.5 7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-7b-instruct"},
+                # Tool calls not supported. JSON doesn't error, but fails.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5"},
+            ),
+        ],
     ),
-
-
-
+    # Qwen 2.5 72B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_72b,
+        friendly_name="Qwen 2.5 72B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-72b-instruct"},
+                # Not consistent with structure data. Works sometimes but not often
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5:72b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/qwen2p5-72b-instruct"
+                },
+                # Fireworks will start tuning, but it never finishes.
+                # provider_finetune_id="accounts/fireworks/models/qwen2p5-72b-instruct",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
     ),
-
-
-
-def get_config_value(key: str):
-    try:
-        return Config.shared().__getattr__(key)
-    except AttributeError:
-        return None
-
-
-def check_provider_warnings(provider_name: ModelProviderName):
-    """
-    Validates that required configuration is present for a given provider.
-
-    Args:
-        provider_name: The provider to check
-
-    Raises:
-        ValueError: If required configuration keys are missing
-    """
-    warning_check = provider_warnings.get(provider_name)
-    if warning_check is None:
-        return
-    for key in warning_check.required_config_keys:
-        if get_config_value(key) is None:
-            raise ValueError(warning_check.message)
-
-
-async def langchain_model_from(
-    name: str, provider_name: str | None = None
-) -> BaseChatModel:
-    """
-    Creates a LangChain chat model instance for the specified model and provider.
-
-    Args:
-        name: The name of the model to instantiate
-        provider_name: Optional specific provider to use (defaults to first available)
-
-    Returns:
-        A configured LangChain chat model instance
-
-    Raises:
-        ValueError: If the model/provider combination is invalid or misconfigured
-    """
-    if name not in ModelName.__members__:
-        raise ValueError(f"Invalid name: {name}")
-
-    # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
-    if model is None:
-        raise ValueError(f"Model {name} not found")
-
-    # If a provider is provided, select the provider from the model's provider_config
-    provider: KilnModelProvider | None = None
-    if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
-    elif provider_name is None:
-        # TODO: priority order
-        provider = model.providers[0]
-    else:
-        provider = next(
-            filter(lambda p: p.name == provider_name, model.providers), None
-        )
-        if provider is None:
-            raise ValueError(f"Provider {provider_name} not found for model {name}")
-
-    check_provider_warnings(provider.name)
-
-    if provider.name == ModelProviderName.openai:
-        api_key = Config.shared().open_ai_api_key
-        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.groq:
-        api_key = Config.shared().groq_api_key
-        if api_key is None:
-            raise ValueError(
-                "Attempted to use Groq without an API key set. "
-                "Get your API key from https://console.groq.com/keys"
-            )
-        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.amazon_bedrock:
-        api_key = Config.shared().bedrock_access_key
-        secret_key = Config.shared().bedrock_secret_key
-        # langchain doesn't allow passing these, so ugly hack to set env vars
-        os.environ["AWS_ACCESS_KEY_ID"] = api_key
-        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
-        return ChatBedrockConverse(
-            **provider.provider_options,
-        )
-    elif provider.name == ModelProviderName.ollama:
-        # Ollama model naming is pretty flexible. We try a few versions of the model name
-        potential_model_names = []
-        if "model" in provider.provider_options:
-            potential_model_names.append(provider.provider_options["model"])
-        if "model_aliases" in provider.provider_options:
-            potential_model_names.extend(provider.provider_options["model_aliases"])
-
-        # Get the list of models Ollama supports
-        ollama_connection = await get_ollama_connection()
-        if ollama_connection is None:
-            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
-
-        for model_name in potential_model_names:
-            if ollama_model_supported(ollama_connection, model_name):
-                return ChatOllama(model=model_name, base_url=ollama_base_url())
-
-        raise ValueError(f"Model {name} not installed on Ollama")
-    elif provider.name == ModelProviderName.openrouter:
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        return ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  # type: ignore[arg-type]
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
-        )
-    else:
-        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
-
-
-def ollama_base_url() -> str:
-    """
-    Gets the base URL for Ollama API connections.
-
-    Returns:
-        The base URL to use for Ollama API calls, using environment variable if set
-        or falling back to localhost default
-    """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
-    return "http://localhost:11434"
-
-
-async def ollama_online() -> bool:
-    """
-    Checks if the Ollama service is available and responding.
-
-    Returns:
-        True if Ollama is available and responding, False otherwise
-    """
-    try:
-        httpx.get(ollama_base_url() + "/api/tags")
-    except httpx.RequestError:
-        return False
-    return True
-
-
-class OllamaConnection(BaseModel):
-    message: str
-    models: List[str]
-
-
-# Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
-    # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.provider_options["model"]
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
-    # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
-        [
-            alias
-            for model in built_in_models
-            for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
-        ]
-    )
-
-    if "models" in tags:
-        models = tags["models"]
-        if isinstance(models, list):
-            model_names = [model["model"] for model in models]
-            available_supported_models = [
-                model
-                for model in model_names
-                if model in supported_ollama_models
-                or model in [f"{m}:latest" for m in supported_ollama_models]
-            ]
-            if available_supported_models:
-                return OllamaConnection(
-                    message="Ollama connected",
-                    models=available_supported_models,
-                )
-
-    return OllamaConnection(
-        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
-        models=[],
-    )
-
-
-async def get_ollama_connection() -> OllamaConnection | None:
-    """
-    Gets the connection status for Ollama.
-    """
-    try:
-        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
-
-    except Exception:
-        return None
-
-    return parse_ollama_tags(tags)
-
-
-def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
-    return model_name in conn.models or f"{model_name}:latest" in conn.models
+]