langfun-0.1.2.dev202412020805-py3-none-any.whl → langfun-0.1.2.dev202412030804-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/llms/__init__.py +1 -7
- langfun/core/llms/openai.py +142 -207
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +23 -422
- langfun/core/llms/vertexai_test.py +21 -335
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/RECORD +10 -10
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/top_level.txt +0 -0
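The substantive change in this release is a rewrite of langfun/core/llms/openai.py: the OpenAI class drops its dependency on the `openai` SDK client and is rebuilt on langfun's shared `rest.REST` base class, with vertexai.py slimmed down in the same spirit. As a rough usage sketch (assuming `gpt-4o` is a registered key in SUPPORTED_MODELS_AND_SETTINGS; the constructor fields are the ones visible in this diff):

from langfun.core.llms import openai

# Construction surface is unchanged; requests now go through langfun's
# REST stack instead of the `openai` SDK.
lm = openai.OpenAI(
    model='gpt-4o',        # must be a key in SUPPORTED_MODELS_AND_SETTINGS
    api_key='sk-...',      # falls back to $OPENAI_API_KEY if omitted
    project='my-project',  # new field; falls back to $OPENAI_PROJECT
)
print(lm.api_endpoint)  # 'https://api.openai.com/v1/chat/completions'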
langfun/core/llms/__init__.py
CHANGED
@@ -120,25 +120,19 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

 from langfun.core.llms.vertexai import VertexAI
-from langfun.core.llms.vertexai import VertexAIRest
-from langfun.core.llms.vertexai import VertexAIRestGemini1_5
 from langfun.core.llms.vertexai import VertexAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_Latest
 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
 from langfun.core.llms.vertexai import VertexAIGeminiPro1
 from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
-from langfun.core.llms.vertexai import VertexAIPalm2
-from langfun.core.llms.vertexai import VertexAIPalm2_32K
-from langfun.core.llms.vertexai import VertexAICustom
+from langfun.core.llms.vertexai import VertexAIEndpoint


 # LLaMA C++ models.
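Net effect on the public surface of langfun.core.llms: the VertexAIRest*, *_Latest, PaLM 2, and custom-model exports are gone, replaced by the single consolidated VertexAIEndpoint export. A sketch of the import change (the old names no longer resolve in the new version; how VertexAIEndpoint is configured is not shown in this diff):

# Before (0.1.2.dev202412020805):
#   from langfun.core.llms import VertexAICustom, VertexAIPalm2_32K
# After (0.1.2.dev202412030804), custom deployments presumably route
# through the one remaining export:
from langfun.core.llms import VertexAIEndpoint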
langfun/core/llms/openai.py
CHANGED
@@ -13,34 +13,14 @@
 # limitations under the License.
 """Language models from OpenAI."""

-import collections
-import functools
 import os
 from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg

-try:
-  import openai  # pylint: disable=g-import-not-at-top
-
-  if hasattr(openai, 'error'):
-    # For lower versions.
-    ServiceUnavailableError = openai.error.ServiceUnavailableError
-    RateLimitError = openai.error.RateLimitError
-    APITimeoutError = (
-        openai.error.APIError,
-        '.*The server had an error processing your request'
-    )
-  else:
-    # For higher versions.
-    ServiceUnavailableError = getattr(openai, 'InternalServerError')
-    RateLimitError = getattr(openai, 'RateLimitError')
-    APITimeoutError = getattr(openai, 'APITimeoutError')
-except ImportError:
-  openai = None
-

 # From https://platform.openai.com/settings/organization/limits
 _DEFAULT_TPM = 250000
@@ -289,7 +269,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
-    # GPT-3 instruction-tuned models
+    # GPT-3 instruction-tuned models (Deprecated)
     'text-curie-001': pg.Dict(
         in_service=False,
         rpm=_DEFAULT_RPM,
@@ -325,9 +305,9 @@ SUPPORTED_MODELS_AND_SETTINGS = {
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
-    # GPT-3 base models
+    # GPT-3 base models that are still in service.
    'babbage-002': pg.Dict(
-        in_service=
+        in_service=True,
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
@@ -340,7 +320,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {


 @lf.use_init_args(['model'])
-class OpenAI(lf.LanguageModel):
+class OpenAI(rest.REST):
   """OpenAI model."""

   model: pg.typing.Annotated[
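The switch of the base class from lf.LanguageModel to rest.REST is the heart of this refactor. Judging only from what this diff exercises (rest.py itself is not shown), a REST-backed model supplies an endpoint, auth headers, a request-JSON builder, and a response parser; a minimal hypothetical subclass might look like:

import langfun.core as lf
from langfun.core.llms import rest

class EchoModel(rest.REST):  # hypothetical, for illustration only
  api_endpoint: str = 'https://example.invalid/v1/chat/completions'

  @property
  def headers(self) -> dict[str, str]:
    return {'Content-Type': 'application/json'}

  def request(self, prompt: lf.Message, sampling_options):
    # Build the JSON body POSTed to `api_endpoint`.
    return dict(messages=[dict(role='user', content=prompt.text)])

  def result(self, json: dict):
    # Map the response JSON back to langfun types.
    return lf.LMSamplingResult(
        [lf.LMSample(json['choices'][0]['message']['content'], score=0.0)]
    )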
@@ -348,7 +328,9 @@ class OpenAI(lf.LanguageModel):
           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
       ),
       'The name of the model to use.',
-  ]
+  ]
+
+  api_endpoint: str = 'https://api.openai.com/v1/chat/completions'

   multimodal: Annotated[
       bool,
@@ -372,27 +354,45 @@ class OpenAI(lf.LanguageModel):
       ),
   ] = None

+  project: Annotated[
+      str | None,
+      (
+          'Project. If None, the key will be read from environment '
+          "variable 'OPENAI_PROJECT'. Based on the value, usages from "
+          "these API requests will count against the project's quota. "
+      ),
+  ] = None
+
   def _on_bound(self):
     super()._on_bound()
-    self.
-    if openai is None:
-      raise RuntimeError(
-          'Please install "langfun[llm-openai]" to use OpenAI models.'
-      )
+    self._api_key = None
+    self._organization = None
+    self._project = None

-
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None)
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
           'variable `OPENAI_API_KEY` with your OpenAI API key.'
       )
-
-
-
-
-
+    self._api_key = api_key
+    self._organization = self.organization or os.environ.get(
+        'OPENAI_ORGANIZATION', None
+    )
+    self._project = self.project or os.environ.get('OPENAI_PROJECT', None)
+
+  @property
+  def headers(self) -> dict[str, Any]:
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f'Bearer {self._api_key}',
+    }
+    if self._organization:
+      headers['OpenAI-Organization'] = self._organization
+    if self._project:
+      headers['OpenAI-Project'] = self._project
+    return headers

   @property
   def model_id(self) -> str:
@@ -428,23 +428,16 @@ class OpenAI(lf.LanguageModel):

   @classmethod
   def dir(cls):
-
-    return openai.Model.list()
+    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

-  @property
-  def is_chat_model(self):
-    """Returns True if the model is a chat model."""
-    return self.model.startswith(('o1', 'gpt-4', 'gpt-3.5-turbo'))
-
-  def _get_request_args(
+  def _request_args(
       self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
     args = dict(
+        model=self.model,
         n=options.n,
-        stream=False,
-        timeout=self.timeout,
         top_logprobs=options.top_logprobs,
     )
     if options.logprobs:
@@ -453,13 +446,10 @@ class OpenAI(lf.LanguageModel):
         raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
       args['logprobs'] = options.logprobs

-    # Completion and ChatCompletion uses different parameter name for model.
-    args['model' if self.is_chat_model else 'engine'] = self.model
-
     if options.temperature is not None:
       args['temperature'] = options.temperature
     if options.max_tokens is not None:
-      args['max_tokens'] = options.max_tokens
+      args['max_completion_tokens'] = options.max_tokens
     if options.top_p is not None:
       args['top_p'] = options.top_p
     if options.stop:
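Two behavioral notes fall out of the _request_args hunks above: the payload now always carries `model` (the old code chose between `model` and `engine` via `is_chat_model`), and `max_tokens` is sent as `max_completion_tokens`. A standalone sketch of the resulting dict for a typical configuration (values illustrative):

# Mirrors the branches in _request_args for n=1, temperature=0.0,
# max_tokens=1024; stop/top_p/seed omitted.
args = dict(model='gpt-4o', n=1, top_logprobs=None)
args['temperature'] = 0.0
args['max_completion_tokens'] = 1024  # renamed from `max_tokens`
assert 'engine' not in args  # the old Completion-style key is gone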
@@ -468,168 +458,113 @@ class OpenAI(lf.LanguageModel):
       args['seed'] = options.random_seed
     return args

-  def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-          **self._get_request_args(self.sampling_options),
-      )
-      # Parse response.
-      samples_by_index = collections.defaultdict(list)
-      for choice in response.choices:
-        samples_by_index[choice.index].append(
-            lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
-        )
-
-      n = len(samples_by_index)
-      estimated_cost = self.estimate_cost(
-          num_input_tokens=response.usage.prompt_tokens,
-          num_output_tokens=response.usage.completion_tokens,
-      )
-      usage = lf.LMSamplingUsage(
-          prompt_tokens=response.usage.prompt_tokens // n,
-          completion_tokens=response.usage.completion_tokens // n,
-          total_tokens=response.usage.total_tokens // n,
-          estimated_cost=(
-              None if estimated_cost is None else (estimated_cost // n)
-          )
-      )
-      return [
-          lf.LMSamplingResult(samples_by_index[index], usage=usage)
-          for index in sorted(samples_by_index.keys())
-      ]
-
-    return self._parallel_execute_with_currency_control(
-        _open_ai_completion,
-        [prompts],
-        retry_on_errors=(
-            ServiceUnavailableError,
-            RateLimitError,
-            APITimeoutError,
-        ),
-    )[0]
-
-  def _chat_complete_batch(
-      self, prompts: list[lf.Message]
-  ) -> list[lf.LMSamplingResult]:
-    def _content_from_message(message: lf.Message):
-      if self.multimodal:
-        content = []
-        for chunk in message.chunk():
-          if isinstance(chunk, str):
-            item = dict(type='text', text=chunk)
-          elif isinstance(chunk, lf_modalities.Image):
-            if chunk.uri and chunk.uri.lower().startswith(
-                ('http:', 'https:', 'ftp:')
-            ):
-              uri = chunk.uri
-            else:
-              uri = chunk.content_uri
-            item = dict(type='image_url', image_url=dict(url=uri))
-          else:
-            raise ValueError(f'Unsupported modality object: {chunk!r}.')
-          content.append(item)
+  def _content_from_message(self, message: lf.Message):
+    """Returns a OpenAI content object from a Langfun message."""
+    def _uri_from(chunk: lf.Modality) -> str:
+      if chunk.uri and chunk.uri.lower().startswith(
+          ('http:', 'https:', 'ftp:')
+      ):
+        return chunk.uri
+      return chunk.content_uri
+
+    content = []
+    for chunk in message.chunk():
+      if isinstance(chunk, str):
+        item = dict(type='text', text=chunk)
+      elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
+        item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
       else:
-
-
-
-    def _open_ai_chat_completion(prompt: lf.Message):
-      request_args = self._get_request_args(self.sampling_options)
-      # Users could use `metadata_json_schema` to pass additional
-      # request arguments.
-      json_schema = prompt.metadata.get('json_schema')
-      if json_schema is not None:
-        if not isinstance(json_schema, dict):
-          raise ValueError(
-              f'`json_schema` must be a dict, got {json_schema!r}.'
-          )
-        if 'title' not in json_schema:
-          raise ValueError(
-              f'The root of `json_schema` must have a `title` field, '
-              f'got {json_schema!r}.'
-          )
-        request_args.update(
-            response_format=dict(
-                type='json_schema',
-                json_schema=dict(
-                    schema=json_schema,
-                    name=json_schema['title'],
-                    strict=True,
-                )
-            )
-        )
-        prompt.metadata.formatted_text = (
-            prompt.text
-            + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-            + pg.to_json_str(request_args['response_format'], json_indent=2)
-        )
+        raise ValueError(f'Unsupported modality: {chunk!r}.')
+      content.append(item)
+    return content

-
-
-
-
-
-
-
-
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
         )
-
-
-
-
-
-      samples = []
-      for choice in response.choices:
-        logprobs = None
-        choice_logprobs = getattr(choice, 'logprobs', None)
-        if choice_logprobs:
-          logprobs = [
-              (
-                  t.token,
-                  t.logprob,
-                  [(tt.token, tt.logprob) for tt in t.top_logprobs],
-              )
-              for t in choice_logprobs.content
-          ]
-        samples.append(
-            lf.LMSample(
-                choice.message.content,
-                score=0.0,
-                logprobs=logprobs,
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
         )
-
-
-
-
-
-
-
-          estimated_cost=self.estimate_cost(
-              num_input_tokens=response.usage.prompt_tokens,
-              num_output_tokens=response.usage.completion_tokens,
+      request_args.update(
+          response_format=dict(
+              type='json_schema',
+              json_schema=dict(
+                  schema=json_schema,
+                  name=json_schema['title'],
+                  strict=True,
               )
-          )
+          )
+      )
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['response_format'], json_indent=2)
+      )
+
+    # Prepare messages.
+    messages = []
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      system_message = lf.SystemMessage.from_value(system_message)
+      messages.append(
+          dict(role='system',
+               content=self._content_from_message(system_message))
       )
+    messages.append(
+        dict(role='user', content=self._content_from_message(prompt))
+    )
+    request = dict()
+    request.update(request_args)
+    request['messages'] = messages
+    return request

-
-
-
-
-
-
-
+  def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
+    # Reference:
+    # https://platform.openai.com/docs/api-reference/chat/object
+    logprobs = None
+    choice_logprobs = choice.get('logprobs')
+    if choice_logprobs:
+      logprobs = [
+          (
+              t['token'],
+              t['logprob'],
+              [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
+          )
+          for t in choice_logprobs['content']
+      ]
+    return lf.LMSample(
+        choice['message']['content'],
+        score=0.0,
+        logprobs=logprobs,
+    )
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_choice(choice) for choice in json['choices']],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=usage['prompt_tokens'],
+                num_output_tokens=usage['completion_tokens'],
+            )
         ),
     )

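Taken together, the new flow is: request() builds the chat-completions JSON body, the rest.REST base class POSTs it to `api_endpoint` with `headers`, and result()/_parse_choice() map the response JSON back to langfun types. A hand-written round-trip sketch with shapes matching the code above (token counts made up):

request_json = {
    'model': 'gpt-4o',
    'n': 1,
    'messages': [
        {'role': 'user', 'content': [{'type': 'text', 'text': 'Hi'}]},
    ],
}

response_json = {
    'choices': [{'message': {'content': 'Hello!'}, 'logprobs': None}],
    'usage': {'prompt_tokens': 3, 'completion_tokens': 2, 'total_tokens': 5},
}
# result(response_json) yields one lf.LMSample('Hello!', score=0.0) and an
# lf.LMSamplingUsage(prompt_tokens=3, completion_tokens=2, total_tokens=5).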