@posthog/ai 2.4.0 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@posthog/ai",
3
- "version": "2.4.0",
3
+ "version": "3.0.0",
4
4
  "description": "PostHog Node.js AI integrations",
5
5
  "repository": {
6
6
  "type": "git",
@@ -12,13 +12,8 @@
12
12
  "types": "lib/index.d.ts",
13
13
  "license": "MIT",
14
14
  "devDependencies": {
15
- "@types/jest": "^28.1.5",
16
15
  "@types/node": "^18.0.0",
17
- "ai": "^4.0.0",
18
- "jest": "^29.0.0",
19
16
  "node-fetch": "^3.3.2",
20
- "openai": "^4.0.0",
21
- "ts-jest": "^29.0.0",
22
17
  "typescript": "^4.7.4"
23
18
  },
24
19
  "keywords": [
@@ -31,7 +26,9 @@
31
26
  ],
32
27
  "dependencies": {
33
28
  "@anthropic-ai/sdk": "^0.36.3",
29
+ "@langchain/core": "^0.3.37",
34
30
  "ai": "^4.0.0",
31
+ "langchain": "^0.3.15",
35
32
  "openai": "^4.0.0",
36
33
  "uuid": "^11.0.5",
37
34
  "zod": "^3.24.1"
@@ -72,7 +72,7 @@ export class WrappedMessages extends AnthropicOriginal.Messages {
72
72
  return parentPromise.then((value) => {
73
73
  const passThroughStream = new PassThrough({ objectMode: true })
74
74
  let accumulatedContent = ''
75
- let usage: { inputTokens: number; outputTokens: number } = {
75
+ const usage: { inputTokens: number; outputTokens: number } = {
76
76
  inputTokens: 0,
77
77
  outputTokens: 0,
78
78
  }
package/src/index.ts CHANGED
@@ -2,8 +2,10 @@ import PostHogOpenAI from './openai'
2
2
  import PostHogAzureOpenAI from './openai/azure'
3
3
  import { wrapVercelLanguageModel } from './vercel/middleware'
4
4
  import PostHogAnthropic from './anthropic'
5
+ import { LangChainCallbackHandler } from './langchain/callbacks'
5
6
 
6
7
  export { PostHogOpenAI as OpenAI }
7
8
  export { PostHogAzureOpenAI as AzureOpenAI }
8
9
  export { PostHogAnthropic as Anthropic }
9
10
  export { wrapVercelLanguageModel as withTracing }
11
+ export { LangChainCallbackHandler }
@@ -0,0 +1,578 @@
1
+ import { PostHog } from 'posthog-node'
2
+ import { withPrivacyMode, getModelParams } from '../utils'
3
+ import { BaseCallbackHandler } from '@langchain/core/callbacks/base'
4
+ import type { Serialized } from '@langchain/core/load/serializable'
5
+ import type { ChainValues } from '@langchain/core/utils/types'
6
+ import type { BaseMessage } from '@langchain/core/messages'
7
+ import type { LLMResult } from '@langchain/core/outputs'
8
+ import type { AgentAction, AgentFinish } from '@langchain/core/agents'
9
+ import type { DocumentInterface } from '@langchain/core/documents'
10
+
11
/**
 * Bookkeeping kept for a run between its start callback and its end/error callback.
 */
interface SpanMetadata {
  /** Display name of the run (falls back to 'trace', 'span' or 'generation' when nothing better is known). */
  name: string
  /** `Date.now()` value (milliseconds) taken when the run started. */
  startTime: number
  /** `Date.now()` value (milliseconds) taken when the run was popped; absent while the run is still open. */
  endTime?: number
  /** Whatever was handed to the start callback as the run's input. */
  input?: any
}
21
+
22
/**
 * Bookkeeping for an LLM / chat-model run: the span fields plus model details.
 */
interface GenerationMetadata extends SpanMetadata {
  /** Provider label (e.g. openai, anthropic). */
  provider?: string
  /** Name of the model that served the generation. */
  model?: string
  /** Model parameters such as temperature or max_tokens. */
  modelParams?: Record<string, any>
  /** Base URL of the API the model was called on. */
  baseUrl?: string
}
32
+
33
/** Metadata of a tracked run: a generation for LLM calls, a plain span for everything else. */
type RunMetadata = GenerationMetadata | SpanMetadata
35
+
36
/** Lookup table from run id to the metadata recorded for that run. */
type RunMetadataStorage = Record<string, RunMetadata>
38
+
39
/**
 * LangChain callback handler that turns LangChain run callbacks into PostHog events.
 *
 * Every `handle*Start` callback records metadata for the run; the matching
 * `handle*End` / `handle*Error` callback removes it again and captures one event:
 * `$ai_generation` for LLM / chat-model runs, `$ai_trace` for a run captured without a
 * parent id and `$ai_span` for a run captured with one.
 */
export class LangChainCallbackHandler extends BaseCallbackHandler {
  /** Handler name reported to LangChain. */
  public name = 'PosthogCallbackHandler'
  private client: PostHog
  private distinctId?: string | number
  private traceId?: string | number
  private properties: Record<string, any>
  private privacyMode: boolean
  private groups: Record<string, any>
  private debug: boolean

  /** Metadata of runs that have started but not finished yet, keyed by run id. */
  private runs: RunMetadataStorage = {}
  /** Child run id -> parent run id; walked upwards to find the root run of a trace. */
  private parentTree: { [runId: string]: string } = {}

  /**
   * @param options.client PostHog client used to capture the events (required).
   * @param options.distinctId Distinct id attached to every event. When omitted, the run id
   *   (traces/spans) or the trace id (generations) is used and `$process_person_profile` is set to false.
   * @param options.traceId Trace id to use instead of the root run id.
   * @param options.properties Extra properties merged into every captured event.
   * @param options.privacyMode Passed to `withPrivacyMode` for inputs and outputs.
   * @param options.groups Groups forwarded to `client.capture`.
   * @param options.debug When true, every callback is logged with `console.log`.
   * @throws Error when no client is supplied.
   */
  constructor(options: {
    client: PostHog
    distinctId?: string | number
    traceId?: string | number
    properties?: Record<string, any>
    privacyMode?: boolean
    groups?: Record<string, any>
    debug?: boolean
  }) {
    if (!options.client) {
      throw new Error('PostHog client is required')
    }
    super()
    this.client = options.client
    this.distinctId = options.distinctId
    this.traceId = options.traceId
    this.properties = options.properties || {}
    this.privacyMode = options.privacyMode || false
    this.groups = options.groups || {}
    this.debug = options.debug || false
  }

  // ===== CALLBACK METHODS =====

  /** Chain started: remember its parent and record span/trace metadata for it. */
  public handleChainStart(
    chain: Serialized,
    inputs: ChainValues,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    runType?: string,
    runName?: string
  ): void {
    this._logDebugEvent('on_chain_start', runId, parentRunId, { inputs, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(chain, inputs, runId, parentRunId, metadata, tags, runName)
  }

  /** Chain finished: capture it as a trace or span with `outputs` as its output state. */
  public handleChainEnd(
    outputs: ChainValues,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    kwargs?: { inputs?: Record<string, unknown> }
  ): void {
    this._logDebugEvent('on_chain_end', runId, parentRunId, { outputs, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, outputs)
  }

  /** Chain failed: capture it as a trace or span carrying the error. */
  public handleChainError(
    error: Error,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    kwargs?: { inputs?: Record<string, unknown> }
  ): void {
    this._logDebugEvent('on_chain_error', runId, parentRunId, { error, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, error)
  }

  /** Chat model started: record generation metadata with the flattened, dict-converted messages as input. */
  public handleChatModelStart(
    serialized: Serialized,
    messages: BaseMessage[][],
    runId: string,
    parentRunId?: string,
    extraParams?: Record<string, unknown>,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_chat_model_start', runId, parentRunId, { messages, tags })
    this._setParentOfRun(runId, parentRunId)
    // Flatten the two-dimensional messages and convert each message to a plain object
    const input = messages.flat().map((m) => this._convertMessageToDict(m))
    this._setLLMMetadata(serialized, runId, input, metadata, extraParams, runName)
  }

  /** Completion-style LLM started: record generation metadata with the raw prompts as input. */
  public handleLLMStart(
    serialized: Serialized,
    prompts: string[],
    runId: string,
    parentRunId?: string,
    extraParams?: Record<string, unknown>,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_llm_start', runId, parentRunId, { prompts, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setLLMMetadata(serialized, runId, prompts, metadata, extraParams, runName)
  }

  /** LLM finished: capture a `$ai_generation` event from the stored metadata and the result. */
  public handleLLMEnd(
    output: LLMResult,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    extraParams?: Record<string, unknown>
  ): void {
    this._logDebugEvent('on_llm_end', runId, parentRunId, { output, tags })
    this._popRunAndCaptureGeneration(runId, parentRunId, output)
  }

  /** LLM failed: capture a `$ai_generation` event carrying the error. */
  public handleLLMError(
    err: Error,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    extraParams?: Record<string, unknown>
  ): void {
    this._logDebugEvent('on_llm_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureGeneration(runId, parentRunId, err)
  }

  /** Tool started: remember its parent and record span/trace metadata for it. */
  public handleToolStart(
    tool: Serialized,
    input: string,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_tool_start', runId, parentRunId, { input, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(tool, input, runId, parentRunId, metadata, tags, runName)
  }

  /** Tool finished: capture it as a trace or span with `output` as its output state. */
  public handleToolEnd(output: any, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_tool_end', runId, parentRunId, { output, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, output)
  }

  /** Tool failed: capture it as a trace or span carrying the error. */
  public handleToolError(err: Error, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_tool_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, err)
  }

  /** Retriever started: remember its parent and record span/trace metadata with the query as input. */
  public handleRetrieverStart(
    retriever: Serialized,
    query: string,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    name?: string
  ): void {
    this._logDebugEvent('on_retriever_start', runId, parentRunId, { query, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(retriever, query, runId, parentRunId, metadata, tags, name)
  }

  /** Retriever finished: capture it as a trace or span with the documents as its output state. */
  public handleRetrieverEnd(
    documents: DocumentInterface[],
    runId: string,
    parentRunId?: string,
    tags?: string[]
  ): void {
    this._logDebugEvent('on_retriever_end', runId, parentRunId, { documents, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, documents)
  }

  /** Retriever failed: capture it as a trace or span carrying the error. */
  public handleRetrieverError(err: Error, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_retriever_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, err)
  }

  /** Agent action: record span/trace metadata for `runId` with the action as input (no serialized object is available). */
  public handleAgentAction(action: AgentAction, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_agent_action', runId, parentRunId, { action, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(null, action, runId, parentRunId)
  }

  /** Agent finished: capture the run stored for `runId` as a trace or span with the finish payload as output. */
  public handleAgentEnd(action: AgentFinish, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_agent_finish', runId, parentRunId, { action, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, action)
  }

  // ===== PRIVATE HELPERS =====

  /** Record the parent of a run; runs without a parent are not stored. */
  private _setParentOfRun(runId: string, parentRunId?: string): void {
    if (parentRunId) {
      this.parentTree[runId] = parentRunId
    }
  }

  /** Forget the parent link of a finished run. */
  private _popParentOfRun(runId: string): void {
    delete this.parentTree[runId]
  }

  /** Follow parent links upwards until a run without a recorded parent is reached. */
  private _findRootRun(runId: string): string {
    let id = runId
    while (this.parentTree[id]) {
      id = this.parentTree[id]
    }
    return id
  }

  /**
   * Store span metadata for a chain / tool / retriever / agent run.
   *
   * @param serialized Serialized LangChain object of the run, or null when there is none.
   * @param input Input recorded for the run.
   * @param args Remaining callback arguments (metadata, tags, run name) searched for an explicit name.
   */
  private _setTraceOrSpanMetadata(
    serialized: any,
    input: any,
    runId: string,
    parentRunId?: string,
    ...args: any[]
  ): void {
    // Use default names if not provided: if this is a top-level run, we mark it as a trace, otherwise as a span.
    const defaultName = parentRunId ? 'span' : 'trace'
    const span: SpanMetadata = {
      name: this._getLangchainRunName(serialized, ...args) || defaultName,
      input,
      startTime: Date.now(),
    }
    this.runs[runId] = span
  }

  /**
   * Store generation metadata for an LLM / chat-model run.
   *
   * @param messages Input recorded for the generation (prompts or converted messages).
   * @param metadata LangChain metadata; `ls_model_name` and `ls_provider` are read from it.
   * @param extraParams Extra callback params; `invocation_params` is turned into model parameters.
   * @param runName Explicit run name, preferred over names derived from `serialized`.
   */
  private _setLLMMetadata(
    serialized: Serialized | null,
    runId: string,
    messages: any,
    metadata?: any,
    extraParams?: any,
    runName?: string
  ): void {
    const generation: GenerationMetadata = {
      name: this._getLangchainRunName(serialized, runName) || 'generation',
      input: messages,
      startTime: Date.now(),
      // Always populated: the pop helpers use the presence of the `modelParams` key to tell
      // generations from spans, so leaving it off would make handleLLMEnd drop the run.
      modelParams: getModelParams(extraParams?.invocation_params ?? null),
    }
    if (metadata) {
      if (metadata.ls_model_name) {
        generation.model = metadata.ls_model_name
      }
      if (metadata.ls_provider) {
        generation.provider = metadata.ls_provider
      }
    }
    if (serialized && 'kwargs' in serialized && serialized.kwargs.openai_api_base) {
      generation.baseUrl = serialized.kwargs.openai_api_base
    }
    this.runs[runId] = generation
  }

  /**
   * Remove and return the metadata stored for a run, stamping its end time.
   * Warns and returns undefined when nothing is stored for `runId`.
   */
  private _popRunMetadata(runId: string): RunMetadata | undefined {
    const endTime = Date.now()
    const run = this.runs[runId]
    if (!run) {
      console.warn(`No run metadata found for run ${runId}`)
      return undefined
    }
    run.endTime = endTime
    delete this.runs[runId]
    return run
  }

  /** Trace id for a run: the configured trace id if any, otherwise the id of the run's root run. */
  private _getTraceId(runId: string): string {
    return this.traceId ? String(this.traceId) : this._findRootRun(runId)
  }

  /**
   * Parent id to report for a run. A parent that has no entry of its own in the parent tree
   * is replaced by the trace id; otherwise the given parent id is returned unchanged.
   */
  private _getParentRunId(traceId: string, runId: string, parentRunId?: string): string | undefined {
    // Replace the parent-run if not found in our stored parent tree.
    if (parentRunId && !this.parentTree[parentRunId]) {
      return traceId
    }
    return parentRunId
  }

  /** Pop the stored run and capture it as `$ai_trace` / `$ai_span`; generations are refused with a warning. */
  private _popRunAndCaptureTraceOrSpan(
    runId: string,
    parentRunId: string | undefined,
    outputs: ChainValues | DocumentInterface[] | AgentFinish | Error | any
  ): void {
    // Resolve the trace id before the parent link of this run is removed.
    const traceId = this._getTraceId(runId)
    this._popParentOfRun(runId)
    const run = this._popRunMetadata(runId)
    if (!run) return
    if ('modelParams' in run) {
      console.warn(`Run ${runId} is a generation, but attempted to be captured as a trace/span.`)
      return
    }
    const actualParentRunId = this._getParentRunId(traceId, runId, parentRunId)
    this._captureTraceOrSpan(traceId, runId, run, outputs, actualParentRunId)
  }

  /**
   * Send a `$ai_span` (parent id given) or `$ai_trace` (no parent id) event.
   * An `Error` output is reported through `$ai_error` / `$ai_is_error`; any other defined
   * output becomes `$ai_output_state`.
   */
  private _captureTraceOrSpan(
    traceId: string,
    runId: string,
    run: SpanMetadata,
    outputs: ChainValues | DocumentInterface[] | AgentFinish | Error | any,
    parentRunId?: string
  ): void {
    const eventName = parentRunId ? '$ai_span' : '$ai_trace'
    // Latency is reported in seconds.
    const latency = run.endTime ? (run.endTime - run.startTime) / 1000 : 0
    const eventProperties: Record<string, any> = {
      $ai_trace_id: traceId,
      $ai_input_state: withPrivacyMode(this.client, this.privacyMode, run.input),
      $ai_latency: latency,
      $ai_span_name: run.name,
      $ai_span_id: runId,
    }
    if (parentRunId) {
      eventProperties['$ai_parent_id'] = parentRunId
    }

    Object.assign(eventProperties, this.properties)
    if (!this.distinctId) {
      eventProperties['$process_person_profile'] = false
    }
    if (outputs instanceof Error) {
      eventProperties['$ai_error'] = outputs.toString()
      eventProperties['$ai_is_error'] = true
    } else if (outputs !== undefined) {
      eventProperties['$ai_output_state'] = withPrivacyMode(this.client, this.privacyMode, outputs)
    }
    this.client.capture({
      distinctId: this.distinctId ? this.distinctId.toString() : runId,
      event: eventName,
      properties: eventProperties,
      groups: this.groups,
    })
  }

  /** Pop the stored run and capture it as `$ai_generation`; non-generation runs are refused with a warning. */
  private _popRunAndCaptureGeneration(
    runId: string,
    parentRunId: string | undefined,
    response: LLMResult | Error
  ): void {
    // Resolve the trace id before the parent link of this run is removed.
    const traceId = this._getTraceId(runId)
    this._popParentOfRun(runId)
    const run = this._popRunMetadata(runId)
    // An unknown run id has already been reported by _popRunMetadata.
    if (!run) return
    if (!('modelParams' in run)) {
      console.warn(`Run ${runId} is not a generation, but attempted to be captured as such.`)
      return
    }
    const actualParentRunId = this._getParentRunId(traceId, runId, parentRunId)
    this._captureGeneration(traceId, runId, run, response, actualParentRunId)
  }

  /**
   * Send a `$ai_generation` event. For an `Error` the HTTP status is taken from `error.status`
   * (500 when absent); for a result, token usage and the texts of the last generation list are attached.
   */
  private _captureGeneration(
    traceId: string,
    runId: string,
    run: GenerationMetadata,
    output: LLMResult | Error,
    parentRunId?: string
  ): void {
    // Latency is reported in seconds.
    const latency = run.endTime ? (run.endTime - run.startTime) / 1000 : 0
    const eventProperties: Record<string, any> = {
      $ai_trace_id: traceId,
      $ai_span_id: runId,
      $ai_span_name: run.name,
      $ai_parent_id: parentRunId,
      $ai_provider: run.provider,
      $ai_model: run.model,
      $ai_model_parameters: run.modelParams,
      $ai_input: withPrivacyMode(this.client, this.privacyMode, run.input),
      $ai_http_status: 200,
      $ai_latency: latency,
      $ai_base_url: run.baseUrl,
    }

    if (output instanceof Error) {
      eventProperties['$ai_http_status'] = (output as any).status || 500
      eventProperties['$ai_error'] = output.toString()
      eventProperties['$ai_is_error'] = true
    } else {
      // Handle token usage
      const [inputTokens, outputTokens] = this.parseUsage(output)
      eventProperties['$ai_input_tokens'] = inputTokens
      eventProperties['$ai_output_tokens'] = outputTokens

      // Handle generations/completions
      let completions
      if (output.generations && Array.isArray(output.generations)) {
        const lastGeneration = output.generations[output.generations.length - 1]
        if (Array.isArray(lastGeneration)) {
          completions = lastGeneration.map((gen) => {
            return { role: 'assistant', content: gen.text }
          })
        }
      }

      if (completions) {
        eventProperties['$ai_output_choices'] = withPrivacyMode(this.client, this.privacyMode, completions)
      }
    }

    Object.assign(eventProperties, this.properties)
    if (!this.distinctId) {
      eventProperties['$process_person_profile'] = false
    }

    this.client.capture({
      distinctId: this.distinctId ? this.distinctId.toString() : traceId,
      event: '$ai_generation',
      properties: eventProperties,
      groups: this.groups,
    })
  }

  /** Log a callback to the console when debug mode is on. */
  private _logDebugEvent(eventName: string, runId: string, parentRunId: string | undefined, extra: any): void {
    if (this.debug) {
      console.log(`Event: ${eventName}, runId: ${runId}, parentRunId: ${parentRunId}, extra:`, extra)
    }
  }

  /**
   * Work out a display name for a run.
   *
   * The extra arguments are searched in order: a non-empty string is taken as an explicit
   * run name, and a plain object contributes its non-empty string `name` (or `runName`).
   * Arrays (e.g. tags) are ignored. Without an explicit name, `serialized.name` and then the
   * last segment of `serialized.id` are used.
   *
   * @returns The name, or undefined when nothing usable was found.
   */
  private _getLangchainRunName(serialized: any, ...args: any[]): string | undefined {
    for (const arg of args) {
      if (typeof arg === 'string') {
        if (arg.length > 0) return arg
        continue
      }
      if (arg && typeof arg === 'object' && !Array.isArray(arg)) {
        // A key that is present but empty must not short-circuit the fallbacks below.
        const explicitName = arg.name ?? arg.runName
        if (typeof explicitName === 'string' && explicitName.length > 0) return explicitName
      }
    }
    if (serialized && serialized.name) return serialized.name
    if (serialized && serialized.id) {
      return Array.isArray(serialized.id) ? serialized.id[serialized.id.length - 1] : serialized.id
    }
    return undefined
  }

  /**
   * Convert a LangChain message into a `{ role, content }` object.
   * Known message types map to fixed roles; any other type keeps its type as role and gets
   * string content. `additional_kwargs`, when present, are spread over the result.
   */
  private _convertMessageToDict(message: any): Record<string, any> {
    let messageDict: Record<string, any> = {}

    // Check the _getType() method or type property instead of instanceof
    const messageType = message._getType?.() || message.type

    switch (messageType) {
      case 'human':
        messageDict = { role: 'user', content: message.content }
        break
      case 'ai':
        messageDict = { role: 'assistant', content: message.content }
        break
      case 'system':
        messageDict = { role: 'system', content: message.content }
        break
      case 'tool':
        messageDict = { role: 'tool', content: message.content }
        break
      case 'function':
        messageDict = { role: 'function', content: message.content }
        break
      default: {
        // String(content) would render structured content as "[object Object]", so
        // non-string content is serialised as JSON and String() is only the last resort.
        const rawContent = message.content
        let textContent: string
        if (typeof rawContent === 'string') {
          textContent = rawContent
        } else {
          try {
            textContent = JSON.stringify(rawContent) ?? String(rawContent)
          } catch {
            textContent = String(rawContent)
          }
        }
        messageDict = { role: messageType || 'unknown', content: textContent }
      }
    }

    if (message.additional_kwargs) {
      messageDict = { ...messageDict, ...message.additional_kwargs }
    }
    return messageDict
  }

  /**
   * Read input/output token counts from a provider usage object.
   * Several provider key spellings are tried in list order; a later matching key overwrites
   * an earlier one, and array values are summed.
   *
   * @returns `[inputTokens, outputTokens]`, each 0 when no matching key is present.
   */
  private _parseUsageModel(usage: any): [number, number] {
    const conversionList: Array<[string, 'input' | 'output']> = [
      ['promptTokens', 'input'],
      ['completionTokens', 'output'],
      ['input_tokens', 'input'],
      ['output_tokens', 'output'],
      ['prompt_token_count', 'input'],
      ['candidates_token_count', 'output'],
      ['inputTokenCount', 'input'],
      ['outputTokenCount', 'output'],
      ['input_token_count', 'input'],
      ['generated_token_count', 'output'],
    ]

    const parsedUsage = conversionList.reduce(
      (acc: { input: number; output: number }, [modelKey, typeKey]) => {
        const value = usage[modelKey]
        if (value != null) {
          const finalCount = Array.isArray(value)
            ? value.reduce((sum: number, tokenCount: number) => sum + tokenCount, 0)
            : value
          acc[typeKey] = finalCount
        }
        return acc
      },
      { input: 0, output: 0 }
    )

    return [parsedUsage.input, parsedUsage.output]
  }

  /**
   * Extract `[inputTokens, outputTokens]` from an LLM result.
   *
   * `llmOutput` is consulted first (`token_usage`, `usage`, `tokenUsage`). If that yields
   * nothing, each generation is inspected: `generationInfo.usage_metadata`, then the
   * `response_metadata` / `usage_metadata` found on `generationInfo` and on the generation's
   * `message`. The first usage object found wins.
   */
  private parseUsage(response: LLMResult): [number, number] {
    let llmUsage: [number, number] = [0, 0]
    const llmUsageKeys = ['token_usage', 'usage', 'tokenUsage']

    const llmOutput = response.llmOutput
    if (llmOutput != null) {
      const key = llmUsageKeys.find((k) => llmOutput[k] != null)
      if (key) {
        llmUsage = this._parseUsageModel(llmOutput[key])
      }
    }

    // If top-level usage info was not found, try checking the generations.
    if (llmUsage[0] === 0 && llmUsage[1] === 0 && response.generations) {
      for (const generation of response.generations) {
        for (const genChunk of generation) {
          if (genChunk.generationInfo?.usage_metadata) {
            return this._parseUsageModel(genChunk.generationInfo.usage_metadata)
          }

          // generationInfo is checked first (as before); chat generations additionally expose
          // response_metadata / usage_metadata on their message, which was never consulted.
          const usageCarriers = [genChunk.generationInfo, (genChunk as any).message]
          for (const carrier of usageCarriers) {
            if (!carrier) continue
            const responseMetadata = carrier.response_metadata ?? {}
            const chunkUsage =
              responseMetadata['usage'] ??
              responseMetadata['amazon-bedrock-invocationMetrics'] ??
              carrier.usage_metadata
            if (chunkUsage) {
              return this._parseUsageModel(chunkUsage)
            }
          }
        }
      }
    }

    return llmUsage
  }
}
package/src/utils.ts CHANGED
@@ -14,8 +14,11 @@ export interface MonitoringParams {
14
14
  }
15
15
 
16
16
  export const getModelParams = (
17
- params: (ChatCompletionCreateParamsBase | MessageCreateParams) & MonitoringParams
17
+ params: ((ChatCompletionCreateParamsBase | MessageCreateParams) & MonitoringParams) | null
18
18
  ): Record<string, any> => {
19
+ if (!params) {
20
+ return {}
21
+ }
19
22
  const modelParams: Record<string, any> = {}
20
23
  const paramKeys = [
21
24
  'temperature',
@@ -1,10 +1,5 @@
1
- import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai'
2
- import type {
3
- LanguageModelV1,
4
- Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
5
- LanguageModelV1Prompt,
6
- LanguageModelV1StreamPart,
7
- } from 'ai'
1
+ import { wrapLanguageModel } from 'ai'
2
+ import type { LanguageModelV1, LanguageModelV1Middleware, LanguageModelV1Prompt, LanguageModelV1StreamPart } from 'ai'
8
3
  import { v4 as uuidv4 } from 'uuid'
9
4
  import { PostHog } from 'posthog-node'
10
5
  import { sendEventToPosthog } from '../utils'
@@ -17,6 +12,12 @@ interface ClientOptions {
17
12
  posthogGroups?: Record<string, any>
18
13
  posthogModelOverride?: string
19
14
  posthogProviderOverride?: string
15
+ posthogCostOverride?: CostOverride
16
+ }
17
+
18
+ interface CostOverride {
19
+ inputTokens: number
20
+ outputTokens: number
20
21
  }
21
22
 
22
23
  interface CreateInstrumentationMiddlewareOptions {
@@ -27,6 +28,7 @@ interface CreateInstrumentationMiddlewareOptions {
27
28
  posthogGroups?: Record<string, any>
28
29
  posthogModelOverride?: string
29
30
  posthogProviderOverride?: string
31
+ posthogCostOverride?: CostOverride
30
32
  }
31
33
 
32
34
  interface PostHogInput {
@@ -83,7 +85,7 @@ export const createInstrumentationMiddleware = (
83
85
  const middleware: LanguageModelV1Middleware = {
84
86
  wrapGenerate: async ({ doGenerate, params }) => {
85
87
  const startTime = Date.now()
86
- let mergedParams = {
88
+ const mergedParams = {
87
89
  ...options,
88
90
  ...mapVercelParams(params),
89
91
  }
@@ -143,7 +145,7 @@ export const createInstrumentationMiddleware = (
143
145
  const startTime = Date.now()
144
146
  let generatedText = ''
145
147
  let usage: { inputTokens?: number; outputTokens?: number } = {}
146
- let mergedParams = {
148
+ const mergedParams = {
147
149
  ...options,
148
150
  ...mapVercelParams(params),
149
151
  }