@posthog/ai 2.4.0 → 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@posthog/ai",
3
- "version": "2.4.0",
3
+ "version": "3.1.0",
4
4
  "description": "PostHog Node.js AI integrations",
5
5
  "repository": {
6
6
  "type": "git",
@@ -12,13 +12,8 @@
12
12
  "types": "lib/index.d.ts",
13
13
  "license": "MIT",
14
14
  "devDependencies": {
15
- "@types/jest": "^28.1.5",
16
15
  "@types/node": "^18.0.0",
17
- "ai": "^4.0.0",
18
- "jest": "^29.0.0",
19
16
  "node-fetch": "^3.3.2",
20
- "openai": "^4.0.0",
21
- "ts-jest": "^29.0.0",
22
17
  "typescript": "^4.7.4"
23
18
  },
24
19
  "keywords": [
@@ -31,7 +26,9 @@
31
26
  ],
32
27
  "dependencies": {
33
28
  "@anthropic-ai/sdk": "^0.36.3",
29
+ "@langchain/core": "^0.3.37",
34
30
  "ai": "^4.0.0",
31
+ "langchain": "^0.3.15",
35
32
  "openai": "^4.0.0",
36
33
  "uuid": "^11.0.5",
37
34
  "zod": "^3.24.1"
@@ -72,7 +72,7 @@ export class WrappedMessages extends AnthropicOriginal.Messages {
72
72
  return parentPromise.then((value) => {
73
73
  const passThroughStream = new PassThrough({ objectMode: true })
74
74
  let accumulatedContent = ''
75
- let usage: { inputTokens: number; outputTokens: number } = {
75
+ const usage: { inputTokens: number; outputTokens: number } = {
76
76
  inputTokens: 0,
77
77
  outputTokens: 0,
78
78
  }
package/src/index.ts CHANGED
@@ -2,8 +2,10 @@ import PostHogOpenAI from './openai'
2
2
  import PostHogAzureOpenAI from './openai/azure'
3
3
  import { wrapVercelLanguageModel } from './vercel/middleware'
4
4
  import PostHogAnthropic from './anthropic'
5
+ import { LangChainCallbackHandler } from './langchain/callbacks'
5
6
 
6
7
  export { PostHogOpenAI as OpenAI }
7
8
  export { PostHogAzureOpenAI as AzureOpenAI }
8
9
  export { PostHogAnthropic as Anthropic }
9
10
  export { wrapVercelLanguageModel as withTracing }
11
+ export { LangChainCallbackHandler }
@@ -0,0 +1,584 @@
1
+ import { PostHog } from 'posthog-node'
2
+ import { withPrivacyMode, getModelParams } from '../utils'
3
+ import { BaseCallbackHandler } from '@langchain/core/callbacks/base'
4
+ import type { Serialized } from '@langchain/core/load/serializable'
5
+ import type { ChainValues } from '@langchain/core/utils/types'
6
+ import type { BaseMessage } from '@langchain/core/messages'
7
+ import type { LLMResult } from '@langchain/core/outputs'
8
+ import type { AgentAction, AgentFinish } from '@langchain/core/agents'
9
+ import type { DocumentInterface } from '@langchain/core/documents'
10
+
11
/** Bookkeeping kept for a chain / tool / retriever / agent run (captured as a trace or span). */
interface SpanMetadata {
  /** Display name of the run, e.g. the chain name. */
  name: string
  /** Epoch milliseconds at which the run started. */
  startTime: number
  /** Epoch milliseconds at which the run finished; unset while the run is still in flight. */
  endTime?: number
  /** Input recorded when the run started. */
  input?: any
}

/** Bookkeeping kept for an LLM / chat-model run (captured as a generation). */
interface GenerationMetadata extends SpanMetadata {
  /** Model provider, e.g. `openai` or `anthropic`. */
  provider?: string
  /** Name of the model that produced the generation. */
  model?: string
  /** Invocation parameters such as temperature or max_tokens. */
  modelParams?: Record<string, any>
  /** API base URL the model was called with, when known. */
  baseUrl?: string
}

/** Metadata of any tracked run: either a span/trace or a generation. */
type RunMetadata = GenerationMetadata | SpanMetadata

/** Per-run metadata, keyed by run id. */
type RunMetadataStorage = Record<string, RunMetadata>
38
+
39
/**
 * LangChain callback handler that reports LangChain runs to PostHog.
 *
 * Chain, tool, retriever and agent runs are captured as `$ai_trace` events when they have no
 * parent and as `$ai_span` events otherwise. LLM and chat-model runs are captured as
 * `$ai_generation` events. Run metadata is buffered in memory between a run's `*Start`
 * callback and its `*End` / `*Error` callback.
 */
export class LangChainCallbackHandler extends BaseCallbackHandler {
  public name = 'PosthogCallbackHandler'
  private client: PostHog
  private distinctId?: string | number
  private traceId?: string | number
  private properties: Record<string, any>
  private privacyMode: boolean
  private groups: Record<string, any>
  private debug: boolean

  /** Metadata of runs that have started but not finished yet, keyed by run id. */
  private runs: RunMetadataStorage = {}
  /** Child run id -> parent run id; walked upwards to find the root run of a trace. */
  private parentTree: { [runId: string]: string } = {}

  /**
   * @param options.client - PostHog client used to capture events. Required.
   * @param options.distinctId - Distinct id events are attributed to. When omitted, events are sent
   *   with `$process_person_profile: false` and the trace id as distinct id.
   * @param options.traceId - Fixed trace id for every event; defaults to the root run id of each run.
   * @param options.properties - Extra properties merged into every event (they override built-in ones).
   * @param options.privacyMode - Forwarded to `withPrivacyMode` for every captured input/output.
   * @param options.groups - PostHog groups attached to every event.
   * @param options.debug - When true, every received callback is logged to the console.
   * @throws Error if `options.client` is missing.
   */
  constructor(options: {
    client: PostHog
    distinctId?: string | number
    traceId?: string | number
    properties?: Record<string, any>
    privacyMode?: boolean
    groups?: Record<string, any>
    debug?: boolean
  }) {
    if (!options.client) {
      throw new Error('PostHog client is required')
    }
    super()
    this.client = options.client
    this.distinctId = options.distinctId
    this.traceId = options.traceId
    this.properties = options.properties || {}
    this.privacyMode = options.privacyMode || false
    this.groups = options.groups || {}
    this.debug = options.debug || false
  }

  // ===== CALLBACK METHODS =====

  /** Records a chain run; it is captured as a trace/span when it ends or errors. */
  public handleChainStart(
    chain: Serialized,
    inputs: ChainValues,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    runType?: string,
    runName?: string
  ): void {
    this._logDebugEvent('on_chain_start', runId, parentRunId, { inputs, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(chain, inputs, runId, parentRunId, metadata, tags, runName)
  }

  /** Captures a finished chain run with its outputs. */
  public handleChainEnd(
    outputs: ChainValues,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    kwargs?: { inputs?: Record<string, unknown> }
  ): void {
    this._logDebugEvent('on_chain_end', runId, parentRunId, { outputs, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, outputs)
  }

  /** Captures a failed chain run, flagged as an error. */
  public handleChainError(
    error: Error,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    kwargs?: { inputs?: Record<string, unknown> }
  ): void {
    this._logDebugEvent('on_chain_error', runId, parentRunId, { error, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, error)
  }

  /** Records a chat-model run; its message batches are flattened into plain role/content objects. */
  public handleChatModelStart(
    serialized: Serialized,
    messages: BaseMessage[][],
    runId: string,
    parentRunId?: string,
    extraParams?: Record<string, unknown>,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_chat_model_start', runId, parentRunId, { messages, tags })
    this._setParentOfRun(runId, parentRunId)
    // Flatten the two-dimensional messages and convert each message to a plain object
    const input = messages.flat().map((m) => this._convertMessageToDict(m))
    this._setLLMMetadata(serialized, runId, input, metadata, extraParams, runName)
  }

  /** Records a (completion-style) LLM run with its raw prompts as input. */
  public handleLLMStart(
    serialized: Serialized,
    prompts: string[],
    runId: string,
    parentRunId?: string,
    extraParams?: Record<string, unknown>,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_llm_start', runId, parentRunId, { prompts, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setLLMMetadata(serialized, runId, prompts, metadata, extraParams, runName)
  }

  /** Captures a finished LLM / chat-model run as an `$ai_generation` event. */
  public handleLLMEnd(
    output: LLMResult,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    extraParams?: Record<string, unknown>
  ): void {
    this._logDebugEvent('on_llm_end', runId, parentRunId, { output, tags })
    this._popRunAndCaptureGeneration(runId, parentRunId, output)
  }

  /** Captures a failed LLM / chat-model run as an errored `$ai_generation` event. */
  public handleLLMError(
    err: Error,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    extraParams?: Record<string, unknown>
  ): void {
    this._logDebugEvent('on_llm_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureGeneration(runId, parentRunId, err)
  }

  /** Records a tool run; it is captured as a trace/span when it ends or errors. */
  public handleToolStart(
    tool: Serialized,
    input: string,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    runName?: string
  ): void {
    this._logDebugEvent('on_tool_start', runId, parentRunId, { input, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(tool, input, runId, parentRunId, metadata, tags, runName)
  }

  /** Captures a finished tool run with its output. */
  public handleToolEnd(output: any, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_tool_end', runId, parentRunId, { output, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, output)
  }

  /** Captures a failed tool run, flagged as an error. */
  public handleToolError(err: Error, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_tool_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, err)
  }

  /** Records a retriever run with its query as input. */
  public handleRetrieverStart(
    retriever: Serialized,
    query: string,
    runId: string,
    parentRunId?: string,
    tags?: string[],
    metadata?: Record<string, unknown>,
    name?: string
  ): void {
    this._logDebugEvent('on_retriever_start', runId, parentRunId, { query, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(retriever, query, runId, parentRunId, metadata, tags, name)
  }

  /** Captures a finished retriever run with the retrieved documents as output. */
  public handleRetrieverEnd(
    documents: DocumentInterface[],
    runId: string,
    parentRunId?: string,
    tags?: string[]
  ): void {
    this._logDebugEvent('on_retriever_end', runId, parentRunId, { documents, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, documents)
  }

  /** Captures a failed retriever run, flagged as an error. */
  public handleRetrieverError(err: Error, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_retriever_error', runId, parentRunId, { err, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, err)
  }

  /**
   * Records an agent action as a span/trace whose input is the action.
   * NOTE(review): if LangChain reports the action under the run id of the enclosing agent chain,
   * this replaces that chain's stored metadata — confirm against LangChain's agent executor.
   */
  public handleAgentAction(action: AgentAction, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_agent_action', runId, parentRunId, { action, tags })
    this._setParentOfRun(runId, parentRunId)
    this._setTraceOrSpanMetadata(null, action, runId, parentRunId)
  }

  /** Captures the run recorded by `handleAgentAction` with the agent's finish value as output. */
  public handleAgentEnd(action: AgentFinish, runId: string, parentRunId?: string, tags?: string[]): void {
    this._logDebugEvent('on_agent_finish', runId, parentRunId, { action, tags })
    this._popRunAndCaptureTraceOrSpan(runId, parentRunId, action)
  }

  // ===== PRIVATE HELPERS =====

  /** Remembers the parent of a run; top-level runs (no parent) are not stored. */
  private _setParentOfRun(runId: string, parentRunId?: string): void {
    if (parentRunId) {
      this.parentTree[runId] = parentRunId
    }
  }

  /** Forgets the parent link of a finished run. */
  private _popParentOfRun(runId: string): void {
    delete this.parentTree[runId]
  }

  /** Walks the parent tree upwards and returns the id of the top-most known ancestor. */
  private _findRootRun(runId: string): string {
    let id = runId
    while (this.parentTree[id]) {
      id = this.parentTree[id]
    }
    return id
  }

  /**
   * Stores span/trace metadata for a starting run.
   * `args` are the extra callback arguments (metadata, tags, run name) used to resolve the name.
   */
  private _setTraceOrSpanMetadata(
    serialized: any,
    input: any,
    runId: string,
    parentRunId?: string,
    ...args: any[]
  ): void {
    // Use default names if not provided: if this is a top-level run, we mark it as a trace, otherwise as a span.
    const defaultName = parentRunId ? 'span' : 'trace'
    const runName = this._getLangchainRunName(serialized, ...args) || defaultName
    this.runs[runId] = {
      name: runName,
      input,
      startTime: Date.now(),
    } as SpanMetadata
  }

  /** Stores generation metadata (model, provider, parameters, base URL) for a starting LLM run. */
  private _setLLMMetadata(
    serialized: Serialized | null,
    runId: string,
    messages: any,
    metadata?: any,
    extraParams?: any,
    runName?: string
  ): void {
    const runNameFound = this._getLangchainRunName(serialized, { extraParams, runName }) || 'generation'
    const generation: GenerationMetadata = {
      name: runNameFound,
      input: messages,
      startTime: Date.now(),
      // Always present (possibly `{}`): the `modelParams` key is what identifies this run as a
      // generation when it is popped again, so it must not depend on extraParams being supplied.
      modelParams: extraParams ? getModelParams(extraParams.invocation_params) : {},
    }
    if (metadata) {
      if (metadata.ls_model_name) {
        generation.model = metadata.ls_model_name
      }
      if (metadata.ls_provider) {
        generation.provider = metadata.ls_provider
      }
    }
    if (serialized && 'kwargs' in serialized && serialized.kwargs?.openai_api_base) {
      generation.baseUrl = serialized.kwargs.openai_api_base
    }
    this.runs[runId] = generation
  }

  /** Removes and returns a run's metadata, stamping its end time; warns when the run is unknown. */
  private _popRunMetadata(runId: string): RunMetadata | undefined {
    const endTime = Date.now()
    const run = this.runs[runId]
    if (!run) {
      console.warn(`No run metadata found for run ${runId}`)
      return undefined
    }
    run.endTime = endTime
    delete this.runs[runId]
    return run
  }

  /** The configured trace id, or else the root run id of `runId`. */
  private _getTraceId(runId: string): string {
    return this.traceId ? String(this.traceId) : this._findRootRun(runId)
  }

  /**
   * Returns the parent id to report. A parent that has no recorded parent itself is a root run
   * and is reported as the trace id (which differs from the root run id only when a custom
   * trace id is configured).
   */
  private _getParentRunId(traceId: string, runId: string, parentRunId?: string): string | undefined {
    if (parentRunId && !this.parentTree[parentRunId]) {
      return traceId
    }
    return parentRunId
  }

  /** Pops a finished non-LLM run and captures it as a trace or span. */
  private _popRunAndCaptureTraceOrSpan(
    runId: string,
    parentRunId: string | undefined,
    outputs: ChainValues | DocumentInterface[] | AgentFinish | Error | any
  ): void {
    // Resolve the trace id before the run's own parent link is removed.
    const traceId = this._getTraceId(runId)
    this._popParentOfRun(runId)
    const run = this._popRunMetadata(runId)
    if (!run) {
      return
    }
    if ('modelParams' in run) {
      console.warn(`Run ${runId} is a generation, but attempted to be captured as a trace/span.`)
      return
    }
    const actualParentRunId = this._getParentRunId(traceId, runId, parentRunId)
    this._captureTraceOrSpan(traceId, runId, run as SpanMetadata, outputs, actualParentRunId)
  }

  /** Sends an `$ai_span` (when it has a parent) or `$ai_trace` event for a finished run. */
  private _captureTraceOrSpan(
    traceId: string,
    runId: string,
    run: SpanMetadata,
    outputs: ChainValues | DocumentInterface[] | AgentFinish | Error | any,
    parentRunId?: string
  ): void {
    const eventName = parentRunId ? '$ai_span' : '$ai_trace'
    // Latency in seconds.
    const latency = run.endTime ? (run.endTime - run.startTime) / 1000 : 0
    const eventProperties: Record<string, any> = {
      $ai_trace_id: traceId,
      $ai_input_state: withPrivacyMode(this.client, this.privacyMode, run.input),
      $ai_latency: latency,
      $ai_span_name: run.name,
      $ai_span_id: runId,
    }
    if (parentRunId) {
      eventProperties['$ai_parent_id'] = parentRunId
    }

    Object.assign(eventProperties, this.properties)
    if (!this.distinctId) {
      eventProperties['$process_person_profile'] = false
    }
    if (outputs instanceof Error) {
      eventProperties['$ai_error'] = outputs.toString()
      eventProperties['$ai_is_error'] = true
    } else if (outputs !== undefined) {
      eventProperties['$ai_output_state'] = withPrivacyMode(this.client, this.privacyMode, outputs)
    }
    this.client.capture({
      // Fall back to the trace id (as generations do) so every event of a trace shares one distinct id.
      distinctId: this.distinctId ? this.distinctId.toString() : traceId,
      event: eventName,
      properties: eventProperties,
      groups: this.groups,
    })
  }

  /** Pops a finished LLM run and captures it as a generation. */
  private _popRunAndCaptureGeneration(
    runId: string,
    parentRunId: string | undefined,
    response: LLMResult | Error
  ): void {
    // Resolve the trace id before the run's own parent link is removed.
    const traceId = this._getTraceId(runId)
    this._popParentOfRun(runId)
    const run = this._popRunMetadata(runId)
    if (!run) {
      // _popRunMetadata has already warned about the missing run.
      return
    }
    if (typeof run !== 'object' || !('modelParams' in run)) {
      console.warn(`Run ${runId} is not a generation, but attempted to be captured as such.`)
      return
    }
    const actualParentRunId = this._getParentRunId(traceId, runId, parentRunId)
    this._captureGeneration(traceId, runId, run as GenerationMetadata, response, actualParentRunId)
  }

  /** Sends an `$ai_generation` event with model info, token usage and output choices (or the error). */
  private _captureGeneration(
    traceId: string,
    runId: string,
    run: GenerationMetadata,
    output: LLMResult | Error,
    parentRunId?: string
  ): void {
    // Latency in seconds.
    const latency = run.endTime ? (run.endTime - run.startTime) / 1000 : 0
    const eventProperties: Record<string, any> = {
      $ai_trace_id: traceId,
      $ai_span_id: runId,
      $ai_span_name: run.name,
      $ai_parent_id: parentRunId,
      $ai_provider: run.provider,
      $ai_model: run.model,
      $ai_model_parameters: run.modelParams,
      $ai_input: withPrivacyMode(this.client, this.privacyMode, run.input),
      $ai_http_status: 200,
      $ai_latency: latency,
      $ai_base_url: run.baseUrl,
    }

    if (output instanceof Error) {
      // Provider SDK errors may carry an HTTP status; otherwise report a generic 500.
      eventProperties['$ai_http_status'] = (output as any).status || 500
      eventProperties['$ai_error'] = output.toString()
      eventProperties['$ai_is_error'] = true
    } else {
      const [inputTokens, outputTokens] = this.parseUsage(output)
      eventProperties['$ai_input_tokens'] = inputTokens
      eventProperties['$ai_output_tokens'] = outputTokens

      // Only the last batch of generations is reported as the output choices.
      let completions
      if (output.generations && Array.isArray(output.generations)) {
        const lastGeneration = output.generations[output.generations.length - 1]
        if (Array.isArray(lastGeneration)) {
          completions = lastGeneration.map((gen) => {
            return { role: 'assistant', content: gen.text }
          })
        }
      }

      if (completions) {
        eventProperties['$ai_output_choices'] = withPrivacyMode(this.client, this.privacyMode, completions)
      }
    }

    Object.assign(eventProperties, this.properties)
    if (!this.distinctId) {
      eventProperties['$process_person_profile'] = false
    }

    this.client.capture({
      distinctId: this.distinctId ? this.distinctId.toString() : traceId,
      event: '$ai_generation',
      properties: eventProperties,
      groups: this.groups,
    })
  }

  /** Logs a received callback when debug mode is on. */
  private _logDebugEvent(eventName: string, runId: string, parentRunId: string | undefined, extra: any): void {
    if (this.debug) {
      console.log(`Event: ${eventName}, runId: ${runId}, parentRunId: ${parentRunId}, extra:`, extra)
    }
  }

  /**
   * Resolves a display name for a run, in priority order:
   *  1. an explicit run name — a non-empty string argument, or an object argument with a `runName` string;
   *  2. an object argument with a non-empty `name` string;
   *  3. `serialized.name`;
   *  4. the last element of `serialized.id` (or `serialized.id` itself when it is not an array).
   * Returns undefined when none applies.
   */
  private _getLangchainRunName(serialized: any, ...args: any[]): string | undefined {
    // LangChain hands the run name to span callbacks as a plain string and `_setLLMMetadata`
    // wraps it as `{ runName }`; both must win over any derived name.
    for (const arg of args) {
      if (typeof arg === 'string' && arg) {
        return arg
      }
      if (arg && typeof arg === 'object' && typeof arg.runName === 'string' && arg.runName) {
        return arg.runName
      }
    }
    for (const arg of args) {
      if (arg && typeof arg === 'object' && typeof arg.name === 'string' && arg.name) {
        return arg.name
      }
    }
    if (serialized && serialized.name) {
      return serialized.name
    }
    if (serialized && serialized.id) {
      return Array.isArray(serialized.id) ? serialized.id[serialized.id.length - 1] : serialized.id
    }
    return undefined
  }

  /**
   * Converts a LangChain message into a `{ role, content, ...additional_kwargs }` object.
   * The type is read via `_getType()` / `type` rather than `instanceof`.
   */
  private _convertMessageToDict(message: any): Record<string, any> {
    let messageDict: Record<string, any> = {}

    const messageType = message._getType?.() || message.type

    switch (messageType) {
      case 'human':
        messageDict = { role: 'user', content: message.content }
        break
      case 'ai':
        messageDict = { role: 'assistant', content: message.content }
        break
      case 'system':
        messageDict = { role: 'system', content: message.content }
        break
      case 'tool':
        messageDict = { role: 'tool', content: message.content }
        break
      case 'function':
        messageDict = { role: 'function', content: message.content }
        break
      default:
        messageDict = { role: messageType || 'unknown', content: String(message.content) }
    }

    if (message.additional_kwargs) {
      messageDict = { ...messageDict, ...message.additional_kwargs }
    }
    return messageDict
  }

  /**
   * Normalises a provider-specific usage object into `[inputTokens, outputTokens]`.
   * Array-valued counts are summed; for each side, the last matching key in the list wins.
   */
  private _parseUsageModel(usage: any): [number, number] {
    const conversionList: Array<[string, 'input' | 'output']> = [
      ['promptTokens', 'input'],
      ['completionTokens', 'output'],
      ['input_tokens', 'input'],
      ['output_tokens', 'output'],
      ['prompt_token_count', 'input'],
      ['candidates_token_count', 'output'],
      ['inputTokenCount', 'input'],
      ['outputTokenCount', 'output'],
      ['input_token_count', 'input'],
      ['generated_token_count', 'output'],
    ]

    const parsedUsage = conversionList.reduce(
      (acc: { input: number; output: number }, [modelKey, typeKey]) => {
        const value = usage[modelKey]
        if (value != null) {
          const finalCount = Array.isArray(value)
            ? value.reduce((sum: number, tokenCount: number) => sum + tokenCount, 0)
            : value
          acc[typeKey] = finalCount
        }
        return acc
      },
      { input: 0, output: 0 }
    )

    return [parsedUsage.input, parsedUsage.output]
  }

  /**
   * Extracts `[inputTokens, outputTokens]` from an LLM result.
   * Looks at `llmOutput` first; if that yields nothing, falls back to the first generation that
   * carries usage information (generationInfo, or the chat message's metadata).
   */
  private parseUsage(response: LLMResult): [number, number] {
    let llmUsage: [number, number] = [0, 0]
    const llmUsageKeys = ['token_usage', 'usage', 'tokenUsage']

    if (response.llmOutput != null) {
      const key = llmUsageKeys.find((k) => response.llmOutput?.[k] != null)
      if (key) {
        llmUsage = this._parseUsageModel(response.llmOutput[key])
      }
    }

    if (llmUsage[0] === 0 && llmUsage[1] === 0 && response.generations) {
      for (const generation of response.generations) {
        for (const genChunk of generation) {
          if (genChunk.generationInfo?.usage_metadata) {
            return this._parseUsageModel(genChunk.generationInfo.usage_metadata)
          }

          // Chat generations carry usage on their message (`ChatGeneration.message`), not on
          // generationInfo; plain generations have no message, so fall back to generationInfo.
          const messageChunk: any = (genChunk as any).message ?? genChunk.generationInfo ?? {}
          const responseMetadata = messageChunk.response_metadata ?? {}
          const chunkUsage =
            responseMetadata['usage'] ??
            responseMetadata['amazon-bedrock-invocationMetrics'] ??
            messageChunk.usage_metadata
          if (chunkUsage) {
            return this._parseUsageModel(chunkUsage)
          }
        }
      }
    }

    return llmUsage
  }
}
package/src/utils.ts CHANGED
@@ -11,11 +11,22 @@ export interface MonitoringParams {
11
11
  posthogProperties?: Record<string, any>
12
12
  posthogPrivacyMode?: boolean
13
13
  posthogGroups?: Record<string, any>
14
+ posthogModelOverride?: string
15
+ posthogProviderOverride?: string
16
+ posthogCostOverride?: CostOverride
17
+ }
18
+
19
/**
 * Custom per-token prices, in USD, used to compute the `$ai_input_cost_usd`,
 * `$ai_output_cost_usd` and `$ai_total_cost_usd` properties of a generation event.
 * Each price is multiplied by the corresponding token count of the request.
 */
export interface CostOverride {
  /** USD cost of a single input (prompt) token. */
  inputCost: number
  /** USD cost of a single output (completion) token. */
  outputCost: number
}
15
23
 
16
24
  export const getModelParams = (
17
- params: (ChatCompletionCreateParamsBase | MessageCreateParams) & MonitoringParams
25
+ params: ((ChatCompletionCreateParamsBase | MessageCreateParams) & MonitoringParams) | null
18
26
  ): Record<string, any> => {
27
+ if (!params) {
28
+ return {}
29
+ }
19
30
  const modelParams: Record<string, any> = {}
20
31
  const paramKeys = [
21
32
  'temperature',
@@ -137,12 +148,23 @@ export const sendEventToPosthog = ({
137
148
  $ai_error: error,
138
149
  }
139
150
  }
151
+ let costOverrideData = {}
152
+ if (params.posthogCostOverride) {
153
+ const inputCostUSD = (params.posthogCostOverride.inputCost ?? 0) * (usage.inputTokens ?? 0)
154
+ const outputCostUSD = (params.posthogCostOverride.outputCost ?? 0) * (usage.outputTokens ?? 0)
155
+ costOverrideData = {
156
+ $ai_input_cost_usd: inputCostUSD,
157
+ $ai_output_cost_usd: outputCostUSD,
158
+ $ai_total_cost_usd: inputCostUSD + outputCostUSD,
159
+ }
160
+ }
161
+
140
162
  client.capture({
141
163
  distinctId: distinctId ?? traceId,
142
164
  event: '$ai_generation',
143
165
  properties: {
144
- $ai_provider: provider,
145
- $ai_model: model,
166
+ $ai_provider: params.posthogProviderOverride ?? provider,
167
+ $ai_model: params.posthogModelOverride ?? model,
146
168
  $ai_model_parameters: getModelParams(params),
147
169
  $ai_input: withPrivacyMode(client, params.posthogPrivacyMode ?? false, input),
148
170
  $ai_output_choices: withPrivacyMode(client, params.posthogPrivacyMode ?? false, output),
@@ -155,6 +177,7 @@ export const sendEventToPosthog = ({
155
177
  ...params.posthogProperties,
156
178
  ...(distinctId ? {} : { $process_person_profile: false }),
157
179
  ...errorData,
180
+ ...costOverrideData,
158
181
  },
159
182
  groups: params.posthogGroups,
160
183
  })