langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to the public registry, and is provided for informational purposes only.
@@ -120,25 +120,19 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
  from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

  from langfun.core.llms.vertexai import VertexAI
- from langfun.core.llms.vertexai import VertexAIRest
- from langfun.core.llms.vertexai import VertexAIRestGemini1_5
  from langfun.core.llms.vertexai import VertexAIGemini1_5
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_Latest
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
  from langfun.core.llms.vertexai import VertexAIGeminiPro1
  from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
- from langfun.core.llms.vertexai import VertexAIPalm2
- from langfun.core.llms.vertexai import VertexAIPalm2_32K
- from langfun.core.llms.vertexai import VertexAICustom
+ from langfun.core.llms.vertexai import VertexAIEndpoint


  # LLaMA C++ models.
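In practical terms, this hunk (in `langfun/core/llms/__init__.py`, judging by the re-exports) removes the `VertexAIRest*` aliases, the `*_Latest` model aliases, the PaLM 2 models, and `VertexAICustom` from the public `langfun.core.llms` surface, and introduces `VertexAIEndpoint`. A minimal migration sketch; treating `VertexAIEndpoint` as the replacement for custom deployed endpoints is an inference from the names, so check the `vertexai` module for its actual constructor arguments:

```python
# Gone in dev202412030804 -- these imports now fail:
#   from langfun.core.llms.vertexai import VertexAIPalm2, VertexAICustom

# New in dev202412030804; presumably covers custom deployed endpoints:
from langfun.core.llms.vertexai import VertexAIEndpoint
```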
@@ -13,34 +13,14 @@
  # limitations under the License.
  """Language models from OpenAI."""

- import collections
- import functools
  import os
  from typing import Annotated, Any

  import langfun.core as lf
  from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import rest
  import pyglove as pg

- try:
-   import openai  # pylint: disable=g-import-not-at-top
-
-   if hasattr(openai, 'error'):
-     # For lower versions.
-     ServiceUnavailableError = openai.error.ServiceUnavailableError
-     RateLimitError = openai.error.RateLimitError
-     APITimeoutError = (
-         openai.error.APIError,
-         '.*The server had an error processing your request'
-     )
-   else:
-     # For higher versions.
-     ServiceUnavailableError = getattr(openai, 'InternalServerError')
-     RateLimitError = getattr(openai, 'RateLimitError')
-     APITimeoutError = getattr(openai, 'APITimeoutError')
- except ImportError:
-   openai = None
-

  # From https://platform.openai.com/settings/organization/limits
  _DEFAULT_TPM = 250000
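Together with the class-definition change below, this hunk (in `langfun/core/llms/openai.py`) moves `OpenAI` off the optional `openai` SDK, and its version-dependent error-class shims, onto langfun's shared REST client, which talks to the HTTP API directly. Judging only by the methods this file now overrides, the `rest.REST` subclass contract looks roughly like the sketch below; this is inferred from the diff, not taken from the module's documentation, and the endpoint URL and response shape are hypothetical:

```python
from typing import Any

import langfun.core as lf
from langfun.core.llms import rest


class MyRESTModel(rest.REST):
  """Hypothetical rest.REST subclass, inferred from the overrides in this diff."""

  api_endpoint: str = 'https://example.com/v1/complete'  # made-up URL

  def _initialize(self) -> None:
    # Resolve credentials lazily, before the first request is sent.
    self._token = 'placeholder'

  @property
  def headers(self) -> dict[str, Any]:
    # HTTP headers attached to every request.
    return {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {self._token}',
    }

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Map a Langfun message to the JSON request body.
    return dict(prompt=prompt.text, n=sampling_options.n)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Map the JSON response body back to Langfun samples.
    return lf.LMSamplingResult(
        [lf.LMSample(json['text'], score=0.0)]  # hypothetical response shape
    )
```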
@@ -289,7 +269,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
          rpm=_DEFAULT_RPM,
          tpm=_DEFAULT_TPM
      ),
-     # GPT-3 instruction-tuned models
+     # GPT-3 instruction-tuned models (Deprecated)
      'text-curie-001': pg.Dict(
          in_service=False,
          rpm=_DEFAULT_RPM,
@@ -325,9 +305,9 @@ SUPPORTED_MODELS_AND_SETTINGS = {
          rpm=_DEFAULT_RPM,
          tpm=_DEFAULT_TPM
      ),
-     # GPT-3 base models
+     # GPT-3 base models that are still in service.
      'babbage-002': pg.Dict(
-         in_service=False,
+         in_service=True,
          rpm=_DEFAULT_RPM,
          tpm=_DEFAULT_TPM
      ),
@@ -340,7 +320,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {


  @lf.use_init_args(['model'])
- class OpenAI(lf.LanguageModel):
+ class OpenAI(rest.REST):
    """OpenAI model."""

    model: pg.typing.Annotated[
@@ -348,7 +328,9 @@ class OpenAI(lf.LanguageModel):
            pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
        ),
        'The name of the model to use.',
-   ] = 'gpt-3.5-turbo'
+   ]
+
+   api_endpoint: str = 'https://api.openai.com/v1/chat/completions'

    multimodal: Annotated[
        bool,
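`model` loses its `'gpt-3.5-turbo'` default, so the model name must now be given explicitly (positionally, thanks to `@lf.use_init_args(['model'])`), and the chat-completions URL becomes a regular overridable field. Assuming fields can be overridden at construction as usual for langfun components, this presumably also allows targeting OpenAI-compatible servers; the local URL below is made up:

```python
from langfun.core.llms.openai import OpenAI

# The model name is now required; the default endpoint is api.openai.com.
gpt4o = OpenAI('gpt-4o')

# Hypothetical: point the same class at an OpenAI-compatible server.
local = OpenAI(
    'gpt-4o',
    api_endpoint='http://localhost:8000/v1/chat/completions',  # made-up URL
)
```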
@@ -372,27 +354,45 @@ class OpenAI(lf.LanguageModel):
        ),
    ] = None

+   project: Annotated[
+       str | None,
+       (
+           'Project. If None, the key will be read from environment '
+           "variable 'OPENAI_PROJECT'. Based on the value, usages from "
+           "these API requests will count against the project's quota. "
+       ),
+   ] = None
+
    def _on_bound(self):
      super()._on_bound()
-     self.__dict__.pop('_api_initialized', None)
-     if openai is None:
-       raise RuntimeError(
-           'Please install "langfun[llm-openai]" to use OpenAI models.'
-       )
+     self._api_key = None
+     self._organization = None
+     self._project = None

-   @functools.cached_property
-   def _api_initialized(self):
+   def _initialize(self):
      api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None)
      if not api_key:
        raise ValueError(
            'Please specify `api_key` during `__init__` or set environment '
            'variable `OPENAI_API_KEY` with your OpenAI API key.'
        )
-     openai.api_key = api_key
-     org = self.organization or os.environ.get('OPENAI_ORGANIZATION', None)
-     if org:
-       openai.organization = org
-     return True
+     self._api_key = api_key
+     self._organization = self.organization or os.environ.get(
+         'OPENAI_ORGANIZATION', None
+     )
+     self._project = self.project or os.environ.get('OPENAI_PROJECT', None)
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     headers = {
+         'Content-Type': 'application/json',
+         'Authorization': f'Bearer {self._api_key}',
+     }
+     if self._organization:
+       headers['OpenAI-Organization'] = self._organization
+     if self._project:
+       headers['OpenAI-Project'] = self._project
+     return headers

    @property
    def model_id(self) -> str:
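Credentials are now resolved in `_initialize` and sent per request via the `headers` property, instead of being written into the `openai` module's globals. Configuration can stay entirely in the environment, using the three variables this hunk reads (values below are placeholders):

```python
import os

from langfun.core.llms.openai import OpenAI

os.environ['OPENAI_API_KEY'] = 'sk-placeholder'
os.environ['OPENAI_ORGANIZATION'] = 'org-placeholder'  # optional
os.environ['OPENAI_PROJECT'] = 'proj-placeholder'      # optional, new here

# No explicit credentials needed; each request carries Authorization,
# OpenAI-Organization and OpenAI-Project headers built by `headers`.
model = OpenAI('gpt-4o')
```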
@@ -428,23 +428,16 @@ class OpenAI(lf.LanguageModel):

    @classmethod
    def dir(cls):
-     assert openai is not None
-     return openai.Model.list()
+     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

-   @property
-   def is_chat_model(self):
-     """Returns True if the model is a chat model."""
-     return self.model.startswith(('o1', 'gpt-4', 'gpt-3.5-turbo'))
-
-   def _get_request_args(
+   def _request_args(
        self, options: lf.LMSamplingOptions) -> dict[str, Any]:
      # Reference:
      # https://platform.openai.com/docs/api-reference/completions/create
      # NOTE(daiyip): options.top_k is not applicable.
      args = dict(
+         model=self.model,
          n=options.n,
-         stream=False,
-         timeout=self.timeout,
          top_logprobs=options.top_logprobs,
      )
      if options.logprobs:
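`dir()` no longer needs the SDK or a network round trip: it simply filters the static `SUPPORTED_MODELS_AND_SETTINGS` table for in-service entries. Both assertions below follow from entries visible in this diff:

```python
from langfun.core.llms.openai import OpenAI

in_service = OpenAI.dir()  # no API call; filters the settings table
assert 'babbage-002' in in_service           # flipped to in_service=True above
assert 'text-curie-001' not in in_service    # in_service=False in the table
```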
@@ -453,13 +446,10 @@ class OpenAI(lf.LanguageModel):
          raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
        args['logprobs'] = options.logprobs

-     # Completion and ChatCompletion uses different parameter name for model.
-     args['model' if self.is_chat_model else 'engine'] = self.model
-
      if options.temperature is not None:
        args['temperature'] = options.temperature
      if options.max_tokens is not None:
-       args['max_tokens'] = options.max_tokens
+       args['max_completion_tokens'] = options.max_tokens
      if options.top_p is not None:
        args['top_p'] = options.top_p
      if options.stop:
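After this hunk, `_request_args` always names the model in the payload, drops the legacy Completions-era `engine`/`stream`/`timeout` keys, and maps Langfun's `max_tokens` onto `max_completion_tokens`, the parameter OpenAI introduced as the successor to `max_tokens` in chat completions. For illustration (calling the private helper directly), the options below should produce roughly this payload:

```python
import langfun.core as lf
from langfun.core.llms.openai import OpenAI

model = OpenAI('gpt-4o', api_key='sk-placeholder')
args = model._request_args(
    lf.LMSamplingOptions(n=1, temperature=0.7, max_tokens=256)
)
# Expected shape (unset options are simply omitted):
# {'model': 'gpt-4o', 'n': 1, 'top_logprobs': None,
#  'temperature': 0.7, 'max_completion_tokens': 256}
```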
@@ -468,168 +458,113 @@ class OpenAI(lf.LanguageModel):
        args['seed'] = options.random_seed
      return args

-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     assert self._api_initialized
-     if self.is_chat_model:
-       return self._chat_complete_batch(prompts)
-     else:
-       return self._complete_batch(prompts)
-
-   def _complete_batch(
-       self, prompts: list[lf.Message]
-   ) -> list[lf.LMSamplingResult]:
-
-     def _open_ai_completion(prompts):
-       assert openai is not None
-       response = openai.Completion.create(
-           prompt=[p.text for p in prompts],
-           **self._get_request_args(self.sampling_options),
-       )
-       # Parse response.
-       samples_by_index = collections.defaultdict(list)
-       for choice in response.choices:
-         samples_by_index[choice.index].append(
-             lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
-         )
-
-       n = len(samples_by_index)
-       estimated_cost = self.estimate_cost(
-           num_input_tokens=response.usage.prompt_tokens,
-           num_output_tokens=response.usage.completion_tokens,
-       )
-       usage = lf.LMSamplingUsage(
-           prompt_tokens=response.usage.prompt_tokens // n,
-           completion_tokens=response.usage.completion_tokens // n,
-           total_tokens=response.usage.total_tokens // n,
-           estimated_cost=(
-               None if estimated_cost is None else (estimated_cost // n)
-           )
-       )
-       return [
-           lf.LMSamplingResult(samples_by_index[index], usage=usage)
-           for index in sorted(samples_by_index.keys())
-       ]
-
-     return self._parallel_execute_with_currency_control(
-         _open_ai_completion,
-         [prompts],
-         retry_on_errors=(
-             ServiceUnavailableError,
-             RateLimitError,
-             APITimeoutError,
-         ),
-     )[0]
-
-   def _chat_complete_batch(
-       self, prompts: list[lf.Message]
-   ) -> list[lf.LMSamplingResult]:
-     def _content_from_message(message: lf.Message):
-       if self.multimodal:
-         content = []
-         for chunk in message.chunk():
-           if isinstance(chunk, str):
-             item = dict(type='text', text=chunk)
-           elif isinstance(chunk, lf_modalities.Image):
-             if chunk.uri and chunk.uri.lower().startswith(
-                 ('http:', 'https:', 'ftp:')
-             ):
-               uri = chunk.uri
-             else:
-               uri = chunk.content_uri
-             item = dict(type='image_url', image_url=dict(url=uri))
-           else:
-             raise ValueError(f'Unsupported modality object: {chunk!r}.')
-           content.append(item)
+   def _content_from_message(self, message: lf.Message):
+     """Returns a OpenAI content object from a Langfun message."""
+     def _uri_from(chunk: lf.Modality) -> str:
+       if chunk.uri and chunk.uri.lower().startswith(
+           ('http:', 'https:', 'ftp:')
+       ):
+         return chunk.uri
+       return chunk.content_uri
+
+     content = []
+     for chunk in message.chunk():
+       if isinstance(chunk, str):
+         item = dict(type='text', text=chunk)
+       elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
+         item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
        else:
-         content = message.text
-       return content
-
-     def _open_ai_chat_completion(prompt: lf.Message):
-       request_args = self._get_request_args(self.sampling_options)
-       # Users could use `metadata_json_schema` to pass additional
-       # request arguments.
-       json_schema = prompt.metadata.get('json_schema')
-       if json_schema is not None:
-         if not isinstance(json_schema, dict):
-           raise ValueError(
-               f'`json_schema` must be a dict, got {json_schema!r}.'
-           )
-         if 'title' not in json_schema:
-           raise ValueError(
-               f'The root of `json_schema` must have a `title` field, '
-               f'got {json_schema!r}.'
-           )
-         request_args.update(
-             response_format=dict(
-                 type='json_schema',
-                 json_schema=dict(
-                     schema=json_schema,
-                     name=json_schema['title'],
-                     strict=True,
-                 )
-             )
-         )
-         prompt.metadata.formatted_text = (
-             prompt.text
-             + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-             + pg.to_json_str(request_args['response_format'], json_indent=2)
-         )
+         raise ValueError(f'Unsupported modality: {chunk!r}.')
+       content.append(item)
+     return content

-     # Prepare messages.
-     messages = []
-     # Users could use `metadata_system_message` to pass system message.
-     system_message = prompt.metadata.get('system_message')
-     if system_message:
-       system_message = lf.SystemMessage.from_value(system_message)
-       messages.append(
-           dict(role='system', content=_content_from_message(system_message))
+   def request(
+       self,
+       prompt: lf.Message,
+       sampling_options: lf.LMSamplingOptions
+   ) -> dict[str, Any]:
+     """Returns the JSON input for a message."""
+     request_args = self._request_args(sampling_options)
+
+     # Users could use `metadata_json_schema` to pass additional
+     # request arguments.
+     json_schema = prompt.metadata.get('json_schema')
+     if json_schema is not None:
+       if not isinstance(json_schema, dict):
+         raise ValueError(
+             f'`json_schema` must be a dict, got {json_schema!r}.'
          )
-     messages.append(dict(role='user', content=_content_from_message(prompt)))
-
-     assert openai is not None
-     response = openai.ChatCompletion.create(messages=messages, **request_args)
-
-     samples = []
-     for choice in response.choices:
-       logprobs = None
-       choice_logprobs = getattr(choice, 'logprobs', None)
-       if choice_logprobs:
-         logprobs = [
-             (
-                 t.token,
-                 t.logprob,
-                 [(tt.token, tt.logprob) for tt in t.top_logprobs],
-             )
-             for t in choice_logprobs.content
-         ]
-       samples.append(
-           lf.LMSample(
-               choice.message.content,
-               score=0.0,
-               logprobs=logprobs,
-           )
+       if 'title' not in json_schema:
+         raise ValueError(
+             f'The root of `json_schema` must have a `title` field, '
+             f'got {json_schema!r}.'
          )
-
-     return lf.LMSamplingResult(
-         samples=samples,
-         usage=lf.LMSamplingUsage(
-             prompt_tokens=response.usage.prompt_tokens,
-             completion_tokens=response.usage.completion_tokens,
-             total_tokens=response.usage.total_tokens,
-             estimated_cost=self.estimate_cost(
-                 num_input_tokens=response.usage.prompt_tokens,
-                 num_output_tokens=response.usage.completion_tokens,
+       request_args.update(
+           response_format=dict(
+               type='json_schema',
+               json_schema=dict(
+                   schema=json_schema,
+                   name=json_schema['title'],
+                   strict=True,
               )
-         ),
+           )
+       )
+       prompt.metadata.formatted_text = (
+           prompt.text
+           + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+           + pg.to_json_str(request_args['response_format'], json_indent=2)
+       )
+
+     # Prepare messages.
+     messages = []
+     # Users could use `metadata_system_message` to pass system message.
+     system_message = prompt.metadata.get('system_message')
+     if system_message:
+       system_message = lf.SystemMessage.from_value(system_message)
+       messages.append(
+           dict(role='system',
+                content=self._content_from_message(system_message))
      )
+     messages.append(
+         dict(role='user', content=self._content_from_message(prompt))
+     )
+     request = dict()
+     request.update(request_args)
+     request['messages'] = messages
+     return request

-     return self._parallel_execute_with_currency_control(
-         _open_ai_chat_completion,
-         prompts,
-         retry_on_errors=(
-             ServiceUnavailableError,
-             RateLimitError,
-             APITimeoutError
+   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
+     # Reference:
+     # https://platform.openai.com/docs/api-reference/chat/object
+     logprobs = None
+     choice_logprobs = choice.get('logprobs')
+     if choice_logprobs:
+       logprobs = [
+           (
+               t['token'],
+               t['logprob'],
+               [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
+           )
+           for t in choice_logprobs['content']
+       ]
+     return lf.LMSample(
+         choice['message']['content'],
+         score=0.0,
+         logprobs=logprobs,
+     )
+
+   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+     usage = json['usage']
+     return lf.LMSamplingResult(
+         samples=[self._parse_choice(choice) for choice in json['choices']],
+         usage=lf.LMSamplingUsage(
+             prompt_tokens=usage['prompt_tokens'],
+             completion_tokens=usage['completion_tokens'],
+             total_tokens=usage['total_tokens'],
+             estimated_cost=self.estimate_cost(
+                 num_input_tokens=usage['prompt_tokens'],
+                 num_output_tokens=usage['completion_tokens'],
+             )
          ),
      )
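With `request` and `result` in place, the REST base class owns transport, retries, and concurrency, and this module reduces to a pure JSON mapper in both directions. A sketch of the response side on a hand-written minimal chat-completions payload (the documented response shape, not captured output; `LMSample` is assumed to coerce the string into a message, per langfun's usual behavior):

```python
from langfun.core.llms.openai import OpenAI

# Minimal, hand-written chat-completions response body.
response_json = {
    'choices': [{
        'index': 0,
        'message': {'role': 'assistant', 'content': 'Hello!'},
        'logprobs': None,
        'finish_reason': 'stop',
    }],
    'usage': {'prompt_tokens': 9, 'completion_tokens': 2, 'total_tokens': 11},
}

model = OpenAI('gpt-4o', api_key='sk-placeholder')
result = model.result(response_json)
assert result.samples[0].response.text == 'Hello!'
assert result.usage.total_tokens == 11
```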